# Source code for pyrit.score.human.human_in_the_loop_gradio
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
from typing import Optional
from pyrit.models import PromptRequestPiece, Score
from pyrit.score.scorer_prompt_validator import ScorerPromptValidator
from pyrit.score.true_false.true_false_score_aggregator import (
TrueFalseAggregatorFunc,
TrueFalseScoreAggregator,
)
from pyrit.score.true_false.true_false_scorer import TrueFalseScorer
class HumanInTheLoopScorerGradio(TrueFalseScorer):
    """
    Create scores from manual human input collected through a Gradio interface.

    The Gradio UI is driven through an ``AppRPCServer`` that is started when the scorer is
    constructed. The server is stopped when an in-flight scoring task is cancelled, and again
    when the scorer is garbage-collected.

    In the future this will not be a TrueFalseScorer. However, it is all that is supported currently.

    Args:
        open_browser (bool): If True, the scorer will open the Gradio interface in a browser
            instead of opening it in PyWebview. Defaults to False.
        validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None, in which
            case a validator restricted to the ``"text"`` data type is used.
        score_aggregator (TrueFalseAggregatorFunc): Aggregator for combining scores. Defaults to
            TrueFalseScoreAggregator.OR.
    """

    # Fallback validator used when the caller supplies none; restricts scoring to "text" pieces.
    _default_validator: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"])

    def __init__(
        self,
        *,
        open_browser: bool = False,
        validator: Optional[ScorerPromptValidator] = None,
        score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR,
    ) -> None:
        # Import here to avoid importing rpyc in the main module that might not be installed
        from pyrit.ui.rpc import AppRPCServer

        super().__init__(validator=validator or self._default_validator, score_aggregator=score_aggregator)
        self._rpc_server = AppRPCServer(open_browser=open_browser)
        self._rpc_server.start()

    async def _score_piece_async(
        self, request_piece: PromptRequestPiece, *, objective: Optional[str] = None
    ) -> list[Score]:
        """
        Score a prompt request piece using human input through the Gradio interface.

        The synchronous ``retrieve_score`` round-trip is run in a worker thread via
        ``asyncio.to_thread`` so the event loop is not blocked while waiting for the human.

        Args:
            request_piece (PromptRequestPiece): The prompt request piece to be scored by a human.
            objective (Optional[str]): The objective to evaluate against. Defaults to None.

        Returns:
            list[Score]: A list containing a single Score object based on human evaluation.

        Raises:
            asyncio.CancelledError: If the awaiting task is cancelled. The RPC server is stopped
                before the cancellation is re-raised.
        """
        try:
            return await asyncio.to_thread(self.retrieve_score, request_piece, objective=objective)
        except asyncio.CancelledError:
            # Cancelling the awaiting task does not interrupt the worker thread itself, which may
            # still be waiting on the RPC server; stopping the server is this code's teardown path
            # (presumably it also releases the waiting thread — TODO confirm in AppRPCServer).
            self._rpc_server.stop()
            raise

    def retrieve_score(self, request_prompt: PromptRequestPiece, *, objective: Optional[str] = None) -> list[Score]:
        """
        Retrieve a score from the human evaluator through the RPC server.

        This is a synchronous call: it waits on the RPC server for a client, sends the prompt,
        then waits for the human's score. ``_score_piece_async`` runs it in a worker thread.

        Args:
            request_prompt (PromptRequestPiece): The prompt request piece to be scored.
            objective (Optional[str]): The objective to evaluate against. Defaults to None.
                Note: it is currently not forwarded to the RPC server.

        Returns:
            list[Score]: A list containing a single Score object from the human evaluator,
                with ``scorer_class_identifier`` set to this scorer's identifier.
        """
        self._rpc_server.wait_for_client()
        self._rpc_server.send_score_prompt(request_prompt)
        score = self._rpc_server.wait_for_score()
        # Attribute the human-provided score to this scorer.
        score.scorer_class_identifier = self.get_identifier()
        return [score]

    def __del__(self) -> None:
        # ``__init__`` may raise before ``_rpc_server`` is assigned (e.g. the lazy import of
        # ``pyrit.ui.rpc`` fails because rpyc is not installed), yet ``__del__`` still runs on
        # the partially-initialised object. Guard against the missing attribute instead of
        # raising AttributeError during garbage collection.
        rpc_server = getattr(self, "_rpc_server", None)
        if rpc_server is not None:
            rpc_server.stop()