Source code for pyrit.prompt_converter.translation_converter
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import pathlib
import uuid
from textwrap import dedent
from typing import Optional
from tenacity import (
AsyncRetrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from pyrit.common.path import DATASETS_PATH
from pyrit.models import (
PromptDataType,
PromptRequestPiece,
PromptRequestResponse,
SeedPrompt,
)
from pyrit.prompt_converter import ConverterResult, PromptConverter
from pyrit.prompt_target import PromptChatTarget
logger = logging.getLogger(__name__)
[docs]
class TranslationConverter(PromptConverter):
    """
    Prompt converter that asks a chat target to translate text into a given language.

    The system prompt is rendered once, at construction time, from a ``SeedPrompt``
    template. Each conversion sends the text to translate, wrapped in begin/end tags,
    to the converter target under a freshly generated conversation ID.
    """

    def __init__(
        self,
        *,
        converter_target: PromptChatTarget,
        language: str,
        prompt_template: Optional[SeedPrompt] = None,
        max_retries: int = 3,
        max_wait_time_in_seconds: int = 60,
    ):
        """
        Initializes a TranslationConverter object.

        Args:
            converter_target (PromptChatTarget): The chat target that performs the translation.
            language (str): The language for the conversion. E.g. Spanish, French, leetspeak, etc.
            prompt_template (SeedPrompt, Optional): The prompt template used to render the system
                prompt. When omitted, ``prompt_converters/translation_converter.yaml`` under
                ``DATASETS_PATH`` is loaded.
            max_retries (int): Value handed to tenacity's ``stop_after_attempt`` when sending the
                translation request. Defaults to 3.
            max_wait_time_in_seconds (int): Upper bound (``max``) of the exponential wait between
                attempts. Defaults to 60.

        Raises:
            ValueError: If the language is not provided.
        """
        # Fail fast: reject a missing language before loading any template from disk.
        if not language:
            raise ValueError("Language must be provided for translation conversion")

        self.converter_target = converter_target

        # Retry strategy for the conversion
        self._max_retries = max_retries
        self._max_wait_time_in_seconds = max_wait_time_in_seconds

        # Fall back to the default template if none is provided
        template = prompt_template or SeedPrompt.from_yaml_file(
            pathlib.Path(DATASETS_PATH) / "prompt_converters" / "translation_converter.yaml"
        )

        # The lower-cased language is used in the per-request user prompt, while the
        # system prompt is rendered with the language exactly as the caller supplied it.
        self.language = language.lower()
        self.system_prompt = template.render_template_value(languages=language)

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
        """
        Translates the input prompt into the configured language using the converter target.

        Args:
            prompt (str): The text to translate.
            input_type (PromptDataType): The data type of ``prompt``. Defaults to "text".

        Returns:
            ConverterResult: The stripped response of the converter target, with output type "text".

        Raises:
            ValueError: If the input type is not supported.
        """
        # Validate first so that a rejected input does not register a system prompt
        # (and a new conversation) on the target.
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        conversation_id = str(uuid.uuid4())
        self.converter_target.set_system_prompt(system_prompt=self.system_prompt, conversation_id=conversation_id)

        # Each piece ends with a newline so the instruction, the tags and the payload
        # sit on separate lines (the instruction line previously ran into "=== begin ===").
        formatted_prompt = dedent(
            f"Translate the following to {self.language} between the begin and end tags:\n"
            "=== begin ===\n"
            f"{prompt}\n"
            "=== end ===\n"
        )
        logger.debug("Formatted Prompt: %s", formatted_prompt)

        request = PromptRequestResponse(
            [
                PromptRequestPiece(
                    role="user",
                    original_value=prompt,
                    converted_value=formatted_prompt,
                    conversation_id=conversation_id,
                    sequence=1,
                    prompt_target_identifier=self.converter_target.get_identifier(),
                    original_value_data_type=input_type,
                    converted_value_data_type=input_type,
                    converter_identifiers=[self.get_identifier()],
                )
            ]
        )

        translation = await self._send_translation_prompt_async(request)
        return ConverterResult(output_text=translation, output_type="text")

    async def _send_translation_prompt_async(self, request: PromptRequestResponse) -> str:
        """
        Sends the translation request to the converter target, retrying on any exception.

        Attempts are bounded by ``self._max_retries`` and separated by an exponential wait
        capped at ``self._max_wait_time_in_seconds``.

        Args:
            request (PromptRequestResponse): The request to send to the converter target.

        Returns:
            str: The response value with surrounding whitespace stripped.

        Raises:
            Exception: When no attempt succeeds. NOTE(review): since ``reraise`` is not set,
                tenacity is expected to raise its own ``RetryError`` once the attempts are
                exhausted, before the explicit ``raise`` below is reached - confirm against
                the installed tenacity version.
        """
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(self._max_retries),
            wait=wait_exponential(multiplier=1, min=1, max=self._max_wait_time_in_seconds),
            retry=retry_if_exception_type(Exception),  # covers all exceptions
        ):
            with attempt:
                logger.debug("Attempt %s for translation", attempt.retry_state.attempt_number)
                response = await self.converter_target.send_prompt_async(prompt_request=request)
                response_msg = response.get_value()
                return response_msg.strip()

        # Defensive guard: only reached if the retry loop ends without returning or raising.
        raise Exception(f"Failed to translate after {self._max_retries} attempts")

    def output_supported(self, output_type: PromptDataType) -> bool:
        """
        Reports whether the converter can produce the given output type.

        Args:
            output_type (PromptDataType): The output data type to check.

        Returns:
            bool: True only when ``output_type`` is "text".
        """
        return output_type == "text"