# Source code for pyrit.executor.attack.core.attack_executor

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Simplified AttackExecutor that uses AttackParameters directly.

This is the new, cleaner design that leverages the params_type architecture.
"""

import asyncio
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterator, List, Optional, Sequence, TypeVar, cast

from pyrit.common.logger import logger
from pyrit.executor.attack.core import (
    AttackStrategy,
    AttackStrategyContextT,
    AttackStrategyResultT,
)
from pyrit.executor.attack.core.attack_parameters import AttackParameters
from pyrit.executor.attack.multi_turn.multi_turn_attack_strategy import (
    MultiTurnAttackContext,
)
from pyrit.executor.attack.single_turn.single_turn_attack_strategy import (
    SingleTurnAttackContext,
)
from pyrit.models import Message, SeedGroup

AttackResultT = TypeVar("AttackResultT")


@dataclass
class AttackExecutorResult(Generic[AttackResultT]):
    """
    Result container for attack execution, supporting both full and partial completion.

    This class holds results from parallel attack execution. It is iterable and
    behaves like a list in the common case where all objectives complete successfully.

    When some objectives don't complete (throw exceptions), access incomplete_objectives
    to retrieve the failures, or use raise_if_incomplete() to raise the first exception.

    Note: "completed" means the execution finished, not that the attack objective was achieved.
    """

    # Results of the executions that finished without raising.
    completed_results: List[AttackResultT]
    # (objective, exception) pairs for the executions that raised.
    incomplete_objectives: List[tuple[str, BaseException]]

    def __iter__(self) -> Iterator[AttackResultT]:
        """
        Iterate over completed results.

        Returns:
            Iterator over completed attack results.
        """
        return iter(self.completed_results)

    def __len__(self) -> int:
        """Return number of completed results."""
        return len(self.completed_results)

    def __getitem__(self, index: int) -> AttackResultT:
        """
        Access completed results by index.

        Returns:
            The attack result at the specified index.
        """
        return self.completed_results[index]

    @property
    def has_incomplete(self) -> bool:
        """Check if any objectives didn't complete execution."""
        return bool(self.incomplete_objectives)

    @property
    def all_completed(self) -> bool:
        """Check if all objectives completed execution."""
        return not self.incomplete_objectives

    @property
    def exceptions(self) -> List[BaseException]:
        """Get all exceptions from incomplete objectives."""
        return [exception for _, exception in self.incomplete_objectives]

    def raise_if_incomplete(self) -> None:
        """
        Raise the first exception if any objectives are incomplete.

        Only the first recorded exception is raised; the remaining ones stay
        available through ``incomplete_objectives`` / ``exceptions``.
        """
        if self.incomplete_objectives:
            raise self.incomplete_objectives[0][1]

    def get_results(self) -> List[AttackResultT]:
        """
        Get completed results, raising if any incomplete.

        Returns:
            List of completed attack results.
        """
        self.raise_if_incomplete()
        return self.completed_results
class AttackExecutor:
    """
    Manages the execution of attack strategies with support for parallel execution.

    The AttackExecutor provides controlled execution of attack strategies with
    concurrency limiting. It uses the attack's params_type to create parameters
    from seed groups.
    """

    def __init__(self, *, max_concurrency: int = 1) -> None:
        """
        Initialize the attack executor with configurable concurrency control.

        Args:
            max_concurrency: Maximum number of concurrent attack executions (default: 1).

        Raises:
            ValueError: If max_concurrency is not a positive integer.
        """
        if max_concurrency <= 0:
            raise ValueError(f"max_concurrency must be a positive integer, got {max_concurrency}")
        self._max_concurrency = max_concurrency

    async def execute_attack_from_seed_groups_async(
        self,
        *,
        attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT],
        seed_groups: Sequence[SeedGroup],
        field_overrides: Optional[Sequence[Dict[str, Any]]] = None,
        return_partial_on_failure: bool = False,
        **broadcast_fields: Any,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        """
        Execute attacks in parallel, extracting parameters from SeedGroups.

        Uses the attack's params_type.from_seed_group() to extract parameters,
        automatically handling which fields the attack accepts.

        Args:
            attack: The attack strategy to execute.
            seed_groups: SeedGroups containing objectives and optional prompts.
            field_overrides: Optional per-seed-group field overrides. If provided,
                must match the length of seed_groups. Each dict is passed to
                from_seed_group() as overrides.
            return_partial_on_failure: If True, returns partial results when some
                objectives fail. If False (default), raises the first exception.
            **broadcast_fields: Fields applied to all seed groups (e.g., memory_labels).
                Per-seed-group field_overrides take precedence.

        Returns:
            AttackExecutorResult with completed results and any incomplete objectives.

        Raises:
            ValueError: If seed_groups is empty or field_overrides length doesn't match.
            BaseException: If return_partial_on_failure=False and any objective fails.
        """
        if not seed_groups:
            raise ValueError("At least one seed_group must be provided")

        if field_overrides and len(field_overrides) != len(seed_groups):
            raise ValueError(
                f"field_overrides length ({len(field_overrides)}) must match "
                f"seed_groups length ({len(seed_groups)})"
            )

        params_type = attack.params_type

        # Build params list using from_seed_group
        params_list: List[AttackParameters] = []
        for i, sg in enumerate(seed_groups):
            # Start with broadcast fields, then layer on per-seed-group overrides
            # so that the per-seed-group values win on key collisions.
            combined_overrides = dict(broadcast_fields)
            if field_overrides:
                combined_overrides.update(field_overrides[i])

            params = params_type.from_seed_group(sg, **combined_overrides)
            params_list.append(params)

        return await self._execute_with_params_list_async(
            attack=attack,
            params_list=params_list,
            return_partial_on_failure=return_partial_on_failure,
        )

    async def execute_attack_async(
        self,
        *,
        attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT],
        objectives: Sequence[str],
        field_overrides: Optional[Sequence[Dict[str, Any]]] = None,
        return_partial_on_failure: bool = False,
        **broadcast_fields: Any,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        """
        Execute attacks in parallel for each objective.

        Creates AttackParameters directly from objectives and field values.

        Args:
            attack: The attack strategy to execute.
            objectives: List of attack objectives.
            field_overrides: Optional per-objective field overrides. If provided,
                must match the length of objectives.
            return_partial_on_failure: If True, returns partial results when some
                objectives fail. If False (default), raises the first exception.
            **broadcast_fields: Fields applied to all objectives (e.g., memory_labels).
                Per-objective field_overrides take precedence.

        Returns:
            AttackExecutorResult with completed results and any incomplete objectives.

        Raises:
            ValueError: If objectives is empty or field_overrides length doesn't match.
            BaseException: If return_partial_on_failure=False and any objective fails.
        """
        if not objectives:
            raise ValueError("At least one objective must be provided")

        if field_overrides and len(field_overrides) != len(objectives):
            raise ValueError(
                f"field_overrides length ({len(field_overrides)}) must match "
                f"objectives length ({len(objectives)})"
            )

        params_type = attack.params_type

        # Build params list
        params_list: List[AttackParameters] = []
        for i, objective in enumerate(objectives):
            # Start with broadcast fields
            fields = dict(broadcast_fields)

            # Apply per-objective overrides
            if field_overrides:
                fields.update(field_overrides[i])

            # Add objective last so it always wins over any "objective" key
            # present in the broadcast fields or overrides.
            fields["objective"] = objective

            params = params_type(**fields)
            params_list.append(params)

        return await self._execute_with_params_list_async(
            attack=attack,
            params_list=params_list,
            return_partial_on_failure=return_partial_on_failure,
        )

    async def _execute_with_params_list_async(
        self,
        *,
        attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT],
        params_list: Sequence[AttackParameters],
        return_partial_on_failure: bool = False,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        """
        Execute attacks in parallel with a list of pre-built parameters.

        This is the core execution method. It creates contexts from params
        and runs attacks with concurrency control.

        Args:
            attack: The attack strategy to execute.
            params_list: List of AttackParameters, one per execution.
            return_partial_on_failure: If True, returns partial results on failure.

        Returns:
            AttackExecutorResult with completed results and any incomplete objectives.
        """
        # At most max_concurrency executions hold the semaphore at any one time.
        semaphore = asyncio.Semaphore(self._max_concurrency)

        async def run_one(params: AttackParameters) -> AttackStrategyResultT:
            async with semaphore:
                # Create context with params
                context = cast(
                    AttackStrategyContextT,
                    attack._context_type(params=params),  # type: ignore[call-arg]
                )
                return await attack.execute_with_context_async(context=context)

        tasks = [run_one(p) for p in params_list]
        # return_exceptions=True: exceptions are returned in place of results
        # instead of propagating, so they can be paired with their objective below.
        results_or_exceptions = await asyncio.gather(*tasks, return_exceptions=True)

        return self._process_execution_results(
            objectives=[p.objective for p in params_list],
            results_or_exceptions=list(results_or_exceptions),
            return_partial_on_failure=return_partial_on_failure,
        )

    def _process_execution_results(
        self,
        *,
        objectives: Sequence[str],
        results_or_exceptions: List[Any],
        return_partial_on_failure: bool,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        """
        Process results from parallel execution into an AttackExecutorResult.

        Args:
            objectives: The objectives that were executed.
            results_or_exceptions: Results or exceptions from asyncio.gather.
            return_partial_on_failure: Whether to return partial results on failure.

        Returns:
            AttackExecutorResult with completed and incomplete results.

        Raises:
            BaseException: If return_partial_on_failure=False and any failed.
        """
        completed: List[AttackStrategyResultT] = []
        incomplete: List[tuple[str, BaseException]] = []

        # Entries are paired positionally: the i-th result belongs to the i-th objective.
        for objective, result in zip(objectives, results_or_exceptions):
            if isinstance(result, BaseException):
                incomplete.append((objective, result))
            else:
                completed.append(result)

        executor_result: AttackExecutorResult[AttackStrategyResultT] = AttackExecutorResult(
            completed_results=completed,
            incomplete_objectives=incomplete,
        )

        if not return_partial_on_failure:
            executor_result.raise_if_incomplete()

        return executor_result

    # =========================================================================
    # Deprecated methods - these will be removed in a future version
    # =========================================================================

    _SingleTurnContextT = TypeVar("_SingleTurnContextT", bound=SingleTurnAttackContext)
    _MultiTurnContextT = TypeVar("_MultiTurnContextT", bound=MultiTurnAttackContext)

    @staticmethod
    def _build_per_objective_overrides(
        *,
        objectives: Sequence[str],
        messages: Optional[Sequence[Message]],
        prepended_conversations: Optional[Sequence[List[Message]]],
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Build per-objective field overrides for the deprecated batch methods.

        Args:
            objectives: The objectives being executed; one override dict is built per objective.
            messages: Optional per-objective messages, stored under "next_message".
            prepended_conversations: Optional per-objective conversations, stored
                under "prepended_conversation".

        Returns:
            None when neither messages nor prepended_conversations is provided (or both
            are empty); otherwise a list with one dict per objective. If a provided list
            is shorter than objectives, the later objectives get no entry for that key.
        """
        if not (messages or prepended_conversations):
            return None

        overrides: List[Dict[str, Any]] = []
        for index in range(len(objectives)):
            entry: Dict[str, Any] = {}
            if messages and index < len(messages):
                entry["next_message"] = messages[index]
            if prepended_conversations and index < len(prepended_conversations):
                entry["prepended_conversation"] = prepended_conversations[index]
            overrides.append(entry)
        return overrides

    async def execute_multi_objective_attack_async(
        self,
        *,
        attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT],
        objectives: List[str],
        prepended_conversation: Optional[List[Message]] = None,
        memory_labels: Optional[Dict[str, str]] = None,
        return_partial_on_failure: bool = False,
        **attack_params: Any,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        """
        Execute the same attack strategy with multiple objectives against the same target in parallel.

        .. deprecated::
            Use :meth:`execute_attack_async` instead. This method will be removed in a future version.

        Args:
            attack: The attack strategy to use for all objectives.
            objectives: List of attack objectives to test.
            prepended_conversation: Conversation to prepend to the target model.
            memory_labels: Additional labels that can be applied to the prompts.
            return_partial_on_failure: If True, returns partial results on failure.
            **attack_params: Additional parameters specific to the attack strategy.

        Returns:
            AttackExecutorResult with completed results and any incomplete objectives.
        """
        logger.warning(
            "execute_multi_objective_attack_async is deprecated and will disappear in 0.13.0. "
            "Use execute_attack_async instead."
        )

        # Build field_overrides if prepended_conversation is provided (broadcast to all).
        # Every override dict references the same conversation list object.
        field_overrides: Optional[List[Dict[str, Any]]] = None
        if prepended_conversation:
            field_overrides = [{"prepended_conversation": prepended_conversation} for _ in objectives]

        return await self.execute_attack_async(
            attack=attack,
            objectives=objectives,
            field_overrides=field_overrides,
            return_partial_on_failure=return_partial_on_failure,
            memory_labels=memory_labels,
            **attack_params,
        )

    async def execute_single_turn_attacks_async(
        self,
        *,
        attack: AttackStrategy["_SingleTurnContextT", AttackStrategyResultT],
        objectives: List[str],
        messages: Optional[List[Message]] = None,
        prepended_conversations: Optional[List[List[Message]]] = None,
        memory_labels: Optional[Dict[str, str]] = None,
        return_partial_on_failure: bool = False,
        **attack_params: Any,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        """
        Execute a batch of single-turn attacks with multiple objectives.

        .. deprecated::
            Use :meth:`execute_attack_async` instead. This method will be removed in a future version.

        Args:
            attack: The single-turn attack strategy to use.
            objectives: List of attack objectives to test.
            messages: List of messages to use for this execution (per-objective).
            prepended_conversations: Conversations to prepend to each objective (per-objective).
            memory_labels: Additional labels that can be applied to the prompts.
            return_partial_on_failure: If True, returns partial results on failure.
            **attack_params: Additional parameters specific to the attack strategy.

        Returns:
            AttackExecutorResult with completed results and any incomplete objectives.

        Raises:
            TypeError: If the attack does not use SingleTurnAttackContext.
        """
        logger.warning(
            "execute_single_turn_attacks_async is deprecated and will disappear in 0.13.0. "
            "Use execute_attack_async instead."
        )

        # Validate that the attack uses SingleTurnAttackContext
        if hasattr(attack, "_context_type") and not issubclass(attack._context_type, SingleTurnAttackContext):
            raise TypeError(
                f"Attack strategy {attack.__class__.__name__} must use SingleTurnAttackContext or a subclass of it."
            )

        # Build field_overrides from per-objective parameters
        field_overrides = self._build_per_objective_overrides(
            objectives=objectives,
            messages=messages,
            prepended_conversations=prepended_conversations,
        )

        return await self.execute_attack_async(
            attack=attack,
            objectives=objectives,
            field_overrides=field_overrides,
            return_partial_on_failure=return_partial_on_failure,
            memory_labels=memory_labels,
            **attack_params,
        )

    async def execute_multi_turn_attacks_async(
        self,
        *,
        attack: AttackStrategy["_MultiTurnContextT", AttackStrategyResultT],
        objectives: List[str],
        messages: Optional[List[Message]] = None,
        prepended_conversations: Optional[List[List[Message]]] = None,
        memory_labels: Optional[Dict[str, str]] = None,
        return_partial_on_failure: bool = False,
        **attack_params: Any,
    ) -> AttackExecutorResult[AttackStrategyResultT]:
        """
        Execute a batch of multi-turn attacks with multiple objectives.

        .. deprecated::
            Use :meth:`execute_attack_async` instead. This method will be removed in a future version.

        Args:
            attack: The multi-turn attack strategy to use.
            objectives: List of attack objectives to test.
            messages: List of messages to use for this execution (per-objective).
            prepended_conversations: Conversations to prepend to each objective (per-objective).
            memory_labels: Additional labels that can be applied to the prompts.
            return_partial_on_failure: If True, returns partial results on failure.
            **attack_params: Additional parameters specific to the attack strategy.

        Returns:
            AttackExecutorResult with completed results and any incomplete objectives.

        Raises:
            TypeError: If the attack does not use MultiTurnAttackContext.
        """
        logger.warning(
            "execute_multi_turn_attacks_async is deprecated and will disappear in 0.13.0. "
            "Use execute_attack_async instead."
        )

        # Validate that the attack uses MultiTurnAttackContext
        if hasattr(attack, "_context_type") and not issubclass(attack._context_type, MultiTurnAttackContext):
            raise TypeError(
                f"Attack strategy {attack.__class__.__name__} must use MultiTurnAttackContext or a subclass of it."
            )

        # Build field_overrides from per-objective parameters
        field_overrides = self._build_per_objective_overrides(
            objectives=objectives,
            messages=messages,
            prepended_conversations=prepended_conversations,
        )

        return await self.execute_attack_async(
            attack=attack,
            objectives=objectives,
            field_overrides=field_overrides,
            return_partial_on_failure=return_partial_on_failure,
            memory_labels=memory_labels,
            **attack_params,
        )