# Source code for pyrit.models.prompt_request_response

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import MutableSequence, Optional, Sequence

from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.literals import PromptDataType, PromptResponseError


class PromptRequestResponse:
    """
    Represents a response to a prompt request.

    This is a single request to a target. It can contain multiple prompt request pieces.

    Parameters:
        request_pieces (list[PromptRequestPiece]): The list of prompt request pieces.
    """

    def __init__(self, request_pieces: list[PromptRequestPiece]):
        self.request_pieces = request_pieces

    def validate(self):
        """
        Validates the request response.

        Every piece must share the first piece's conversation ID and role, and every
        piece must have a truthy ``converted_value``.

        Raises:
            ValueError: If there are no request pieces, if a piece's conversation ID or
                role differs from the first piece's, or if a piece's converted value is
                empty/None.
        """
        if not self.request_pieces:
            raise ValueError("Empty request pieces.")

        # Each piece is compared against the first one (which trivially matches itself).
        # Capturing the role up front means a falsy first role cannot be silently
        # replaced by a later, different role.
        first_piece = self.request_pieces[0]
        conversation_id = first_piece.conversation_id
        role = first_piece.role

        for request_piece in self.request_pieces:
            if request_piece.conversation_id != conversation_id:
                raise ValueError("Conversation ID mismatch.")

            # `not` rejects both None and an empty value.
            if not request_piece.converted_value:
                raise ValueError("Converted prompt text is None.")

            if request_piece.role != role:
                raise ValueError("Inconsistent roles within the same prompt request response entry.")

    def __str__(self):
        # One line per piece, in stored order.
        return "\n".join(str(request_piece) for request_piece in self.request_pieces)
def group_conversation_request_pieces_by_sequence(
    request_pieces: Sequence[PromptRequestPiece],
) -> MutableSequence[PromptRequestResponse]:
    """
    Groups prompt request pieces from the same conversation into PromptRequestResponses.

    This is done using the sequence number and conversation ID.

    Args:
        request_pieces (Sequence[PromptRequestPiece]): A list of PromptRequestPiece objects representing
            individual request pieces.

    Returns:
        MutableSequence[PromptRequestResponse]: A list of PromptRequestResponse objects representing grouped
            request pieces. The list is ordered by ascending sequence number; within a group, pieces keep
            their input order. An empty input yields an empty list.

    Raises:
        ValueError: If the conversation ID of any request piece does not match the conversation ID of the
            first request piece.

    Example:
    >>> request_pieces = [
    >>>     PromptRequestPiece(conversation_id=1, sequence=1, original_value="Hello"),
    >>>     PromptRequestPiece(conversation_id=1, sequence=2, original_value="How are you?"),
    >>>     PromptRequestPiece(conversation_id=1, sequence=1, original_value="Hi"),
    >>>     PromptRequestPiece(conversation_id=1, sequence=2, original_value="I'm good, thanks!")
    >>> ]
    >>> grouped_responses = group_conversation_request_pieces_by_sequence(request_pieces)
    ... [
    ...     PromptRequestResponse(request_pieces=[
    ...         PromptRequestPiece(conversation_id=1, sequence=1, original_value="Hello"),
    ...         PromptRequestPiece(conversation_id=1, sequence=1, original_value="Hi")
    ...     ]),
    ...     PromptRequestResponse(request_pieces=[
    ...         PromptRequestPiece(conversation_id=1, sequence=2, original_value="How are you?"),
    ...         PromptRequestPiece(conversation_id=1, sequence=2, original_value="I'm good, thanks!")
    ...     ])
    ... ]
    """
    if not request_pieces:
        return []

    conversation_id = request_pieces[0].conversation_id
    conversation_by_sequence: dict[int, list[PromptRequestPiece]] = {}

    for request_piece in request_pieces:
        if request_piece.conversation_id != conversation_id:
            raise ValueError("Conversation ID must match.")
        # Append to the bucket for this sequence, creating it on first sight.
        conversation_by_sequence.setdefault(request_piece.sequence, []).append(request_piece)

    return [PromptRequestResponse(conversation_by_sequence[seq]) for seq in sorted(conversation_by_sequence)]
def construct_response_from_request(
    request: PromptRequestPiece,
    response_text_pieces: list[str],
    response_type: PromptDataType = "text",
    prompt_metadata: Optional[str] = None,
    error: PromptResponseError = "none",
) -> PromptRequestResponse:
    """
    Constructs a response entry from a request.

    One PromptRequestPiece is created per entry of ``response_text_pieces``, in order. Every new piece
    has the role "assistant" and reuses the conversation ID, labels, prompt target identifier and
    orchestrator identifier of ``request``.

    Args:
        request (PromptRequestPiece): The request piece being answered.
        response_text_pieces (list[str]): Response values; each becomes the ``original_value`` of one piece.
        response_type (PromptDataType): Recorded as both the original and the converted value data type.
        prompt_metadata (Optional[str]): Metadata attached to every response piece.
        error (PromptResponseError): Response error value attached to every response piece.

    Returns:
        PromptRequestResponse: The response wrapping the newly built pieces.
    """
    built_pieces = []
    for text_value in response_text_pieces:
        # NOTE(review): `request.labels` is passed as-is (not copied here), so unless
        # PromptRequestPiece copies it, the request and its response pieces share one object.
        built_pieces.append(
            PromptRequestPiece(
                role="assistant",
                original_value=text_value,
                conversation_id=request.conversation_id,
                labels=request.labels,
                prompt_target_identifier=request.prompt_target_identifier,
                orchestrator_identifier=request.orchestrator_identifier,
                original_value_data_type=response_type,
                converted_value_data_type=response_type,
                prompt_metadata=prompt_metadata,
                response_error=error,
            )
        )

    return PromptRequestResponse(request_pieces=built_pieces)