Source code for pyrit.score.scorer_prompt_validator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional, Sequence, get_args

from pyrit.models import ChatMessageRole, Message, MessagePiece, PromptDataType


class ScorerPromptValidator:
    """
    Validates message pieces and scorer configurations.

    This class provides validation for scorer inputs, ensuring that message pieces meet
    required criteria such as data types, roles, and metadata requirements.
    """

    def __init__(
        self,
        *,
        supported_data_types: Optional[Sequence[PromptDataType]] = None,
        required_metadata: Optional[Sequence[str]] = None,
        supported_roles: Optional[Sequence[ChatMessageRole]] = None,
        max_pieces_in_response: Optional[int] = None,
        max_text_length: Optional[int] = None,
        enforce_all_pieces_valid: Optional[bool] = False,
        raise_on_no_valid_pieces: Optional[bool] = True,
        is_objective_required: bool = False,
    ):
        """
        Initialize the ScorerPromptValidator.

        Args:
            supported_data_types (Optional[Sequence[PromptDataType]]): Data types that the scorer
                supports. Defaults to all data types if not provided (an empty sequence is
                treated the same as not provided).
            required_metadata (Optional[Sequence[str]]): Metadata keys that must be present in
                message pieces. Defaults to empty list.
            supported_roles (Optional[Sequence[ChatMessageRole]]): Message roles that the scorer
                supports. Defaults to all roles if not provided (an empty sequence is treated
                the same as not provided).
            max_pieces_in_response (Optional[int]): Maximum number of pieces allowed in a
                response. Defaults to None (no limit).
            max_text_length (Optional[int]): Maximum character length for text data type pieces.
                Defaults to None (no limit).
            enforce_all_pieces_valid (Optional[bool]): Whether all pieces must be valid or just
                at least one. Defaults to False.
            raise_on_no_valid_pieces (Optional[bool]): Whether to raise ValueError when no pieces
                are valid. Defaults to True for backwards compatibility. Set to False to allow
                empty scores.
            is_objective_required (bool): Whether an objective must be provided for scoring.
                Defaults to False.
        """
        # Truthiness (not ``is None``) is deliberate: an empty sequence falls back to "all".
        self._supported_data_types = supported_data_types or get_args(PromptDataType)
        self._supported_roles = supported_roles or get_args(ChatMessageRole)
        self._required_metadata = required_metadata or []

        self._max_pieces_in_response = max_pieces_in_response
        self._max_text_length = max_text_length
        self._enforce_all_pieces_valid = enforce_all_pieces_valid
        self._raise_on_no_valid_pieces = raise_on_no_valid_pieces
        self._is_objective_required = is_objective_required

    def validate(self, message: Message, objective: str | None) -> None:
        """
        Validate a message and objective against configured requirements.

        Checks run in this order: per-piece support (raising immediately on an unsupported
        piece when all pieces must be valid), at least one valid piece, the maximum piece
        count, and finally the presence of an objective.

        Args:
            message (Message): The message to validate.
            objective (str | None): The objective string, if required.

        Raises:
            ValueError: If validation fails due to unsupported pieces, exceeding max pieces,
                or missing objective.
        """
        valid_pieces_count = 0
        for piece in message.message_pieces:
            if self.is_message_piece_supported(piece):
                valid_pieces_count += 1
            elif self._enforce_all_pieces_valid:
                raise ValueError(
                    f"Message piece {piece.id} with data type {piece.converted_value_data_type} is not supported."
                )

        if valid_pieces_count < 1 and self._raise_on_no_valid_pieces:
            attempted_metadata = [getattr(piece, "prompt_metadata", None) for piece in message.message_pieces]
            # Report every criterion that can make a piece invalid (types, metadata, roles,
            # text length) with an unambiguous label, so the failure can be diagnosed.
            raise ValueError(
                "There are no valid pieces to score. \n\n"
                f"Required types: {self._supported_data_types}. "
                f"Required metadata: {self._required_metadata}. "
                f"Supported roles: {self._supported_roles}. "
                f"Text length limit: {self._max_text_length}. "
                f"Max pieces in response: {self._max_pieces_in_response}. "
                f"Objective required: {self._is_objective_required}. "
                f"Message pieces: {message.message_pieces}. "
                f"Prompt metadata: {attempted_metadata}. "
                f"Objective included: {objective}. "
            )

        if self._max_pieces_in_response is not None:
            if len(message.message_pieces) > self._max_pieces_in_response:
                raise ValueError(
                    f"Message has {len(message.message_pieces)} pieces, "
                    f"exceeding the limit of {self._max_pieces_in_response}."
                )

        # An empty string counts as "not provided".
        if self._is_objective_required and not objective:
            raise ValueError("Objective is required but not provided.")

    def is_message_piece_supported(self, message_piece: MessagePiece) -> bool:
        """
        Check if a message piece is supported by this validator.

        A piece is supported when its converted data type and role are in the supported sets,
        it carries every required metadata key, and (for text pieces, when a limit is set)
        its converted value does not exceed the maximum text length.

        Args:
            message_piece (MessagePiece): The message piece to check.

        Returns:
            bool: True if the message piece meets all validation criteria, False otherwise.
        """
        if message_piece.converted_value_data_type not in self._supported_data_types:
            return False

        # Treat absent or None metadata as empty so a piece without metadata is reported as
        # unsupported instead of raising TypeError on the membership test.
        piece_metadata = getattr(message_piece, "prompt_metadata", None) or {}
        if any(key not in piece_metadata for key in self._required_metadata):
            return False

        if message_piece.role not in self._supported_roles:
            return False

        # Check text length limit for text data types
        if self._max_text_length is not None and message_piece.converted_value_data_type == "text":
            text_length = len(message_piece.converted_value or "")
            if text_length > self._max_text_length:
                return False

        return True