Source code for pyrit.executor.attack.core.attack_strategy

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import dataclasses
import logging
import time
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Generic, List, Optional, Type, TypeVar, cast, overload

from pyrit.common.logger import logger
from pyrit.executor.attack.core.attack_config import AttackScoringConfig
from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT
from pyrit.executor.core import (
    Strategy,
    StrategyContext,
    StrategyEvent,
    StrategyEventData,
    StrategyEventHandler,
)
from pyrit.memory.central_memory import CentralMemory
from pyrit.models import (
    AttackOutcome,
    AttackResult,
    ConversationReference,
    Message,
)
from pyrit.prompt_target import PromptTarget

AttackStrategyContextT = TypeVar("AttackStrategyContextT", bound="AttackContext")
AttackStrategyResultT = TypeVar("AttackStrategyResultT", bound="AttackResult")


@dataclass
class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]):
    """
    Common base for the context object passed through an attack.

    A context bundles the caller-provided ``params`` with state that changes
    while the attack runs. ``objective`` is always read from ``params``.
    ``memory_labels``, ``prepended_conversation`` and ``next_message`` return the
    matching ``_*_override`` field when it has been set and fall back to ``params``
    otherwise; assigning to one of these properties stores the value in its
    override field and leaves ``params`` untouched.

    Attacks that generate certain values internally (e.g., RolePlayAttack generates
    next_message and prepended_conversation) can set the override fields during
    _setup_async.
    """

    # Caller-provided inputs; treated as immutable by this class
    params: AttackParamsT

    # Start time of the attack execution
    start_time: float = 0.0

    # Conversations relevant while the attack is running
    related_conversations: set[ConversationReference] = field(default_factory=set)

    # Override slots; None means "not overridden, read from params"
    _next_message_override: Optional[Message] = None
    _prepended_conversation_override: Optional[List[Message]] = None
    _memory_labels_override: Optional[Dict[str, str]] = None

    @property
    def objective(self) -> str:
        """Natural-language description of what the attack tries to achieve."""
        return self.params.objective

    @property
    def memory_labels(self) -> Dict[str, str]:
        """Labels applied to prompts during the attack; never None."""
        overridden = self._memory_labels_override
        if overridden is None:
            # No override: use the caller's labels, or an empty dict if unset/empty
            return self.params.memory_labels or {}
        return overridden

    @memory_labels.setter
    def memory_labels(self, value: Dict[str, str]) -> None:
        """Store labels in the override slot (e.g. merged strategy + context labels)."""
        self._memory_labels_override = value

    @property
    def prepended_conversation(self) -> List[Message]:
        """Conversation that is automatically prepended to the target model."""
        overridden = self._prepended_conversation_override
        if overridden is not None:
            return overridden
        # params types are not required to carry this field; a missing or empty
        # value yields a fresh empty list
        from_params = getattr(self.params, "prepended_conversation", None)
        if from_params:
            return from_params
        return []

    @prepended_conversation.setter
    def prepended_conversation(self, value: List[Message]) -> None:
        """Store an internally generated prepended conversation in the override slot."""
        self._prepended_conversation_override = value

    @property
    def next_message(self) -> Optional[Message]:
        """Optional message to send to the objective target."""
        overridden = self._next_message_override
        if overridden is not None:
            return overridden
        # params types are not required to carry this field; missing means None
        return getattr(self.params, "next_message", None)

    @next_message.setter
    def next_message(self, value: Optional[Message]) -> None:
        """Store an internally generated next message in the override slot."""
        self._next_message_override = value
class _DefaultAttackStrategyEventHandler(StrategyEventHandler[AttackStrategyContextT, AttackStrategyResultT]):
    """
    Built-in event handler installed on attack strategies.

    It stamps the start time before execution, and after execution it records the
    elapsed time on the result, logs the outcome and stores the result in memory.
    Every other event is only logged at debug level.
    """

    def __init__(self, logger: logging.Logger = logger):
        """
        Set up the logger, the event dispatch table and the memory handle.

        Args:
            logger (logging.Logger): Logger instance for logging events.
        """
        self._logger = logger
        # Events with dedicated handling; anything else is routed to _on
        self._events = {
            StrategyEvent.ON_PRE_EXECUTE: self._on_pre_execute,
            StrategyEvent.ON_POST_EXECUTE: self._on_post_execute,
        }
        self._memory = CentralMemory.get_memory_instance()

    async def on_event(self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]) -> None:
        """
        Dispatch an event to its dedicated handler, or to the generic one.

        Args:
            event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event
                data containing context and result.
        """
        callback = self._events.get(event_data.event, self._on)
        await callback(event_data)

    async def _on(self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]) -> None:
        """
        Generic handler for events without dedicated logic: emit a debug log line.

        Args:
            event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event
                data containing context and result.
        """
        self._logger.debug(f"Attack is in '{event_data.event.value}' stage for {self.__class__.__name__}")

    async def _on_pre_execute(
        self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]
    ) -> None:
        """
        Record the start time on the context and log that the attack is starting.

        Args:
            event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event
                data containing context and result.

        Raises:
            ValueError: If the attack context is None.
        """
        context = event_data.context
        if not context:
            raise ValueError("Attack context is None. Cannot proceed with execution.")

        # perf_counter value; read back in _on_post_execute to compute the duration
        context.start_time = time.perf_counter()
        self._logger.info(f"Starting attack: {context.objective}")

    async def _on_post_execute(
        self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]
    ) -> None:
        """
        Attach the execution time to the result, log the outcome and persist the result.

        Args:
            event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event
                data containing context and result.

        Raises:
            ValueError: If the attack result is None.
        """
        result = event_data.result
        if not result:
            raise ValueError("Attack result is None. Cannot log or record the outcome.")

        # Elapsed wall time since _on_pre_execute, truncated to whole milliseconds
        execution_time_ms = int((time.perf_counter() - event_data.context.start_time) * 1000)
        result.execution_time_ms = execution_time_ms
        self._logger.debug(f"Attack execution completed in {execution_time_ms}ms")

        self._log_attack_outcome(result)
        self._memory.add_attack_results_to_memory(attack_results=[result])

    def _log_attack_outcome(self, result: AttackResult) -> None:
        """
        Log a one-line, info-level summary of the attack outcome.

        Args:
            result (AttackResult): The result of the attack containing outcome and reason.
        """
        # The name used in the message is this handler's own class name
        attack_name = self.__class__.__name__
        reason = f"Reason: {result.outcome_reason or 'Not specified'}"
        outcome = result.outcome

        if outcome == AttackOutcome.SUCCESS:
            summary = f"{attack_name} achieved the objective. {reason}"
        elif outcome == AttackOutcome.UNDETERMINED:
            summary = f"{attack_name} outcome is undetermined. {reason}"
        else:
            # Any other outcome is reported as not achieving the objective
            summary = f"{attack_name} did not achieve the objective. {reason}"

        self._logger.info(summary)
class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], ABC):
    """
    Abstract base class for attack strategies.

    Holds the objective target and the parameter type of an attack, and provides
    ``execute_async``, which turns keyword arguments into a params object and a
    context before delegating to ``execute_with_context_async``.
    """

    def __init__(
        self,
        *,
        objective_target: PromptTarget,
        context_type: type[AttackStrategyContextT],
        params_type: Type[AttackParamsT] = AttackParameters,  # type: ignore[assignment]
        logger: logging.Logger = logger,
    ):
        """
        Initialize the attack strategy with a specific context type and logger.

        Args:
            objective_target (PromptTarget): The target system to attack.
            context_type (type[AttackStrategyContextT]): The type of context this strategy
                operates on.
            params_type (Type[AttackParamsT]): The type of parameters this strategy accepts.
                Defaults to AttackParameters. Use AttackParameters.excluding() to create a
                params type that rejects certain fields.
            logger (logging.Logger): Logger instance for logging events.
        """
        # The default handler takes care of timing, outcome logging and persistence
        default_handler = _DefaultAttackStrategyEventHandler[AttackStrategyContextT, AttackStrategyResultT](
            logger=logger
        )
        super().__init__(
            context_type=context_type,
            event_handler=default_handler,
            logger=logger,
        )
        self._objective_target = objective_target
        self._params_type = params_type

    @property
    def params_type(self) -> Type[AttackParameters]:
        """
        Get the parameters type for this attack strategy.

        Returns:
            Type[AttackParameters]: The parameters type this strategy accepts.
        """
        return self._params_type

    def get_objective_target(self) -> PromptTarget:
        """
        Get the objective target for this attack strategy.

        Returns:
            PromptTarget: The target system being attacked.
        """
        return self._objective_target

    def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]:
        """
        Get the attack scoring configuration used by this strategy.

        Returns:
            Optional[AttackScoringConfig]: The scoring configuration, or None if not applicable.

        Note:
            This base implementation always returns None. Subclasses that use scoring
            should override this method to return their scoring configuration.
        """
        return None

    @overload
    async def execute_async(
        self,
        *,
        objective: str,
        next_message: Optional[Message] = None,
        prepended_conversation: Optional[List[Message]] = None,
        memory_labels: Optional[dict[str, str]] = None,
        **kwargs,
    ) -> AttackStrategyResultT: ...

    @overload
    async def execute_async(
        self,
        **kwargs,
    ) -> AttackStrategyResultT: ...

    async def execute_async(
        self,
        **kwargs,
    ) -> AttackStrategyResultT:
        """
        Execute the attack strategy asynchronously with the provided parameters.

        Keyword arguments whose value is None are dropped. Each remaining argument is
        routed to the params object when its name is a field of the params type,
        otherwise to the context when its name is a field of the context type; names
        matching neither are rejected.

        Args:
            objective (str): The objective of the attack.
            next_message (Optional[Message]): Message to send to the target.
            prepended_conversation (Optional[List[Message]]): Conversation to prepend.
            memory_labels (Optional[Dict[str, str]]): Memory labels for the attack context.
            **kwargs: Additional context-specific parameters (conversation_id, system_prompt, etc.).

        Returns:
            AttackStrategyResultT: The result of the attack execution.

        Raises:
            ValueError: If required parameters are missing or if unsupported parameters are provided.
        """
        # Names accepted by the params dataclass and by the context dataclass;
        # "params" itself is supplied below and therefore not caller-settable
        params_fields = {spec.name for spec in dataclasses.fields(self._params_type)}
        context_fields = {spec.name for spec in dataclasses.fields(self._context_type)} - {"params"}

        # None is treated as "not provided", even for names that are otherwise unknown
        provided = {name: value for name, value in kwargs.items() if value is not None}

        # A name that is both a params field and a context field goes to params
        params_kwargs = {name: value for name, value in provided.items() if name in params_fields}
        context_kwargs = {
            name: value
            for name, value in provided.items()
            if name not in params_fields and name in context_fields
        }
        unknown_fields = {
            name for name in provided if name not in params_fields and name not in context_fields
        }

        if unknown_fields:
            raise ValueError(
                f"{self.__class__.__name__} does not accept parameters: {unknown_fields}. "
                f"Accepted attack parameters: {params_fields}. "
                f"Accepted context parameters: {context_fields}"
            )

        if "objective" not in params_kwargs:
            raise ValueError("objective is required")

        params = self._params_type(**params_kwargs)

        # The cast is needed because the type checker doesn't know that _context_type
        # (AttackContext or a subclass) always accepts 'params' as a keyword argument.
        context = cast(
            AttackStrategyContextT, self._context_type(params=params, **context_kwargs)
        )  # type: ignore[call-arg]

        return await self.execute_with_context_async(context=context)