Source code for pyrit.prompt_target.openai.openai_chat_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Any, MutableSequence, Optional

from pyrit.common import convert_local_image_to_data_url
from pyrit.exceptions import (
    EmptyResponseException,
    PyritException,
    pyrit_target_retry,
)
from pyrit.models import (
    ChatMessage,
    ChatMessageListDictContent,
    Message,
    MessagePiece,
    construct_response_from_request,
)
from pyrit.prompt_target import (
    OpenAITarget,
    PromptChatTarget,
    limit_requests_per_minute,
)
from pyrit.prompt_target.common.utils import validate_temperature, validate_top_p

logger = logging.getLogger(__name__)


[docs] class OpenAIChatTarget(OpenAITarget, PromptChatTarget): """ This class facilitates multimodal (image and text) input and text output generation. This works with GPT3.5, GPT4, GPT4o, GPT-V, and other compatible models Args: api_key (str): The api key for the OpenAI API endpoint (str): The endpoint for the OpenAI API model_name (str): The model name for the OpenAI API deployment_name (str): For Azure, the deployment name temperature (float): The temperature for the completion max_completion_tokens (int): The maximum number of tokens to be returned by the model. The total length of input tokens and generated tokens is limited by the model's context length. max_tokens (int): Deprecated. Use max_completion_tokens instead top_p (float): The nucleus sampling probability. frequency_penalty (float): Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. presence_penalty (float): Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. seed (int): This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. n (int): How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. extra_body_parameters (dict): Additional parameters to send in the request body """
[docs] def __init__( self, *, max_completion_tokens: Optional[int] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, seed: Optional[int] = None, n: Optional[int] = None, is_json_supported: bool = True, extra_body_parameters: Optional[dict[str, Any]] = None, **kwargs, ): """ Args: model_name (str, Optional): The name of the model. If no value is provided, the OPENAI_CHAT_MODEL environment variable will be used. endpoint (str, Optional): The target URL for the OpenAI service. api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service, or a callable that returns an access token. For Azure endpoints with Entra authentication, pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint)). Defaults to the `OPENAI_CHAT_KEY` environment variable. headers (str, Optional): Headers of the endpoint (JSON). max_requests_per_minute (int, Optional): Number of requests the target can handle per minute before hitting a rate limit. The number of requests sent to the target will be capped at the value provided. max_completion_tokens (int, Optional): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. NOTE: Specify this value when using an o1 series model. max_tokens (int, Optional): The maximum number of tokens that can be generated in the chat completion. This value can be used to control costs for text generated via API. This value is now deprecated in favor of `max_completion_tokens`, and IS NOT COMPATIBLE with o1 series models. temperature (float, Optional): The temperature parameter for controlling the randomness of the response. top_p (float, Optional): The top-p parameter for controlling the diversity of the response. frequency_penalty (float, Optional): The frequency penalty parameter for penalizing frequently generated tokens. presence_penalty (float, Optional): The presence penalty parameter for penalizing tokens that are already present in the conversation history. seed (int, Optional): If specified, openAI will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. n (int, Optional): The number of completions to generate for each prompt. is_json_supported (bool, Optional): If True, the target will support formatting responses as JSON by setting the response_format header. Official OpenAI models all support this, but if you are using this target with different models, is_json_supported should be set correctly to avoid issues when using adversarial infrastructure (e.g. Crescendo scorers will set this flag). extra_body_parameters (dict, Optional): Additional parameters to be included in the request body. httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the `httpx.AsyncClient()` constructor. For example, to specify a 3 minute timeout: httpx_client_kwargs={"timeout": 180} Raises: PyritException: If the temperature or top_p values are out of bounds. ValueError: If the temperature is not between 0 and 2 (inclusive). ValueError: If the top_p is not between 0 and 1 (inclusive). ValueError: If both `max_completion_tokens` and `max_tokens` are provided. RateLimitException: If the target is rate-limited. httpx.HTTPStatusError: If the request fails with a 400 Bad Request or 429 Too Many Requests error. json.JSONDecodeError: If the response from the target is not valid JSON. Exception: If the request fails for any other reason. """ super().__init__(**kwargs) # Validate temperature and top_p validate_temperature(temperature) validate_top_p(top_p) if max_completion_tokens and max_tokens: raise ValueError("Cannot provide both max_tokens and max_completion_tokens.") self._temperature = temperature self._top_p = top_p self._is_json_supported = is_json_supported self._max_completion_tokens = max_completion_tokens self._max_tokens = max_tokens self._frequency_penalty = frequency_penalty self._presence_penalty = presence_penalty self._seed = seed self._n = n self._extra_body_parameters = extra_body_parameters
def _set_openai_env_configuration_vars(self): self.model_name_environment_variable = "OPENAI_CHAT_MODEL" self.endpoint_environment_variable = "OPENAI_CHAT_ENDPOINT" self.api_key_environment_variable = "OPENAI_CHAT_KEY" def _get_target_api_paths(self) -> list[str]: """Return API paths that should not be in the URL.""" return ["/chat/completions", "/v1/chat/completions"] def _get_provider_examples(self) -> dict[str, str]: """Return provider-specific example URLs.""" return { ".openai.azure.com": "https://{resource}.openai.azure.com/openai/v1", "api.openai.com": "https://api.openai.com/v1", "api.anthropic.com": "https://api.anthropic.com/v1", "generativelanguage.googleapis.com": "https://generativelanguage.googleapis.com/v1beta/openai", } @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Asynchronously sends a message and handles the response within a managed conversation context. Args: message (Message): The message object. Returns: list[Message]: A list containing the response from the prompt target. """ self._validate_request(message=message) message_piece: MessagePiece = message.message_pieces[0] is_json_response = self.is_response_format_json(message_piece) # Get conversation from memory and append the current message conversation = self._memory.get_conversation(conversation_id=message_piece.conversation_id) conversation.append(message) logger.info(f"Sending the following prompt to the prompt target: {message}") body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response) # Use unified error handling - automatically detects ChatCompletion and validates response = await self._handle_openai_request( api_call=lambda: self._async_client.chat.completions.create(**body), request=message, ) return [response] def _check_content_filter(self, response: Any) -> bool: """ Check if a Chat Completions API response has finish_reason=content_filter. Args: response: A ChatCompletion object from the OpenAI SDK. Returns: True if content was filtered, False otherwise. """ try: if response.choices and response.choices[0].finish_reason == "content_filter": return True except (AttributeError, IndexError): pass return False def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]: """ Validate a Chat Completions API response for errors. Checks for: - Missing choices - Invalid finish_reason - Empty content Args: response: The ChatCompletion response from OpenAI SDK. request: The original request MessagePiece. Returns: None if valid, does not return Message for content filter (handled by _check_content_filter). Raises: PyritException: For unexpected response structures or finish reasons. EmptyResponseException: When the API returns an empty response. """ # Check for missing choices if not response.choices: raise PyritException(message="No choices returned in the completion response.") choice = response.choices[0] finish_reason = choice.finish_reason # Check finish_reason (content_filter is handled by _check_content_filter) if finish_reason not in ["stop", "length", "content_filter"]: # finish_reason="stop" means API returned complete message # "length" means API returned incomplete message due to max_tokens limit raise PyritException( message=f"Unknown finish_reason {finish_reason} from response: {response.model_dump_json()}" ) # Check for empty content content = choice.message.content or "" if not content: logger.error("The chat returned an empty response.") raise EmptyResponseException(message="The chat returned an empty response.") return None async def _construct_message_from_response(self, response: Any, request: MessagePiece) -> Message: """ Construct a Message from a ChatCompletion response. Args: response: The ChatCompletion response from OpenAI SDK. request: The original request MessagePiece. Returns: Message: Constructed message with extracted content. """ extracted_response = response.choices[0].message.content or "" return construct_response_from_request(request=request, response_text_pieces=[extracted_response])
[docs] def is_json_response_supported(self) -> bool: """Indicates that this target supports JSON response format.""" return self._is_json_supported
async def _build_chat_messages_async(self, conversation: MutableSequence[Message]) -> list[dict]: """ Builds chat messages based on message entries. Args: conversation (list[Message]): A list of Message objects. Returns: list[dict]: The list of constructed chat messages. """ if self._is_text_message_format(conversation): return self._build_chat_messages_for_text(conversation) else: return await self._build_chat_messages_for_multi_modal_async(conversation) def _is_text_message_format(self, conversation: MutableSequence[Message]) -> bool: """ Checks if the message piece is in text message format. Args: conversation list[Message]: The conversation Returns: bool: True if the message piece is in text message format, False otherwise. """ for turn in conversation: if len(turn.message_pieces) != 1: return False if turn.message_pieces[0].converted_value_data_type != "text": return False return True def _build_chat_messages_for_text(self, conversation: MutableSequence[Message]) -> list[dict]: """ Builds chat messages based on message entries. This is needed because many openai "compatible" models don't support ChatMessageListDictContent format (this is more universally accepted). Args: conversation (list[Message]): A list of Message objects. Returns: list[dict]: The list of constructed chat messages. """ chat_messages: list[dict] = [] for message in conversation: # validated to only have one text entry if len(message.message_pieces) != 1: raise ValueError("_build_chat_messages_for_text only supports a single message piece.") message_piece = message.message_pieces[0] if message_piece.converted_value_data_type != "text": raise ValueError("_build_chat_messages_for_text only supports text.") chat_message = ChatMessage(role=message_piece.role, content=message_piece.converted_value) chat_messages.append(chat_message.model_dump(exclude_none=True)) return chat_messages async def _build_chat_messages_for_multi_modal_async(self, conversation: MutableSequence[Message]) -> list[dict]: """ Builds chat messages based on message entries. Args: conversation (list[Message]): A list of Message objects. Returns: list[dict]: The list of constructed chat messages. """ chat_messages: list[dict] = [] for message in conversation: message_pieces = message.message_pieces content = [] role = None for message_piece in message_pieces: role = message_piece.role if message_piece.converted_value_data_type == "text": entry = {"type": "text", "text": message_piece.converted_value} content.append(entry) elif message_piece.converted_value_data_type == "image_path": data_base64_encoded_url = await convert_local_image_to_data_url(message_piece.converted_value) image_url_entry = {"url": data_base64_encoded_url} entry = {"type": "image_url", "image_url": image_url_entry} # type: ignore content.append(entry) else: raise ValueError( f"Multimodal data type {message_piece.converted_value_data_type} is not yet supported." ) if not role: raise ValueError("No role could be determined from the message pieces.") chat_message = ChatMessageListDictContent(role=role, content=content) # type: ignore chat_messages.append(chat_message.model_dump(exclude_none=True)) return chat_messages async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict: messages = await self._build_chat_messages_async(conversation) body_parameters = { "model": self._model_name, "max_completion_tokens": self._max_completion_tokens, "max_tokens": self._max_tokens, "temperature": self._temperature, "top_p": self._top_p, "frequency_penalty": self._frequency_penalty, "presence_penalty": self._presence_penalty, "stream": False, "seed": self._seed, "n": self._n, "messages": messages, "response_format": {"type": "json_object"} if is_json_response else None, } if self._extra_body_parameters: for key, value in self._extra_body_parameters.items(): body_parameters[key] = value # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} def _validate_request(self, *, message: Message) -> None: """ Validates the structure and content of a message for compatibility of this target. Args: message (Message): The message object. Raises: ValueError: If any of the message pieces have a data type other than 'text' or 'image_path'. """ converted_prompt_data_types = [ message_piece.converted_value_data_type for message_piece in message.message_pieces ] # Some models may not support all of these for prompt_data_type in converted_prompt_data_types: if prompt_data_type not in ["text", "image_path"]: raise ValueError(f"This target only supports text and image_path. Received: {prompt_data_type}.")