# Source code for pyrit.message_normalizer.message_normalizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
from typing import Any, Generic, List, Literal, Protocol, TypeVar

from pyrit.models import Message

# Type alias for system message handling strategies.
# apply_system_message_behavior (defined in this module) dispatches on exactly these three values
# and raises ValueError for anything else.
SystemMessageBehavior = Literal["keep", "squash", "ignore"]
"""
How to handle system messages in models with varying support:
- "keep": Keep system messages as-is (default for most models)
- "squash": Merge system message into first user message
- "ignore": Drop system messages entirely
"""


class DictConvertible(Protocol):
    """Structural type for any object that exposes a ``to_dict()`` method."""

    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary representation of this object."""


# Item type produced by MessageListNormalizer; bound to DictConvertible so that
# every normalized item is guaranteed to offer to_dict().
T = TypeVar("T", bound=DictConvertible)


class MessageListNormalizer(abc.ABC, Generic[T]):
    """
    Abstract base class for normalizers that return a list of items.

    Subclasses specify the type T (e.g., Message, ChatMessage) that the list contains.
    T must implement the DictConvertible protocol (have a to_dict() method).
    """

    @abc.abstractmethod
    async def normalize_async(self, messages: List[Message]) -> List[T]:
        """
        Normalize the list of messages into a list of items.

        Args:
            messages: The list of Message objects to normalize.

        Returns:
            A list of normalized items of type T.
        """

    async def normalize_to_dicts_async(self, messages: List[Message]) -> List[dict[str, Any]]:
        """
        Normalize the list of messages into a list of dictionaries.

        This method uses normalize_async and calls to_dict() on each item.

        Args:
            messages: The list of Message objects to normalize.

        Returns:
            A list of dictionaries representing the normalized messages.
        """
        # Delegate to the subclass-provided normalizer, then serialize each item.
        normalized = await self.normalize_async(messages)
        return [item.to_dict() for item in normalized]
class MessageStringNormalizer(abc.ABC):
    """
    Abstract base class for normalizers that return a string representation.

    Use this for formatting messages into text for non-chat targets or context strings.
    """

    @abc.abstractmethod
    async def normalize_string_async(self, messages: List[Message]) -> str:
        """
        Normalize the list of messages into a string representation.

        Args:
            messages: The list of Message objects to normalize.

        Returns:
            A string representation of the messages.
        """
async def apply_system_message_behavior(messages: List[Message], behavior: SystemMessageBehavior) -> List[Message]:
    """
    Apply a system message behavior to a list of messages.

    This is a helper function used by normalizers to preprocess messages
    based on how the target handles system messages.

    Args:
        messages: The list of Message objects to process.
        behavior: How to handle system messages:
            - "keep": Return messages unchanged
            - "squash": Merge system into first user message
            - "ignore": Remove system messages

    Returns:
        The processed list of Message objects.

    Raises:
        ValueError: If an unknown behavior is provided.
    """
    if behavior == "keep":
        # The input list itself is returned (no copy is made).
        return messages

    if behavior == "squash":
        # Import here to avoid circular imports
        from pyrit.message_normalizer.generic_system_squash import (
            GenericSystemSquashNormalizer,
        )

        return await GenericSystemSquashNormalizer().normalize_async(messages)

    if behavior == "ignore":
        # Builds a new list; the caller's list is left untouched.
        return [msg for msg in messages if msg.role != "system"]

    # This should never happen due to Literal type, but handle it gracefully
    raise ValueError(f"Unknown system message behavior: {behavior}")