# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import hashlib
import json
from dataclasses import Field, asdict, dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Any, Literal, Type, TypeVar
import pyrit
from pyrit.common.deprecation import print_deprecation_message
from pyrit.identifiers.class_name_utils import class_name_to_snake_case
IdentifierType = Literal["class", "instance"]
class _ExcludeFrom(Enum):
"""
Enum specifying what a field should be excluded from.
    Used as values in the set stored under the _EXCLUDE metadata key of dataclass fields.
Values:
HASH: Exclude the field from hash computation (field is still stored).
        STORAGE: Exclude the field from storage (implies HASH; the field is also excluded from the hash).
The `expands_to` property returns the full set of exclusions that apply.
For example, STORAGE.expands_to returns {STORAGE, HASH} since excluding
from storage implicitly means excluding from hash as well.
"""
HASH = "hash"
STORAGE = "storage"
@property
def expands_to(self) -> set["_ExcludeFrom"]:
"""
Get the full set of exclusions that this value implies.
This implements a catalog pattern where certain exclusions automatically
include others. For example, STORAGE expands to {STORAGE, HASH} because
a field excluded from storage should never be included in the hash.
Returns:
set[_ExcludeFrom]: The complete set of exclusions including implied ones.
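        Example:
            An illustrative check of the expansion semantics (STORAGE implies
            HASH, but HASH stands alone):

            >>> _ExcludeFrom.HASH in _ExcludeFrom.STORAGE.expands_to
            True
            >>> _ExcludeFrom.STORAGE in _ExcludeFrom.HASH.expands_to
            False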
"""
return _EXPANSION_CATALOG[self]
# Lookup table for exclusion expansion - defined after enum so values exist
_EXPANSION_CATALOG: dict[_ExcludeFrom, set[_ExcludeFrom]] = {
_ExcludeFrom.HASH: {_ExcludeFrom.HASH},
_ExcludeFrom.STORAGE: {_ExcludeFrom.STORAGE, _ExcludeFrom.HASH},
}
def _expand_exclusions(exclude_set: set[_ExcludeFrom]) -> set[_ExcludeFrom]:
"""
Expand a set of exclusions to include all implied exclusions.
Args:
exclude_set: A set of _ExcludeFrom values.
Returns:
set[_ExcludeFrom]: The expanded set including all implied exclusions.
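    Example:
        An illustrative expansion of a STORAGE-only set (STORAGE implies HASH):

        >>> _expand_exclusions({_ExcludeFrom.STORAGE}) == {_ExcludeFrom.STORAGE, _ExcludeFrom.HASH}
        True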
"""
expanded: set[_ExcludeFrom] = set()
for exclusion in exclude_set:
expanded.update(exclusion.expands_to)
return expanded
# Metadata keys for field configuration
# _EXCLUDE is a metadata key whose value is a set of _ExcludeFrom enum values.
# Examples:
# field(metadata={_EXCLUDE: {_ExcludeFrom.HASH}}) # Stored but not hashed
# field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) # Not stored and not hashed (STORAGE implies HASH)
_EXCLUDE = "exclude"
_MAX_STORAGE_LENGTH = "max_storage_length"
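# Illustrative sketches (no field on the base Identifier uses _MAX_STORAGE_LENGTH directly):
#   field(metadata={_MAX_STORAGE_LENGTH: 100})  # hashed in full, truncated in storage
#   field(metadata={_EXCLUDE: {_ExcludeFrom.HASH}, _MAX_STORAGE_LENGTH: 100})  # not hashed, truncated in storage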
def _is_excluded_from_hash(f: Field[Any]) -> bool:
"""
Check if a field should be excluded from hash computation.
A field is excluded from hash if, after expansion, the exclusion set contains _ExcludeFrom.HASH.
This uses the catalog expansion pattern where STORAGE automatically implies HASH.
Args:
f: A dataclass field object.
Returns:
True if the field should be excluded from hash computation.
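    Example:
        An illustrative check on a standalone field object (``field`` is already
        imported in this module); STORAGE expands to include HASH:

        >>> _is_excluded_from_hash(field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}))
        True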
"""
exclude_set = f.metadata.get(_EXCLUDE, set())
expanded = _expand_exclusions(exclude_set)
return _ExcludeFrom.HASH in expanded
def _is_excluded_from_storage(f: Field[Any]) -> bool:
"""
Check if a field should be excluded from storage.
A field is excluded from storage if, after expansion, the exclusion set contains _ExcludeFrom.STORAGE.
Args:
f: A dataclass field object.
Returns:
True if the field should be excluded from storage.
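    Example:
        Illustrative: HASH alone does not imply STORAGE:

        >>> _is_excluded_from_storage(field(metadata={_EXCLUDE: {_ExcludeFrom.HASH}}))
        False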
"""
exclude_set = f.metadata.get(_EXCLUDE, set())
expanded = _expand_exclusions(exclude_set)
return _ExcludeFrom.STORAGE in expanded
T = TypeVar("T", bound="Identifier")
@dataclass(frozen=True)
class Identifier:
"""
Base dataclass for identifying PyRIT components.
This frozen dataclass provides a stable identifier for registry items,
targets, scorers, attacks, converters, and other components. The hash is computed at creation
time from the core fields and remains constant.
This class serves as:
1. Base for registry metadata (replacing RegistryItemMetadata)
2. Future replacement for get_identifier() dict patterns
All component-specific identifier types should extend this with additional fields.
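    Example:
        An illustrative instantiation (the names are hypothetical); the derived
        fields are computed in ``__post_init__``:

        >>> ident = Identifier(
        ...     class_name="MyScorer",
        ...     class_module="pyrit.score.my_scorer",
        ...     class_description="A sample scorer",
        ...     identifier_type="instance",
        ... )
        >>> ident.unique_name == f"{ident.snake_class_name}::{ident.hash[:8]}"
        True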
"""
class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer")
class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer")
# Fields excluded from storage (STORAGE auto-expands to include HASH)
class_description: str = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})
identifier_type: IdentifierType = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})
# Auto-computed fields
snake_class_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})
hash: str | None = field(default=None, compare=False, kw_only=True, metadata={_EXCLUDE: {_ExcludeFrom.HASH}})
# {full_snake_case}::{hash[:8]}
unique_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}})
# Version field - stored but not hashed (allows version tracking without affecting identity)
pyrit_version: str = field(
default_factory=lambda: pyrit.__version__, kw_only=True, metadata={_EXCLUDE: {_ExcludeFrom.HASH}}
)
def __post_init__(self) -> None:
"""Compute derived fields: snake_class_name, hash, and unique_name."""
# Use object.__setattr__ since this is a frozen dataclass
# 1. Compute snake_class_name
object.__setattr__(self, "snake_class_name", class_name_to_snake_case(self.class_name))
# 2. Compute hash only if not already provided (e.g., from from_dict)
computed_hash = self.hash if self.hash is not None else self._compute_hash()
object.__setattr__(self, "hash", computed_hash)
# 3. Compute unique_name: full snake_case :: hash prefix
full_snake = class_name_to_snake_case(self.class_name)
object.__setattr__(self, "unique_name", f"{full_snake}::{computed_hash[:8]}")
def _compute_hash(self) -> str:
"""
Compute a stable SHA256 hash from identifier fields not excluded from hashing.
        Fields are excluded from hash computation if their _EXCLUDE metadata set
        expands to include _ExcludeFrom.HASH: either _ExcludeFrom.HASH directly,
        or _ExcludeFrom.STORAGE (which implies HASH).
Returns:
A hex string of the SHA256 hash.
"""
hashable_dict: dict[str, Any] = {
f.name: getattr(self, f.name) for f in fields(self) if not _is_excluded_from_hash(f)
}
config_json = json.dumps(hashable_dict, sort_keys=True, separators=(",", ":"), default=_dataclass_encoder)
return hashlib.sha256(config_json.encode("utf-8")).hexdigest()
def to_dict(self) -> dict[str, Any]:
"""
Return only fields suitable for DB storage.
        String values longer than a field's max_storage_length metadata limit
        are truncated to the first N characters followed by a SHA256 hash of
        the full value, formatted as:
        "<first N chars>... [sha256:<hash[:16]>]"
Nested Identifier objects are recursively serialized to dicts.
Returns:
dict[str, Any]: A dictionary containing the storable fields.
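        Example:
            An illustrative look at the storable keys for a hypothetical
            identifier (fields excluded from storage are dropped):

            >>> ident = Identifier("MyScorer", "pyrit.score.my_scorer", "A sample scorer", "instance")
            >>> sorted(ident.to_dict())
            ['class_module', 'class_name', 'hash', 'pyrit_version']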
"""
result: dict[str, Any] = {}
for f in fields(self):
if _is_excluded_from_storage(f):
continue
value = getattr(self, f.name)
max_len = f.metadata.get(_MAX_STORAGE_LENGTH)
if max_len is not None and isinstance(value, str) and len(value) > max_len:
truncated = value[:max_len]
field_hash = hashlib.sha256(value.encode()).hexdigest()[:16]
value = f"{truncated}... [sha256:{field_hash}]"
# Recursively serialize nested Identifier objects
elif isinstance(value, Identifier):
value = value.to_dict()
elif isinstance(value, list) and value and isinstance(value[0], Identifier):
value = [item.to_dict() for item in value]
# Exclude None and empty values
if value is None or value == "" or value == [] or value == {}:
continue
result[f.name] = value
return result
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""
Create an Identifier from a dictionary (e.g., retrieved from database).
Note:
For fields with max_storage_length, stored values may be truncated
strings like "<first N chars>... [sha256:<hash>]". If a 'hash' key is
present in the input dict, it will be preserved rather than recomputed,
ensuring identity matching works correctly.
Args:
data: The dictionary representation.
Returns:
A new Identifier instance.
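        Example:
            An illustrative sketch of migrating a legacy dict (this emits a
            deprecation message for the ``__type__``/``__module__`` keys)::

                legacy = {"__type__": "MyScorer", "__module__": "pyrit.score.my_scorer"}
                ident = Identifier.from_dict(legacy)
                assert ident.class_name == "MyScorer"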
"""
# Create a mutable copy
data = dict(data)
# Handle legacy key mappings for class_name
if "class_name" not in data:
if "__type__" in data:
print_deprecation_message(
old_item="'__type__' key in Identifier dict",
new_item="'class_name' key",
removed_in="0.13.0",
)
data["class_name"] = data.pop("__type__")
elif "type" in data:
print_deprecation_message(
old_item="'type' key in Identifier dict",
new_item="'class_name' key",
removed_in="0.13.0",
)
data["class_name"] = data.pop("type")
else:
# Default for truly legacy data without any class identifier
data["class_name"] = "Unknown"
# Handle legacy key mapping for class_module
if "class_module" not in data:
if "__module__" in data:
print_deprecation_message(
old_item="'__module__' key in Identifier dict",
new_item="'class_module' key",
removed_in="0.13.0",
)
data["class_module"] = data.pop("__module__")
else:
# Default for truly legacy data without module info
data["class_module"] = "unknown"
# Provide defaults for fields excluded from storage (not in stored dicts)
if "class_description" not in data:
data["class_description"] = ""
if "identifier_type" not in data:
data["identifier_type"] = "instance"
# Get the set of valid field names for this class
valid_fields = {f.name for f in fields(cls) if f.init}
filtered_data = {k: v for k, v in data.items() if k in valid_fields}
return cls(**filtered_data)
@classmethod
def normalize(cls: Type[T], value: T | dict[str, Any]) -> T:
"""
Normalize a value to an Identifier instance.
This method handles conversion from legacy dict format to Identifier,
emitting a deprecation warning when a dict is passed. Existing Identifier
instances are returned as-is.
Args:
value: An Identifier instance or a dict (legacy format).
Returns:
The normalized Identifier instance.
Raises:
TypeError: If value is not an Identifier or dict.
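        Example:
            An illustrative sketch: instances pass through unchanged, while
            dicts are converted with a deprecation message::

                ident = Identifier("MyScorer", "pyrit.score.my_scorer", "", "instance")
                assert Identifier.normalize(ident) is ident
                assert Identifier.normalize(ident.to_dict()).class_name == "MyScorer"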
"""
if isinstance(value, cls):
return value
if isinstance(value, dict):
print_deprecation_message(
old_item=f"dict for {cls.__name__}",
new_item=cls.__name__,
removed_in="0.14.0",
)
return cls.from_dict(value)
raise TypeError(f"Expected {cls.__name__} or dict, got {type(value).__name__}")
def _dataclass_encoder(obj: Any) -> Any:
"""
JSON encoder that handles dataclasses by converting them to dicts.
Args:
obj: The object to encode.
Returns:
Any: The dictionary representation of the dataclass.
Raises:
TypeError: If the object is not a dataclass instance.
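    Example:
        An illustrative use as a ``json.dumps`` fallback (``_Point`` is a
        hypothetical dataclass):

        >>> import json
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class _Point:
        ...     x: int
        ...     y: int
        >>> json.dumps(_Point(1, 2), default=_dataclass_encoder)
        '{"x": 1, "y": 2}'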
"""
if is_dataclass(obj) and not isinstance(obj, type):
return asdict(obj)
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")