Source code for pyrit.score.batch_scorer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import uuid
from datetime import datetime
from typing import Optional, Sequence

from pyrit.memory import CentralMemory
from pyrit.models import (
    PromptRequestPiece,
    PromptRequestResponse,
    Score,
    group_request_pieces_into_conversations,
)
from pyrit.score.scorer import Scorer

logger = logging.getLogger(__name__)


class BatchScorer:
    """
    A utility class for scoring prompts in batches in a parallelizable and convenient way.

    This class provides functionality to score existing prompts stored in memory
    without any target interaction, making it a pure scoring utility.
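
    Example (a minimal sketch; assumes memory has been initialized, e.g. via
    ``initialize_pyrit``, and that ``my_target`` is a configured chat target):

        from pyrit.score import SelfAskRefusalScorer

        batch_scorer = BatchScorer(batch_size=10)
        scores = await batch_scorer.score_responses_by_filters_async(
            scorer=SelfAskRefusalScorer(chat_target=my_target),
            labels={"operation": "op-1"},
        )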
"""

    def __init__(
        self,
        *,
        batch_size: int = 10,
    ) -> None:
        """
        Initialize the BatchScorer.

        Args:
            batch_size (int): The (max) batch size for sending prompts. Defaults to 10.
                Note: If using a scorer that takes a prompt target and the target enforces a
                max-requests-per-minute limit, set this to 1 to ensure proper rate limit
                management.
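
        Example (a sketch; relevant when the scorer's target sets a rate limit
        such as ``max_requests_per_minute``):

            batch_scorer = BatchScorer(batch_size=1)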
"""
self._memory = CentralMemory.get_memory_instance()
self._batch_size = batch_size

    async def score_responses_by_filters_async(
        self,
        *,
        scorer: Scorer,
        attack_id: Optional[str | uuid.UUID] = None,
        conversation_id: Optional[str | uuid.UUID] = None,
        prompt_ids: Optional[list[str] | list[uuid.UUID]] = None,
        labels: Optional[dict[str, str]] = None,
        sent_after: Optional[datetime] = None,
        sent_before: Optional[datetime] = None,
        original_values: Optional[list[str]] = None,
        converted_values: Optional[list[str]] = None,
        data_type: Optional[str] = None,
        not_data_type: Optional[str] = None,
        converted_value_sha256: Optional[list[str]] = None,
        objective: str = "",
    ) -> list[Score]:
"""
Score the responses that match the specified filters.
Args:
scorer (Scorer): The Scorer object to use for scoring.
attack_id (Optional[str | uuid.UUID]): The ID of the attack. Defaults to None.
conversation_id (Optional[str | uuid.UUID]): The ID of the conversation. Defaults to None.
prompt_ids (Optional[list[str] | list[uuid.UUID]]): A list of prompt IDs. Defaults to None.
labels (Optional[dict[str, str]]): A dictionary of labels. Defaults to None.
sent_after (Optional[datetime]): Filter for prompts sent after this datetime. Defaults to None.
sent_before (Optional[datetime]): Filter for prompts sent before this datetime. Defaults to None.
original_values (Optional[list[str]]): A list of original values. Defaults to None.
converted_values (Optional[list[str]]): A list of converted values. Defaults to None.
data_type (Optional[str]): The data type to filter by. Defaults to None.
not_data_type (Optional[str]): The data type to exclude. Defaults to None.
converted_value_sha256 (Optional[list[str]]): A list of SHA256 hashes of converted values.
Defaults to None.
objective (str): A task is used to give the scorer more context on what exactly to score.
A task might be the request prompt text or the original attack model's objective.
**Note: the same task is applied to all matched prompts.** Defaults to an empty string.
Returns:
list[Score]: A list of Score objects for responses that match the specified filters.
Raises:
ValueError: If no entries match the provided filters.
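
        Example (a minimal sketch; the filter values are illustrative assumptions):

            scores = await batch_scorer.score_responses_by_filters_async(
                scorer=my_scorer,
                labels={"operation": "op-1"},
                not_data_type="error",
            )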
"""
request_pieces: Sequence[PromptRequestPiece] = []
request_pieces = self._memory.get_prompt_request_pieces(
attack_id=attack_id,
conversation_id=conversation_id,
prompt_ids=prompt_ids,
labels=labels,
sent_after=sent_after,
sent_before=sent_before,
original_values=original_values,
converted_values=converted_values,
data_type=data_type,
not_data_type=not_data_type,
converted_value_sha256=converted_value_sha256,
)
if not request_pieces:
raise ValueError("No entries match the provided filters. Please check your filters.")
# Group pieces by conversation
conversations = group_request_pieces_into_conversations(request_pieces)
# Flatten all conversations into a single list of responses
responses: list[PromptRequestResponse] = []
for conversation in conversations:
responses.extend(conversation)
return await scorer.score_prompts_batch_async(
request_responses=responses, objectives=[objective] * len(responses), batch_size=self._batch_size
)

    def _remove_duplicates(self, request_responses: Sequence[PromptRequestPiece]) -> list[PromptRequestPiece]:
        """
        Remove duplicates from the list of PromptRequestPiece objects.

        This ensures that identical prompts are not scored twice by filtering
        to only include original prompts (where original_prompt_id == id).

        Args:
            request_responses (Sequence[PromptRequestPiece]): The request responses to deduplicate.

        Returns:
            list[PromptRequestPiece]: A list with duplicates removed.
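
        Example (illustrative; ``piece`` and ``duplicate`` are hypothetical pieces,
        with ``duplicate.original_prompt_id`` pointing at ``piece.id``):

            originals = self._remove_duplicates([piece, duplicate])
            # -> [piece]; duplicate is dropped because its original_prompt_id != id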
"""
return [response for response in request_responses if response.original_prompt_id == response.id]