# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any
from pyrit.exceptions import (
pyrit_target_retry,
)
from pyrit.models import (
Message,
MessagePiece,
construct_response_from_request,
data_serializer_factory,
)
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute
from pyrit.prompt_target.openai.openai_error_handling import _is_content_filter_error
# Module-level logger named after this module; used by OpenAIVideoTarget for
# progress and error reporting.
logger = logging.getLogger(__name__)
class OpenAIVideoTarget(OpenAITarget):
    """
    OpenAI Video Target using the OpenAI SDK for video generation.

    Supports Sora-2 and Sora-2-Pro models via the OpenAI videos API.

    Supported resolutions:
        - Sora-2: 720x1280, 1280x720
        - Sora-2-Pro: 720x1280, 1280x720, 1024x1792, 1792x1024

    Supported durations: 4, 8, or 12 seconds

    Default: model="sora-2", resolution="1280x720", duration=4 seconds
    """

    # Union of the resolutions accepted by the supported models. Validation in
    # this class is done against this combined list only; it does not check the
    # resolution against the specific model that was configured.
    SUPPORTED_RESOLUTIONS = ["720x1280", "1280x720", "1024x1792", "1792x1024"]

    # Allowed video lengths, in seconds.
    SUPPORTED_DURATIONS = [4, 8, 12]

    def __init__(
        self,
        *,
        resolution_dimensions: str = "1280x720",
        n_seconds: int = 4,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the OpenAI Video Target.

        Args:
            model_name (str, Optional): The video model to use (e.g., "sora-2", "sora-2-pro").
                Defaults to "sora-2".
            endpoint (str, Optional): The target URL for the OpenAI service.
            api_key (str, Optional): The API key for accessing the service.
                Uses OPENAI_VIDEO_KEY environment variable by default.
            headers (str, Optional): Extra headers of the endpoint (JSON).
            use_entra_auth (bool, Optional): When set to True, user authentication is used
                instead of API Key.
            max_requests_per_minute (int, Optional): Number of requests the target can handle per
                minute before hitting a rate limit.
            resolution_dimensions (str, Optional): Resolution dimensions for the video in
                WIDTHxHEIGHT format. Defaults to "1280x720".
                Supported resolutions:
                - Sora-2: "720x1280", "1280x720"
                - Sora-2-Pro: "720x1280", "1280x720", "1024x1792", "1792x1024"
            n_seconds (int, Optional): The duration of the generated video (in seconds).
                Defaults to 4. Supported values: 4, 8, or 12 seconds.
            **kwargs: Remaining keyword arguments (those listed above other than
                ``resolution_dimensions`` and ``n_seconds``) are forwarded unchanged
                to the ``OpenAITarget`` initializer.

        Raises:
            ValueError: If ``n_seconds`` or ``resolution_dimensions`` is not supported.
        """
        super().__init__(**kwargs)

        # _validate_duration() reads self._n_seconds, so assign before validating.
        self._n_seconds = n_seconds
        self._validate_duration()

        self._size = self._validate_resolution(resolution_dimensions=resolution_dimensions)

    def _set_openai_env_configuration_vars(self) -> None:
        """Set environment variable names."""
        self.model_name_environment_variable = "OPENAI_VIDEO_MODEL"
        self.endpoint_environment_variable = "OPENAI_VIDEO_ENDPOINT"
        self.api_key_environment_variable = "OPENAI_VIDEO_KEY"

    def _get_target_api_paths(self) -> list[str]:
        """Return API paths that should not be in the URL."""
        return ["/videos", "/v1/videos"]

    def _get_provider_examples(self) -> dict[str, str]:
        """Return provider-specific example URLs."""
        return {
            ".openai.azure.com": "https://{resource}.openai.azure.com/openai/v1",
            "api.openai.com": "https://api.openai.com/v1",
        }

    def _validate_resolution(self, *, resolution_dimensions: str) -> str:
        """
        Validate resolution dimensions.

        Args:
            resolution_dimensions: Resolution in WIDTHxHEIGHT format.

        Returns:
            The validated resolution string.

        Raises:
            ValueError: If the resolution is not supported.
        """
        if resolution_dimensions not in self.SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"Invalid resolution '{resolution_dimensions}'. "
                f"Supported resolutions: {', '.join(self.SUPPORTED_RESOLUTIONS)}"
            )
        return resolution_dimensions

    def _validate_duration(self) -> None:
        """
        Validate video duration.

        Checks ``self._n_seconds`` against ``SUPPORTED_DURATIONS``.

        Raises:
            ValueError: If the duration is not supported.
        """
        if self._n_seconds not in self.SUPPORTED_DURATIONS:
            raise ValueError(
                f"Invalid duration {self._n_seconds}s. "
                f"Supported durations: {', '.join(map(str, self.SUPPORTED_DURATIONS))} seconds"
            )

    @limit_requests_per_minute
    @pyrit_target_retry
    async def send_prompt_async(self, *, message: Message) -> list[Message]:
        """
        Asynchronously sends a message and generates a video using the OpenAI SDK.

        Args:
            message (Message): The message object containing the prompt.

        Returns:
            list[Message]: A list containing the response with the generated video path.

        Raises:
            RateLimitException: If the rate limit is exceeded.
            ValueError: If the request is invalid.
        """
        self._validate_request(message=message)

        message_piece = message.message_pieces[0]
        prompt = message_piece.converted_value

        logger.info(f"Sending video generation prompt: {prompt}")

        # The membership test in _validate_duration() also accepts integral
        # floats (4.0 == 4), for which str() would yield "4.0". Normalise to
        # int first so the request always carries "4", "8" or "12".
        seconds = str(int(self._n_seconds))

        # Use unified error handler - automatically detects Video and validates
        response = await self._handle_openai_request(
            api_call=lambda: self._async_client.videos.create_and_poll(
                model=self._model_name,  # type: ignore[arg-type]
                prompt=prompt,
                size=self._size,  # type: ignore[arg-type]
                seconds=seconds,  # type: ignore[arg-type]
            ),
            request=message,
        )
        return [response]

    def _check_content_filter(self, response: Any) -> bool:
        """
        Check if a video generation response was content filtered.

        Response indicates content filtering through:
        - Status is "failed"
        - Error code is "content_filter" (output-side filtering)
        - Error code is "moderation_blocked" (input moderation)

        Note: Input-side filtering (content_policy_violation via BadRequestError) is also caught
        by the base class before reaching this method.

        Args:
            response: A Video object from the OpenAI SDK.

        Returns:
            True if content was filtered, False otherwise.
        """
        if response.status == "failed" and response.error:
            # Convert response to dict and use common filter detection
            response_dict = response.model_dump()
            return _is_content_filter_error(response_dict)
        return False

    async def _construct_message_from_response(self, response: Any, request: Any) -> Message:
        """
        Construct a Message from a video response.

        A "completed" video is downloaded and saved; a "failed" video or any other
        status is turned into an error response instead of raising.

        Args:
            response: The Video response from OpenAI SDK.
            request: The original request MessagePiece.

        Returns:
            Message: Constructed message with video file path, or an error message.
        """
        video = response

        # Check if video generation was successful
        if video.status == "completed":
            logger.info(f"Video generation completed successfully: {video.id}")

            # Download video content using SDK
            video_response = await self._async_client.videos.download_content(video.id)
            # Extract bytes from HttpxBinaryResponseContent
            video_content = video_response.content

            # Save the video to storage
            return await self._save_video_response(request=request, video_data=video_content)

        if video.status == "failed":
            # Handle failed video generation (non-content-filter)
            error_message = str(video.error) if video.error else "Video generation failed"
            logger.error(f"Video generation failed: {error_message}")

            # Non-content-filter errors are returned as processing errors
            return construct_response_from_request(
                request=request,
                response_text_pieces=[error_message],
                response_type="error",
                error="processing",
            )

        # Unexpected status
        error_message = f"Video generation ended with unexpected status: {video.status}"
        logger.error(error_message)
        return construct_response_from_request(
            request=request,
            response_text_pieces=[error_message],
            response_type="error",
            error="unknown",
        )

    async def _save_video_response(self, *, request: MessagePiece, video_data: bytes) -> Message:
        """
        Save video data to storage and construct response.

        Args:
            request: The original request message piece.
            video_data: The video content as bytes.

        Returns:
            Message: The response with the video file path.
        """
        # Save video using data serializer
        data = data_serializer_factory(category="prompt-memory-entries", data_type="video_path")
        await data.save_data(data=video_data)
        video_path = data.value
        logger.info(f"Video saved to: {video_path}")

        # Construct response
        response_entry = construct_response_from_request(
            request=request,
            response_text_pieces=[video_path],
            response_type="video_path",
        )
        return response_entry

    def _validate_request(self, *, message: Message) -> None:
        """
        Validate the request message.

        The message must contain exactly one piece, and that piece must be text.

        Args:
            message: The message to validate.

        Raises:
            ValueError: If the request is invalid.
        """
        n_pieces = len(message.message_pieces)
        if n_pieces != 1:
            raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

        piece_type = message.message_pieces[0].converted_value_data_type
        if piece_type != "text":
            raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")

    def is_json_response_supported(self) -> bool:
        """
        Check if the target supports JSON response data.

        Returns:
            bool: False, as video generation doesn't return JSON content.
        """
        return False