Source code for pyrit.prompt_target.openai.openai_image_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any, Dict, Literal, Optional

from pyrit.exceptions import (
    EmptyResponseException,
    pyrit_target_retry,
)
from pyrit.models import (
    Message,
    PromptDataType,
    construct_response_from_request,
    data_serializer_factory,
)
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute

logger = logging.getLogger(__name__)


class OpenAIImageTarget(OpenAITarget):
    """A target for image generation using OpenAI's image models."""

    def __init__(
        self,
        image_size: Literal["256x256", "512x512", "1024x1024", "1536x1024", "1024x1536"] = "1024x1024",
        quality: Optional[Literal["standard", "hd", "low", "medium", "high"]] = None,
        style: Optional[Literal["natural", "vivid"]] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the image target with specified parameters.

        Args:
            model_name (str, Optional): The name of the model.
                If no value is provided, the OPENAI_IMAGE_MODEL environment variable will be used.
            endpoint (str, Optional): The target URL for the OpenAI service.
            api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service,
                or a callable that returns an access token. For Azure endpoints with Entra
                authentication, pass a token provider from pyrit.auth
                (e.g., get_azure_openai_auth(endpoint)).
                Defaults to the `OPENAI_IMAGE_API_KEY` environment variable.
            headers (str, Optional): Headers of the endpoint (JSON).
            max_requests_per_minute (int, Optional): Number of requests the target can handle per
                minute before hitting a rate limit. The number of requests sent to the target
                will be capped at the value provided.
            image_size (Literal["256x256", "512x512", "1024x1024", "1536x1024", "1024x1536"], Optional):
                The size of the generated images. Defaults to "1024x1024".
            quality (Literal["standard", "hd", "low", "medium", "high"], Optional): The quality of the
                generated images. Different models support different quality settings. For DALL-E-3,
                there's "standard" and "hd". For newer models, there are "low", "medium", and "high".
                Default is to not specify.
            style (Literal["natural", "vivid"], Optional): The style of the generated images.
                Default is to not specify.
            *args: Additional positional arguments to be passed to OpenAITarget.
            **kwargs: Additional keyword arguments to be passed to OpenAITarget.
            httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
                `httpx.AsyncClient()` constructor.
                For example, to specify a 3 minutes timeout: httpx_client_kwargs={"timeout": 180}
        """
        self.quality = quality
        self.style = style
        self.image_size = image_size

        # Flag to track if we need to explicitly request b64_json format.
        # Will be set to True if the model returns URLs instead of base64.
        self._requires_response_format = False

        super().__init__(*args, **kwargs)

    def _set_openai_env_configuration_vars(self):
        """Set the names of the environment variables used to configure this target."""
        self.model_name_environment_variable = "OPENAI_IMAGE_MODEL"
        self.endpoint_environment_variable = "OPENAI_IMAGE_ENDPOINT"
        self.api_key_environment_variable = "OPENAI_IMAGE_API_KEY"

    def _get_target_api_paths(self) -> list[str]:
        """Return API paths that should not be in the URL."""
        return ["/images/generations", "/v1/images/generations"]

    def _get_provider_examples(self) -> dict[str, str]:
        """Return provider-specific example URLs."""
        return {
            ".openai.azure.com": "https://{resource}.openai.azure.com/openai/v1",
            "api.openai.com": "https://api.openai.com/v1",
        }

    @limit_requests_per_minute
    @pyrit_target_retry
    async def send_prompt_async(
        self,
        *,
        message: Message,
    ) -> list[Message]:
        """
        Send a prompt to the image generation target and return the response.

        Args:
            message (Message): The message to send. Must contain exactly one text piece.

        Returns:
            list[Message]: A list containing the response from the image target.

        Raises:
            ValueError: If the message does not consist of a single text piece.
        """
        self._validate_request(message=message)
        message_piece = message.message_pieces[0]

        # Lazy %-formatting: the piece is only rendered if INFO logging is enabled.
        logger.info("Sending the following prompt to the prompt target: %s", message_piece)

        # Construct request parameters
        image_generation_args: Dict[str, Any] = {
            "model": self._model_name,
            "prompt": message_piece.converted_value,
            "size": self.image_size,
        }

        # Add response_format only once we've detected the model returns URLs by default
        if self._requires_response_format:
            image_generation_args["response_format"] = "b64_json"

        # quality / style are optional and only sent when explicitly configured
        if self.quality:
            image_generation_args["quality"] = self.quality
        if self.style:
            image_generation_args["style"] = self.style

        # Use unified error handler for consistent error handling
        response = await self._handle_openai_request(
            api_call=lambda: self._async_client.images.generate(**image_generation_args),
            request=message,
        )
        return [response]

    async def _construct_message_from_response(self, response: Any, request: Any) -> Message:
        """
        Construct a Message from an ImagesResponse.

        Args:
            response: The ImagesResponse from OpenAI SDK.
            request: The original request, forwarded to construct_response_from_request.

        Returns:
            Message: Constructed message with image path.

        Raises:
            EmptyResponseException: If the image generation returned an empty response
                (no data entries, or an entry with neither base64 nor URL) or if the
                model returned a URL instead of base64.

        Note:
            PyRIT expects base64-encoded images. Some models (like dall-e) return URLs by default,
            while others (like gpt-image-1) always return base64. This method detects the format
            and adapts automatically.
        """
        # `data` may be None or an empty list; report that as an empty response
        # rather than letting a TypeError/IndexError escape from the indexing below.
        images = response.data
        if not images:
            raise EmptyResponseException(message="The image generation returned an empty response.")

        image_data = images[0]

        # Try to get base64 data first (preferred format)
        b64_data = getattr(image_data, "b64_json", None)

        if not b64_data:
            # Check if URL format was returned instead
            image_url = getattr(image_data, "url", None)
            if image_url:
                # Model returned URL instead of base64 - set the flag so the next
                # call to send_prompt_async adds response_format="b64_json"
                logger.info(
                    "Image model returned URL instead of base64. "
                    "Setting flag to request b64_json format in future calls."
                )
                self._requires_response_format = True
                raise EmptyResponseException(
                    message="Image was returned as URL instead of base64. Retrying with response_format parameter."
                )

            # Neither URL nor base64 - truly empty response
            raise EmptyResponseException(message="The image generation returned an empty response.")

        # Save the image and get the file path
        data = data_serializer_factory(category="prompt-memory-entries", data_type="image_path")
        await data.save_b64_image(data=b64_data)
        resp_text = data.value
        response_type: PromptDataType = "image_path"

        return construct_response_from_request(
            request=request, response_text_pieces=[resp_text], response_type=response_type
        )

    def _validate_request(self, *, message: Message) -> None:
        """
        Validate that the message is a single text piece.

        Args:
            message (Message): The message to validate.

        Raises:
            ValueError: If the message has more or fewer than one piece, or if the
                piece's converted value data type is not "text".
        """
        n_pieces = len(message.message_pieces)
        if n_pieces != 1:
            raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

        piece_type = message.message_pieces[0].converted_value_data_type
        if piece_type != "text":
            raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")

    def is_json_response_supported(self) -> bool:
        """
        Check if the target supports JSON as a response format.

        Returns:
            bool: True if JSON response is supported, False otherwise.
        """
        return False