# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
import copy
import logging
import traceback
from typing import Any, List, Optional
from uuid import uuid4
from pyrit.exceptions import EmptyResponseException, PyritException
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import (
Message,
MessagePiece,
SeedGroup,
construct_response_from_request,
)
from pyrit.prompt_normalizer import NormalizerRequest, PromptConverterConfiguration
from pyrit.prompt_target import PromptTarget
from pyrit.prompt_target.batch_helper import batch_task_async
logger = logging.getLogger(__name__)
[docs]
class PromptNormalizer:
"""
Handles normalization and processing of prompts before they are sent to targets.
"""
_memory: MemoryInterface = None
[docs]
def __init__(self, start_token: str = "⟪", end_token: str = "⟫") -> None:
"""
Initialize the PromptNormalizer.
start_token and end_token are used to delineate which part of a prompt is converted.
"""
self._memory = CentralMemory.get_memory_instance()
self._start_token = start_token
self._end_token = end_token
self.id = str(uuid4())
[docs]
async def send_prompt_async(
self,
*,
seed_group: SeedGroup,
target: PromptTarget,
conversation_id: Optional[str] = None,
request_converter_configurations: list[PromptConverterConfiguration] = [],
response_converter_configurations: list[PromptConverterConfiguration] = [],
labels: Optional[dict[str, str]] = None,
attack_identifier: Optional[dict[str, str]] = None,
) -> Message:
"""
Send a single request to a target.
Args:
seed_group (SeedGroup): The seed group to be sent.
target (PromptTarget): The target to which the prompt is sent.
conversation_id (str, optional): The ID of the conversation. Defaults to None.
request_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for
converting the request. Defaults to an empty list.
response_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for
converting the response. Defaults to an empty list.
labels (Optional[dict[str, str]], optional): Labels associated with the request. Defaults to None.
attack_identifier (Optional[dict[str, str]], optional): Identifier for the attack. Defaults to
None.
Raises:
Exception: If an error occurs during the request processing.
ValueError: If the prompts in the SeedGroup are not part of the same sequence.
Returns:
Message: The response received from the target.
"""
# Validates that the SeedPrompts in the SeedGroup are part of the same sequence
if len(set(prompt.sequence for prompt in seed_group.prompts)) > 1:
raise ValueError("All SeedPrompts in the SeedGroup must have the same sequence.")
request = await self._build_message(
seed_group=seed_group,
conversation_id=conversation_id,
request_converter_configurations=request_converter_configurations,
target=target,
labels=labels,
attack_identifier=attack_identifier,
)
await self._calc_hash(request=request)
responses = None
try:
responses = await target.send_prompt_async(message=request)
self._memory.add_message_to_memory(request=request)
except EmptyResponseException:
# Empty responses are retried, but we don't want them to stop execution
self._memory.add_message_to_memory(request=request)
responses = [
construct_response_from_request(
request=request.message_pieces[0],
response_text_pieces=[""],
response_type="text",
error="empty",
)
]
except Exception as ex:
# Ensure request to memory before processing exception
self._memory.add_message_to_memory(request=request)
error_response = construct_response_from_request(
request=request.message_pieces[0],
response_text_pieces=[f"{ex}\n{repr(ex)}\n{traceback.format_exc()}"],
response_type="error",
error="processing",
)
await self._calc_hash(request=error_response)
self._memory.add_message_to_memory(request=error_response)
cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None
raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex
# handling empty responses message list and None responses
if not responses or not any(responses):
return None
# Process all response messages (targets return list[Message])
# Only apply response converters to the last message (final response)
# Intermediate messages are tool calls/outputs that don't need conversion
for i, resp in enumerate(responses):
is_last = i == len(responses) - 1
if is_last:
await self.convert_values(converter_configurations=response_converter_configurations, message=resp)
await self._calc_hash(request=resp)
self._memory.add_message_to_memory(request=resp)
# Return the last response for backward compatibility
return responses[-1]
[docs]
async def send_prompt_batch_to_target_async(
self,
*,
requests: list[NormalizerRequest],
target: PromptTarget,
labels: Optional[dict[str, str]] = None,
attack_identifier: Optional[dict[str, str]] = None,
batch_size: int = 10,
) -> list[Message]:
"""
Send a batch of prompts to the target asynchronously.
Args:
requests (list[NormalizerRequest]): A list of NormalizerRequest objects to be sent.
target (PromptTarget): The target to which the prompts are sent.
labels (Optional[dict[str, str]], optional): A dictionary of labels to be included with the request.
Defaults to None.
attack_identifier (Optional[dict[str, str]], optional): A dictionary identifying the attack.
Defaults to None.
batch_size (int, optional): The number of prompts to include in each batch. Defaults to 10.
Returns:
list[Message]: A list of Message objects representing the responses
received for each prompt.
"""
batch_items: List[List[Any]] = [
[request.seed_group for request in requests],
[request.request_converter_configurations for request in requests],
[request.response_converter_configurations for request in requests],
[request.conversation_id for request in requests],
]
batch_item_keys = [
"seed_group",
"request_converter_configurations",
"response_converter_configurations",
"conversation_id",
]
responses = await batch_task_async(
prompt_target=target,
batch_size=batch_size,
items_to_batch=batch_items,
task_func=self.send_prompt_async,
task_arguments=batch_item_keys,
target=target,
labels=labels,
attack_identifier=attack_identifier,
)
# Filter out None responses (e.g., from empty responses)
return [response for response in responses if response is not None]
[docs]
async def convert_values(
self,
converter_configurations: list[PromptConverterConfiguration],
message: Message,
) -> None:
"""
Apply converter configurations to message pieces.
Args:
converter_configurations (list[PromptConverterConfiguration]): List of configurations specifying
which converters to apply and to which message pieces.
message (Message): The message containing pieces to be converted.
Raises:
PyritException: If a converter raises a PyRIT exception (re-raised with enhanced context).
RuntimeError: If a converter raises a non-PyRIT exception (wrapped with converter context).
"""
for converter_configuration in converter_configurations:
for piece_index, piece in enumerate(message.message_pieces):
indexes = converter_configuration.indexes_to_apply
data_types = converter_configuration.prompt_data_types_to_apply
if indexes and piece_index not in indexes:
continue
if data_types and piece.converted_value_data_type not in data_types:
continue
piece.converter_identifiers.extend(
[converter.get_identifier() for converter in converter_configuration.converters]
)
converted_text = piece.converted_value
converted_text_data_type = piece.converted_value_data_type
for converter in converter_configuration.converters:
try:
converter_result = await converter.convert_tokens_async(
prompt=converted_text,
input_type=converted_text_data_type,
start_token=self._start_token,
end_token=self._end_token,
)
converted_text = converter_result.output_text
converted_text_data_type = converter_result.output_type
except PyritException as e:
# Re-raise PyRIT exceptions with enhanced context while preserving type for retry decorators
e.message = f"Error in converter {converter.__class__.__name__}: {e.message}"
e.args = (f"Status Code: {e.status_code}, Message: {e.message}",)
raise
except Exception as e:
# Wrap non-PyRIT exceptions for better error tracing
raise RuntimeError(f"Error in converter {converter.__class__.__name__}: {str(e)}") from e
piece.converted_value = converted_text
piece.converted_value_data_type = converted_text_data_type
async def _calc_hash(self, request: Message) -> None:
"""Add a request to the memory."""
tasks = [asyncio.create_task(piece.set_sha256_values_async()) for piece in request.message_pieces]
await asyncio.gather(*tasks)
async def _build_message(
self,
*,
seed_group: SeedGroup,
conversation_id: str,
request_converter_configurations: list[PromptConverterConfiguration],
target: PromptTarget,
labels: dict[str, str],
attack_identifier: Optional[dict[str, str]] = None,
) -> Message:
"""
Build a message based on the given parameters.
Applies parameters and converters to the prompt text and puts all the pieces together.
Args:
seed_group (SeedGroup): The group of seed prompts to be used.
conversation_id (str): The ID of the conversation.
request_converter_configurations (list[PromptConverterConfiguration]): List of configurations for
request converters.
target (PromptTarget): The target for the prompt.
labels (dict[str, str]): A dictionary of labels associated with the prompt.
attack_identifier (Optional[dict[str, str]]): An optional dictionary for attack identifiers.
Returns:
Message: The message object.
"""
entries = []
# All message pieces within Message needs to have same conversation ID.
conversation_id = conversation_id if conversation_id else str(uuid4())
for seed_prompt in seed_group.prompts:
message_piece = MessagePiece(
role=seed_prompt.role,
original_value=seed_prompt.value,
conversation_id=conversation_id,
sequence=seed_prompt.sequence,
labels=labels,
prompt_metadata=seed_prompt.metadata,
prompt_target_identifier=target.get_identifier(),
attack_identifier=attack_identifier,
original_value_data_type=seed_prompt.data_type,
)
entries.append(message_piece)
response = Message(message_pieces=entries)
await self.convert_values(converter_configurations=request_converter_configurations, message=response)
return response
[docs]
async def add_prepended_conversation_to_memory(
self,
conversation_id: str,
should_convert: bool = True,
converter_configurations: Optional[list[PromptConverterConfiguration]] = None,
attack_identifier: Optional[dict[str, str]] = None,
prepended_conversation: Optional[list[Message]] = None,
) -> Optional[list[Message]]:
"""
Process the prepended conversation by converting it if needed and adding it to memory.
Args:
conversation_id (str): The conversation ID to use for the message pieces
should_convert (bool): Whether to convert the prepended conversation
converter_configurations (Optional[list[PromptConverterConfiguration]]): Configurations for converting the
request
attack_identifier (Optional[dict[str, str]]): Identifier for the attack
prepended_conversation (Optional[list[Message]]): The conversation to prepend
Returns:
Optional[list[Message]]: The processed prepended conversation
"""
if not prepended_conversation:
return None
# Create a deep copy of the prepended conversation to avoid modifying the original
prepended_conversation = copy.deepcopy(prepended_conversation)
for request in prepended_conversation:
if should_convert and converter_configurations:
await self.convert_values(message=request, converter_configurations=converter_configurations)
for piece in request.message_pieces:
piece.conversation_id = conversation_id
if attack_identifier:
piece.attack_identifier = attack_identifier
# if the piece is retrieved from somewhere else, it needs to be unique
# and if not, this won't hurt anything
piece.id = uuid4()
self._memory.add_message_to_memory(request=request)
return prepended_conversation