Source code for pyrit.prompt_converter.translation_converter

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import pathlib
import uuid
from textwrap import dedent
from typing import Optional

from tenacity import (
    AsyncRetrying,
    RetryError,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from pyrit.common.path import DATASETS_PATH
from pyrit.models import (
    PromptDataType,
    PromptRequestPiece,
    PromptRequestResponse,
    SeedPrompt,
)
from pyrit.prompt_converter import ConverterResult, PromptConverter
from pyrit.prompt_target import PromptChatTarget

logger = logging.getLogger(__name__)


class TranslationConverter(PromptConverter):
    """
    Converter that asks a chat target to translate a text prompt into a configured language.
    """

    def __init__(
        self,
        *,
        converter_target: PromptChatTarget,
        language: str,
        prompt_template: Optional[SeedPrompt] = None,
        max_retries: int = 3,
        max_wait_time_in_seconds: int = 60,
    ):
        """
        Initializes a TranslationConverter object.

        Args:
            converter_target (PromptChatTarget): The chat target that performs the translation.
            language (str): The language for the conversion. E.g. Spanish, French, leetspeak, etc.
            prompt_template (SeedPrompt, Optional): The prompt template used to build the system prompt.
                Defaults to the bundled ``prompt_converters/translation_converter.yaml`` template.
            max_retries (int): Maximum number of attempts (including the first) when sending the
                translation request to the target. Defaults to 3.
            max_wait_time_in_seconds (int): Upper bound, in seconds, on the exponential back-off
                wait between attempts. Defaults to 60.

        Raises:
            ValueError: If the language is not provided.
        """
        # Validate first so an empty language fails fast, before the default template is read from disk.
        if not language:
            raise ValueError("Language must be provided for translation conversion")

        self.converter_target = converter_target

        # Retry strategy for the conversion
        self._max_retries = max_retries
        self._max_wait_time_in_seconds = max_wait_time_in_seconds

        # Fall back to the bundled template when none is provided.
        prompt_template = (
            prompt_template
            if prompt_template
            else SeedPrompt.from_yaml_file(
                pathlib.Path(DATASETS_PATH) / "prompt_converters" / "translation_converter.yaml"
            )
        )

        self.language = language.lower()
        # The template is rendered with the language exactly as supplied (not lower-cased).
        self.system_prompt = prompt_template.render_template_value(languages=language)

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
        """
        Translates the prompt into the configured language using the converter target.

        A fresh conversation is created for each call. Its system prompt is the rendered
        translation template, and the user message wraps ``prompt`` in begin/end tags.

        Args:
            prompt (str): The prompt to translate.
            input_type (PromptDataType): The data type of the prompt. Only "text" is supported.

        Returns:
            ConverterResult: The translated text, with output_type "text".

        Raises:
            ValueError: If ``input_type`` is not supported.
        """
        # Validate before touching the target so an unsupported input does not register
        # a system prompt for a conversation that will never be used.
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        conversation_id = str(uuid.uuid4())
        self.converter_target.set_system_prompt(system_prompt=self.system_prompt, conversation_id=conversation_id)

        # The newline after the instruction keeps "=== begin ===" on its own line.
        formatted_prompt = dedent(
            f"Translate the following to {self.language} between the begin and end tags:\n"
            "=== begin ===\n"
            f"{prompt}\n"
            "=== end ===\n"
        )

        logger.debug("Formatted Prompt: %s", formatted_prompt)

        request = PromptRequestResponse(
            [
                PromptRequestPiece(
                    role="user",
                    original_value=prompt,
                    converted_value=formatted_prompt,
                    conversation_id=conversation_id,
                    sequence=1,
                    prompt_target_identifier=self.converter_target.get_identifier(),
                    original_value_data_type=input_type,
                    converted_value_data_type=input_type,
                    converter_identifiers=[self.get_identifier()],
                )
            ]
        )

        translation = await self._send_translation_prompt_async(request)
        return ConverterResult(output_text=translation, output_type="text")

    async def _send_translation_prompt_async(self, request: PromptRequestResponse) -> str:
        """
        Sends the translation request to the converter target, retrying on any exception.

        Args:
            request (PromptRequestResponse): The request to send.

        Returns:
            str: The target's response value with surrounding whitespace stripped.

        Raises:
            Exception: If every attempt fails. The last attempt's exception is chained as the cause.
        """
        try:
            async for attempt in AsyncRetrying(
                stop=stop_after_attempt(self._max_retries),
                wait=wait_exponential(multiplier=1, min=1, max=self._max_wait_time_in_seconds),
                retry=retry_if_exception_type(Exception),  # covers all exceptions
            ):
                with attempt:
                    logger.debug("Attempt %d for translation", attempt.retry_state.attempt_number)
                    response = await self.converter_target.send_prompt_async(prompt_request=request)
                    response_msg = response.get_value()
                    return response_msg.strip()
        except RetryError as retry_error:
            # With the default reraise=False, tenacity raises RetryError once the stop
            # condition is met; surface a clear message and keep the real failure as the cause.
            raise Exception(
                f"Failed to translate after {self._max_retries} attempts"
            ) from retry_error.last_attempt.exception()

        # Defensive: the loop above either returns or raises.
        raise Exception(f"Failed to translate after {self._max_retries} attempts")

    def input_supported(self, input_type: PromptDataType) -> bool:
        """Returns True only for "text" input."""
        return input_type == "text"

    def output_supported(self, output_type: PromptDataType) -> bool:
        """Returns True only for "text" output."""
        return output_type == "text"