Source code for pyrit.models.message

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Dict, MutableSequence, Optional, Sequence, Union

from pyrit.common.utils import combine_dict
from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError
from pyrit.models.message_piece import MessagePiece


class Message:
    """
    Represents a message in a conversation, for example a prompt or a response to a prompt.

    This is a single request to a target. It can contain multiple message pieces, all of
    which are expected to share the same conversation ID, sequence and role.

    Parameters:
        message_pieces (Sequence[MessagePiece]): The list of message pieces. Must not be empty.
        skip_validation (Optional[bool]): When truthy, ``validate()`` is not run on construction.

    Raises:
        ValueError: If ``message_pieces`` is empty, or if validation is enabled and fails.
    """

    def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False):
        if not message_pieces:
            raise ValueError("Message must have at least one message piece.")
        # The sequence is stored as given (not copied).
        self.message_pieces = message_pieces
        if not skip_validation:
            self.validate()

    def _first_piece(self) -> MessagePiece:
        """Return the first message piece, raising ValueError if there are no pieces."""
        # message_pieces is a public attribute and may have been emptied after construction.
        if not self.message_pieces:
            raise ValueError("Empty message pieces.")
        return self.message_pieces[0]

    def get_value(self, n: int = 0) -> str:
        """
        Return the converted value of the nth message piece.

        Raises:
            IndexError: If ``n`` is greater than or equal to the number of pieces.
        """
        if n >= len(self.message_pieces):
            raise IndexError(f"No message piece at index {n}.")
        return self.message_pieces[n].converted_value

    def get_values(self) -> list[str]:
        """Return the converted values of all message pieces, in order."""
        return [piece.converted_value for piece in self.message_pieces]

    def get_piece(self, n: int = 0) -> MessagePiece:
        """
        Return the nth message piece.

        Raises:
            ValueError: If there are no message pieces.
            IndexError: If ``n`` is greater than or equal to the number of pieces.
        """
        if not self.message_pieces:
            raise ValueError("Empty message pieces.")
        if n >= len(self.message_pieces):
            raise IndexError(f"No message piece at index {n}.")
        return self.message_pieces[n]

    @property
    def role(self) -> ChatMessageRole:
        """Return the role of the first message piece (they should all be the same)."""
        return self._first_piece().role

    @property
    def conversation_id(self) -> str:
        """Return the conversation ID of the first message piece (they should all be the same)."""
        return self._first_piece().conversation_id

    @property
    def sequence(self) -> int:
        """Return the sequence of the first message piece (they should all be the same)."""
        return self._first_piece().sequence

    def is_error(self) -> bool:
        """
        Return True if any message piece has a ``response_error`` other than "none"
        or a ``converted_value_data_type`` of "error".
        """
        return any(
            piece.response_error != "none" or piece.converted_value_data_type == "error"
            for piece in self.message_pieces
        )

    def set_response_not_in_database(self):
        """
        Set that the prompt is not in the database.

        This is needed when we're scoring prompts or other things that have not been sent by PyRIT.
        Delegates to ``set_piece_not_in_database()`` on every message piece.
        """
        for piece in self.message_pieces:
            piece.set_piece_not_in_database()

    def validate(self):
        """
        Validate that the message pieces form one consistent message.

        Every piece must share the conversation ID, sequence and role of the first piece,
        and must have a converted value that is not None.

        Raises:
            ValueError: If there are no pieces or any of the checks above fails.
        """
        first = self._first_piece()
        conversation_id = first.conversation_id
        sequence = first.sequence
        role = first.role

        # The order of the checks determines which error is reported for a piece
        # that violates more than one rule.
        for piece in self.message_pieces:
            if piece.conversation_id != conversation_id:
                raise ValueError("Conversation ID mismatch.")
            if piece.sequence != sequence:
                raise ValueError("Inconsistent sequences within the same message entry.")
            if piece.converted_value is None:
                raise ValueError("Converted prompt text is None.")
            if piece.role != role:
                raise ValueError("Inconsistent roles within the same message entry.")

    def __str__(self):
        """Return the string form of every message piece, one per line."""
        return "\n".join(str(piece) for piece in self.message_pieces)

    @staticmethod
    def get_all_values(messages: Sequence[Message]) -> list[str]:
        """Return all converted values across the provided messages."""
        values: list[str] = []
        for message in messages:
            values.extend(message.get_values())
        return values

    @staticmethod
    def flatten_to_message_pieces(
        messages: Sequence[Message],
    ) -> MutableSequence[MessagePiece]:
        """
        Return a single new list holding the message pieces of all given messages, in order.

        A falsy ``messages`` argument (for example an empty list) yields an empty list.
        """
        if not messages:
            return []
        message_pieces: MutableSequence[MessagePiece] = []
        for message in messages:
            message_pieces.extend(message.message_pieces)
        return message_pieces

    @classmethod
    def from_prompt(cls, *, prompt: str, role: ChatMessageRole) -> Message:
        """
        Build a message holding a single piece created from ``prompt`` and ``role``.

        The prompt is passed to ``MessagePiece`` as ``original_value``.
        """
        piece = MessagePiece(original_value=prompt, role=role)
        return cls(message_pieces=[piece])

    @classmethod
    def from_system_prompt(cls, system_prompt: str) -> Message:
        """Build a single-piece message with the "system" role from ``system_prompt``."""
        return cls.from_prompt(prompt=system_prompt, role="system")
def group_conversation_message_pieces_by_sequence(
    message_pieces: Sequence[MessagePiece],
) -> MutableSequence[Message]:
    """
    Group the message pieces of a single conversation into Messages, one per sequence number.

    Pieces that share a sequence number end up in the same Message, in the order in which
    they appear in the input. The resulting Messages are ordered by ascending sequence number.

    Args:
        message_pieces (Sequence[MessagePiece]): Message pieces that all belong to the
            same conversation.

    Returns:
        MutableSequence[Message]: One Message per distinct sequence number, ordered by
            sequence number. An empty input yields an empty list.

    Raises:
        ValueError: If the conversation ID of any message piece does not match the
            conversation ID of the first message piece. Constructing each Message may
            also raise ValueError if its pieces fail Message validation.

    Example (schematic):
        Given four pieces of one conversation with sequences ``1, 2, 1, 2``, the result is
        two Messages: the first holds the two sequence-1 pieces, the second holds the two
        sequence-2 pieces.
    """
    if not message_pieces:
        return []

    expected_id = message_pieces[0].conversation_id
    pieces_by_sequence: dict[int, list[MessagePiece]] = {}

    for piece in message_pieces:
        if piece.conversation_id != expected_id:
            raise ValueError(
                f"All message pieces must be from the same conversation. "
                f"Expected conversation_id='{expected_id}', but found '{piece.conversation_id}'. "
                f"If grouping pieces from multiple conversations, group by conversation_id first."
            )
        # Create the bucket on first sight of a sequence number, then append to it.
        pieces_by_sequence.setdefault(piece.sequence, []).append(piece)

    return [Message(pieces_by_sequence[seq]) for seq in sorted(pieces_by_sequence)]
def group_message_pieces_into_conversations(
    message_pieces: Sequence[MessagePiece],
) -> list[list[Message]]:
    """
    Split message pieces from possibly many conversations into per-conversation Message lists.

    Pieces are first bucketed by conversation ID; each bucket is then grouped by sequence
    number via ``group_conversation_message_pieces_by_sequence``.

    Args:
        message_pieces (Sequence[MessagePiece]): Message pieces from potentially
            different conversations.

    Returns:
        list[list[Message]]: One inner list per conversation, in the order in which each
            conversation ID first appears in the input. Each inner list holds that
            conversation's Messages grouped by sequence. An empty input yields an empty list.

    Example (schematic):
        Pieces ``conv1/seq1, conv2/seq1, conv1/seq2, conv2/seq2`` produce two conversations:
        ``[[Message(seq=1), Message(seq=2)], [Message(seq=1), Message(seq=2)]]``
        (conv1 first, then conv2).
    """
    if not message_pieces:
        return []

    # Bucket the pieces by conversation ID; dict insertion order keeps first-seen order.
    pieces_by_conversation: dict[str, list[MessagePiece]] = {}
    for message_piece in message_pieces:
        pieces_by_conversation.setdefault(message_piece.conversation_id, []).append(message_piece)

    # Within each conversation, group the pieces by sequence number.
    return [
        list(group_conversation_message_pieces_by_sequence(conversation_pieces))
        for conversation_pieces in pieces_by_conversation.values()
    ]
def construct_response_from_request(
    request: MessagePiece,
    response_text_pieces: list[str],
    response_type: PromptDataType = "text",
    prompt_metadata: Optional[Dict[str, Union[str, int]]] = None,
    error: PromptResponseError = "none",
) -> Message:
    """
    Build an "assistant" response Message that mirrors the identifying fields of a request piece.

    One MessagePiece is created per entry in ``response_text_pieces``. Each piece copies the
    request's conversation ID, labels, prompt target identifier and attack identifier.

    Args:
        request (MessagePiece): The request piece the response belongs to.
        response_text_pieces (list[str]): The response values, one per piece to create.
        response_type (PromptDataType): Used as both the original and converted data type
            of every response piece.
        prompt_metadata (Optional[Dict[str, Union[str, int]]]): Metadata for the response.
            When the request has metadata of its own, the two are merged with ``combine_dict``.
        error (PromptResponseError): The response error recorded on every piece.

    Returns:
        Message: A Message containing the constructed response pieces.

    Raises:
        ValueError: Raised by ``Message`` if ``response_text_pieces`` is empty or the
            resulting pieces fail validation.
    """
    response_metadata = prompt_metadata
    if request.prompt_metadata:
        # Request metadata is combined with whatever the caller supplied (if anything).
        response_metadata = combine_dict(request.prompt_metadata, prompt_metadata or {})

    response_pieces = [
        MessagePiece(
            role="assistant",
            original_value=response_text,
            conversation_id=request.conversation_id,
            labels=request.labels,
            prompt_target_identifier=request.prompt_target_identifier,
            attack_identifier=request.attack_identifier,
            original_value_data_type=response_type,
            converted_value_data_type=response_type,
            prompt_metadata=response_metadata,
            response_error=error,
        )
        for response_text in response_text_pieces
    ]
    return Message(message_pieces=response_pieces)