Source code for pyrit.models.seeds.seed_group

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
SeedGroup - Container for grouping seeds together.

Provides functionality for grouping prompts, objectives, and simulated conversation
configurations together with consistent group IDs and roles.
"""

from __future__ import annotations

import logging
import uuid
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Union

from pyrit.common.yaml_loadable import YamlLoadable
from pyrit.models.message import Message
from pyrit.models.message_piece import MessagePiece
from pyrit.models.seeds.seed import Seed
from pyrit.models.seeds.seed_objective import SeedObjective
from pyrit.models.seeds.seed_prompt import SeedPrompt
from pyrit.models.seeds.seed_simulated_conversation import SeedSimulatedConversation

logger = logging.getLogger(__name__)


[docs] class SeedGroup(YamlLoadable): """ A container for grouping prompts that need to be sent together. This class handles: - Grouping of SeedPrompt, SeedObjective, and SeedSimulatedConversation - Consistent group IDs and roles across seeds - Prepended conversation and next message extraction - Validation of sequence overlaps between SeedPrompts and SeedSimulatedConversation All prompts in the group share the same `prompt_group_id`. """ seeds: List[Seed]
[docs] def __init__( self, *, seeds: Sequence[Union[Seed, Dict[str, Any]]], ): """ Initialize a SeedGroup. Args: seeds: Sequence of seeds. Can include: - SeedObjective (or dict with seed_type="objective") - SeedSimulatedConversation (or dict with seed_type="simulated_conversation") - SeedPrompt for prompts (or dict with seed_type="prompt" or no seed_type) Note: is_objective and is_simulated_conversation are deprecated since 0.13.0. Raises: ValueError: If seeds is empty. ValueError: If multiple objectives are provided. ValueError: If SeedPrompt sequences overlap with SeedSimulatedConversation range. """ if not seeds: raise ValueError("SeedGroup cannot be empty.") self.seeds = [] for seed in seeds: if isinstance(seed, Seed): self.seeds.append(seed) elif isinstance(seed, dict): # Support new seed_type field with backward compatibility for deprecated fields seed_type = seed.pop("seed_type", None) is_objective = seed.pop("is_objective", False) if is_objective: warnings.warn( "is_objective is deprecated since 0.13.0. Use seed_type='objective' instead.", DeprecationWarning, stacklevel=2, ) # Only error if seed_type is explicitly set to a conflicting value if seed_type is not None and seed_type != "objective": raise ValueError("Conflicting seed_type and is_objective values.") if seed_type == "simulated_conversation": # SeedSimulatedConversation doesn't use data_type (always text) seed.pop("data_type", None) self.seeds.append(SeedSimulatedConversation.from_dict(seed)) elif seed_type == "objective" or (seed_type is None and is_objective): # SeedObjective doesn't use data_type (always text) seed.pop("data_type", None) self.seeds.append(SeedObjective(**seed)) else: self.seeds.append(SeedPrompt(**seed)) else: raise ValueError(f"Invalid seed type: {type(seed)}") # Validate and normalize the seeds self.validate() # Extract simulated conversation config self._simulated_conversation_config = self._get_simulated_conversation() # Reconstruct seeds in canonical order: objective, simulated_conversation, sorted prompts objective = self._get_objective() simulated_conv = self._simulated_conversation_config sorted_prompts = sorted(self.prompts, key=lambda p: p.sequence if p.sequence is not None else 0) self.seeds = [] if objective: self.seeds.append(objective) if simulated_conv: self.seeds.append(simulated_conv) self.seeds.extend(sorted_prompts)
# ========================================================================= # Validation # =========================================================================
[docs] def validate(self) -> None: """ Validate the seed group state. This method can be called after external modifications to seeds to ensure the group remains in a valid state. It is automatically called during initialization. Raises: ValueError: If validation fails. """ if not self.seeds: raise ValueError("SeedGroup cannot be empty.") self._enforce_consistent_group_id() self._enforce_consistent_role() self._enforce_max_one_objective() self._enforce_max_one_simulated_conversation() self._enforce_no_sequence_overlap_with_simulated()
def _enforce_max_one_objective(self) -> None: """Ensure at most one objective is present.""" if len([s for s in self.seeds if isinstance(s, SeedObjective)]) > 1: raise ValueError("SeedGroup can only have one objective.") def _enforce_max_one_simulated_conversation(self) -> None: """Ensure at most one simulated conversation is present.""" if len([s for s in self.seeds if isinstance(s, SeedSimulatedConversation)]) > 1: raise ValueError("SeedGroup can only have one simulated conversation.") def _enforce_consistent_group_id(self) -> None: """ Ensure all seeds share the same group ID. If any seeds have a group ID, all must match. If none have one, assigns a new UUID. Raises: ValueError: If multiple different group IDs exist. """ existing_group_ids = {seed.prompt_group_id for seed in self.seeds if seed.prompt_group_id is not None} if len(existing_group_ids) > 1: raise ValueError("Inconsistent group IDs found across seeds.") elif len(existing_group_ids) == 1: group_id = existing_group_ids.pop() for seed in self.seeds: seed.prompt_group_id = group_id else: new_group_id = uuid.uuid4() for seed in self.seeds: seed.prompt_group_id = new_group_id def _enforce_consistent_role(self) -> None: """ Ensure all prompts in a sequence have consistent roles. Raises: ValueError: If roles are inconsistent within a sequence. ValueError: If no roles are set in a multi-sequence group. """ grouped_prompts = defaultdict(list) for prompt in self.prompts: grouped_prompts[prompt.sequence].append(prompt) num_sequences = len(grouped_prompts) for sequence, prompts in grouped_prompts.items(): roles = {prompt.role for prompt in prompts if prompt.role is not None} if not roles and num_sequences > 1: raise ValueError( f"No roles set for sequence {sequence} in a multi-sequence group. " "Please ensure at least one prompt within a sequence has an assigned role." ) if len(roles) > 1: raise ValueError(f"Inconsistent roles found for sequence {sequence}: {roles}") role = roles.pop() if roles else "user" for prompt in prompts: prompt.role = role def _enforce_no_sequence_overlap_with_simulated(self) -> None: """ Ensure SeedPrompt sequences don't overlap with SeedSimulatedConversation range. When a SeedSimulatedConversation is present, it will generate turns that occupy sequence numbers [config.sequence, config.sequence + config.num_turns * 2 - 1]. SeedPrompts must not have sequences that fall within this range. Raises: ValueError: If any SeedPrompt sequence overlaps with the simulated range. """ simulated_config = self._get_simulated_conversation() if simulated_config is None: return simulated_range = simulated_config.sequence_range for prompt in self.prompts: if prompt.sequence in simulated_range: raise ValueError( f"SeedPrompt sequence {prompt.sequence} overlaps with SeedSimulatedConversation " f"range {list(simulated_range)}. Adjust the SeedPrompt sequence or the " f"SeedSimulatedConversation sequence/num_turns to avoid overlap." ) # ========================================================================= # Seed Accessors # ========================================================================= def _get_objective(self) -> Optional[SeedObjective]: """Get the objective seed if present.""" for seed in self.seeds: if isinstance(seed, SeedObjective): return seed return None def _get_simulated_conversation(self) -> Optional[SeedSimulatedConversation]: """Get the simulated conversation seed if present.""" for seed in self.seeds: if isinstance(seed, SeedSimulatedConversation): return seed return None @property def prompts(self) -> Sequence[SeedPrompt]: """Get all SeedPrompt instances from this group.""" return [seed for seed in self.seeds if isinstance(seed, SeedPrompt)] @property def objective(self) -> Optional[SeedObjective]: """Get the objective for this group.""" return self._get_objective() @property def harm_categories(self) -> List[str]: """ Returns a deduplicated list of all harm categories from all seeds. Returns: List of harm categories with duplicates removed. """ categories: List[str] = [] for seed in self.seeds: if seed.harm_categories: categories.extend(seed.harm_categories) return list(set(categories)) # ========================================================================= # Simulated Conversation # ========================================================================= @property def simulated_conversation_config(self) -> Optional[SeedSimulatedConversation]: """Get the simulated conversation configuration if set.""" return self._simulated_conversation_config @property def has_simulated_conversation(self) -> bool: """Check if this group uses simulated conversation generation.""" return self._simulated_conversation_config is not None # ========================================================================= # Message Extraction # ========================================================================= @property def prepended_conversation(self) -> Optional[List[Message]]: """ Returns Messages that should be prepended as conversation history. Returns all messages except the last user sequence. Returns: Messages for conversation history, or None if empty. """ if not self.prompts: return None last_role = self._get_last_sequence_role() unique_sequences = sorted({prompt.sequence for prompt in self.prompts}) if last_role == "user": if len(unique_sequences) <= 1: return None last_sequence = unique_sequences[-1] prepended_prompts = [p for p in self.prompts if p.sequence != last_sequence] if not prepended_prompts: return None return self._prompts_to_messages(prepended_prompts) else: return self._prompts_to_messages(list(self.prompts)) @property def next_message(self) -> Optional[Message]: """ Returns a Message containing only the last turn's prompts if it's a user message. Returns: Message for the current/last turn if user role, or None otherwise. """ if not self.prompts: return None last_role = self._get_last_sequence_role() if last_role != "user": return None unique_sequences = sorted({prompt.sequence for prompt in self.prompts}) last_sequence = unique_sequences[-1] current_turn_prompts = [p for p in self.prompts if p.sequence == last_sequence] if not current_turn_prompts: return None messages = self._prompts_to_messages(current_turn_prompts) return messages[0] if messages else None @property def user_messages(self) -> List[Message]: """ Returns all prompts as user Messages, one per sequence. Returns: All user messages in sequence order, or empty list if no prompts. """ if not self.prompts: return [] return self._prompts_to_messages(list(self.prompts)) def _get_last_sequence_role(self) -> Optional[str]: """ Get the role of the last sequence. Returns: The role of the last sequence, or None if no prompts exist. """ if not self.prompts: return None unique_sequences = sorted({prompt.sequence for prompt in self.prompts}) last_sequence = unique_sequences[-1] last_sequence_prompts = [p for p in self.prompts if p.sequence == last_sequence] return last_sequence_prompts[0].role if last_sequence_prompts else None def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> List[Message]: """ Convert a sequence of SeedPrompts to Messages. Groups prompts by sequence number and creates one Message per sequence. Args: prompts: The prompts to convert. Returns: Messages created from the prompts. """ sequence_groups = defaultdict(list) for prompt in prompts: sequence_groups[prompt.sequence].append(prompt) messages = [] for sequence in sorted(sequence_groups.keys()): sequence_prompts = sequence_groups[sequence] message_pieces = [] for prompt in sequence_prompts: role = prompt.role or "user" if role == "assistant": role = "simulated_assistant" piece = MessagePiece( role=role, original_value=prompt.value, original_value_data_type=prompt.data_type or "text", prompt_target_identifier=None, conversation_id=str(prompt.prompt_group_id), sequence=sequence, prompt_metadata=prompt.metadata, ) message_pieces.append(piece) messages.append(Message(message_pieces=message_pieces)) return messages # ========================================================================= # Utility Methods # =========================================================================
[docs] def render_template_value(self, **kwargs: Any) -> None: """ Renders seed values as templates with provided parameters. Args: kwargs: Key-value pairs to replace in seed values. """ for seed in self.seeds: seed.value = seed.render_template_value(**kwargs)
[docs] def is_single_turn(self) -> bool: """Check if this is a single-turn group (single request without objective).""" return self.is_single_request() and not self.objective
[docs] def is_single_request(self) -> bool: """Check if all prompts are in a single sequence.""" unique_sequences = {prompt.sequence for prompt in self.prompts} return len(unique_sequences) == 1
[docs] def is_single_part_single_text_request(self) -> bool: """Check if this is a single text prompt.""" return len(self.prompts) == 1 and self.prompts[0].data_type == "text"
def __repr__(self) -> str: sim_info = " (simulated)" if self.has_simulated_conversation else "" return f"<SeedGroup(seeds={len(self.seeds)}{sim_info})>"