Source code for pyrit.prompt_target.common.prompt_target
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
import logging
from typing import Any, Dict, List, Optional
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import Identifiable, Message
logger = logging.getLogger(__name__)
[docs]
class PromptTarget(abc.ABC, Identifiable):
    """
    Abstract base class for prompt targets.

    A prompt target is a destination where prompts can be sent to interact with various services,
    models, or APIs. This class defines the interface that all prompt targets must implement.
    """

    _memory: MemoryInterface

    #: A list of PromptConverters that are supported by the prompt target.
    #: An empty list implies that the prompt target supports all converters.
    #: NOTE(review): this base class declares but never assigns this attribute; subclasses are
    #: presumably expected to set it — confirm before reading it on a bare PromptTarget.
    supported_converters: List[Any]

    def __init__(
        self,
        verbose: bool = False,
        max_requests_per_minute: Optional[int] = None,
        endpoint: str = "",
        model_name: str = "",
        underlying_model: Optional[str] = None,
    ) -> None:
        """
        Initialize the PromptTarget.

        Args:
            verbose (bool): Enable verbose logging. Defaults to False.
            max_requests_per_minute (int, Optional): Maximum number of requests per minute.
                This base class only stores the value; it does not enforce it itself.
            endpoint (str): The endpoint URL. Defaults to empty string.
            model_name (str): The model name. Defaults to empty string.
            underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for
                identification purposes. This is useful when the deployment name in Azure differs
                from the actual model. If not provided, `model_name` will be used for the identifier.
                Defaults to None.
        """
        self._memory = CentralMemory.get_memory_instance()
        self._verbose = verbose
        self._max_requests_per_minute = max_requests_per_minute
        self._endpoint = endpoint
        self._model_name = model_name
        self._underlying_model = underlying_model

        if self._verbose:
            # This configures the *root* logger. logging.basicConfig does nothing when the root
            # logger already has handlers, so an application's own logging setup takes precedence.
            logging.basicConfig(level=logging.INFO)

    @abc.abstractmethod
    async def send_prompt_async(self, *, message: Message) -> list[Message]:
        """
        Send a normalized prompt async to the prompt target.

        Args:
            message (Message): The message to send to the target.

        Returns:
            list[Message]: A list of message responses. Most targets return a single message,
                but some (like response target with tool calls) may return multiple messages.
        """

    @abc.abstractmethod
    def _validate_request(self, *, message: Message) -> None:
        """
        Validate the provided message.

        Implementations presumably signal an invalid message by raising an exception
        (the return type is None) — confirm against concrete subclasses.

        Args:
            message (Message): The message to validate.
        """

    def set_model_name(self, *, model_name: str) -> None:
        """
        Set the model name for this target.

        Note:
            If an `underlying_model` was supplied, `get_identifier` keeps reporting it as
            "model_name"; this setter only changes the fallback value.

        Args:
            model_name (str): The model name to set.
        """
        self._model_name = model_name

    def dispose_db_engine(self) -> None:
        """
        Dispose database engine to release database connections and resources.

        Delegates to the memory instance obtained from `CentralMemory` at construction time.
        """
        self._memory.dispose_engine()

    def get_identifier(self) -> Dict[str, Any]:
        """
        Get an identifier dictionary for this prompt target.

        This includes essential attributes needed for scorer evaluation and registry tracking.
        Subclasses should override this method to include additional relevant attributes
        (e.g., temperature, top_p) when available.

        Returns:
            Dict[str, Any]: A dictionary containing identification attributes. It always holds
                "__type__" and "__module__". It holds "endpoint" and "model_name" only when those
                values are non-empty, and "temperature"/"top_p" only when a subclass set them
                to a non-None value.

        Note:
            If `self._underlying_model` is set (this base class takes it from the constructor;
            subclasses may populate it differently, e.g. from configuration), it is used as the
            "model_name" for the identifier. Otherwise, `self._model_name` (which is often the
            deployment name in Azure) is used.
        """
        public_attributes: Dict[str, Any] = {
            "__type__": self.__class__.__name__,
            "__module__": self.__class__.__module__,
        }

        if self._endpoint:
            public_attributes["endpoint"] = self._endpoint

        # Prefer the underlying model (e.g. "gpt-4o") over self._model_name, which is often
        # the Azure deployment name and may not reflect the actual model being served.
        model_name = self._underlying_model or self._model_name
        if model_name:
            public_attributes["model_name"] = model_name

        # Sampling parameters are not declared on this base class; only some subclasses set them.
        # getattr with a default reads each attribute once and avoids referencing undeclared
        # attributes directly (which static type checkers reject).
        temperature = getattr(self, "_temperature", None)
        if temperature is not None:
            public_attributes["temperature"] = temperature

        top_p = getattr(self, "_top_p", None)
        if top_p is not None:
            public_attributes["top_p"] = top_p

        return public_attributes