# Source code for pyrit.orchestrator.orchestrator_class

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
import ast
import logging
import uuid

from typing import Optional

from pyrit.common import default_values
from pyrit.memory import MemoryInterface, CentralMemory
from pyrit.models import PromptDataType, Identifier
from pyrit.prompt_converter import PromptConverter
from pyrit.prompt_normalizer import NormalizerRequest, NormalizerRequestPiece

logger = logging.getLogger(__name__)


class Orchestrator(abc.ABC, Identifier):
    """
    Abstract base class for orchestrators.

    Holds the default prompt converters, a handle to the central memory instance, a unique
    orchestrator id, and the global memory labels read from the ``GLOBAL_MEMORY_LABELS``
    setting. Instances can be used as context managers; leaving the ``with`` block disposes
    the memory's database engine.
    """

    # Shared memory backend, obtained from CentralMemory in __init__.
    _memory: MemoryInterface

    def __init__(
        self,
        *,
        prompt_converters: Optional[list[PromptConverter]] = None,
        verbose: bool = False,
    ):
        """
        Initialize the orchestrator.

        Args:
            prompt_converters: Converters used by default when building normalizer requests.
                ``None`` or an empty list means no converters.
            verbose: If True, configure root logging at INFO level via ``logging.basicConfig``.

        Raises:
            ValueError: If ``GLOBAL_MEMORY_LABELS`` is set but is not a valid dict literal.
        """
        self._prompt_converters = prompt_converters if prompt_converters else []
        self._memory = CentralMemory.get_memory_instance()
        self._verbose = verbose
        self._id = uuid.uuid4()

        # Global memory labels (e.g. from .env.local). These are applied to all prompts sent
        # via this orchestrator, merged with per-call labels in _combine_with_global_memory_labels.
        self._global_memory_labels: dict[str, str] = self._load_global_memory_labels()

        if self._verbose:
            logging.basicConfig(level=logging.INFO)

    @staticmethod
    def _load_global_memory_labels() -> dict[str, str]:
        """
        Read and parse the ``GLOBAL_MEMORY_LABELS`` setting.

        The value is parsed with ``ast.literal_eval``, which accepts Python literals only and
        never executes code. An unset or empty value yields an empty dict.

        Returns:
            The parsed labels dictionary.

        Raises:
            ValueError: If the value is not a valid Python literal, or does not evaluate to a dict.
        """
        raw_labels = default_values.get_non_required_value(env_var_name="GLOBAL_MEMORY_LABELS", passed_value=None)
        if not raw_labels:
            return {}

        try:
            parsed = ast.literal_eval(raw_labels)
        except (ValueError, SyntaxError) as err:
            raise ValueError(
                "GLOBAL_MEMORY_LABELS must be a Python dict literal, e.g. {'op_name': 'demo'}"
            ) from err

        # Fail here, where the cause is obvious, instead of later when the labels are merged.
        if not isinstance(parsed, dict):
            raise ValueError(
                f"GLOBAL_MEMORY_LABELS must evaluate to a dict, got {type(parsed).__name__}"
            )
        return parsed

    def __enter__(self):
        """Enter the runtime context; ``with ... as x`` binds the orchestrator itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Exit the runtime context and dispose the memory's database engine.

        Returns None, so an exception raised inside the ``with`` block is not suppressed.
        """
        self.dispose_db_engine()

    def dispose_db_engine(self) -> None:
        """
        Dispose database engine to release database connections and resources.
        """
        self._memory.dispose_engine()

    def _create_normalizer_request(
        self,
        prompt_text: str,
        prompt_type: PromptDataType = "text",
        converters: Optional[list[PromptConverter]] = None,
        metadata=None,
        conversation_id=None,
    ) -> NormalizerRequest:
        """
        Build a single-piece normalizer request for the given prompt.

        Args:
            prompt_text: The prompt value to send.
            prompt_type: Data type of the prompt value.
            converters: Converters to apply. ``None`` falls back to the orchestrator's default
                converters; an explicit empty list means no converters.
            metadata: Passed through unchanged to ``NormalizerRequestPiece``.
            conversation_id: Passed through unchanged to ``NormalizerRequest``.

        Returns:
            A ``NormalizerRequest`` containing exactly one ``NormalizerRequestPiece``.
        """
        if converters is None:
            converters = self._prompt_converters

        request_piece = NormalizerRequestPiece(
            request_converters=converters,
            prompt_value=prompt_text,
            prompt_data_type=prompt_type,
            metadata=metadata,
        )
        return NormalizerRequest(request_pieces=[request_piece], conversation_id=conversation_id)

    def _combine_with_global_memory_labels(self, memory_labels: Optional[dict[str, str]]) -> dict[str, str]:
        """
        Combine the global memory labels with the provided memory labels.

        The passed ``memory_labels`` take precedence on key collisions. ``None`` or an empty
        mapping on either side is treated as ``{}``. A new dict is returned and neither input
        is mutated.
        """
        return {**(self._global_memory_labels or {}), **(memory_labels or {})}

    def get_memory(self):
        """
        Retrieve the prompt request pieces stored in memory for this orchestrator's id.
        """
        return self._memory.get_prompt_request_piece_by_orchestrator_id(orchestrator_id=self._id)

    def get_score_memory(self):
        """
        Retrieve the scores of the PromptRequestPieces associated with this orchestrator.

        These exist if a scorer is provided to the orchestrator.
        """
        return self._memory.get_scores_by_orchestrator_id(orchestrator_id=self._id)

    def get_identifier(self) -> dict[str, str]:
        """
        Return an identifier dict with the concrete class name, its module, and this
        orchestrator's id as a string.
        """
        return {
            "__type__": self.__class__.__name__,
            "__module__": self.__class__.__module__,
            "id": str(self._id),
        }