Source code for pyrit.executor.attack.component.conversation_manager

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from pyrit.common.utils import combine_dict
from pyrit.executor.attack.component.prepended_conversation_config import (
    PrependedConversationConfig,
)
from pyrit.memory import CentralMemory
from pyrit.message_normalizer import ConversationContextNormalizer
from pyrit.models import ChatMessageRole, Message, MessagePiece, Score
from pyrit.prompt_normalizer.prompt_converter_configuration import (
    PromptConverterConfiguration,
)
from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer
from pyrit.prompt_target import PromptTarget
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget

if TYPE_CHECKING:
    from pyrit.executor.attack.core import AttackContext

logger = logging.getLogger(__name__)


def mark_messages_as_simulated(messages: Sequence[Message]) -> List[Message]:
    """
    Mark assistant messages as simulated_assistant for traceability.

    This function converts all assistant roles to simulated_assistant in the
    provided messages. This is useful when loading conversations from YAML files
    or other sources where the responses are not from actual targets.

    Args:
        messages (Sequence[Message]): The messages to mark as simulated.

    Returns:
        List[Message]: The same messages with assistant roles converted to simulated_assistant.
            Modifies the messages in place and also returns them for convenience.
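
    Example (illustrative sketch; assumes a single-piece text message):

        loaded = [Message.from_prompt(prompt="Sure, here it is.", role="assistant")]
        marked = mark_messages_as_simulated(loaded)
        # marked[0].message_pieces[0]._role is now "simulated_assistant"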
    """
    result = list(messages)
    for message in result:
        for piece in message.message_pieces:
            if piece._role == "assistant":
                piece._role = "simulated_assistant"
    return result


def get_adversarial_chat_messages(
    prepended_conversation: List[Message],
    *,
    adversarial_chat_conversation_id: str,
    attack_identifier: Dict[str, str],
    adversarial_chat_target_identifier: Dict[str, str],
    labels: Optional[Dict[str, str]] = None,
) -> List[Message]:
    """
    Transform prepended conversation messages for adversarial chat with swapped roles.

    This function creates new Message objects with swapped roles for use in adversarial
    chat conversations. From the adversarial chat's perspective:
    - "user" messages become "assistant" (prompts it generated)
    - "assistant" messages become "user" (responses it received)
    - System messages are skipped (adversarial chat has its own system prompt)

    All messages receive new UUIDs to distinguish them from the originals.

    Args:
        prepended_conversation: The original conversation messages to transform.
        adversarial_chat_conversation_id: Conversation ID for the adversarial chat.
        attack_identifier: Attack identifier to associate with messages.
        adversarial_chat_target_identifier: Target identifier for the adversarial chat.
        labels: Optional labels to associate with the messages.

    Returns:
        List of transformed messages with swapped roles and new IDs.
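
    Example (illustrative sketch; the identifier dictionaries are placeholders):

        original = [Message.from_prompt(prompt="Tell me a secret.", role="user")]
        swapped = get_adversarial_chat_messages(
            prepended_conversation=original,
            adversarial_chat_conversation_id=str(uuid.uuid4()),
            attack_identifier={"__type__": "ExampleAttack"},
            adversarial_chat_target_identifier={"__type__": "ExampleTarget"},
        )
        # The "user" prompt now appears as an "assistant" message (with a new
        # UUID) from the adversarial chat's perspective.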
    """
    if not prepended_conversation:
        return []

    role_swap: Dict[ChatMessageRole, ChatMessageRole] = {
        "user": "assistant",
        "assistant": "user",
        "simulated_assistant": "user",
    }

    result: List[Message] = []

    for message in prepended_conversation:
        for piece in message.message_pieces:
            # Skip system messages - adversarial chat has its own system prompt
            if piece.api_role == "system":
                continue

            # Create a new piece with swapped role for adversarial chat
            swapped_role = role_swap.get(piece.api_role, piece.api_role)

            adversarial_piece = MessagePiece(
                id=uuid.uuid4(),
                role=swapped_role,
                original_value=piece.original_value,
                converted_value=piece.converted_value,
                original_value_data_type=piece.original_value_data_type,
                converted_value_data_type=piece.converted_value_data_type,
                conversation_id=adversarial_chat_conversation_id,
                attack_identifier=attack_identifier,
                prompt_target_identifier=adversarial_chat_target_identifier,
                labels=labels,
            )

            result.append(adversarial_piece.to_message())

    logger.debug(f"Created {len(result)} adversarial chat messages with swapped roles")
    return result


async def build_conversation_context_string_async(messages: List[Message]) -> str:
    """
    Build a formatted context string from a list of messages.

    This is a convenience function that uses ConversationContextNormalizer
    to format messages into a "Turn N: User/Assistant" format suitable for
    use in system prompts.

    Args:
        messages: The conversation messages to format.

    Returns:
        A formatted string representing the conversation context.
        Returns empty string if no messages provided.
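
    Example (illustrative sketch; must be awaited inside an event loop):

        history = [
            Message.from_prompt(prompt="Hello", role="user"),
            Message.from_prompt(prompt="Hi, how can I help?", role="assistant"),
        ]
        context_str = await build_conversation_context_string_async(history)
        # Rendered in the "Turn N: User/Assistant" format described above.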
    """
    if not messages:
        return ""
    normalizer = ConversationContextNormalizer()
    return await normalizer.normalize_string_async(messages)


def get_prepended_turn_count(prepended_conversation: Optional[List[Message]]) -> int:
    """
    Count the number of turns (assistant responses) in a prepended conversation.

    This is used to offset iteration counts so that executed_turns reflects
    the total conversation depth including prepended messages.

    Args:
        prepended_conversation: The prepended conversation messages, or None.

    Returns:
        int: The number of assistant messages in the prepended conversation.
            Returns 0 if prepended_conversation is None or empty.
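
    Example (illustrative sketch):

        history = [
            Message.from_prompt(prompt="Hello", role="user"),
            Message.from_prompt(prompt="Hi there!", role="assistant"),
            Message.from_prompt(prompt="Tell me more", role="user"),
        ]
        get_prepended_turn_count(history)  # returns 1 (one assistant response)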
    """
    if not prepended_conversation:
        return 0
    return sum(1 for msg in prepended_conversation if msg.api_role == "assistant")


@dataclass
class ConversationState:
    """Container for conversation state data returned from context initialization."""

    turn_count: int = 0

    # Scores from the last assistant message (for attack-specific interpretation)
    # Used by Crescendo to detect refusals and objective achievement
    last_assistant_message_scores: List[Score] = field(default_factory=list)
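

# Example (illustrative sketch of how a multi-turn attack might consume the state;
# `manager` and `max_turns` are placeholders supplied by the calling attack):
#
#     state = await manager.initialize_context_async(...)
#     remaining_turns = max_turns - state.turn_count
#     if state.last_assistant_message_scores:
#         ...  # e.g. Crescendo-style refusal / objective checks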


class ConversationManager:
    """
    Manages conversations for attacks, handling message history, system prompts,
    and conversation state.

    This class provides methods to:
    - Initialize attack context with prepended conversations
    - Retrieve conversation history
    - Set system prompts for chat targets
    """

    def __init__(
        self,
        *,
        attack_identifier: Dict[str, str],
        prompt_normalizer: Optional[PromptNormalizer] = None,
    ):
        """
        Initialize the conversation manager.

        Args:
            attack_identifier: The identifier of the attack this manager belongs to.
            prompt_normalizer: Optional prompt normalizer for converting prompts.
                If not provided, a default PromptNormalizer instance will be created.
        """
        self._prompt_normalizer = prompt_normalizer or PromptNormalizer()
        self._memory = CentralMemory.get_memory_instance()
        self._attack_identifier = attack_identifier

    def get_conversation(self, conversation_id: str) -> List[Message]:
        """
        Retrieve a conversation by its ID.

        Args:
            conversation_id: The ID of the conversation to retrieve.

        Returns:
            A list of messages in the conversation, ordered by creation time.
            Returns empty list if no messages exist.
        """
        conversation = self._memory.get_conversation(conversation_id=conversation_id)
        return list(conversation)

    def get_last_message(
        self, *, conversation_id: str, role: Optional[ChatMessageRole] = None
    ) -> Optional[MessagePiece]:
        """
        Retrieve the most recent message from a conversation.

        Args:
            conversation_id: The ID of the conversation to retrieve from.
            role: If provided, return only the last message matching this role.

        Returns:
            The last message piece, or None if no messages exist.
        """
        conversation = self.get_conversation(conversation_id)
        if not conversation:
            return None

        if role:
            for m in reversed(conversation):
                piece = m.get_piece()
                if piece.api_role == role:
                    return piece
            return None

        return conversation[-1].get_piece()
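
    # Example (illustrative): fetch the last assistant reply in a conversation,
    # returning None if the target has not responded yet.
    #
    #     last_reply = manager.get_last_message(
    #         conversation_id=conversation_id, role="assistant"
    #     )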

    def set_system_prompt(
        self,
        *,
        target: PromptChatTarget,
        conversation_id: str,
        system_prompt: str,
        labels: Optional[Dict[str, str]] = None,
    ) -> None:
        """
        Set or update the system prompt for a conversation.

        Args:
            target: The chat target to set the system prompt on.
            conversation_id: Unique identifier for the conversation.
            system_prompt: The system prompt text.
            labels: Optional labels to associate with the system prompt.
        """
        target.set_system_prompt(
            system_prompt=system_prompt,
            conversation_id=conversation_id,
            attack_identifier=self._attack_identifier,
            labels=labels,
        )

    async def initialize_context_async(
        self,
        *,
        context: "AttackContext[Any]",
        target: PromptTarget,
        conversation_id: str,
        request_converters: Optional[List[PromptConverterConfiguration]] = None,
        prepended_conversation_config: Optional["PrependedConversationConfig"] = None,
        max_turns: Optional[int] = None,
        memory_labels: Optional[Dict[str, str]] = None,
    ) -> ConversationState:
        """
        Initialize attack context with prepended conversation and merged labels.

        This is the primary method for setting up an attack context. It:
        1. Merges memory_labels from attack strategy with context labels
        2. Processes prepended_conversation based on target type and config
        3. Updates context.executed_turns for multi-turn attacks
        4. Sets context.next_message if there's an unanswered user message

        For PromptChatTarget:
        - Adds prepended messages to memory with simulated_assistant role
        - All messages get new UUIDs

        For non-chat PromptTarget:
        - If `config.non_chat_target_behavior="normalize_first_turn"`: normalizes
          conversation to string and prepends to context.next_message
        - If `config.non_chat_target_behavior="raise"`: raises ValueError

        Args:
            context: The attack context to initialize.
            target: The objective target for the conversation.
            conversation_id: Unique identifier for the conversation.
            request_converters: Converters to apply to messages.
            prepended_conversation_config: Configuration for handling prepended conversation.
            max_turns: Maximum turns allowed (for validation and state tracking).
            memory_labels: Labels from the attack strategy to merge with context labels.

        Returns:
            ConversationState with turn_count and last_assistant_message_scores.

        Raises:
            ValueError: If conversation_id is empty, or if prepended_conversation
                requires a PromptChatTarget but target is not one.
        """
        if not conversation_id:
            raise ValueError("conversation_id cannot be empty")

        # Merge memory labels: attack strategy labels + context labels
        context.memory_labels = combine_dict(existing_dict=memory_labels, new_dict=context.memory_labels)

        state = ConversationState()
        prepended_conversation = context.prepended_conversation

        if not prepended_conversation:
            logger.debug(f"No prepended conversation for context initialization: {conversation_id}")
            return state

        # Handle target type compatibility
        is_chat_target = isinstance(target, PromptChatTarget)

        if not is_chat_target:
            return await self._handle_non_chat_target_async(
                context=context,
                prepended_conversation=prepended_conversation,
                config=prepended_conversation_config,
            )

        # Process prepended conversation for objective target
        return await self._process_prepended_for_chat_target_async(
            context=context,
            prepended_conversation=prepended_conversation,
            conversation_id=conversation_id,
            request_converters=request_converters,
            prepended_conversation_config=prepended_conversation_config,
            max_turns=max_turns,
        )
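
    # Example (illustrative sketch; `attack_identifier`, `context`, and
    # `objective_target` are placeholders supplied by the calling attack):
    #
    #     manager = ConversationManager(attack_identifier=attack_identifier)
    #     state = await manager.initialize_context_async(
    #         context=context,
    #         target=objective_target,
    #         conversation_id=str(uuid.uuid4()),
    #         max_turns=5,
    #     )
    #     # For multi-turn attacks, context.executed_turns now equals state.turn_count.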

    async def _handle_non_chat_target_async(
        self,
        *,
        context: "AttackContext[Any]",
        prepended_conversation: List[Message],
        config: Optional["PrependedConversationConfig"],
    ) -> ConversationState:
        """
        Handle prepended conversation for non-chat targets.

        Args:
            context: The attack context.
            prepended_conversation: Messages to prepend.
            config: Configuration for non-chat target behavior.

        Returns:
            Empty ConversationState (non-chat targets don't track turns).

        Raises:
            ValueError: If config requires raising for non-chat targets.
        """
        if config is None:
            config = PrependedConversationConfig()

        if config.non_chat_target_behavior == "raise":
            raise ValueError(
                "prepended_conversation requires the objective target to be a PromptChatTarget. "
                "Non-chat objective targets do not support conversation history. "
                "Use PrependedConversationConfig with non_chat_target_behavior='normalize_first_turn' "
                "to normalize the conversation into the first message instead."
            )

        # Normalize conversation to string
        normalizer = config.get_message_normalizer()
        normalized_context = await normalizer.normalize_string_async(prepended_conversation)

        # Prepend to next_message if it exists, otherwise create new message
        if context.next_message is not None:
            # Find an existing text piece to prepend to
            text_piece = None
            for piece in context.next_message.message_pieces:
                if piece.original_value_data_type == "text":
                    text_piece = piece
                    break

            if text_piece:
                # Prepend context to the existing text piece
                text_piece.original_value = f"{normalized_context}\n\n{text_piece.original_value}"
                text_piece.converted_value = f"{normalized_context}\n\n{text_piece.converted_value}"
            else:
                # No text piece found (multimodal message), add a new text piece at the beginning
                context_piece = MessagePiece(
                    id=uuid.uuid4(),
                    role="user",
                    original_value=normalized_context,
                    converted_value=normalized_context,
                    original_value_data_type="text",
                    converted_value_data_type="text",
                )
                # Create a new message with the context piece prepended
                context.next_message = Message(
                    message_pieces=[context_piece] + list(context.next_message.message_pieces)
                )
        else:
            # Create new message with just the context
            context.next_message = Message.from_prompt(prompt=normalized_context, role="user")

        logger.debug(f"Normalized prepended conversation for non-chat target: {len(normalized_context)} characters")

        return ConversationState()
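
    # Example (illustrative sketch; assumes PrependedConversationConfig accepts
    # non_chat_target_behavior as a constructor argument):
    #
    #     config = PrependedConversationConfig(non_chat_target_behavior="normalize_first_turn")
    #     state = await manager.initialize_context_async(
    #         context=context,
    #         target=non_chat_target,
    #         conversation_id=str(uuid.uuid4()),
    #         prepended_conversation_config=config,
    #     )
    #     # context.next_message now starts with the normalized conversation history.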

    async def add_prepended_conversation_to_memory_async(
        self,
        *,
        prepended_conversation: List[Message],
        conversation_id: str,
        request_converters: Optional[List[PromptConverterConfiguration]] = None,
        prepended_conversation_config: Optional["PrependedConversationConfig"] = None,
        max_turns: Optional[int] = None,
    ) -> int:
        """
        Add prepended conversation messages to memory for a chat target.

        This is a lower-level method that handles adding messages to memory without
        modifying any attack context state. It can be called directly by attacks that
        manage their own state (like TAP nodes) or internally by initialize_context_async
        for standard attacks.

        Messages are added with:
        - Duplicated message objects (preserves originals)
        - simulated_assistant role for assistant messages (for traceability)
        - Converters applied based on config

        Args:
            prepended_conversation: Messages to add to memory.
            conversation_id: Conversation ID to assign to all messages.
            request_converters: Optional converters to apply to messages.
            prepended_conversation_config: Optional configuration for converter roles.
            max_turns: If provided, validates that turn count doesn't exceed this limit.

        Returns:
            The number of turns (assistant messages) added.

        Raises:
            ValueError: If max_turns is exceeded by the prepended conversation.
        """
        # Filter valid messages
        valid_messages = [msg for msg in prepended_conversation if msg and msg.message_pieces]
        if not valid_messages:
            return 0

        # Get roles that should have converters applied
        apply_to_roles = (
            prepended_conversation_config.apply_converters_to_roles if prepended_conversation_config else None
        )

        turn_count = 0
        for i, message in enumerate(valid_messages):
            message_copy = message.duplicate_message()
            message_copy.set_simulated_role()

            for piece in message_copy.message_pieces:
                piece.conversation_id = conversation_id
                piece.attack_identifier = self._attack_identifier

            # Count turns at message level (only assistant/simulated_assistant messages)
            # A multi-part response still counts as one turn
            if message_copy.api_role == "assistant":
                turn_count += 1
                if max_turns is not None and turn_count > max_turns:
                    raise ValueError(
                        f"Prepended conversation has {turn_count} turns, "
                        f"exceeding max_turns={max_turns}. Reduce prepended turns or increase max_turns."
                    )

            # Apply converters if configured
            if request_converters:
                await self._apply_converters_async(
                    message=message_copy,
                    request_converters=request_converters,
                    apply_to_roles=apply_to_roles,
                )

            # Add to memory
            self._memory.add_message_to_memory(request=message_copy)
            logger.debug(f"Added prepended message {i + 1}/{len(valid_messages)} to memory")

        return turn_count
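
    # Example (illustrative sketch; `seed_conversation` and `node_conversation_id`
    # are placeholders, as used by attacks such as TAP that manage their own state):
    #
    #     turns_added = await manager.add_prepended_conversation_to_memory_async(
    #         prepended_conversation=seed_conversation,
    #         conversation_id=node_conversation_id,
    #         max_turns=10,
    #     )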

    async def _process_prepended_for_chat_target_async(
        self,
        *,
        context: "AttackContext[Any]",
        prepended_conversation: List[Message],
        conversation_id: str,
        request_converters: Optional[List[PromptConverterConfiguration]],
        prepended_conversation_config: Optional["PrependedConversationConfig"],
        max_turns: Optional[int],
    ) -> ConversationState:
        """
        Process prepended conversation for a chat target.

        Adds messages to memory with:
        - New UUIDs for all pieces
        - simulated_assistant role for assistant messages
        - Converters applied based on config

        Args:
            context: The attack context.
            prepended_conversation: Messages to add to memory.
            conversation_id: Conversation ID for the messages.
            request_converters: Converters to apply.
            prepended_conversation_config: Configuration for converter roles.
            max_turns: Maximum turns for validation.

        Returns:
            ConversationState with turn_count and scores.
        """
        state = ConversationState()
        is_multi_turn = max_turns is not None

        # Filter valid messages
        valid_messages = [msg for msg in prepended_conversation if msg and msg.message_pieces]
        if not valid_messages:
            return state

        # Use the lower-level method to add messages to memory
        state.turn_count = await self.add_prepended_conversation_to_memory_async(
            prepended_conversation=prepended_conversation,
            conversation_id=conversation_id,
            request_converters=request_converters,
            prepended_conversation_config=prepended_conversation_config,
            max_turns=max_turns,
        )

        # Update context for multi-turn attacks
        if is_multi_turn:
            # Update executed_turns
            if hasattr(context, "executed_turns"):
                context.executed_turns = state.turn_count

            # Extract scores for last assistant message if it exists
            # Multi-part messages (e.g., text + image) may have scores on multiple pieces
            last_message = valid_messages[-1]
            if last_message.api_role == "assistant":
                prompt_ids = [str(piece.original_prompt_id) for piece in last_message.message_pieces]
                state.last_assistant_message_scores = list(self._memory.get_prompt_scores(prompt_ids=prompt_ids))

        return state

    async def _apply_converters_async(
        self,
        *,
        message: Message,
        request_converters: List[PromptConverterConfiguration],
        apply_to_roles: Optional[List[ChatMessageRole]],
    ) -> None:
        """
        Apply converters to message pieces.

        Args:
            message: The message containing pieces to convert.
            request_converters: Converter configurations to apply.
            apply_to_roles: If provided, only apply to pieces with these roles.
                If None, apply to all roles.
        """
        for piece in message.message_pieces:
            # Filter by role if specified
            if apply_to_roles is not None and piece.api_role not in apply_to_roles:
                continue

            temp_message = Message(message_pieces=[piece])
            await self._prompt_normalizer.convert_values(
                message=temp_message,
                converter_configurations=request_converters,
            )