Source code for pyrit.prompt_target.openai.openai_dall_e_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from typing import Any, Dict, List, Literal

import httpx

from pyrit.common import net_utility
from pyrit.exceptions import (
    EmptyResponseException,
    handle_bad_request_exception,
    pyrit_target_retry,
)
from pyrit.exceptions.exception_classes import RateLimitException
from pyrit.models import (
    PromptDataType,
    PromptRequestResponse,
    construct_response_from_request,
    data_serializer_factory,
)
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute

# Module-scoped logger; named after this module so its output can be filtered independently.
logger = logging.getLogger(name=__name__)


class OpenAIDALLETarget(OpenAITarget):
    """
    Prompt target that generates images with OpenAI's DALL-E models.

    A single text prompt is sent to the configured image-generation endpoint. Every
    base64 image in the reply is saved through PyRIT's data serializer and returned
    as an ``image_path`` response piece (one piece per generated image).
    """

    def __init__(
        self,
        image_size: Literal["256x256", "512x512", "1024x1024"] = "1024x1024",
        num_images: int = 1,
        dalle_version: Literal["dall-e-2", "dall-e-3"] = "dall-e-2",
        quality: Literal["standard", "hd"] = "standard",
        style: Literal["natural", "vivid"] = "natural",
        *args,
        **kwargs,
    ):
        """
        Initialize the DALL-E target with specified parameters.

        Args:
            model_name (str, Optional): The name of the model.
            endpoint (str, Optional): The target URL for the OpenAI service.
            api_key (str, Optional): The API key for accessing the Azure OpenAI service.
                Defaults to the OPENAI_DALLE_API_KEY environment variable.
            headers (str, Optional): Headers of the endpoint (JSON).
            use_aad_auth (bool, Optional): When set to True, user authentication is used
                instead of API Key. DefaultAzureCredential is taken for
                https://cognitiveservices.azure.com/.default . Please run `az login` locally
                to leverage user AuthN.
            api_version (str, Optional): The version of the Azure OpenAI API. Defaults to
                "2024-06-01".
            max_requests_per_minute (int, Optional): Number of requests the target can handle per
                minute before hitting a rate limit. The number of requests sent to the target
                will be capped at the value provided.
            image_size (Literal["256x256", "512x512", "1024x1024"], Optional): The size of the
                generated images. Defaults to "1024x1024".
            num_images (int, Optional): The number of images to generate. Defaults to 1.
                For DALL-E-2, this can be between 1 and 10. For DALL-E-3, this must be 1.
            dalle_version (Literal["dall-e-2", "dall-e-3"], Optional): The version of DALL-E to use.
                Defaults to "dall-e-2".
            quality (Literal["standard", "hd"], Optional): The quality of the generated images.
                Only sent to the API for DALL-E-3. Defaults to "standard".
            style (Literal["natural", "vivid"], Optional): The style of the generated images.
                Only sent to the API for DALL-E-3. Defaults to "natural".
            *args: Additional positional arguments to be passed to OpenAITarget.
            **kwargs: Additional keyword arguments to be passed to OpenAITarget.
            httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
                httpx.AsyncClient() constructor.
                For example, to specify a 3 minutes timeout: httpx_client_kwargs={"timeout": 180}

        Raises:
            ValueError: If `num_images` is not 1 for DALL-E-3.
            ValueError: If `num_images` is less than 1 or greater than 10 for DALL-E-2.
        """
        self.dalle_version = dalle_version
        if dalle_version == "dall-e-3":
            if num_images != 1:
                raise ValueError("DALL-E-3 can only generate 1 image at a time.")
        elif dalle_version == "dall-e-2":
            if num_images < 1 or num_images > 10:
                raise ValueError("DALL-E-2 can generate only up to 10 images at a time.")

        self.image_size = image_size
        self.num_images = num_images
        # Stored for every version so the attributes always exist; _construct_request_body
        # only forwards them to the API when dalle_version is "dall-e-3".
        self.quality = quality
        self.style = style

        super().__init__(*args, **kwargs)

    def _set_openai_env_configuration_vars(self):
        """Name the environment variables that hold the DALL-E model, endpoint and API key."""
        self.model_name_environment_variable = "OPENAI_DALLE_MODEL"
        self.endpoint_environment_variable = "OPENAI_DALLE_ENDPOINT"
        self.api_key_environment_variable = "OPENAI_DALLE_API_KEY"

    @limit_requests_per_minute
    @pyrit_target_retry
    async def send_prompt_async(
        self,
        *,
        prompt_request: PromptRequestResponse,
    ) -> PromptRequestResponse:
        """
        Send a prompt to the DALL-E target and return the response.

        Args:
            prompt_request (PromptRequestResponse): The prompt request to send. It must contain
                exactly one text request piece.

        Returns:
            PromptRequestResponse: A response with one ``image_path`` piece per generated image.
                On HTTP 400, the result of ``handle_bad_request_exception`` is returned instead.

        Raises:
            ValueError: If the request is not a single text piece.
            RateLimitException: If the endpoint responds with HTTP 429.
            EmptyResponseException: If the response contains no image data.
            httpx.HTTPStatusError: For other HTTP error statuses.
        """
        self._validate_request(prompt_request=prompt_request)
        request = prompt_request.request_pieces[0]

        logger.info(f"Sending the following prompt to the prompt target: {request}")

        # Refresh auth headers if using AAD
        self.refresh_auth_headers()

        body = self._construct_request_body(prompt=request.converted_value)

        params: Dict[str, str] = {}
        if self._api_version is not None:
            params["api-version"] = self._api_version

        try:
            http_response: httpx.Response = await net_utility.make_request_and_raise_if_error_async(
                endpoint_uri=self._endpoint,
                method="POST",
                headers=self._headers,
                request_body=body,
                params=params,
                **self._httpx_client_kwargs,
            )
        except httpx.HTTPStatusError as status_error:
            status_code = status_error.response.status_code
            if status_code == 400:
                # Handle Bad Request: converted into a response rather than raised.
                return handle_bad_request_exception(response_text=status_error.response.text, request=request)
            if status_code == 429:
                # Presumably handled (backed off / retried) by pyrit_target_retry — confirm there.
                raise RateLimitException() from status_error
            raise

        b64_images = self._extract_b64_images(json.loads(http_response.text))

        # Handle empty response using retry
        if not b64_images:
            raise EmptyResponseException(message="The chat returned an empty response.")

        # Persist every returned image; with num_images > 1 the API returns several entries.
        image_paths: List[str] = []
        for b64_data in b64_images:
            data = data_serializer_factory(category="prompt-memory-entries", data_type="image_path")
            await data.save_b64_image(data=b64_data)
            image_paths.append(data.value)

        response_type: PromptDataType = "image_path"
        return construct_response_from_request(
            request=request, response_text_pieces=image_paths, response_type=response_type
        )

    @staticmethod
    def _extract_b64_images(json_response: Any) -> List[str]:
        """
        Return the non-empty ``b64_json`` payloads from an image-generation response body.

        A non-dict body or a missing/empty ``data`` list yields ``[]`` instead of raising
        KeyError/IndexError, so the caller can treat it as an empty response.
        """
        if not isinstance(json_response, dict):
            return []
        items = json_response.get("data") or []
        return [
            item["b64_json"]
            for item in items
            if isinstance(item, dict) and item.get("b64_json")
        ]

    def _construct_request_body(self, prompt: str) -> Dict[str, Any]:
        """
        Build the JSON body for the image-generation request.

        Images are requested as base64 (``response_format="b64_json"``). ``quality`` and
        ``style`` are only included for DALL-E-3.
        """
        image_generation_args: Dict[str, Any] = {
            "model": self._model_name,
            "prompt": prompt,
            "n": self.num_images,
            "size": self.image_size,
            "response_format": "b64_json",
        }
        if self.dalle_version == "dall-e-3" and self.quality and self.style:
            image_generation_args["quality"] = self.quality
            image_generation_args["style"] = self.style

        return image_generation_args

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """
        Ensure the request is exactly one piece whose converted value is text.

        Raises:
            ValueError: If there is not exactly one piece, or the piece is not text.
        """
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")

    def is_json_response_supported(self) -> bool:
        """Indicates that this target does not support JSON response format."""
        return False