# Source code for pyrit.executor.attack.component.conversation_manager

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from pyrit.memory import CentralMemory
from pyrit.models import ChatMessageRole, Message, MessagePiece, Score
from pyrit.prompt_normalizer.prompt_converter_configuration import (
    PromptConverterConfiguration,
)
from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.prompt_target import PromptTarget

logger = logging.getLogger(__name__)


@dataclass
class ConversationState:
    """
    Container for conversation state data shared between attack components.

    Attributes:
        turn_count (int): Number of assistant-role pieces counted while replaying a
            prepended conversation in multi-turn mode. Defaults to 0.
        last_user_message (str): Converted value of the most recent user message
            extracted from the prepended conversation. Defaults to an empty string.
        last_assistant_message_scores (List[Score]): Scores retrieved from memory for
            the last assistant message. Defaults to a fresh empty list per instance.
    """

    turn_count: int = 0
    last_user_message: str = ""
    # default_factory gives every instance its own list instead of a shared mutable default
    last_assistant_message_scores: List[Score] = field(default_factory=list)
class ConversationManager:
    """
    Manages conversations for attacks, handling message history, system prompts,
    and conversation state.

    This class provides methods to retrieve conversations, add system prompts,
    and update conversation state with prepended messages.
    """

    def __init__(
        self,
        *,
        attack_identifier: dict[str, str],
        prompt_normalizer: Optional[PromptNormalizer] = None,
    ):
        """
        Initialize the conversation manager.

        Args:
            attack_identifier (dict[str, str]): The identifier of the attack this manager
                belongs to. It is stamped onto every prepended message piece and passed
                along when setting system prompts.
            prompt_normalizer (Optional[PromptNormalizer]): Optional prompt normalizer to use
                for converting prompts. If not provided (or falsy), a default
                PromptNormalizer instance is created.
        """
        self._prompt_normalizer = prompt_normalizer or PromptNormalizer()
        self._memory = CentralMemory.get_memory_instance()
        self._attack_identifier = attack_identifier

    def get_conversation(self, conversation_id: str) -> List[Message]:
        """
        Retrieve a conversation by its ID.

        Args:
            conversation_id (str): The ID of the conversation to retrieve.

        Returns:
            List[Message]: The messages of the conversation, in the order returned by
                memory, copied into a new list. Empty if memory returns no messages.
        """
        conversation = self._memory.get_conversation(conversation_id=conversation_id)
        return list(conversation)

    def get_last_message(
        self, *, conversation_id: str, role: Optional[ChatMessageRole] = None
    ) -> Optional[MessagePiece]:
        """
        Retrieve the most recent message from a conversation.

        Args:
            conversation_id (str): The ID of the conversation to retrieve the last message from.
            role (Optional[ChatMessageRole]): If provided, only return the last message whose
                piece (as returned by `Message.get_piece()`) matches this role.

        Returns:
            Optional[MessagePiece]: The last matching message piece, or `None` if the
                conversation is empty or no message matches `role`.
        """
        conversation = self.get_conversation(conversation_id)
        if not conversation:
            return None

        if not role:
            return conversation[-1].get_piece()

        # Walk backwards so the first match is the most recent message with that role
        for message in reversed(conversation):
            piece = message.get_piece()
            if piece.role == role:
                return piece
        return None

    def set_system_prompt(
        self,
        *,
        target: PromptChatTarget,
        conversation_id: str,
        system_prompt: str,
        labels: Optional[Dict[str, str]] = None,
    ) -> None:
        """
        Set or update the system-level prompt associated with a conversation.

        This helper is intended for conversational (`PromptChatTarget`) goals, where a
        dedicated system prompt influences the behavior of the LLM for all subsequent
        user / assistant messages in the same `conversation_id`. The call is delegated
        to `target.set_system_prompt` together with this manager's attack identifier.

        Args:
            target (PromptChatTarget): The target to set the system prompt on.
            conversation_id (str): Unique identifier for the conversation to set the
                system prompt on.
            system_prompt (str): The system prompt to set for the conversation.
            labels (Optional[Dict[str, str]]): Optional labels to associate with the
                system prompt.
        """
        target.set_system_prompt(
            system_prompt=system_prompt,
            conversation_id=conversation_id,
            attack_identifier=self._attack_identifier,
            labels=labels,
        )

    async def update_conversation_state_async(
        self,
        *,
        conversation_id: str,
        target: Optional[Union[PromptTarget, PromptChatTarget]] = None,
        prepended_conversation: List[Message],
        request_converters: Optional[List[PromptConverterConfiguration]] = None,
        response_converters: Optional[List[PromptConverterConfiguration]] = None,
        max_turns: Optional[int] = None,
    ) -> ConversationState:
        """
        Prepare a chat conversation by attaching history, enforcing target-specific rules,
        optionally converting prompts, and returning a `ConversationState`.

        This helper supports two distinct usage patterns:

        Single-turn bootstrap - When `max_turns` is **not** supplied the function injects
        the provided `prepended_conversation` into memory, performs any requested prompt
        conversions, and returns a default state.

        Multi-turn continuation - When `max_turns` **is** supplied the function also
        verifies that the history does not exceed the allowed turn budget, excludes a
        trailing user message from memory (so that an attack can re-inject it as the
        "live" request), and extracts the turn count, last user message and last
        assistant scores into the returned state.

        Args:
            conversation_id (str): Unique identifier for the conversation to update or create.
            target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set system
                prompts on (if applicable).
            prepended_conversation (List[Message]): List of messages to prepend to the
                conversation history. `None` entries and messages without pieces are ignored.
            request_converters (Optional[List[PromptConverterConfiguration]]): Converter
                configurations applied to user (request) pieces.
            response_converters (Optional[List[PromptConverterConfiguration]]): Converter
                configurations applied to assistant (response) pieces.
            max_turns (Optional[int]): Maximum number of turns allowed in the conversation.
                If not provided, the function assumes a single-turn context.

        Returns:
            ConversationState: A snapshot of the conversation state after processing the
                prepended messages. Only populated beyond defaults in multi-turn mode.

        Raises:
            ValueError: If `conversation_id` is empty, if a system piece is present but
                `target` is missing or not a `PromptChatTarget`, if the prepended
                assistant turns exceed `max_turns`, or if a scored trailing assistant
                message has no preceding user message.
        """
        if not conversation_id:
            raise ValueError("conversation_id cannot be empty")

        # Initialize conversation state
        state = ConversationState()
        logger.debug("Preparing conversation with ID: %s", conversation_id)

        # Do not proceed if no history is provided
        if not prepended_conversation:
            logger.debug("No history provided for conversation initialization: %s", conversation_id)
            return state

        # Filter out None values and empty requests
        valid_requests = [req for req in prepended_conversation if req is not None and req.message_pieces]
        if not valid_requests:
            logger.debug("No valid requests in prepended conversation for: %s", conversation_id)
            return state

        # A trailing user message in multi-turn mode is not stored: the attack sends it itself
        last_message = valid_requests[-1].message_pieces[0]
        is_multi_turn = max_turns is not None
        should_exclude_last = is_multi_turn and last_message.role == "user"

        total = len(valid_requests)
        has_converters = bool(request_converters or response_converters)

        for position, request in enumerate(valid_requests, start=1):
            # Skip the last message if it's a user message in multi-turn context
            if should_exclude_last and position == total:
                logger.debug("Skipping last user message (will be added by attack)")
                continue

            # Apply role-specific converters if any were configured
            if has_converters:
                logger.debug("Converting request %d/%d in conversation %s", position, total, conversation_id)
                await self._apply_role_specific_converters_async(
                    request=request,
                    request_converters=request_converters,
                    response_converters=response_converters,
                )

            # Stamp IDs, handle system prompts / turn counting, and persist to memory
            logger.debug("Processing message %d/%d in conversation %s", position, total, conversation_id)
            await self._process_prepended_message_async(
                request=request,
                conversation_id=conversation_id,
                conversation_state=state,
                target=target,
                max_turns=max_turns,
            )

        # Extract state from the conversation history (only for multi-turn conversations)
        if is_multi_turn:
            await self._populate_conversation_state_async(
                last_message=last_message,
                prepended_conversation=valid_requests,
                conversation_state=state,
            )

        return state

    async def _apply_role_specific_converters_async(
        self,
        *,
        request: Message,
        request_converters: Optional[List[PromptConverterConfiguration]] = None,
        response_converters: Optional[List[PromptConverterConfiguration]] = None,
    ) -> None:
        """
        Apply role-specific converters to the pieces of a message.

        - Request converters are applied to 'user' role pieces
        - Response converters are applied to 'assistant' role pieces
        - No converters are applied to pieces with any other role (e.g. 'system')

        Args:
            request (Message): The request containing pieces to convert.
            request_converters (Optional[List[PromptConverterConfiguration]]):
                Converter configurations to apply to 'user' role pieces.
            response_converters (Optional[List[PromptConverterConfiguration]]):
                Converter configurations to apply to 'assistant' role pieces.
        """
        for piece in request.message_pieces:
            applicable_converters: Optional[List[PromptConverterConfiguration]] = None

            if piece.role == "user" and request_converters:
                applicable_converters = request_converters
            elif piece.role == "assistant" and response_converters:
                applicable_converters = response_converters
            # Any other role keeps applicable_converters as None and is left untouched

            if applicable_converters:
                # Wrap the single piece in a temporary Message so it is converted on its own
                temp_request = Message(message_pieces=[piece])
                await self._prompt_normalizer.convert_values(
                    message=temp_request,
                    converter_configurations=applicable_converters,
                )

    async def _process_prepended_message_async(
        self,
        *,
        request: Message,
        conversation_id: str,
        conversation_state: ConversationState,
        target: Optional[Union[PromptTarget, PromptChatTarget]] = None,
        max_turns: Optional[int] = None,
    ) -> None:
        """
        Process a prepended message and update the conversation state.

        Each piece is assigned the conversation ID, this manager's attack identifier and
        a fresh UUID (the pieces are mutated in place), then processed according to its
        role. The message is added to memory unless any of its pieces is excluded
        (system pieces are stored by the target itself).

        Args:
            request (Message): The request containing pieces to process. A falsy request
                or one without pieces is ignored.
            conversation_id (str): The ID of the conversation to update.
            conversation_state (ConversationState): The current state of the conversation.
            target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set
                system prompts on (if applicable).
            max_turns (Optional[int]): Maximum allowed turns for the conversation.

        Raises:
            ValueError: If a system prompt is provided but the target is missing or does
                not support it, or if the turn budget is exceeded.
        """
        # Nothing to do for an empty request
        if not request or not request.message_pieces:
            return

        save_to_memory = True

        for piece in request.message_pieces:
            # Re-home the piece into this conversation under a new identity
            piece.conversation_id = conversation_id
            piece.attack_identifier = self._attack_identifier
            piece.id = uuid.uuid4()

            # Process the piece based on its role
            self._process_piece(
                piece=piece,
                conversation_state=conversation_state,
                max_turns=max_turns,
                target=target,
            )

            if ConversationManager._should_exclude_piece_from_memory(piece=piece, max_turns=max_turns):
                # One excluded piece keeps the whole message out of memory
                save_to_memory = False

        if save_to_memory:
            self._memory.add_message_to_memory(request=request)

    def _process_piece(
        self,
        *,
        piece: MessagePiece,
        conversation_state: ConversationState,
        max_turns: Optional[int] = None,
        target: Optional[Union[PromptTarget, PromptChatTarget]] = None,
    ) -> None:
        """
        Process a message piece based on its role and update conversation state.

        System pieces are forwarded to the target as a system prompt. Assistant pieces
        increment the turn count when `max_turns` is set. Other pieces are left as is.

        Args:
            piece (MessagePiece): The piece to process.
            conversation_state (ConversationState): The current state of the conversation.
            max_turns (Optional[int]): Maximum allowed turns (for validation). `None`
                means single-turn mode, in which turns are not counted.
            target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set
                system prompts on.

        Raises:
            ValueError: If max_turns would be exceeded by this piece.
            ValueError: If a system prompt is provided but target is missing or is not a
                PromptChatTarget.
        """
        is_multi_turn = max_turns is not None

        # Handle system prompts (both single-turn and multi-turn)
        if piece.role == "system":
            if target is None:
                raise ValueError("Target must be provided to handle system prompts")
            if not isinstance(target, PromptChatTarget):
                raise ValueError("Target must be a PromptChatTarget to set system prompts")

            self.set_system_prompt(
                target=target,
                conversation_id=piece.conversation_id,
                system_prompt=piece.converted_value,
                labels=piece.labels,
            )

        # Handle assistant messages (count turns for multi-turn only)
        elif piece.role == "assistant" and is_multi_turn:
            conversation_state.turn_count += 1

            # Compare against None rather than truthiness so a budget of 0 is still enforced
            if max_turns is not None and conversation_state.turn_count > max_turns:
                raise ValueError(
                    f"The number of turns in the prepended conversation ({conversation_state.turn_count-1}) is equal to"
                    + f" or exceeds the maximum number of turns ({max_turns}), which means the"
                    + " conversation will not be able to continue. Please reduce the number of turns in"
                    + " the prepended conversation or increase the maximum number of turns and try again."
                )

    @staticmethod
    def _should_exclude_piece_from_memory(*, piece: MessagePiece, max_turns: Optional[int] = None) -> bool:
        """
        Decide whether a piece must be kept out of memory.

        Args:
            piece (MessagePiece): The piece to inspect.
            max_turns (Optional[int]): Accepted for signature compatibility; not used.

        Returns:
            bool: True if the piece has the 'system' role, False otherwise.
        """
        # System pieces should always be excluded from memory because set_system_prompt function
        # is called on the target, which internally adds them to memory
        return piece.role == "system"

    async def _populate_conversation_state_async(
        self,
        *,
        prepended_conversation: List[Message],
        last_message: MessagePiece,
        conversation_state: ConversationState,
    ) -> None:
        """
        Extract conversation context from the last messages in prepended_conversation.

        This extracts:
        - The last user message, when the history ends with a user message.
        - The scores of the last assistant message and, if any scores exist, the user
          message that precedes it.

        Args:
            prepended_conversation (List[Message]): Complete conversation history.
            last_message (MessagePiece): The last message piece in the history.
            conversation_state (ConversationState): State object to populate.

        Raises:
            ValueError: If the last assistant message has scores but is not preceded by
                a user message.
        """
        if not prepended_conversation:
            return  # Nothing to extract from empty history

        if last_message.role == "user":
            conversation_state.last_user_message = last_message.converted_value
            logger.debug("Extracted last user message: %s...", conversation_state.last_user_message[:50])

        elif last_message.role == "assistant":
            # Scores are looked up by the original prompt id of the assistant piece
            conversation_state.last_assistant_message_scores = list(
                self._memory.get_prompt_scores(prompt_ids=[str(last_message.original_prompt_id)])
            )

            # Do not set last user message if there are no scores for the last assistant message
            if not conversation_state.last_assistant_message_scores:
                logger.debug("No scores found for last assistant message")
                return

            # A scored assistant message must be preceded by the user message it answers
            if len(prepended_conversation) > 1 and prepended_conversation[-2].get_piece().role == "user":
                conversation_state.last_user_message = prepended_conversation[-2].get_value()
                logger.debug("Extracted preceding user message: %s...", conversation_state.last_user_message[:50])
            else:
                raise ValueError(
                    "There must be a user message preceding the assistant message in prepended conversations."
                )