# Source code for pyrit.score.scorer_prompt_validator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional, Sequence, get_args

from pyrit.models import PromptRequestPiece, PromptRequestResponse
from pyrit.models.literals import PromptDataType


class ScorerPromptValidator:
    """Check that a request/response is something a scorer is able to score.

    A validator is configured with the data types and metadata keys a scorer
    needs, an optional cap on the number of pieces in a response, and whether
    an objective must accompany the score request. ``validate`` raises
    ``ValueError`` when those requirements are not met.
    """

    def __init__(
        self,
        *,
        supported_data_types: Optional[Sequence[PromptDataType]] = None,
        required_metadata: Optional[Sequence[str]] = None,
        max_pieces_in_response: Optional[int] = None,
        enforce_all_pieces_valid: Optional[bool] = False,
        is_objective_required: bool = False,
    ):
        """
        Args:
            supported_data_types: Values a piece's ``converted_value_data_type`` may take.
                When omitted or empty, falls back to ``typing.get_args(PromptDataType)``
                (presumably every known data type).
            required_metadata: Keys that must all be present in a piece's
                ``prompt_metadata`` for that piece to count as supported.
            max_pieces_in_response: Upper bound on the total number of pieces in a
                request/response; ``None`` disables the check.
            enforce_all_pieces_valid: When truthy, any unsupported piece makes
                ``validate`` raise instead of being skipped.
            is_objective_required: When true, ``validate`` requires a non-empty objective.
        """
        # An empty or missing sequence means "accept every data type".
        if supported_data_types:
            self._supported_data_types = supported_data_types
        else:
            self._supported_data_types = get_args(PromptDataType)
        self._required_metadata = required_metadata or []
        self._max_pieces_in_response = max_pieces_in_response
        self._enforce_all_pieces_valid = enforce_all_pieces_valid
        self._is_objective_required = is_objective_required

    def validate(self, request_response: PromptRequestResponse, objective: str | None) -> None:
        """Ensure ``request_response`` and ``objective`` satisfy this validator.

        Args:
            request_response: The request/response whose pieces will be scored.
            objective: The scoring objective, if any.

        Raises:
            ValueError: If a piece is unsupported while ``enforce_all_pieces_valid`` is set;
                if no piece is supported; if the total number of pieces exceeds
                ``max_pieces_in_response``; or if an objective is required but missing/empty.
        """
        valid_pieces_count = 0
        for piece in request_response.request_pieces:
            if self.is_request_piece_supported(piece):
                valid_pieces_count += 1
            elif self._enforce_all_pieces_valid:
                # A piece can be rejected for its data type OR for missing metadata,
                # so report both requirements rather than only blaming the data type.
                raise ValueError(
                    f"Request piece {piece.id} with data type {piece.converted_value_data_type} is not supported. "
                    f"Supported types: {self._supported_data_types}. "
                    f"Required metadata: {self._required_metadata}."
                )

        if valid_pieces_count < 1:
            attempted_metadata = [getattr(piece, "prompt_metadata", None) for piece in request_response.request_pieces]
            raise ValueError(
                "There are no valid pieces to score. \n\n"
                f"Required types: {self._supported_data_types}. "
                f"Required metadata: {self._required_metadata}. "
                f"Length limit: {self._max_pieces_in_response}. "
                f"Objective required: {self._is_objective_required}. "
                f"Prompt pieces: {request_response.request_pieces}. "
                f"Prompt metadata: {attempted_metadata}. "
                f"Objective included: {objective}. "
            )

        # The limit applies to all pieces, not just the supported ones.
        if self._max_pieces_in_response is not None:
            if len(request_response.request_pieces) > self._max_pieces_in_response:
                raise ValueError(
                    f"Request response has {len(request_response.request_pieces)} pieces, "
                    f"exceeding the limit of {self._max_pieces_in_response}."
                )

        if self._is_objective_required and not objective:
            raise ValueError("Objective is required but not provided.")

    def is_request_piece_supported(self, request_piece: PromptRequestPiece) -> bool:
        """Return True if the piece has a supported data type and every required metadata key.

        Args:
            request_piece: The piece to check.

        Returns:
            ``False`` if the piece's ``converted_value_data_type`` is not supported or any
            key from ``required_metadata`` is absent from its ``prompt_metadata``;
            ``True`` otherwise.
        """
        if request_piece.converted_value_data_type not in self._supported_data_types:
            return False
        # Read metadata as defensively as validate() does: a None/missing prompt_metadata
        # is treated as empty, so the piece is reported unsupported instead of raising TypeError.
        prompt_metadata = getattr(request_piece, "prompt_metadata", None) or {}
        return all(key in prompt_metadata for key in self._required_metadata)