# Source code for pyrit.prompt_normalizer.prompt_normalizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import copy
import logging
import traceback
from typing import Any, List, Optional
from uuid import uuid4

from pyrit.exceptions import (
    ComponentRole,
    EmptyResponseException,
    execution_context,
    get_execution_context,
)
from pyrit.identifiers import AttackIdentifier
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import (
    Message,
    construct_response_from_request,
)
from pyrit.prompt_normalizer import NormalizerRequest, PromptConverterConfiguration
from pyrit.prompt_target import PromptTarget
from pyrit.prompt_target.batch_helper import batch_task_async

logger = logging.getLogger(__name__)


class PromptNormalizer:
    """
    Handles normalization and processing of prompts before they are sent to targets.
    """

    # Class-level placeholder; every instance binds the central memory in __init__.
    _memory: MemoryInterface = None

    def __init__(self, start_token: str = "⟪", end_token: str = "⟫") -> None:
        """
        Initialize the PromptNormalizer.

        start_token and end_token are used to delineate which part of a prompt is converted.

        Args:
            start_token (str): Marker passed to converters as the start of the convertible span.
            end_token (str): Marker passed to converters as the end of the convertible span.
        """
        self._memory = CentralMemory.get_memory_instance()
        self._start_token = start_token
        self._end_token = end_token
        self.id = str(uuid4())

    async def send_prompt_async(
        self,
        *,
        message: Message,
        target: PromptTarget,
        conversation_id: Optional[str] = None,
        request_converter_configurations: Optional[list[PromptConverterConfiguration]] = None,
        response_converter_configurations: Optional[list[PromptConverterConfiguration]] = None,
        labels: Optional[dict[str, str]] = None,
        attack_identifier: Optional[AttackIdentifier] = None,
    ) -> Message:
        """
        Send a single request to a target.

        The incoming message is deep-copied, so the caller's object is never mutated.

        Args:
            message (Message): The message to be sent.
            target (PromptTarget): The target to which the prompt is sent.
            conversation_id (str, optional): The ID of the conversation. A new UUID is
                generated when omitted or empty. Defaults to None.
            request_converter_configurations (list[PromptConverterConfiguration], optional):
                Configurations for converting the request. Defaults to None (no converters).
            response_converter_configurations (list[PromptConverterConfiguration], optional):
                Configurations for converting the response. Defaults to None (no converters).
            labels (Optional[dict[str, str]], optional): Labels associated with the request.
                Defaults to None.
            attack_identifier (Optional[AttackIdentifier], optional): Identifier for the attack.
                Defaults to None.

        Raises:
            Exception: If an error occurs during the request processing.
            ValueError: If the message pieces are not part of the same sequence.

        Returns:
            Message: The last response received from the target. None is returned when the
            target produces no (or only falsy) responses.
        """
        # None (rather than a shared mutable []) is the default; normalize to fresh lists.
        if request_converter_configurations is None:
            request_converter_configurations = []
        if response_converter_configurations is None:
            response_converter_configurations = []

        # Validates that the MessagePieces in the Message are part of the same sequence
        if len({piece.sequence for piece in message.message_pieces}) > 1:
            raise ValueError("All MessagePieces in the Message must have the same sequence.")

        # Prepare the request by updating conversation ID, labels, and attack identifier
        request = copy.deepcopy(message)
        conversation_id = conversation_id if conversation_id else str(uuid4())
        for piece in request.message_pieces:
            piece.conversation_id = conversation_id
            if labels:
                piece.labels = labels
            piece.prompt_target_identifier = target.get_identifier()
            if attack_identifier:
                piece.attack_identifier = attack_identifier

        # Apply request converters
        await self.convert_values(converter_configurations=request_converter_configurations, message=request)
        await self._calc_hash(request=request)

        responses = None
        try:
            responses = await target.send_prompt_async(message=request)
            self._memory.add_message_to_memory(request=request)
        except EmptyResponseException:
            # Empty responses are retried, but we don't want them to stop execution
            self._memory.add_message_to_memory(request=request)
            responses = [
                construct_response_from_request(
                    request=request.message_pieces[0],
                    response_text_pieces=[""],
                    response_type="text",
                    error="empty",
                )
            ]
        except Exception as ex:
            # Ensure request to memory before processing exception
            self._memory.add_message_to_memory(request=request)

            # Persist an error response so the failure is recorded next to the request.
            error_response = construct_response_from_request(
                request=request.message_pieces[0],
                response_text_pieces=[f"{ex}\n{repr(ex)}\n{traceback.format_exc()}"],
                response_type="error",
                error="processing",
            )
            await self._calc_hash(request=error_response)
            self._memory.add_message_to_memory(request=error_response)

            cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None
            raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex

        # handling empty responses message list and None responses
        if not responses or not any(responses):
            return None

        # Process all response messages (targets return list[Message])
        # Only apply response converters to the last message (final response)
        # Intermediate messages are tool calls/outputs that don't need conversion
        last_index = len(responses) - 1
        for index, resp in enumerate(responses):
            if index == last_index:
                await self.convert_values(converter_configurations=response_converter_configurations, message=resp)
            await self._calc_hash(request=resp)
            self._memory.add_message_to_memory(request=resp)

        # Return the last response for backward compatibility
        return responses[-1]

    async def send_prompt_batch_to_target_async(
        self,
        *,
        requests: list[NormalizerRequest],
        target: PromptTarget,
        labels: Optional[dict[str, str]] = None,
        attack_identifier: Optional[AttackIdentifier] = None,
        batch_size: int = 10,
    ) -> list[Message]:
        """
        Send a batch of prompts to the target asynchronously.

        Args:
            requests (list[NormalizerRequest]): A list of NormalizerRequest objects to be sent.
            target (PromptTarget): The target to which the prompts are sent.
            labels (Optional[dict[str, str]], optional): A dictionary of labels to be included
                with the request. Defaults to None.
            attack_identifier (Optional[AttackIdentifier], optional): The attack identifier.
                Defaults to None.
            batch_size (int, optional): The number of prompts to include in each batch.
                Defaults to 10.

        Returns:
            list[Message]: A list of Message objects representing the responses received for
            each prompt. None responses are dropped, so the list may be shorter than
            ``requests``.
        """
        # Column-oriented layout: one list per keyword argument of send_prompt_async,
        # in the same order as batch_item_keys below.
        batch_items: List[List[Any]] = [
            [request.message for request in requests],
            [request.request_converter_configurations for request in requests],
            [request.response_converter_configurations for request in requests],
            [request.conversation_id for request in requests],
        ]

        batch_item_keys = [
            "message",
            "request_converter_configurations",
            "response_converter_configurations",
            "conversation_id",
        ]

        responses = await batch_task_async(
            prompt_target=target,
            batch_size=batch_size,
            items_to_batch=batch_items,
            task_func=self.send_prompt_async,
            task_arguments=batch_item_keys,
            target=target,
            labels=labels,
            attack_identifier=attack_identifier,
        )

        # Filter out None responses (e.g., from empty responses)
        return [response for response in responses if response is not None]

    async def convert_values(
        self,
        converter_configurations: list[PromptConverterConfiguration],
        message: Message,
    ) -> None:
        """
        Apply converter configurations to message pieces.

        Pieces are updated in place: their ``converted_value``,
        ``converted_value_data_type`` and ``converter_identifiers`` are modified.

        Args:
            converter_configurations (list[PromptConverterConfiguration]): List of configurations
                specifying which converters to apply and to which message pieces.
            message (Message): The message containing pieces to be converted.

        Raises:
            Exception: Any exception from converters propagates with execution context
                for error tracing.
        """
        for converter_configuration in converter_configurations:
            # These filters are per-configuration, so read them once rather than per piece.
            indexes = converter_configuration.indexes_to_apply
            data_types = converter_configuration.prompt_data_types_to_apply

            for piece_index, piece in enumerate(message.message_pieces):
                # An empty/None filter means "apply to everything".
                if indexes and piece_index not in indexes:
                    continue
                if data_types and piece.converted_value_data_type not in data_types:
                    continue

                piece.converter_identifiers.extend(
                    [converter.get_identifier() for converter in converter_configuration.converters]
                )

                converted_text = piece.converted_value
                converted_text_data_type = piece.converted_value_data_type

                # Converters are chained: each one receives the previous converter's output.
                for converter in converter_configuration.converters:
                    # Inherit attack context from outer execution context (set by attack strategy)
                    outer_context = get_execution_context()
                    # Exceptions propagate unchanged - execution context will add converter details
                    with execution_context(
                        component_role=ComponentRole.CONVERTER,
                        attack_strategy_name=outer_context.attack_strategy_name if outer_context else None,
                        attack_identifier=outer_context.attack_identifier if outer_context else None,
                        component_identifier=converter.get_identifier(),
                        objective_target_conversation_id=(
                            outer_context.objective_target_conversation_id if outer_context else None
                        ),
                    ):
                        converter_result = await converter.convert_tokens_async(
                            prompt=converted_text,
                            input_type=converted_text_data_type,
                            start_token=self._start_token,
                            end_token=self._end_token,
                        )
                        converted_text = converter_result.output_text
                        converted_text_data_type = converter_result.output_type

                piece.converted_value = converted_text
                piece.converted_value_data_type = converted_text_data_type

    async def _calc_hash(self, request: Message) -> None:
        """Await ``set_sha256_values_async`` on every piece of the message concurrently."""
        await asyncio.gather(*(piece.set_sha256_values_async() for piece in request.message_pieces))

    async def add_prepended_conversation_to_memory(
        self,
        conversation_id: str,
        should_convert: bool = True,
        converter_configurations: Optional[list[PromptConverterConfiguration]] = None,
        attack_identifier: Optional[AttackIdentifier] = None,
        prepended_conversation: Optional[list[Message]] = None,
    ) -> Optional[list[Message]]:
        """
        Process the prepended conversation by converting it if needed and adding it to memory.

        Args:
            conversation_id (str): The conversation ID to use for the message pieces
            should_convert (bool): Whether to convert the prepended conversation
            converter_configurations (Optional[list[PromptConverterConfiguration]]):
                Configurations for converting the request
            attack_identifier (Optional[AttackIdentifier]): Identifier for the attack
            prepended_conversation (Optional[list[Message]]): The conversation to prepend

        Returns:
            Optional[list[Message]]: The processed prepended conversation (a deep copy of the
            input), or None when no prepended conversation was supplied.
        """
        if not prepended_conversation:
            return None

        # Create a deep copy of the prepended conversation to avoid modifying the original
        prepended_conversation = copy.deepcopy(prepended_conversation)

        for request in prepended_conversation:
            if should_convert and converter_configurations:
                await self.convert_values(message=request, converter_configurations=converter_configurations)

            for piece in request.message_pieces:
                piece.conversation_id = conversation_id
                if attack_identifier:
                    piece.attack_identifier = attack_identifier

                # if the piece is retrieved from somewhere else, it needs to be unique
                # and if not, this won't hurt anything
                piece.id = uuid4()

            self._memory.add_message_to_memory(request=request)

        return prepended_conversation