Source code for pyrit.score.question_answer_scorer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Optional

from pyrit.models import Score
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.score import Scorer


class QuestionAnswerScorer(Scorer):
    """
    A true/false scorer that checks whether a response contains the correct answer
    recorded in its prompt metadata.
    """

    CORRECT_ANSWER_MATCHING_PATTERNS = ["{correct_answer_index}:", "{correct_answer}"]
    def __init__(
        self,
        *,
        correct_answer_matching_patterns: list[str] = CORRECT_ANSWER_MATCHING_PATTERNS,
        category: str = "",
    ) -> None:
        """
        Scores PromptRequestResponse objects that contain correct_answer_index and/or correct_answer metadata.

        Args:
            correct_answer_matching_patterns (list[str]): A list of patterns to check for in the response.
                If any pattern is found in the response, the score will be True. These patterns should be
                format strings that will be formatted with the correct answer metadata.
            category (str): The category to assign to the resulting scores. Defaults to an empty string.
        """
        self._correct_answer_matching_patterns = correct_answer_matching_patterns
        self._score_category = category
        self.scorer_type = "true_false"
    async def score_async(self, request_response: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]:
        """
        Score the request_response by checking whether it contains any of the correct answer matching patterns.

        Args:
            request_response (PromptRequestPiece): The answer given by the target.
            task (str): The task based on which the text should be scored (the original attacker model's
                objective). Defaults to None.

        Returns:
            list[Score]: A list containing a single Score object representing the result.
        """
        self.validate(request_response, task=task)

        result = False
        matching_text = None

        # validate() only guarantees that at least one of the two metadata keys is present,
        # so use .get() rather than indexing to avoid a KeyError when one is missing.
        correct_index = request_response.prompt_metadata.get("correct_answer_index")
        correct_answer = request_response.prompt_metadata.get("correct_answer")

        for pattern in self._correct_answer_matching_patterns:
            text = pattern.format(correct_answer_index=correct_index, correct_answer=correct_answer).lower()
            if text in request_response.converted_value.lower():
                result = True
                matching_text = text
                break

        scores = [
            Score(
                score_value=str(result),
                score_value_description=None,
                score_metadata=None,
                score_type=self.scorer_type,
                score_category=self._score_category,
                score_rationale=(
                    f"Found matching text '{matching_text}' in response"
                    if matching_text
                    else "No matching text found in response"
                ),
                scorer_class_identifier=self.get_identifier(),
                prompt_request_response_id=request_response.id,
                task=task,
            )
        ]
        self._memory.add_scores_to_memory(scores=scores)
        return scores
    def validate(self, request_response: PromptRequestPiece, *, task: Optional[str] = None):
        """
        Validates the request_response piece to score, since some scorers require specific
        PromptRequestPiece types or values.

        Args:
            request_response (PromptRequestPiece): The request response to be validated.
            task (str): The task based on which the text should be scored (the original attacker model's
                objective).
        """
        if request_response.converted_value_data_type != "text":
            raise ValueError("Question Answer Scorer only supports text data type")

        if not request_response.prompt_metadata or (
            "correct_answer_index" not in request_response.prompt_metadata
            and "correct_answer" not in request_response.prompt_metadata
        ):
            raise ValueError(
                "Question Answer Scorer requires metadata with either correct_answer_index or correct_answer"
            )
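The matching logic above is plain string formatting plus a case-insensitive substring check, so it can be illustrated without any PyRIT infrastructure. The sketch below is a minimal, self-contained approximation of what score_async does with the default patterns; the metadata values and response text are hypothetical example data, not part of the library.

# Minimal sketch of the pattern-matching step, using hypothetical metadata and
# response text. Only the default pattern strings come from the class above.
patterns = ["{correct_answer_index}:", "{correct_answer}"]
metadata = {"correct_answer_index": 2, "correct_answer": "Paris"}  # hypothetical
response_text = "The correct choice is 2: Paris."  # hypothetical target output

matching_text = None
for pattern in patterns:
    # Fill the format string with the answer metadata, then compare case-insensitively.
    candidate = pattern.format(**metadata).lower()
    if candidate in response_text.lower():
        matching_text = candidate
        break

print(matching_text)  # prints "2:" -- the index pattern matches, so the score would be True

Because both default patterns are checked, a response matches either on the answer's index ("2:") or on the answer text itself ("paris"); in the scorer, the first match short-circuits the loop and is recorded in the score rationale.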