# Source code for pyrit.executor.attack.core.attack_parameters
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type, TypeVar
from pyrit.models import Message, SeedGroup
# Type variable bound to AttackParameters. It is used to annotate `cls` in the
# classmethod constructors below so that their return type is the class they
# were invoked on rather than the base class.
AttackParamsT = TypeVar("AttackParamsT", bound="AttackParameters")
[docs]
@dataclass(frozen=True)
class AttackParameters:
    """
    Immutable parameters for attack execution.

    This class defines the standard contract for attack parameters. All attacks
    at a given level of the hierarchy share the same parameter signature.

    Attacks that don't accept certain parameters should use the `excluding()` factory
    to create a derived params type without those fields. Attacks that need additional
    parameters should extend this class with new fields.
    """

    # Natural-language description of what the attack tries to achieve (required)
    objective: str

    # Optional message to send to the objective target (overrides objective if provided)
    next_message: Optional[Message] = None

    # Conversation that is automatically prepended to the target model
    prepended_conversation: Optional[List[Message]] = None

    # Additional labels that can be applied to the prompts throughout the attack
    memory_labels: Optional[Dict[str, str]] = field(default_factory=dict)

    @classmethod
    def from_seed_group(
        cls: Type[AttackParamsT],
        seed_group: SeedGroup,
        **overrides: Any,
    ) -> AttackParamsT:
        """
        Create an AttackParameters instance from a SeedGroup.

        Extracts standard fields from the seed group and applies any overrides.
        Fields other than the four standard ones (for example, fields added by a
        subclass) are only populated when supplied through ``overrides``.

        Args:
            seed_group: The seed group to extract parameters from.
            **overrides: Field overrides to apply. Must be valid fields for this params type.

        Returns:
            An instance of this AttackParameters type.

        Raises:
            ValueError: If seed_group has no objective or if overrides contain invalid fields.
        """
        return cls._build_from_seed_group(cls, seed_group, overrides)

    @classmethod
    def excluding(cls, *field_names: str) -> Type["AttackParameters"]:
        """
        Create a new params type that has every field of this class except the given ones.

        The result is a new frozen dataclass built with ``dataclasses.make_dataclass``.
        It is created without base classes, so it is NOT a real subclass of ``cls``
        (``isinstance`` checks against ``AttackParameters`` will be False); only the
        remaining fields, their defaults, and a ``from_seed_group()`` classmethod are
        carried over. That ``from_seed_group()`` raises if excluded fields are passed
        as overrides.

        Args:
            *field_names: Names of fields to exclude from the new params type.

        Returns:
            A new frozen dataclass type without the specified fields.

        Raises:
            ValueError: If any field_name is not a valid field of this class.

        Example:
            RolePlayAttackParameters = AttackParameters.excluding("next_message", "prepended_conversation")
        """
        # Validate all field names exist
        current_fields = {f.name for f in dataclasses.fields(cls)}
        invalid = set(field_names) - current_fields
        if invalid:
            raise ValueError(f"Cannot exclude non-existent fields: {invalid}. Valid fields: {current_fields}")

        # Build the field specs for make_dataclass, keeping declaration order and defaults
        new_fields: List[tuple] = []
        for f in dataclasses.fields(cls):
            if f.name in field_names:
                continue
            if f.default is not dataclasses.MISSING:
                new_fields.append((f.name, f.type, field(default=f.default)))
            elif f.default_factory is not dataclasses.MISSING:
                new_fields.append((f.name, f.type, field(default_factory=f.default_factory)))
            else:
                new_fields.append((f.name, f.type))

        # Generate a descriptive class name (sorted so argument order does not matter)
        excluded_str = "_".join(sorted(field_names))
        class_name = f"{cls.__name__}Excluding_{excluded_str}"

        # Create the new dataclass
        new_cls = dataclasses.make_dataclass(
            class_name,
            new_fields,
            frozen=True,
        )

        # The generated class does not inherit from ``cls``, so it needs its own
        # ``from_seed_group``. A named function (rather than a lambda with short
        # parameter names) is used so the signature matches the base method and
        # callers can pass ``seed_group=...`` by keyword.
        def from_seed_group(target_cls: Any, seed_group: SeedGroup, **overrides: Any) -> Any:
            """Create an instance of this derived params type from a SeedGroup."""
            return cls._from_seed_group_impl(target_cls, seed_group, **overrides)

        new_cls.from_seed_group = classmethod(from_seed_group)  # type: ignore[attr-defined,method-assign]
        return new_cls  # type: ignore[return-value]

    @classmethod
    def _from_seed_group_impl(
        cls: Type[AttackParamsT],
        target_cls: Type[AttackParamsT],
        seed_group: SeedGroup,
        **overrides: Any,
    ) -> AttackParamsT:
        """
        Implement from_seed_group for dynamically created classes.

        Args:
            target_cls: The actual class to instantiate (may be a dynamically created class).
            seed_group: The seed group to extract parameters from.
            **overrides: Field overrides to apply.

        Returns:
            An instance of target_cls.

        Raises:
            ValueError: If seed_group has no objective or if overrides contain invalid fields.
        """
        return cls._build_from_seed_group(target_cls, seed_group, overrides)

    @staticmethod
    def _build_from_seed_group(target_cls: Any, seed_group: Any, overrides: Dict[str, Any]) -> Any:
        """
        Shared implementation behind from_seed_group() and _from_seed_group_impl().

        Args:
            target_cls: The dataclass type to instantiate.
            seed_group: The seed group to extract parameters from.
            overrides: Field overrides to apply, keyed by field name.

        Returns:
            An instance of target_cls.

        Raises:
            ValueError: If overrides contain fields target_cls does not declare, or if
                seed_group has no objective.
        """
        # Get valid field names for the target class
        valid_fields = {f.name for f in dataclasses.fields(target_cls)}

        # Validate overrides don't contain invalid fields
        invalid_fields = set(overrides.keys()) - valid_fields
        if invalid_fields:
            raise ValueError(
                f"{target_cls.__name__} does not accept parameters: {invalid_fields}. "
                f"Accepted parameters: {valid_fields}"
            )

        # The objective is required even when the target class excludes that field
        if seed_group.objective is None:
            raise ValueError("SeedGroup must have an objective")

        # Build params dict, only including fields the target class accepts
        params: Dict[str, Any] = {}
        if "objective" in valid_fields:
            params["objective"] = seed_group.objective.value
        if "next_message" in valid_fields:
            params["next_message"] = seed_group.next_message
        if "prepended_conversation" in valid_fields:
            params["prepended_conversation"] = seed_group.prepended_conversation
        if "memory_labels" in valid_fields:
            # Start from a fresh dict; labels are only supplied through overrides
            params["memory_labels"] = {}

        # Apply overrides (already validated above)
        params.update(overrides)
        return target_cls(**params)