Source code for pyrit.score.human_in_the_loop_gradio
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
from typing import Optional
from pyrit.models import PromptRequestPiece, Score
from pyrit.score.scorer import Scorer
[docs]
class HumanInTheLoopScorerGradio(Scorer):
    """
    Create scores from manual human input using Gradio and add them to the database.

    Parameters:
        open_browser (bool): If True, the scorer opens the Gradio interface in a browser
            instead of opening it in PyWebview.
    """

    def __init__(self, *, open_browser: bool = False) -> None:
        # Import here to avoid importing rpyc in the main module that might not be installed
        from pyrit.ui.rpc import AppRPCServer

        self._rpc_server = AppRPCServer(open_browser=open_browser)
        self._rpc_server.start()

    async def score_async(
        self, request_response: PromptRequestPiece, *, task: Optional[str] = None
    ) -> list[Score]:
        """
        Ask a human, via the Gradio UI, to score ``request_response`` and store the result.

        Args:
            request_response: The piece to be scored; its converted value must be text.
            task: Optional task description, forwarded to ``retrieve_score``.

        Returns:
            The list returned by ``retrieve_score`` (a single human-provided score).

        Raises:
            ValueError: If ``request_response`` is not text (see ``validate``).
            asyncio.CancelledError: Re-raised after stopping the RPC server when the
                awaiting task is cancelled.
        """
        self.validate(request_response=request_response)
        try:
            # retrieve_score makes blocking RPC calls; run it on a worker thread so the
            # event loop stays responsive while waiting for the human.
            scores = await asyncio.to_thread(self.retrieve_score, request_response, task=task)
        except asyncio.CancelledError:
            # Cancelling the awaiting task does not interrupt the worker thread, so stop the
            # RPC server as the only available cleanup (presumably this also unblocks the
            # thread's wait — confirm against AppRPCServer).
            self._rpc_server.stop()
            raise
        self._memory.add_scores_to_memory(scores=scores)
        return scores

    def retrieve_score(
        self, request_prompt: PromptRequestPiece, *, task: Optional[str] = None
    ) -> list[Score]:
        """
        Block until a human submits a score for ``request_prompt`` through the Gradio UI.

        Waits for a UI client to connect, sends it the prompt, waits for the resulting
        score, and stamps that score with this scorer's identifier.

        Args:
            request_prompt: The piece to present for scoring.
            task: Not used by this method.

        Returns:
            A single-element list containing the received score.
        """
        self._rpc_server.wait_for_client()
        self._rpc_server.send_score_prompt(request_prompt)
        score = self._rpc_server.wait_for_score()
        score.scorer_class_identifier = self.get_identifier()
        return [score]

    def validate(self, request_response: PromptRequestPiece, *, task: Optional[str] = None) -> None:
        """
        Ensure the piece can be presented for manual scoring in the Gradio UI.

        Args:
            request_response: The piece to check.
            task: Not used by this method.

        Raises:
            ValueError: If ``request_response.converted_value_data_type`` is not ``"text"``.
        """
        if request_response.converted_value_data_type != "text":
            raise ValueError("Prompt data type must be 'text' for Gradio manual scoring.")

    def __del__(self):
        # __init__ may have raised before _rpc_server was assigned (e.g. the optional rpyc
        # dependency is missing); without this guard the finalizer would raise
        # AttributeError, which Python reports as "Exception ignored in ...".
        rpc_server = getattr(self, "_rpc_server", None)
        if rpc_server is not None:
            rpc_server.stop()