Source code for pyrit.prompt_target.common.prompt_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
import logging
from typing import Any, List, Optional

from pyrit.identifiers import Identifiable, TargetIdentifier
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import Message

logger = logging.getLogger(__name__)


[docs] class PromptTarget(Identifiable[TargetIdentifier]): """ Abstract base class for prompt targets. A prompt target is a destination where prompts can be sent to interact with various services, models, or APIs. This class defines the interface that all prompt targets must implement. """ _memory: MemoryInterface #: A list of PromptConverters that are supported by the prompt target. #: An empty list implies that the prompt target supports all converters. supported_converters: List[Any] _identifier: Optional[TargetIdentifier] = None
[docs] def __init__( self, verbose: bool = False, max_requests_per_minute: Optional[int] = None, endpoint: str = "", model_name: str = "", underlying_model: Optional[str] = None, ) -> None: """ Initialize the PromptTarget. Args: verbose (bool): Enable verbose logging. Defaults to False. max_requests_per_minute (int, Optional): Maximum number of requests per minute. endpoint (str): The endpoint URL. Defaults to empty string. model_name (str): The model name. Defaults to empty string. underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for identification purposes. This is useful when the deployment name in Azure differs from the actual model. If not provided, `model_name` will be used for the identifier. Defaults to None. """ self._memory = CentralMemory.get_memory_instance() self._verbose = verbose self._max_requests_per_minute = max_requests_per_minute self._endpoint = endpoint self._model_name = model_name self._underlying_model = underlying_model if self._verbose: logging.basicConfig(level=logging.INFO)
[docs] @abc.abstractmethod async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Send a normalized prompt async to the prompt target. Returns: list[Message]: A list of message responses. Most targets return a single message, but some (like response target with tool calls) may return multiple messages. """
@abc.abstractmethod def _validate_request(self, *, message: Message) -> None: """ Validate the provided message. Args: message: The message to validate. """
[docs] def set_model_name(self, *, model_name: str) -> None: """ Set the model name for this target. Args: model_name (str): The model name to set. """ self._model_name = model_name
[docs] def dispose_db_engine(self) -> None: """ Dispose database engine to release database connections and resources. """ self._memory.dispose_engine()
def _create_identifier( self, *, temperature: Optional[float] = None, top_p: Optional[float] = None, target_specific_params: Optional[dict[str, Any]] = None, ) -> TargetIdentifier: """ Construct the target identifier. Subclasses should call this method in their _build_identifier() implementation to set the identifier with their specific parameters. Args: temperature (Optional[float]): The temperature parameter for generation. Defaults to None. top_p (Optional[float]): The top_p parameter for generation. Defaults to None. target_specific_params (Optional[dict[str, Any]]): Additional target-specific parameters that should be included in the identifier. Defaults to None. Returns: TargetIdentifier: The identifier for this prompt target. """ # Determine the model name to use model_name = "" if self._underlying_model: model_name = self._underlying_model elif self._model_name: model_name = self._model_name # Late import to avoid circular dependency from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget return TargetIdentifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", endpoint=self._endpoint, model_name=model_name, temperature=temperature, top_p=top_p, max_requests_per_minute=self._max_requests_per_minute, supports_conversation_history=isinstance(self, PromptChatTarget), target_specific_params=target_specific_params, ) def _build_identifier(self) -> TargetIdentifier: """ Build the identifier for this target. Subclasses can override this method to call _create_identifier() with their specific parameters (temperature, top_p, target_specific_params). The base implementation calls _create_identifier() with no parameters, which works for targets that don't have model-specific settings. Returns: TargetIdentifier: The identifier for this prompt target. """ return self._create_identifier()