Source code for pyrit.prompt_normalizer.prompt_normalizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
import asyncio
import logging
from typing import Any, List, Optional
from uuid import uuid4

from pyrit.common.batch_helper import batch_task_async
from pyrit.exceptions import EmptyResponseException
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import (
    PromptRequestPiece,
    PromptRequestResponse,
    construct_response_from_request,
)
from pyrit.models.filter_criteria import PromptConverterState, PromptFilterCriteria
from pyrit.models.seed_prompt import SeedPromptGroup
from pyrit.prompt_normalizer import NormalizerRequest, PromptConverterConfiguration
from pyrit.prompt_target import PromptTarget

logger = logging.getLogger(__name__)


[docs] class PromptNormalizer(abc.ABC): _memory: MemoryInterface = None
[docs] def __init__(self, start_token: str = "⟪", end_token: str = "⟫") -> None: """ Initializes the PromptNormalizer. start_token and end_token are used to delineate which part of a prompt is converted. """ self._memory = CentralMemory.get_memory_instance() self._start_token = start_token self._end_token = end_token self.id = str(uuid4()) self._skip_criteria: Optional[PromptFilterCriteria] = None
[docs] async def send_prompt_async( self, *, seed_prompt_group: SeedPromptGroup, target: PromptTarget, conversation_id: str = None, request_converter_configurations: list[PromptConverterConfiguration] = [], response_converter_configurations: list[PromptConverterConfiguration] = [], sequence: int = -1, labels: Optional[dict[str, str]] = None, orchestrator_identifier: Optional[dict[str, str]] = None, ) -> PromptRequestResponse: """ Sends a single request to a target. Args: seed_prompt_group (SeedPromptGroup): The seed prompt group to be sent. target (PromptTarget): The target to which the prompt is sent. conversation_id (str, optional): The ID of the conversation. Defaults to None. request_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for converting the request. Defaults to an empty list. response_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for converting the response. Defaults to an empty list. sequence (int, optional): The sequence number of the request. Defaults to -1. labels (Optional[dict[str, str]], optional): Labels associated with the request. Defaults to None. orchestrator_identifier (Optional[dict[str, str]], optional): Identifier for the orchestrator. Defaults to None. Raises: Exception: If an error occurs during the request processing. Returns: PromptRequestResponse: The response received from the target. """ request = await self._build_prompt_request_response( seed_prompt_group=seed_prompt_group, conversation_id=conversation_id, request_converter_configurations=request_converter_configurations, target=target, sequence=sequence, labels=labels, orchestrator_identifier=orchestrator_identifier, ) await self._calc_hash(request=request) if self._should_skip_based_on_skip_criteria(request): return None response = None try: response = await target.send_prompt_async(prompt_request=request) self._memory.add_request_response_to_memory(request=request) except EmptyResponseException: # Empty responses are retried, but we don't want them to stop execution self._memory.add_request_response_to_memory(request=request) response = construct_response_from_request( request=request.request_pieces[0], response_text_pieces=[""], response_type="text", error="empty", ) except Exception as ex: # Ensure request to memory before processing exception self._memory.add_request_response_to_memory(request=request) error_response = construct_response_from_request( request=request.request_pieces[0], response_text_pieces=[str(ex)], response_type="error", error="processing", ) await self._calc_hash(request=error_response) self._memory.add_request_response_to_memory(request=error_response) cid = request.request_pieces[0].conversation_id if request and request.request_pieces else None raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex if response is None: return None await self.convert_values(converter_configurations=response_converter_configurations, request_response=response) await self._calc_hash(request=response) self._memory.add_request_response_to_memory(request=response) return response
[docs] async def send_prompt_batch_to_target_async( self, *, requests: list[NormalizerRequest], target: PromptTarget, labels: Optional[dict[str, str]] = None, orchestrator_identifier: Optional[dict[str, str]] = None, batch_size: int = 10, ) -> list[PromptRequestResponse]: """ Sends a batch of prompts to the target asynchronously. Args: requests (list[NormalizerRequest]): A list of NormalizerRequest objects to be sent. target (PromptTarget): The target to which the prompts are sent. labels (Optional[dict[str, str]], optional): A dictionary of labels to be included with the request. Defaults to None. orchestrator_identifier (Optional[dict[str, str]], optional): A dictionary identifying the orchestrator. Defaults to None. batch_size (int, optional): The number of prompts to include in each batch. Defaults to 10. Returns: list[PromptRequestResponse]: A list of PromptRequestResponse objects representing the responses received for each prompt. """ batch_items: List[List[Any]] = [ [request.seed_prompt_group for request in requests], [request.request_converter_configurations for request in requests], [request.response_converter_configurations for request in requests], [request.conversation_id for request in requests], ] batch_item_keys = [ "seed_prompt_group", "request_converter_configurations", "response_converter_configurations", "conversation_id", ] responses = await batch_task_async( prompt_target=target, batch_size=batch_size, items_to_batch=batch_items, task_func=self.send_prompt_async, task_arguments=batch_item_keys, target=target, labels=labels, orchestrator_identifier=orchestrator_identifier, ) # send_prompt_async can return None if the prompt is skipped return [response for response in responses if response is not None]
[docs] async def convert_values( self, converter_configurations: list[PromptConverterConfiguration], request_response: PromptRequestResponse, ): for converter_configuration in converter_configurations: for piece_index, piece in enumerate(request_response.request_pieces): indexes = converter_configuration.indexes_to_apply data_types = converter_configuration.prompt_data_types_to_apply if indexes and piece_index not in indexes: continue if data_types and piece.converted_value_data_type not in data_types: continue piece.converter_identifiers.extend( [converter.get_identifier() for converter in converter_configuration.converters] ) converted_text = piece.converted_value converted_text_data_type = piece.converted_value_data_type for converter in converter_configuration.converters: converter_result = await converter.convert_tokens_async( prompt=converted_text, input_type=converted_text_data_type, start_token=self._start_token, end_token=self._end_token, ) converted_text = converter_result.output_text converted_text_data_type = converter_result.output_type piece.converted_value = converted_text piece.converted_value_data_type = converted_text_data_type
[docs] def set_skip_criteria(self, skip_criteria: PromptFilterCriteria, skip_value_type: PromptConverterState) -> None: """ Sets the skip criteria for the orchestrator. If prompts match this in memory and are the same as one being sent, then they won't be sent to a target. Prompts are the same if either the original prompt or the converted prompt, determined by skip_value_type flag. """ self._skip_criteria = skip_criteria prompts_to_skip = self._memory.get_prompt_request_pieces( role="user", orchestrator_id=self._skip_criteria.orchestrator_id, conversation_id=self._skip_criteria.conversation_id, prompt_ids=self._skip_criteria.prompt_ids, labels=self._skip_criteria.labels, sent_after=self._skip_criteria.sent_after, sent_before=self._skip_criteria.sent_before, original_values=self._skip_criteria.original_values, converted_values=self._skip_criteria.converted_values, data_type=self._skip_criteria.data_type, not_data_type=self._skip_criteria.not_data_type, converted_value_sha256=self._skip_criteria.converted_value_sha256, ) self._original_sha256_prompts_to_skip = [ prompt.original_value_sha256 for prompt in prompts_to_skip if prompt.original_value_sha256 ] self._converted_sha256_prompts_to_skip = [ prompt.converted_value_sha256 for prompt in prompts_to_skip if prompt.converted_value_sha256 ] self._skip_value_type = skip_value_type
def _should_skip_based_on_skip_criteria(self, prompt_request: PromptRequestResponse) -> bool: """ Filters out prompts from prompt_request_list that match the skip criteria. Every request_piece of the prompt_request needs to have matching sha256 to skip. """ if not self._skip_criteria: return False for user_prompt in prompt_request.request_pieces: if self._skip_value_type == "converted": if user_prompt.converted_value_sha256 not in self._converted_sha256_prompts_to_skip: return False else: if user_prompt.original_value_sha256 not in self._original_sha256_prompts_to_skip: return False return True async def _calc_hash(self, request: PromptRequestResponse) -> None: """ Adds a request to the memory. """ tasks = [asyncio.create_task(piece.set_sha256_values_async()) for piece in request.request_pieces] await asyncio.gather(*tasks) async def _build_prompt_request_response( self, *, seed_prompt_group: SeedPromptGroup, conversation_id: str, request_converter_configurations: list[PromptConverterConfiguration], target: PromptTarget, sequence: int, labels: dict[str, str], orchestrator_identifier: Optional[dict[str, str]], ) -> PromptRequestResponse: """ Builds a prompt request response based on the given parameters. Applies parameters and converters to the prompt text and puts all the pieces together. Args: seed_prompt_group (SeedPromptGroup): The group of seed prompts to be used. conversation_id (str): The ID of the conversation. request_converter_configurations (list[PromptConverterConfiguration]): List of configurations for request converters. target (PromptTarget): The target for the prompt. sequence (int): The sequence number of the prompt. labels (dict[str, str]): A dictionary of labels associated with the prompt. orchestrator_identifier (Optional[dict[str, str]]): An optional dictionary for orchestrator identifiers. Returns: PromptRequestResponse: The prompt request response object. """ entries = [] # All prompt request pieces within PromptRequestResponse needs to have same conversation ID. conversation_id = conversation_id if conversation_id else str(uuid4()) for seed_prompt in seed_prompt_group.prompts: prompt_request_piece = PromptRequestPiece( role="user", original_value=seed_prompt.value, conversation_id=conversation_id, sequence=sequence, labels=labels, prompt_metadata=seed_prompt.metadata, prompt_target_identifier=target.get_identifier(), orchestrator_identifier=orchestrator_identifier, original_value_data_type=seed_prompt.data_type, ) entries.append(prompt_request_piece) response = PromptRequestResponse(request_pieces=entries) await self.convert_values(converter_configurations=request_converter_configurations, request_response=response) return response