Source code for pyrit.models.seeds.seed_simulated_conversation

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
SeedSimulatedConversation - Configuration for generating simulated conversations dynamically.

This class holds the configuration (prompts, num_turns) needed to generate a simulated
conversation. It is a pure data/config class - the actual generation logic lives in
`pyrit.executor.attack.component.simulated_conversation`.

As a Seed subclass, it can be stored in the database for reproducibility tracking.
"""

from __future__ import annotations

import enum
import hashlib
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import pyrit
from pyrit.common.path import EXECUTOR_SIMULATED_TARGET_PATH
from pyrit.models.seeds.seed import Seed
from pyrit.models.seeds.seed_prompt import SeedPrompt

logger = logging.getLogger(__name__)


class SimulatedTargetSystemPromptPaths(enum.Enum):
    """Predefined locations of system prompt YAML files for the simulated target."""

    # Absolute path of the bundled "compliant" simulated-target prompt.
    COMPLIANT = (Path(EXECUTOR_SIMULATED_TARGET_PATH) / "compliant.yaml").resolve()
class NextMessageSystemPromptPaths(enum.Enum):
    """Predefined locations of system prompt YAML files for next-message generation."""

    # Absolute path of the bundled "direct" next-message prompt.
    DIRECT = (Path(EXECUTOR_SIMULATED_TARGET_PATH) / "direct_next_message.yaml").resolve()
class SeedSimulatedConversation(Seed):
    """
    Configuration for generating a simulated conversation dynamically.

    This class holds the paths and parameters needed to generate prepended conversation
    content by running an adversarial chat against a simulated (compliant) target.

    This is a pure configuration class. The actual generation is performed by
    `generate_simulated_conversation_async` in the executor layer, which accepts
    this config along with runtime dependencies (adversarial_chat target, scorer).

    The `value` property returns a JSON serialization of the config for database
    storage and deduplication.

    Attributes:
        num_turns: Number of conversation turns to generate.
        sequence: Starting sequence number for the generated turns.
        adversarial_chat_system_prompt_path: Path to the adversarial chat system prompt YAML.
        simulated_target_system_prompt_path: Path to the simulated target system prompt YAML.
            Defaults to the compliant prompt if not specified.
        next_message_system_prompt_path: Optional path to the system prompt for generating
            an additional user message after the simulated conversation. If provided,
            a single LLM call generates a final user message that attempts to get the
            target to fulfill the objective in their next response.
        pyrit_version: PyRIT version recorded with the configuration.
    """

    def __init__(
        self,
        *,
        adversarial_chat_system_prompt_path: Union[str, Path],
        simulated_target_system_prompt_path: Optional[Union[str, Path]] = None,
        next_message_system_prompt_path: Optional[Union[str, Path]] = None,
        num_turns: int = 3,
        sequence: int = 0,
        pyrit_version: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a SeedSimulatedConversation.

        Args:
            adversarial_chat_system_prompt_path: Path to YAML file containing the
                adversarial chat system prompt.
            simulated_target_system_prompt_path: Optional path to YAML file containing
                the simulated target system prompt. Defaults to the compliant prompt.
            next_message_system_prompt_path: Optional path to YAML file containing
                the system prompt for generating a final user message. If provided,
                after the simulated conversation is generated, a single LLM call
                generates a user message that attempts to get the target to fulfill
                the objective. Defaults to None (no next message generation).
            num_turns: Number of conversation turns to generate. Defaults to 3.
            sequence: The starting sequence number for generated turns. When combined
                with static SeedPrompts, this determines where the simulated turns
                are inserted. Defaults to 0.
            pyrit_version: PyRIT version for reproducibility tracking. Defaults to
                the current version.
            **kwargs: Additional arguments passed to the Seed base class. A supplied
                `value` is discarded because the value is always computed here.

        Raises:
            ValueError: If `num_turns` is not positive or `sequence` is negative.
        """
        # Apply default for simulated target system prompt if not provided
        if simulated_target_system_prompt_path is None:
            simulated_target_system_prompt_path = SimulatedTargetSystemPromptPaths.COMPLIANT.value

        if num_turns <= 0:
            raise ValueError("num_turns must be a positive integer")
        if sequence < 0:
            raise ValueError("sequence must be a non-negative integer")

        # These attributes must exist before the base-class initializer runs,
        # because the value handed to it is derived from them.
        self.adversarial_chat_system_prompt_path = Path(adversarial_chat_system_prompt_path)
        self.simulated_target_system_prompt_path = Path(simulated_target_system_prompt_path)
        self.next_message_system_prompt_path = (
            Path(next_message_system_prompt_path) if next_message_system_prompt_path else None
        )
        self.num_turns = num_turns
        self.sequence = sequence
        self.pyrit_version = pyrit_version or pyrit.__version__

        # Remove 'value' from kwargs if present since we compute it
        kwargs.pop("value", None)
        super().__init__(value=self._compute_value(), **kwargs)

    def _config_dict(self) -> Dict[str, Any]:
        """Return the JSON-serializable configuration shared by `value` and the identifier."""
        next_message_path = self.next_message_system_prompt_path
        return {
            "num_turns": self.num_turns,
            "sequence": self.sequence,
            "adversarial_chat_system_prompt_path": str(self.adversarial_chat_system_prompt_path),
            "simulated_target_system_prompt_path": str(self.simulated_target_system_prompt_path),
            "next_message_system_prompt_path": str(next_message_path) if next_message_path else None,
            "pyrit_version": self.pyrit_version,
        }

    def _compute_value(self) -> str:
        """Compute the value field as a compact, key-sorted JSON serialization of the config."""
        return json.dumps(self._config_dict(), sort_keys=True, separators=(",", ":"))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SeedSimulatedConversation":
        """
        Create a SeedSimulatedConversation from a dictionary, typically from YAML.

        Expected format:
            num_turns: 3
            sequence: 0  # optional
            adversarial_chat_system_prompt_path: path/to/adversarial.yaml
            simulated_target_system_prompt_path: path/to/simulated.yaml  # optional
            next_message_system_prompt_path: path/to/next_message.yaml  # optional
            pyrit_version: "x.y.z"  # optional

        Args:
            data: Dictionary containing the configuration.

        Returns:
            A new SeedSimulatedConversation instance.

        Raises:
            ValueError: If `adversarial_chat_system_prompt_path` is missing or empty.
        """
        adversarial_path = data.get("adversarial_chat_system_prompt_path")
        if not adversarial_path:
            raise ValueError("adversarial_chat_system_prompt_path is required")

        return cls(
            num_turns=data.get("num_turns", 3),
            sequence=data.get("sequence", 0),
            adversarial_chat_system_prompt_path=adversarial_path,
            simulated_target_system_prompt_path=data.get("simulated_target_system_prompt_path"),
            next_message_system_prompt_path=data.get("next_message_system_prompt_path"),
            # Keep a recorded version so a reloaded config serializes and hashes
            # the same way; a missing/None entry falls back to the current version.
            pyrit_version=data.get("pyrit_version"),
        )

    @classmethod
    def from_yaml_with_required_parameters(
        cls,
        template_path: Union[str, Path],
        required_parameters: list[str],
        error_message: Optional[str] = None,
    ) -> "SeedSimulatedConversation":
        """
        Load a SeedSimulatedConversation from a YAML file and validate required parameters.

        Args:
            template_path: Path to the YAML file containing the config.
            required_parameters: List of attribute names that must exist on the
                loaded instance and not be None.
            error_message: Custom error message if validation fails.

        Returns:
            The loaded and validated SeedSimulatedConversation.

        Raises:
            ValueError: If a required parameter is missing or None. The first
                offending parameter triggers the error.
        """
        instance = cls.from_yaml_file(template_path)

        for param in required_parameters:
            # A missing attribute and an attribute set to None are treated alike.
            if getattr(instance, param, None) is None:
                raise ValueError(error_message or f"Missing required parameter: {param}")

        return instance

    def get_identifier(self) -> Dict[str, Any]:
        """
        Get an identifier dict capturing this configuration for comparison/storage.

        Returns:
            Dictionary with a `__type__` marker followed by the configuration details.
        """
        return {"__type__": "SeedSimulatedConversation", **self._config_dict()}

    def compute_hash(self) -> str:
        """
        Compute a deterministic hash of this configuration.

        Returns:
            The SHA256 hex digest of the key-sorted JSON form of `get_identifier()`.
        """
        config_json = json.dumps(self.get_identifier(), sort_keys=True, separators=(",", ":"))
        return hashlib.sha256(config_json.encode("utf-8")).hexdigest()

    @staticmethod
    def load_simulated_target_system_prompt(
        *,
        objective: str,
        num_turns: int,
        simulated_target_system_prompt_path: Optional[Union[str, Path]] = None,
    ) -> Optional[str]:
        """
        Load and render the simulated target system prompt.

        If no path is provided, returns None (no system prompt).
        Validates that the template has required `objective` and `num_turns` parameters.

        Args:
            objective: The objective to render into the template.
            num_turns: The number of turns to render into the template.
            simulated_target_system_prompt_path: Optional path to the prompt YAML file.
                If None, no system prompt is used.

        Returns:
            The rendered system prompt string, or None if no path is provided.

        Raises:
            ValueError: If the template doesn't have required parameters.
        """
        if simulated_target_system_prompt_path is None:
            return None

        template = SeedPrompt.from_yaml_with_required_parameters(
            template_path=simulated_target_system_prompt_path,
            required_parameters=["objective", "num_turns"],
            error_message="Simulated target system prompt must have objective and num_turns parameters",
        )

        return template.render_template_value(
            objective=objective,
            num_turns=num_turns,
        )

    @property
    def sequence_range(self) -> range:
        """
        Get the range of sequence numbers this simulated conversation will occupy.

        Each turn generates 2 messages (user + assistant), so num_turns generates
        num_turns * 2 messages. If next_message_system_prompt_path is set, an
        additional user message is added at the end.

        Returns:
            A range object representing the sequence numbers.
        """
        message_count = self.num_turns * 2 + (1 if self.next_message_system_prompt_path else 0)
        return range(self.sequence, self.sequence + message_count)

    def __repr__(self) -> str:
        """Return a short summary: turn count, start sequence, next-message flag, prompt file name."""
        has_next_msg = self.next_message_system_prompt_path is not None
        return (
            f"<SeedSimulatedConversation(num_turns={self.num_turns}, sequence={self.sequence}, "
            f"next_message={has_next_msg}, "
            f"adversarial_path={self.adversarial_chat_system_prompt_path.name})>"
        )