# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import Dict, MutableSequence, Optional, Sequence, Union
from pyrit.common.utils import combine_dict
from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError
from pyrit.models.message_piece import MessagePiece
class Message:
    """
    Represents a message in a conversation, for example a prompt or a response to a prompt.

    This is a single request to a target. It can contain multiple message pieces, all of
    which are expected to share the same conversation ID, sequence and role (see ``validate``).

    Parameters:
        message_pieces (Sequence[MessagePiece]): The list of message pieces. Must not be empty.
        skip_validation (Optional[bool]): When truthy, ``validate`` is not called on construction.
    """

    def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False):
        """
        Store the message pieces and, unless ``skip_validation`` is truthy, validate them.

        Raises:
            ValueError: If ``message_pieces`` is empty, or if validation fails.
        """
        if not message_pieces:
            raise ValueError("Message must have at least one message piece.")
        self.message_pieces = message_pieces
        if not skip_validation:
            self.validate()

    def _first_piece(self) -> MessagePiece:
        """Return the first message piece, raising ValueError when there are none."""
        # message_pieces is a public attribute and may have been replaced after construction,
        # so the emptiness check is repeated here rather than trusted from __init__.
        if len(self.message_pieces) == 0:
            raise ValueError("Empty message pieces.")
        return self.message_pieces[0]

    def get_value(self, n: int = 0) -> str:
        """
        Return the converted value of the nth message piece.

        Only the upper bound is checked explicitly; a negative ``n`` is passed straight to the
        underlying sequence.

        Raises:
            IndexError: If ``n`` is greater than or equal to the number of message pieces.
        """
        if n >= len(self.message_pieces):
            raise IndexError(f"No message piece at index {n}.")
        return self.message_pieces[n].converted_value

    def get_values(self) -> list[str]:
        """Return the converted values of all message pieces."""
        return [message_piece.converted_value for message_piece in self.message_pieces]

    def get_piece(self, n: int = 0) -> MessagePiece:
        """
        Return the nth message piece.

        Raises:
            ValueError: If there are no message pieces.
            IndexError: If ``n`` is greater than or equal to the number of message pieces.
        """
        if len(self.message_pieces) == 0:
            raise ValueError("Empty message pieces.")
        if n >= len(self.message_pieces):
            raise IndexError(f"No message piece at index {n}.")
        return self.message_pieces[n]

    @property
    def role(self) -> ChatMessageRole:
        """Return the role of the first message piece (they should all be the same)."""
        return self._first_piece().role

    @property
    def conversation_id(self) -> str:
        """Return the conversation ID of the first message piece (they should all be the same)."""
        return self._first_piece().conversation_id

    @property
    def sequence(self) -> int:
        """Return the sequence of the first message piece (they should all be the same)."""
        return self._first_piece().sequence

    def is_error(self) -> bool:
        """
        Returns True if any of the message pieces have an error response.

        A piece counts as an error when its ``response_error`` is not ``"none"`` or its
        ``converted_value_data_type`` is ``"error"``.
        """
        return any(
            piece.response_error != "none" or piece.converted_value_data_type == "error"
            for piece in self.message_pieces
        )

    def set_response_not_in_database(self):
        """
        Set that the prompt is not in the database.

        This is needed when we're scoring prompts or other things that have not been sent by PyRIT.
        It delegates to ``set_piece_not_in_database`` on every message piece.
        """
        for piece in self.message_pieces:
            piece.set_piece_not_in_database()

    def validate(self):
        """
        Validates that the message pieces are consistent with each other.

        Raises:
            ValueError: If there are no message pieces, if any piece differs from the first
                piece in conversation ID, sequence or role, or if any piece has a
                ``converted_value`` of None.
        """
        first_piece = self._first_piece()
        conversation_id = first_piece.conversation_id
        sequence = first_piece.sequence
        role = first_piece.role

        for message_piece in self.message_pieces:
            if message_piece.conversation_id != conversation_id:
                raise ValueError("Conversation ID mismatch.")
            if message_piece.sequence != sequence:
                raise ValueError("Inconsistent sequences within the same message entry.")
            if message_piece.converted_value is None:
                raise ValueError("Converted prompt text is None.")
            if message_piece.role != role:
                raise ValueError("Inconsistent roles within the same message entry.")

    def __str__(self):
        """Return the string form of every message piece, one per line."""
        return "\n".join(str(message_piece) for message_piece in self.message_pieces)

    @staticmethod
    def get_all_values(messages: Sequence[Message]) -> list[str]:
        """Return all converted values across the provided messages."""
        return [value for message in messages for value in message.get_values()]

    @staticmethod
    def flatten_to_message_pieces(
        messages: Sequence[Message],
    ) -> MutableSequence[MessagePiece]:
        """
        Return a single list holding the message pieces of every message, in order.

        A falsy ``messages`` argument (for example an empty list or None) yields an empty list.
        """
        if not messages:
            return []
        return [piece for message in messages for piece in message.message_pieces]

    @classmethod
    def from_prompt(cls, *, prompt: str, role: ChatMessageRole) -> Message:
        """Build a Message containing one MessagePiece created from ``prompt`` and ``role``."""
        piece = MessagePiece(original_value=prompt, role=role)
        return cls(message_pieces=[piece])

    @classmethod
    def from_system_prompt(cls, system_prompt: str) -> Message:
        """Build a single-piece Message from ``system_prompt`` with the ``"system"`` role."""
        return cls.from_prompt(prompt=system_prompt, role="system")
def group_conversation_message_pieces_by_sequence(
    message_pieces: Sequence[MessagePiece],
) -> MutableSequence[Message]:
    """
    Groups message pieces from the same conversation into Messages.

    Pieces are bucketed by their sequence number; every piece must carry the same
    conversation ID as the first piece. Within a bucket, pieces keep their input order.

    Args:
        message_pieces (Sequence[MessagePiece]): A list of MessagePiece objects representing individual
            message pieces.

    Returns:
        MutableSequence[Message]: A list of Message objects representing grouped message
            pieces. This is ordered by the sequence number. An empty input yields an empty list.

    Raises:
        ValueError: If the conversation ID of any message piece does not match the conversation ID of the first
            message piece, or if a resulting group fails Message validation.

    Example:
        Given four pieces from one conversation whose sequence numbers are, in input order,
        1, 2, 1, 2, the result is two Messages: the first holds the two sequence-1 pieces and
        the second holds the two sequence-2 pieces.

        >>> grouped = group_conversation_message_pieces_by_sequence(message_pieces)
        >>> [message.sequence for message in grouped]
        [1, 2]
    """
    if not message_pieces:
        return []

    conversation_id = message_pieces[0].conversation_id
    pieces_by_sequence: dict[int, list[MessagePiece]] = {}

    for message_piece in message_pieces:
        if message_piece.conversation_id != conversation_id:
            raise ValueError(
                "All message pieces must be from the same conversation. "
                f"Expected conversation_id='{conversation_id}', but found '{message_piece.conversation_id}'. "
                "If grouping pieces from multiple conversations, group by conversation_id first."
            )
        pieces_by_sequence.setdefault(message_piece.sequence, []).append(message_piece)

    # Message construction validates each group (consistent role, non-None converted value, ...).
    return [Message(pieces_by_sequence[seq]) for seq in sorted(pieces_by_sequence)]
def group_message_pieces_into_conversations(
    message_pieces: Sequence[MessagePiece],
) -> list[list[Message]]:
    """
    Groups message pieces from multiple conversations into separate conversation groups.

    This function first groups pieces by conversation ID, then groups each conversation's
    pieces by sequence number. Each conversation is returned as a separate list of
    Message objects.

    Args:
        message_pieces (Sequence[MessagePiece]): A list of MessagePiece objects from
            potentially different conversations.

    Returns:
        list[list[Message]]: A list of conversations, where each conversation is a list
            of Message objects grouped by sequence. Conversations appear in the order in
            which their conversation ID is first seen in ``message_pieces``. An empty
            input yields an empty list.

    Example:
        Given pieces whose (conversation_id, sequence) pairs are, in input order,
        ("conv1", 1), ("conv2", 1), ("conv1", 2), ("conv2", 2), the result is a list of
        two conversations, each holding two Messages (sequence 1 then sequence 2):

        >>> conversations = group_message_pieces_into_conversations(message_pieces)
        >>> [[message.sequence for message in conversation] for conversation in conversations]
        [[1, 2], [1, 2]]
    """
    if not message_pieces:
        return []

    # Group pieces by conversation ID; dict insertion order keeps first-seen ordering.
    pieces_by_conversation: dict[str, list[MessagePiece]] = {}
    for piece in message_pieces:
        pieces_by_conversation.setdefault(piece.conversation_id, []).append(piece)

    # For each conversation, group by sequence.
    return [
        list(group_conversation_message_pieces_by_sequence(conversation_pieces))
        for conversation_pieces in pieces_by_conversation.values()
    ]
def construct_response_from_request(
    request: MessagePiece,
    response_text_pieces: list[str],
    response_type: PromptDataType = "text",
    prompt_metadata: Optional[Dict[str, Union[str, int]]] = None,
    error: PromptResponseError = "none",
) -> Message:
    """
    Constructs a response entry from a request.

    One assistant-role MessagePiece is created per entry in ``response_text_pieces``. Each
    piece copies the conversation ID, labels, prompt target identifier and attack identifier
    from ``request`` and uses ``response_type`` for both its original and converted data type.

    Args:
        request (MessagePiece): The request piece the response belongs to.
        response_text_pieces (list[str]): The response values, one per resulting piece.
        response_type (PromptDataType): Data type assigned to every response piece.
        prompt_metadata (Optional[Dict[str, Union[str, int]]]): Metadata for the response pieces.
            When the request has metadata of its own, the two are merged with ``combine_dict``
            (which side wins on duplicate keys is decided by that helper).
        error (PromptResponseError): The response error value assigned to every piece.

    Returns:
        Message: A message wrapping the newly created response pieces.

    Raises:
        ValueError: If ``response_text_pieces`` is empty, because a Message requires at
            least one piece.
    """
    # Only merge when the request actually carries metadata; otherwise pass the argument through.
    response_metadata = prompt_metadata
    if request.prompt_metadata:
        response_metadata = combine_dict(request.prompt_metadata, prompt_metadata or {})

    response_pieces = []
    for text in response_text_pieces:
        response_pieces.append(
            MessagePiece(
                role="assistant",
                original_value=text,
                conversation_id=request.conversation_id,
                labels=request.labels,
                prompt_target_identifier=request.prompt_target_identifier,
                attack_identifier=request.attack_identifier,
                original_value_data_type=response_type,
                converted_value_data_type=response_type,
                prompt_metadata=response_metadata,
                response_error=error,
            )
        )

    return Message(message_pieces=response_pieces)