Source code for pyrit.executor.attack.multi_turn.multi_turn_attack_strategy

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
import uuid
from abc import ABC
from dataclasses import dataclass, field
from typing import List, Optional, TypeVar, overload

from pyrit.common.logger import logger
from pyrit.common.utils import get_kwarg_param
from pyrit.executor.attack.core import (
    AttackContext,
    AttackStrategy,
    AttackStrategyResultT,
)
from pyrit.models import (
    PromptRequestResponse,
    Score,
)

# Type variable for the context class a multi-turn strategy operates on. The bound is written
# as a string (forward reference) because MultiTurnAttackContext is defined further down in
# this module.
MultiTurnAttackStrategyContextT = TypeVar("MultiTurnAttackStrategyContextT", bound="MultiTurnAttackContext")


def _new_conversation_id() -> str:
    """Return a freshly generated random UUID (version 4) rendered as a string."""
    return str(uuid.uuid4())


@dataclass
class ConversationSession:
    """Session for conversations.

    Attributes:
        conversation_id: Unique identifier of the main conversation between
            the attacker and model. Defaults to a new UUID4 string.
        adversarial_chat_conversation_id: Separate identifier used when the
            attack leverages an adversarial chat. Defaults to a new UUID4
            string, generated independently of ``conversation_id``.
    """

    conversation_id: str = field(default_factory=_new_conversation_id)
    adversarial_chat_conversation_id: str = field(default_factory=_new_conversation_id)
@dataclass
class MultiTurnAttackContext(AttackContext):
    """Context for multi-turn attacks.

    Attributes:
        session: Object holding all conversation-level identifiers for this
            attack. A new ``ConversationSession`` is created per context.
        executed_turns: Counter of turns that have actually been executed so
            far. Starts at 0.
        last_response: Model response produced in the latest turn, or
            ``None`` if there is none yet.
        last_score: Score assigned to the latest response by a scorer
            component, or ``None`` if there is none yet.
        custom_prompt: Optional custom prompt that overrides the default one
            for the next turn.
    """

    session: ConversationSession = field(default_factory=ConversationSession)
    executed_turns: int = 0
    last_response: Optional[PromptRequestResponse] = None
    last_score: Optional[Score] = None
    custom_prompt: Optional[str] = None
class MultiTurnAttackStrategy(AttackStrategy[MultiTurnAttackStrategyContextT, AttackStrategyResultT], ABC):
    """
    Strategy for executing multi-turn attacks.

    This strategy is designed to handle attacks that consist of multiple turns
    of interaction with the target model.
    """

    def __init__(self, *, context_type: type[MultiTurnAttackStrategyContextT], logger: logging.Logger = logger):
        """
        The base class for multi-turn attack strategies.

        Args:
            context_type (type[MultiTurnAttackContext]): The type of context this strategy will use
            logger (logging.Logger): Logger instance for logging events and messages
        """
        super().__init__(context_type=context_type, logger=logger)

    @overload
    async def execute_async(
        self,
        *,
        objective: str,
        prepended_conversation: Optional[List[PromptRequestResponse]] = None,
        custom_prompt: Optional[str] = None,
        memory_labels: Optional[dict[str, str]] = None,
        **kwargs,
    ) -> AttackStrategyResultT:
        """
        Execute the multi-turn attack strategy asynchronously with the provided parameters.

        Args:
            objective (str): The objective of the attack.
            prepended_conversation (Optional[List[PromptRequestResponse]]): Conversation to prepend.
            custom_prompt (Optional[str]): Custom prompt for the attack.
            memory_labels (Optional[dict[str, str]]): Memory labels for the attack context.
            **kwargs: Additional parameters for the attack.

        Returns:
            AttackStrategyResultT: The result of the attack execution.
        """
        ...

    @overload
    async def execute_async(
        self,
        **kwargs,
    ) -> AttackStrategyResultT: ...

    async def execute_async(
        self,
        **kwargs,
    ) -> AttackStrategyResultT:
        """
        Execute the attack strategy asynchronously with the provided parameters.

        The optional ``custom_prompt`` keyword is validated as a ``str`` and then
        forwarded, together with the remaining keyword arguments, to the parent
        class's ``execute_async``.

        Args:
            **kwargs: Keyword arguments for the attack; see the typed overload
                above for the documented parameters.

        Returns:
            AttackStrategyResultT: The result returned by the parent ``execute_async``.
        """
        # Validate parameters before creating context
        # NOTE(review): the call below passes custom_prompt explicitly next to **kwargs, so this
        # relies on get_kwarg_param removing "custom_prompt" from kwargs (otherwise the keyword
        # would be supplied twice) — confirm against pyrit.common.utils.
        custom_prompt = get_kwarg_param(kwargs=kwargs, param_name="custom_prompt", expected_type=str, required=False)

        return await super().execute_async(**kwargs, custom_prompt=custom_prompt)