# Source code for pyrit.prompt_target.openai.openai_video_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
from mimetypes import guess_type
from typing import Any, Optional, Union, cast

from openai.types import VideoSeconds, VideoSize

from pyrit.exceptions import (
    pyrit_target_retry,
)
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import (
    DataTypeSerializer,
    Message,
    MessagePiece,
    construct_response_from_request,
    data_serializer_factory,
)
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute
from pyrit.prompt_target.openai.openai_error_handling import _is_content_filter_error
from pyrit.prompt_target.openai.openai_target import OpenAITarget

logger = logging.getLogger(__name__)


class OpenAIVideoTarget(OpenAITarget):
    """
    OpenAI Video Target using the OpenAI SDK for video generation.

    Supports Sora-2 and Sora-2-Pro models via the OpenAI videos API.

    Supports three modes:
    - Text-to-video: Generate video from a text prompt
    - Text+Image-to-video: Generate video using an image as the first frame (include image_path piece)
    - Remix: Create variation of existing video (include video_id in prompt_metadata)

    Supported resolutions:
    - Sora-2: 720x1280, 1280x720
    - Sora-2-Pro: 720x1280, 1280x720, 1024x1792, 1792x1024

    Supported durations: 4, 8, or 12 seconds
    Default: resolution="1280x720", duration=4 seconds

    Supported image formats for text+image-to-video: JPEG, PNG, WEBP
    """

    # NOTE(review): this is the union across models; _validate_resolution does not
    # enforce per-model limits (e.g. 1024x1792 passes even for sora-2) — the API
    # itself rejects unsupported model/resolution combinations.
    SUPPORTED_RESOLUTIONS: list[VideoSize] = ["720x1280", "1280x720", "1024x1792", "1792x1024"]
    # Durations are string literals because the SDK's VideoSeconds type is string-valued.
    SUPPORTED_DURATIONS: list[VideoSeconds] = ["4", "8", "12"]
    # MIME types accepted for the input_reference first-frame image.
    SUPPORTED_IMAGE_FORMATS: list[str] = ["image/jpeg", "image/png", "image/webp"]

    # Single-turn only; _validate_request also rejects conversations with prior turns.
    _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=False)

    def __init__(
        self,
        *,
        resolution_dimensions: VideoSize = "1280x720",
        n_seconds: int | VideoSeconds = 4,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the OpenAI Video Target.

        Args:
            model_name (str, Optional): The video model to use (e.g., "sora-2", "sora-2-pro")
                (or deployment name in Azure). If no value is provided, the OPENAI_VIDEO_MODEL
                environment variable will be used.
            endpoint (str, Optional): The target URL for the OpenAI service.
            api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI
                service, or a callable that returns an access token. For Azure endpoints with
                Entra authentication, pass a token provider from pyrit.auth
                (e.g., get_azure_openai_auth(endpoint)). Uses OPENAI_VIDEO_KEY environment
                variable by default.
            headers (str, Optional): Extra headers of the endpoint (JSON).
            max_requests_per_minute (int, Optional): Number of requests the target can handle
                per minute before hitting a rate limit.
            resolution_dimensions (VideoSize, Optional): Resolution dimensions for the video.
                Defaults to "1280x720". Supported resolutions:
                - Sora-2: "720x1280", "1280x720"
                - Sora-2-Pro: "720x1280", "1280x720", "1024x1792", "1792x1024"
            n_seconds (int | VideoSeconds, Optional): The duration of the generated video.
                Accepts an int (4, 8, 12) or a VideoSeconds string ("4", "8", "12").
                Defaults to 4.
            **kwargs: Additional keyword arguments passed to the parent OpenAITarget class.
                httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
                ``httpx.AsyncClient()`` constructor. For example, to specify a 3 minute
                timeout: ``httpx_client_kwargs={"timeout": 180}``

        Remix workflow:
            To remix an existing video, set ``prompt_metadata={"video_id": "<id>"}`` on the
            text MessagePiece. The video_id is returned in the response metadata after any
            successful generation (``response.message_pieces[0].prompt_metadata["video_id"]``).
        """
        super().__init__(**kwargs)
        # Normalize int durations (4) to the SDK's string form ("4") before validation.
        self._n_seconds: VideoSeconds = (
            cast("VideoSeconds", str(n_seconds)) if isinstance(n_seconds, int) else n_seconds
        )
        self._validate_duration()
        self._size: VideoSize = self._validate_resolution(resolution_dimensions=resolution_dimensions)

    def _set_openai_env_configuration_vars(self) -> None:
        """Set environment variable names."""
        self.model_name_environment_variable = "OPENAI_VIDEO_MODEL"
        self.endpoint_environment_variable = "OPENAI_VIDEO_ENDPOINT"
        self.api_key_environment_variable = "OPENAI_VIDEO_KEY"
        self.underlying_model_environment_variable = "OPENAI_VIDEO_UNDERLYING_MODEL"

    def _get_target_api_paths(self) -> list[str]:
        """Return API paths that should not be in the URL."""
        return ["/videos", "/v1/videos"]

    def _get_provider_examples(self) -> dict[str, str]:
        """Return provider-specific example URLs."""
        return {
            ".openai.azure.com": "https://{resource}.openai.azure.com/openai/v1",
            "api.openai.com": "https://api.openai.com/v1",
        }

    def _build_identifier(self) -> ComponentIdentifier:
        """
        Build the identifier with video generation-specific parameters.

        Returns:
            ComponentIdentifier: The identifier for this target instance.
        """
        return self._create_identifier(
            params={
                "resolution": self._size,
                "n_seconds": self._n_seconds,
            },
        )

    def _validate_resolution(self, *, resolution_dimensions: VideoSize) -> VideoSize:
        """
        Validate resolution dimensions.

        Args:
            resolution_dimensions: Resolution in WIDTHxHEIGHT format.

        Returns:
            The validated resolution string.

        Raises:
            ValueError: If the resolution is not supported.
        """
        if resolution_dimensions not in self.SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"Invalid resolution '{resolution_dimensions}'. "
                f"Supported resolutions: {', '.join(self.SUPPORTED_RESOLUTIONS)}"
            )
        return resolution_dimensions

    def _validate_duration(self) -> None:
        """
        Validate video duration.

        Raises:
            ValueError: If the duration is not supported.
        """
        # self._n_seconds is already normalized to a string in __init__, so a plain
        # membership test against the string list is sufficient.
        if self._n_seconds not in self.SUPPORTED_DURATIONS:
            raise ValueError(
                f"Invalid duration '{self._n_seconds}'. "
                f"Supported durations: {', '.join(self.SUPPORTED_DURATIONS)} seconds"
            )

    @limit_requests_per_minute
    @pyrit_target_retry
    async def send_prompt_async(self, *, message: Message) -> list[Message]:
        """
        Asynchronously sends a message and generates a video using the OpenAI SDK.

        Supports three modes:
        - Text-to-video: Single text piece
        - Text+Image-to-video: Text piece + image_path piece (image becomes first frame)
        - Remix: Text piece with prompt_metadata["video_id"] set to an existing video ID

        Args:
            message: The message object containing the prompt.

        Returns:
            A list containing the response with the generated video path.

        Raises:
            RateLimitException: If the rate limit is exceeded.
            ValueError: If the request is invalid.
        """
        self._validate_request(message=message)

        text_piece = message.get_piece_by_type(data_type="text")
        image_piece = message.get_piece_by_type(data_type="image_path")
        prompt = text_piece.converted_value

        # Check for remix mode via prompt_metadata
        remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None

        logger.info(f"Sending video generation prompt: {prompt}")

        # Mode precedence: remix wins over image input (conflicting combinations are
        # already rejected by _validate_request).
        if remix_video_id:
            response = await self._send_remix_async(video_id=str(remix_video_id), prompt=prompt, request=message)
        elif image_piece:
            response = await self._send_text_plus_image_to_video_async(
                image_piece=image_piece, prompt=prompt, request=message
            )
        else:
            response = await self._send_text_to_video_async(prompt=prompt, request=message)

        return [response]

    async def _send_remix_async(self, *, video_id: str, prompt: str, request: Message) -> Message:
        """
        Send a remix request for an existing video.

        Args:
            video_id: The ID of the completed video to remix.
            prompt: The text prompt directing the remix.
            request: The original request message.

        Returns:
            The response Message with the generated video path.
        """
        logger.info(f"Remix mode: Creating variation of video {video_id}")
        return await self._handle_openai_request(
            api_call=lambda: self._remix_and_poll_async(video_id=video_id, prompt=prompt),
            request=request,
        )

    async def _send_text_plus_image_to_video_async(
        self, *, image_piece: MessagePiece, prompt: str, request: Message
    ) -> Message:
        """
        Send a text+image-to-video request using an image as the first frame.

        Args:
            image_piece: The MessagePiece containing the image path.
            prompt: The text prompt describing the desired video.
            request: The original request message.

        Returns:
            The response Message with the generated video path.
        """
        logger.info("Text+Image-to-video mode: Using image as first frame")
        input_file = await self._prepare_image_input_async(image_piece=image_piece)
        return await self._handle_openai_request(
            api_call=lambda: self._async_client.videos.create_and_poll(
                model=self._model_name,
                prompt=prompt,
                size=self._size,
                seconds=self._n_seconds,
                input_reference=input_file,
            ),
            request=request,
        )

    async def _send_text_to_video_async(self, *, prompt: str, request: Message) -> Message:
        """
        Send a text-to-video generation request.

        Args:
            prompt: The text prompt describing the desired video.
            request: The original request message.

        Returns:
            The response Message with the generated video path.
        """
        return await self._handle_openai_request(
            api_call=lambda: self._async_client.videos.create_and_poll(
                model=self._model_name,
                prompt=prompt,
                size=self._size,
                seconds=self._n_seconds,
            ),
            request=request,
        )

    async def _prepare_image_input_async(self, *, image_piece: MessagePiece) -> tuple[str, bytes, str]:
        """
        Prepare image data for the OpenAI video API input_reference parameter.

        Reads the image bytes from storage and determines the MIME type.

        Args:
            image_piece: The MessagePiece containing the image path.

        Returns:
            A tuple of (filename, image_bytes, mime_type) for the SDK.

        Raises:
            ValueError: If the image format is not supported.
        """
        image_path = image_piece.converted_value
        image_serializer = data_serializer_factory(
            value=image_path, data_type="image_path", category="prompt-memory-entries"
        )
        image_bytes = await image_serializer.read_data()

        # Prefer the project's serializer-based detection; fall back to extension-based
        # guessing from the stdlib when it yields nothing.
        mime_type = DataTypeSerializer.get_mime_type(image_path)
        if not mime_type:
            mime_type, _ = guess_type(image_path, strict=False)

        if not mime_type or mime_type not in self.SUPPORTED_IMAGE_FORMATS:
            raise ValueError(
                f"Unsupported image format: {mime_type or 'unknown'}. "
                f"Supported formats: {', '.join(self.SUPPORTED_IMAGE_FORMATS)}"
            )

        filename = os.path.basename(image_path)
        return (filename, image_bytes, mime_type)

    async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any:
        """
        Create a remix of an existing video and poll until complete.

        The OpenAI SDK's remix() method returns immediately with a job status.
        This method polls until the job completes or fails.

        Args:
            video_id: The ID of the completed video to remix.
            prompt: The text prompt directing the remix.

        Returns:
            The completed Video object from the OpenAI SDK.
        """
        video = await self._async_client.videos.remix(video_id, prompt=prompt)

        # Poll until completion if not already done
        if video.status not in ["completed", "failed"]:
            # SDK poll() blocks until the job reaches a terminal state.
            video = await self._async_client.videos.poll(video.id)

        return video

    def _check_content_filter(self, response: Any) -> bool:
        """
        Check if a video generation response was content filtered.

        Response indicates content filtering through:
        - Status is "failed"
        - Error code is "content_filter" (output-side filtering)
        - Error code is "moderation_blocked" (input moderation)

        Note: Input-side filtering (content_policy_violation via BadRequestError)
        is also caught by the base class before reaching this method.

        Args:
            response: A Video object from the OpenAI SDK.

        Returns:
            True if content was filtered, False otherwise.
        """
        if response.status == "failed" and response.error:
            # Convert response to dict and use common filter detection
            response_dict = response.model_dump()
            return _is_content_filter_error(response_dict)
        return False

    async def _construct_message_from_response(self, response: Any, request: Any) -> Message:
        """
        Construct a Message from a video response.

        Args:
            response: The Video response from OpenAI SDK.
            request: The original request MessagePiece.

        Returns:
            Message: Constructed message with video file path.
        """
        video = response

        # Check if video generation was successful
        if video.status == "completed":
            logger.info(f"Video generation completed successfully: {video.id}")

            # Log remix metadata if available
            if video.remixed_from_video_id:
                logger.info(f"Video was remixed from: {video.remixed_from_video_id}")

            # Download video content using SDK
            video_response = await self._async_client.videos.download_content(video.id)
            # Extract bytes from HttpxBinaryResponseContent
            video_content = video_response.content

            # Save the video to storage (include video.id for chaining remixes)
            return await self._save_video_response(request=request, video_data=video_content, video_id=video.id)

        if video.status == "failed":
            # Handle failed video generation (non-content-filter)
            error_message = str(video.error) if video.error else "Video generation failed"
            logger.error(f"Video generation failed: {error_message}")

            # Non-content-filter errors are returned as processing errors
            return construct_response_from_request(
                request=request,
                response_text_pieces=[error_message],
                response_type="error",
                error="processing",
            )

        # Unexpected status
        error_message = f"Video generation ended with unexpected status: {video.status}"
        logger.error(error_message)
        return construct_response_from_request(
            request=request,
            response_text_pieces=[error_message],
            response_type="error",
            error="unknown",
        )

    async def _save_video_response(
        self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None
    ) -> Message:
        """
        Save video data to storage and construct response.

        Args:
            request: The original request message piece.
            video_data: The video content as bytes.
            video_id: The video ID from the API (stored in metadata for chaining remixes).

        Returns:
            Message: The response with the video file path.
        """
        # Save video using data serializer
        data = data_serializer_factory(category="prompt-memory-entries", data_type="video_path")
        await data.save_data(data=video_data)
        video_path = data.value

        logger.info(f"Video saved to: {video_path}")

        # Include video_id in metadata for chaining (e.g., remix the generated video later)
        prompt_metadata: Optional[dict[str, Union[str, int]]] = {"video_id": video_id} if video_id else None

        # Construct response
        return construct_response_from_request(
            request=request,
            response_text_pieces=[video_path],
            response_type="video_path",
            prompt_metadata=prompt_metadata,
        )

    def _validate_request(self, *, message: Message) -> None:
        """
        Validate the request message.

        Accepts:
        - Single text piece (text-to-video or remix mode)
        - Text piece + image_path piece (text+image-to-video mode)

        Args:
            message: The message to validate.

        Raises:
            ValueError: If the request is invalid.
        """
        text_pieces = message.get_pieces_by_type(data_type="text")
        image_pieces = message.get_pieces_by_type(data_type="image_path")

        # Check for unsupported types
        supported_count = len(text_pieces) + len(image_pieces)
        if supported_count != len(message.message_pieces):
            other_types = [
                p.converted_value_data_type
                for p in message.message_pieces
                if p.converted_value_data_type not in ("text", "image_path")
            ]
            raise ValueError(f"Unsupported piece types: {other_types}. Only 'text' and 'image_path' are supported.")

        # Must have exactly one text piece
        if len(text_pieces) != 1:
            raise ValueError(f"Expected exactly 1 text piece, got {len(text_pieces)}.")

        # At most one image piece
        if len(image_pieces) > 1:
            raise ValueError(f"Expected at most 1 image piece, got {len(image_pieces)}.")

        # Check for conflicting modes: remix + image
        text_piece = text_pieces[0]
        remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None
        if remix_video_id and image_pieces:
            raise ValueError("Cannot use image input in remix mode. Remix uses existing video as reference.")

        # Enforce the single-turn capability: any prior message in this conversation
        # means a second turn is being attempted.
        messages = self._memory.get_conversation(conversation_id=text_piece.conversation_id)
        n_messages = len(messages)

        if n_messages > 0:
            raise ValueError(
                "This target only supports a single turn conversation. "
                f"Received: {n_messages} messages which indicates a prior turn."
            )

    def is_json_response_supported(self) -> bool:
        """
        Check if the target supports JSON response data.

        Returns:
            bool: False, as video generation doesn't return JSON content.
        """
        return False