Source code for pyrit.models.prompt_request_response

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, MutableSequence, Optional, Sequence, Union

from pyrit.common.utils import combine_dict
from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError
from pyrit.models.prompt_request_piece import PromptRequestPiece


class PromptRequestResponse:
    """
    Represents a response to a prompt request.

    This is a single request to a target. It can contain multiple prompt request pieces.

    Parameters:
        request_pieces (Sequence[PromptRequestPiece]): The list of prompt request pieces.
        skip_validation (Optional[bool]): When truthy, ``validate()`` is not run on construction.

    Raises:
        ValueError: If ``request_pieces`` is empty, or if validation fails.
    """

    def __init__(self, request_pieces: Sequence[PromptRequestPiece], *, skip_validation: Optional[bool] = False):
        if not request_pieces:
            raise ValueError("PromptRequestResponse must have at least one request piece.")

        self.request_pieces = request_pieces

        if not skip_validation:
            self.validate()

    def get_value(self, n: int = 0) -> str:
        """Return the converted value of the nth request piece.

        Raises:
            IndexError: If ``n`` is past the last request piece.
        """
        if n >= len(self.request_pieces):
            raise IndexError(f"No request piece at index {n}.")
        return self.request_pieces[n].converted_value

    def get_values(self) -> list[str]:
        """Return the converted values of all request pieces."""
        return [request_piece.converted_value for request_piece in self.request_pieces]

    def get_piece(self, n: int = 0) -> PromptRequestPiece:
        """Return the nth request piece.

        Raises:
            ValueError: If there are no request pieces.
            IndexError: If ``n`` is past the last request piece.
        """
        if not self.request_pieces:
            raise ValueError("Empty request pieces.")

        if n >= len(self.request_pieces):
            raise IndexError(f"No request piece at index {n}.")

        return self.request_pieces[n]

    def get_role(self) -> ChatMessageRole:
        """Return the role of the first request.

        Raises:
            ValueError: If there are no request pieces.
        """
        if not self.request_pieces:
            raise ValueError("Empty request pieces.")
        return self.request_pieces[0].role

    def is_error(self) -> bool:
        """
        Returns True if any of the request pieces has an error response.

        A piece counts as an error when its ``response_error`` is not ``"none"`` or its
        ``converted_value_data_type`` is ``"error"``.
        """
        return any(
            piece.response_error != "none" or piece.converted_value_data_type == "error"
            for piece in self.request_pieces
        )

    def set_response_not_in_database(self):
        """
        Set that the prompt is not in the database.

        This is needed when we're scoring prompts or other things that have not been sent by PyRIT
        """
        for piece in self.request_pieces:
            piece.set_piece_not_in_database()

    def validate(self):
        """
        Validates the request response.

        All pieces must share the conversation ID, sequence and role of the first piece,
        and no piece may have a ``converted_value`` of None.

        Raises:
            ValueError: If there are no request pieces or any of the checks above fails.
        """
        if not self.request_pieces:
            raise ValueError("Empty request pieces.")

        first_piece = self.request_pieces[0]
        conversation_id = first_piece.conversation_id
        sequence = first_piece.sequence
        role = first_piece.role

        # The order of the checks is kept stable so that the first failing check
        # (and therefore the error message) is the same for any given input.
        for request_piece in self.request_pieces:
            if request_piece.conversation_id != conversation_id:
                raise ValueError("Conversation ID mismatch.")

            if request_piece.sequence != sequence:
                raise ValueError("Inconsistent sequences within the same prompt request response entry.")

            if request_piece.converted_value is None:
                raise ValueError("Converted prompt text is None.")

            if request_piece.role != role:
                raise ValueError("Inconsistent roles within the same prompt request response entry.")

    def __str__(self):
        """Return the string form of every request piece, one per line."""
        return "\n".join(str(request_piece) for request_piece in self.request_pieces)

    @staticmethod
    def get_all_values(request_responses: Sequence["PromptRequestResponse"]) -> list[str]:
        """Return all converted values across the provided request responses."""
        values: list[str] = []
        for request_response in request_responses:
            values.extend(request_response.get_values())
        return values

    @staticmethod
    def flatten_to_prompt_request_pieces(
        request_responses: Sequence["PromptRequestResponse"],
    ) -> MutableSequence[PromptRequestPiece]:
        """Return the request pieces of all given request responses as one flat list.

        Pieces keep the order of ``request_responses`` and their order within each response.
        A falsy argument (for example an empty sequence) yields an empty list.
        """
        if not request_responses:
            return []

        response_pieces: MutableSequence[PromptRequestPiece] = []
        for response in request_responses:
            response_pieces.extend(response.request_pieces)
        return response_pieces

    @classmethod
    def from_prompt(cls, *, prompt: str, role: ChatMessageRole) -> "PromptRequestResponse":
        """Build a single-piece request response from ``prompt`` with the given ``role``."""
        piece = PromptRequestPiece(original_value=prompt, role=role)
        return cls(request_pieces=[piece])

    @classmethod
    def from_system_prompt(cls, system_prompt: str) -> "PromptRequestResponse":
        """Build a single-piece request response from ``system_prompt`` with role ``"system"``."""
        return cls.from_prompt(prompt=system_prompt, role="system")
def group_conversation_request_pieces_by_sequence(
    request_pieces: Sequence[PromptRequestPiece],
) -> MutableSequence[PromptRequestResponse]:
    """
    Groups prompt request pieces from the same conversation into PromptRequestResponses.

    This is done using the sequence number and conversation ID.

    Args:
        request_pieces (Sequence[PromptRequestPiece]): A list of PromptRequestPiece objects representing individual
            request pieces.

    Returns:
        MutableSequence[PromptRequestResponse]: A list of PromptRequestResponse objects representing grouped request
            pieces. This is ordered by the sequence number. Within a group, pieces keep their input order.
            An empty input yields an empty list.

    Raises:
        ValueError: If the conversation ID of any request piece does not match the conversation ID of the first
            request piece, or if a resulting group fails PromptRequestResponse validation.

    Example (abridged; the constructor arguments shown are illustrative only):

        request_pieces = [
            PromptRequestPiece(conversation_id=1, sequence=1, text="Given this list of creatures, which is your
            favorite:"),
            PromptRequestPiece(conversation_id=1, sequence=2, text="Good question!"),
            PromptRequestPiece(conversation_id=1, sequence=1, text="Raccoon, Narwhal, or Sloth?"),
            PromptRequestPiece(conversation_id=1, sequence=2, text="I'd have to say raccoons are my favorite!"),
        ]
        grouped_responses = group_conversation_request_pieces_by_sequence(request_pieces)

        # grouped_responses is:
        [
            PromptRequestResponse(request_pieces=[
                PromptRequestPiece(conversation_id=1, sequence=1, text="Given this list of creatures, which is your
                favorite:"),
                PromptRequestPiece(conversation_id=1, sequence=1, text="Raccoon, Narwhal, or Sloth?")
            ]),
            PromptRequestResponse(request_pieces=[
                PromptRequestPiece(conversation_id=1, sequence=2, text="Good question!"),
                PromptRequestPiece(conversation_id=1, sequence=2, text="I'd have to say raccoons are my favorite!")
            ])
        ]
    """
    if not request_pieces:
        return []

    conversation_id = request_pieces[0].conversation_id

    conversation_by_sequence: dict[int, list[PromptRequestPiece]] = {}

    for request_piece in request_pieces:
        if request_piece.conversation_id != conversation_id:
            raise ValueError(
                f"All request pieces must be from the same conversation. "
                f"Expected conversation_id='{conversation_id}', but found '{request_piece.conversation_id}'. "
                f"If grouping pieces from multiple conversations, group by conversation_id first."
            )

        # setdefault creates the bucket on first sight of a sequence number.
        conversation_by_sequence.setdefault(request_piece.sequence, []).append(request_piece)

    # Iterating a dict yields its keys, so sorted() orders the groups by sequence number.
    return [PromptRequestResponse(conversation_by_sequence[seq]) for seq in sorted(conversation_by_sequence)]
def group_request_pieces_into_conversations(
    request_pieces: Sequence[PromptRequestPiece],
) -> list[list[PromptRequestResponse]]:
    """
    Splits request pieces from one or more conversations into per-conversation groups.

    Pieces are first bucketed by conversation ID. Each bucket is then passed to
    group_conversation_request_pieces_by_sequence, so every conversation comes back as
    its own list of PromptRequestResponse objects.

    Args:
        request_pieces (Sequence[PromptRequestPiece]): Request pieces that may belong to
            different conversations.

    Returns:
        list[list[PromptRequestResponse]]: One inner list per conversation, in the order each
            conversation ID first appears in the input. An empty input yields an empty list.

    Example (abridged; the constructor arguments shown are illustrative only):

        request_pieces = [
            PromptRequestPiece(conversation_id="conv1", sequence=1, text="Hello"),
            PromptRequestPiece(conversation_id="conv2", sequence=1, text="Hi there"),
            PromptRequestPiece(conversation_id="conv1", sequence=2, text="How are you?"),
            PromptRequestPiece(conversation_id="conv2", sequence=2, text="I'm good"),
        ]
        conversations = group_request_pieces_into_conversations(request_pieces)
        # Returns a list of 2 conversations:
        # [
        #     [PromptRequestResponse(seq=1), PromptRequestResponse(seq=2)],  # conv1
        #     [PromptRequestResponse(seq=1), PromptRequestResponse(seq=2)]   # conv2
        # ]
    """
    if not request_pieces:
        return []

    # Bucket the pieces per conversation; dict insertion order keeps first-seen order.
    pieces_by_conversation: dict[str, list[PromptRequestPiece]] = {}
    for request_piece in request_pieces:
        pieces_by_conversation.setdefault(request_piece.conversation_id, []).append(request_piece)

    # Turn every bucket into its sequence-ordered list of responses.
    return [
        list(group_conversation_request_pieces_by_sequence(bucket))
        for bucket in pieces_by_conversation.values()
    ]
def construct_response_from_request(
    request: PromptRequestPiece,
    response_text_pieces: list[str],
    response_type: PromptDataType = "text",
    prompt_metadata: Optional[Dict[str, Union[str, int]]] = None,
    error: PromptResponseError = "none",
) -> PromptRequestResponse:
    """
    Builds the response entry that answers a given request piece.

    One "assistant" piece is created per entry of ``response_text_pieces``. Each piece copies
    the conversation ID, labels, prompt target identifier and attack identifier from ``request``.

    Args:
        request (PromptRequestPiece): The request piece being answered.
        response_text_pieces (list[str]): The original value of each response piece.
        response_type (PromptDataType): Used as both the original and converted value data type.
        prompt_metadata (Optional[Dict[str, Union[str, int]]]): Metadata for the response pieces.
            When the request has metadata of its own, the two are merged with ``combine_dict``.
        error (PromptResponseError): The response error recorded on every piece.

    Returns:
        PromptRequestResponse: The response containing all constructed pieces.
    """
    response_metadata = prompt_metadata
    if request.prompt_metadata:
        response_metadata = combine_dict(request.prompt_metadata, prompt_metadata or {})

    response_pieces = []
    for response_text in response_text_pieces:
        response_pieces.append(
            PromptRequestPiece(
                role="assistant",
                original_value=response_text,
                conversation_id=request.conversation_id,
                labels=request.labels,
                prompt_target_identifier=request.prompt_target_identifier,
                attack_identifier=request.attack_identifier,
                original_value_data_type=response_type,
                converted_value_data_type=response_type,
                prompt_metadata=response_metadata,
                response_error=error,
            )
        )

    return PromptRequestResponse(request_pieces=response_pieces)