# Source code for pyrit.chat_message_normalizer.chat_message_normalizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
from pyrit.models import ChatMessage
from typing import TypeVar, Generic

# Type variable for the result produced by a normalizer. It is constrained to
# exactly two forms: a plain string, or a list of ChatMessage objects.
T = TypeVar("T", str, list[ChatMessage])


class ChatMessageNormalizer(abc.ABC, Generic[T]):
    """
    Abstract base class for objects that convert a list of chat messages
    into the representation ``T`` expected by a model or target.
    """

    @abc.abstractmethod
    def normalize(self, messages: list[ChatMessage]) -> T:
        """
        Normalizes the list of chat messages into a compatible format for the model or target

        Args:
            messages: The list of chat messages to convert.

        Returns:
            The messages in the normalizer's output format.
        """

    @staticmethod
    def squash_system_message(messages: list[ChatMessage], squash_function) -> list[ChatMessage]:
        """
        Folds a leading system message into the message that follows it.

        Args:
            messages: The list of chat messages.
            squash_function: Callable invoked as
                ``squash_function(system_message, next_message)``; its return
                value becomes the first element of the result.

        Returns:
            The input list itself when the first message is not a system
            message. A single user message carrying the system content when
            the system message is the only entry. Otherwise a new list made of
            the squashed message followed by the messages after the second.

        Raises:
            ValueError: If ``messages`` is empty.
        """
        if not messages:
            raise ValueError("ChatMessage list cannot be empty")

        leading = messages[0]

        # Nothing to squash unless the conversation opens with a system message.
        if leading.role != "system":
            return messages

        # A lone system message has no partner to merge into, so it is
        # re-issued as a user message with the same content.
        if len(messages) == 1:
            return [ChatMessage(role="user", content=leading.content)]

        merged = squash_function(leading, messages[1])
        tail = messages[2:]
        return [merged] + tail