Source code for pyrit.orchestrator.scoring_orchestrator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Sequence

from pyrit.models import PromptRequestPiece, Score
from pyrit.orchestrator import Orchestrator
from pyrit.score.scorer import Scorer

logger = logging.getLogger(__name__)
class ScoringOrchestrator(Orchestrator):
"""
This orchestrator scores prompts in a parallelizable and convenient way.
"""
def __init__(
self,
batch_size: int = 10,
verbose: bool = False,
) -> None:
"""
Args:
batch_size (int, Optional): The (max) batch size for sending prompts. Defaults to 10.
Note: If using a scorer that takes a prompt target, and providing max requests per
minute on the target, this should be set to 1 to ensure proper rate limit management.
"""
super().__init__(verbose=verbose)
self._batch_size = batch_size
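    # Usage note (sketch): when the scorer's prompt target enforces a
    # max-requests-per-minute limit, construct this as
    # ScoringOrchestrator(batch_size=1) so that batched scoring calls
    # respect the target's rate limit.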
async def score_prompts_by_orchestrator_id_async(
self,
*,
scorer: Scorer,
orchestrator_ids: list[str],
responses_only: bool = True,
) -> list[Score]:
"""
Scores prompts using the Scorer for prompts correlated to the orchestrator_ids.
"""
request_pieces: list[PromptRequestPiece] = []
        for orchestrator_id in orchestrator_ids:
            request_pieces.extend(
                self._memory.get_prompt_request_piece_by_orchestrator_id(orchestrator_id=orchestrator_id)
            )
if responses_only:
request_pieces = self._extract_responses_only(request_pieces)
request_pieces = self._remove_duplicates(request_pieces)
return await scorer.score_prompts_batch_async(request_responses=request_pieces, batch_size=self._batch_size)
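    # Example (sketch, assuming a configured Scorer instance ``scorer`` and an
    # ``orch_id`` from a previous orchestrator run whose prompts are in memory):
    #
    #   scoring_orchestrator = ScoringOrchestrator()
    #   scores = await scoring_orchestrator.score_prompts_by_orchestrator_id_async(
    #       scorer=scorer, orchestrator_ids=[orch_id]
    #   )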
async def score_prompts_by_memory_labels_async(
self,
*,
scorer: Scorer,
memory_labels: dict[str, str] = {},
responses_only: bool = True,
) -> list[Score]:
"""
Scores prompts using the Scorer for prompts based on the memory labels.
"""
if not memory_labels:
raise ValueError("Invalid memory_labels: Please provide valid memory labels.")
request_pieces: list[PromptRequestPiece] = self._memory.get_prompt_request_piece_by_memory_labels(
memory_labels=memory_labels
)
if not request_pieces:
raise ValueError("No entries match the provided memory labels. Please check your memory labels.")
if responses_only:
request_pieces = self._extract_responses_only(request_pieces)
request_pieces = self._remove_duplicates(request_pieces)
return await scorer.score_prompts_batch_async(request_responses=request_pieces, batch_size=self._batch_size)
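    # Example (sketch, assuming prompts were previously sent with the memory
    # label {"op_name": "new_op"}):
    #
    #   scores = await scoring_orchestrator.score_prompts_by_memory_labels_async(
    #       scorer=scorer, memory_labels={"op_name": "new_op"}
    #   )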
async def score_prompts_by_request_id_async(
self,
*,
scorer: Scorer,
prompt_ids: list[str],
responses_only: bool = False,
) -> list[Score]:
"""
Scores prompts using the Scorer for prompts with the prompt_ids
"""
request_pieces: Sequence[PromptRequestPiece] = []
request_pieces = self._memory.get_prompt_request_pieces_by_id(prompt_ids=prompt_ids)
if responses_only:
request_pieces = self._extract_responses_only(request_pieces)
request_pieces = self._remove_duplicates(request_pieces)
return await scorer.score_prompts_batch_async(request_responses=request_pieces, batch_size=self._batch_size)
    def _extract_responses_only(self, request_responses: Sequence[PromptRequestPiece]) -> list[PromptRequestPiece]:
        """
        Extracts only the response pieces (those with the "assistant" role) from the list of
        PromptRequestPiece objects.
        """
        return [response for response in request_responses if response.role == "assistant"]
    def _remove_duplicates(self, request_responses: Sequence[PromptRequestPiece]) -> list[PromptRequestPiece]:
        """
        Removes duplicates from the list of PromptRequestPiece objects so that identical prompts are not
        scored twice. Only original pieces are kept: a piece is an original when its original_prompt_id
        matches its own id, whereas converted or duplicated pieces point back to a different original.
        """
        return [response for response in request_responses if response.original_prompt_id == response.id]
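# ---------------------------------------------------------------------------
# End-to-end usage sketch (illustrative only; ``SubStringScorer`` stands in
# for any configured Scorer, and ``prompt_ids`` for IDs already stored in
# memory):
#
#   import asyncio
#
#   from pyrit.orchestrator import ScoringOrchestrator
#   from pyrit.score import SubStringScorer
#
#   async def main():
#       scorer = SubStringScorer(substring="error", category="technical")
#       scoring_orchestrator = ScoringOrchestrator(batch_size=10)
#       scores = await scoring_orchestrator.score_prompts_by_request_id_async(
#           scorer=scorer, prompt_ids=prompt_ids
#       )
#       for score in scores:
#           print(score)
#
#   asyncio.run(main())
# ---------------------------------------------------------------------------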