Source code for pyrit.orchestrator.scoring_orchestrator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Sequence

from pyrit.models import PromptRequestPiece, Score
from pyrit.orchestrator import Orchestrator
from pyrit.score.scorer import Scorer

logger = logging.getLogger(__name__)


class ScoringOrchestrator(Orchestrator):
    """
    This orchestrator scores prompts in a parallelizable and convenient way.
    """
    def __init__(
        self,
        batch_size: int = 10,
        verbose: bool = False,
    ) -> None:
        """
        Args:
            batch_size (int, Optional): The (max) batch size for sending prompts. Defaults to 10.
                Note: If using a scorer that takes a prompt target, and providing max requests per minute
                on the target, this should be set to 1 to ensure proper rate limit management.
        """
        super().__init__(verbose=verbose)
        self._batch_size = batch_size
    async def score_responses_by_orchestrator_id_async(
        self,
        *,
        scorer: Scorer,
        orchestrator_ids: list[str],
    ) -> list[Score]:
        """
        Scores prompts using the Scorer for prompts correlated to the orchestrator_ids.
        """
        request_pieces: list[PromptRequestPiece] = []
        for orchestrator_id in orchestrator_ids:
            request_pieces.extend(self._memory.get_prompt_request_pieces(orchestrator_id=orchestrator_id))

        request_pieces = self._remove_duplicates(request_pieces)

        return await scorer.score_responses_inferring_tasks_batch_async(
            request_responses=request_pieces, batch_size=self._batch_size
        )
    async def score_responses_by_memory_labels_async(
        self,
        *,
        scorer: Scorer,
        memory_labels: dict[str, str] = {},
    ) -> list[Score]:
        """
        Scores prompts using the Scorer for prompts based on the memory labels.
        """
        if not memory_labels:
            raise ValueError("Invalid memory_labels: Please provide valid memory labels.")

        request_pieces: list[PromptRequestPiece] = self._memory.get_prompt_request_pieces(labels=memory_labels)
        if not request_pieces:
            raise ValueError("No entries match the provided memory labels. Please check your memory labels.")

        request_pieces = self._remove_duplicates(request_pieces)

        return await scorer.score_responses_inferring_tasks_batch_async(
            request_responses=request_pieces, batch_size=self._batch_size
        )
    async def score_prompts_by_id_async(
        self,
        *,
        scorer: Scorer,
        prompt_ids: list[str],
        responses_only: bool = False,
        task: str = "",
    ) -> list[Score]:
        """
        Scores prompts using the Scorer for prompts with the given prompt_ids.
        The task is the task to score against.
        """
        request_pieces: Sequence[PromptRequestPiece] = self._memory.get_prompt_request_pieces(prompt_ids=prompt_ids)

        if responses_only:
            request_pieces = self._extract_responses_only(request_pieces)

        request_pieces = self._remove_duplicates(request_pieces)

        return await scorer.score_prompts_with_tasks_batch_async(
            request_responses=request_pieces, batch_size=self._batch_size, tasks=[task] * len(request_pieces)
        )
    def _extract_responses_only(self, request_responses: Sequence[PromptRequestPiece]) -> list[PromptRequestPiece]:
        """
        Extracts the responses from the list of PromptRequestPiece objects.
        """
        return [response for response in request_responses if response.role == "assistant"]

    def _remove_duplicates(self, request_responses: Sequence[PromptRequestPiece]) -> list[PromptRequestPiece]:
        """
        Removes the duplicates from the list of PromptRequestPiece objects
        so that identical prompts are not scored twice.
        """
        # Keep only pieces that are their own original (original_prompt_id == id).
        return [response for response in request_responses if response.original_prompt_id == response.id]
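
Usage sketch (editorial, not part of the module above): scoring the stored responses of known orchestrator runs. It assumes an initialized PyRIT environment with populated memory and an AzureContentFilterScorer configured through its usual environment variables; the orchestrator id is a placeholder.

import asyncio

from pyrit.orchestrator import ScoringOrchestrator
from pyrit.score import AzureContentFilterScorer


async def main() -> None:
    # Default batch_size is fine here; set batch_size=1 when the scorer uses
    # a rate-limited prompt target (see the __init__ docstring above).
    orchestrator = ScoringOrchestrator()
    scores = await orchestrator.score_responses_by_orchestrator_id_async(
        scorer=AzureContentFilterScorer(),
        orchestrator_ids=["<orchestrator-uuid>"],  # placeholder id
    )
    for score in scores:
        print(score)


asyncio.run(main())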
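
A similar sketch for label-based selection, reusing the imports above. The label key/value pair is hypothetical; score_responses_by_memory_labels_async raises ValueError when the labels are empty or match no stored pieces.

async def main() -> None:
    orchestrator = ScoringOrchestrator()
    scores = await orchestrator.score_responses_by_memory_labels_async(
        scorer=AzureContentFilterScorer(),
        memory_labels={"operation": "op_1"},  # hypothetical label
    )
    print(f"Scored {len(scores)} pieces")


asyncio.run(main())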
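
And one for id-based scoring with an explicit task, again reusing the imports above: responses_only=True keeps only assistant pieces (via _extract_responses_only), and the same task string is applied to every scored piece. The id and task text are placeholders.

async def main() -> None:
    orchestrator = ScoringOrchestrator()
    scores = await orchestrator.score_prompts_by_id_async(
        scorer=AzureContentFilterScorer(),
        prompt_ids=["<prompt-piece-uuid>"],  # placeholder id
        responses_only=True,
        task="<objective being scored against>",  # placeholder task
    )
    for score in scores:
        print(score)


asyncio.run(main())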