Source code for pyrit.prompt_target.openai.openai_realtime_target
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
import base64
import logging
import re
import wave
from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, Tuple
from openai import AsyncOpenAI
from pyrit.exceptions import (
pyrit_target_retry,
)
from pyrit.exceptions.exception_classes import ServerErrorException
from pyrit.models import (
Message,
construct_response_from_request,
data_serializer_factory,
)
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute
logger = logging.getLogger(__name__)
RealTimeVoice = Literal["alloy", "echo", "shimmer"]
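# Illustrative sketch (not from the module): RealTimeVoice constrains the `voice`
# argument at type-check time.
#
#     voice: RealTimeVoice = "alloy"  # accepted
#     voice = "nova"                  # rejected by a type checker such as mypy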
@dataclass
class RealtimeTargetResult:
"""
Represents the result of a Realtime API request, containing audio data and transcripts.
Attributes:
audio_bytes: Raw audio data returned by the API
transcripts: List of text transcripts generated from the audio
"""
    audio_bytes: bytes = b""
transcripts: List[str] = field(default_factory=list)
def flatten_transcripts(self) -> str:
"""
Flattens the list of transcripts into a single string.
Returns:
A single string containing all transcripts concatenated together.
"""
return "".join(self.transcripts)
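# A minimal usage sketch of RealtimeTargetResult (illustrative values only):
#
#     result = RealtimeTargetResult()
#     result.transcripts.extend(["Hello", ", ", "world"])  # transcript deltas arrive in order
#     result.audio_bytes += b"\x00\x01"  # raw PCM16 chunks are appended as they stream in
#     assert result.flatten_transcripts() == "Hello, world"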
class RealtimeTarget(OpenAITarget):
def __init__(
self,
*,
voice: Optional[RealTimeVoice] = None,
existing_convo: Optional[dict] = None,
**kwargs,
) -> None:
"""
RealtimeTarget class for Azure OpenAI Realtime API.
Read more at https://learn.microsoft.com/en-us/azure/ai-services/openai/realtime-audio-reference
and https://platform.openai.com/docs/guides/realtime-websocket
Args:
model_name (str, Optional): The name of the model.
If no value is provided, the OPENAI_REALTIME_MODEL environment variable will be used.
endpoint (str, Optional): The target URL for the OpenAI service.
Defaults to the `OPENAI_REALTIME_ENDPOINT` environment variable.
api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service,
or a callable that returns an access token. For Azure endpoints with Entra authentication,
pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint)).
Defaults to the `OPENAI_REALTIME_API_KEY` environment variable.
headers (str, Optional): Headers of the endpoint (JSON).
max_requests_per_minute (int, Optional): Number of requests the target can handle per
minute before hitting a rate limit. The number of requests sent to the target
will be capped at the value provided.
            voice (literal str, Optional): The voice to use. Defaults to None.
                The only voices supported by the Azure OpenAI Realtime API are "alloy", "echo", and "shimmer".
existing_convo (dict[str, websockets.WebSocketClientProtocol], Optional): Existing conversations.
httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
httpx.AsyncClient() constructor.
                For example, to specify a 3-minute timeout: httpx_client_kwargs={"timeout": 180}
"""
super().__init__(**kwargs)
self.voice = voice
self._existing_conversation = existing_convo if existing_convo is not None else {}
self._realtime_client = None
def _set_openai_env_configuration_vars(self):
self.model_name_environment_variable = "OPENAI_REALTIME_MODEL"
self.endpoint_environment_variable = "OPENAI_REALTIME_ENDPOINT"
self.api_key_environment_variable = "OPENAI_REALTIME_API_KEY"
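    # Configuration sketch: these environment variables are read when the corresponding
    # constructor arguments are omitted. The model name below is an assumption for
    # illustration; use whichever realtime deployment your endpoint exposes.
    #
    #     import os
    #
    #     os.environ["OPENAI_REALTIME_MODEL"] = "gpt-4o-realtime-preview"  # assumed
    #     os.environ["OPENAI_REALTIME_ENDPOINT"] = "wss://{resource}.openai.azure.com/openai/v1"
    #     os.environ["OPENAI_REALTIME_API_KEY"] = "<your-api-key>"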
def _get_target_api_paths(self) -> list[str]:
"""Return API paths that should not be in the URL."""
return ["/realtime", "/v1/realtime"]
def _get_provider_examples(self) -> dict[str, str]:
"""Return provider-specific example URLs."""
return {
".openai.azure.com": "wss://{resource}.openai.azure.com/openai/v1",
"api.openai.com": "wss://api.openai.com/v1",
}
def _validate_url_for_target(self, endpoint_url: str) -> None:
"""
Validate URL for Realtime API with websocket-specific checks.
Args:
endpoint_url: The endpoint URL to validate.
"""
# Convert https to wss for validation (this is expected for websockets)
check_url = endpoint_url.replace("https://", "wss://") if endpoint_url.startswith("https://") else endpoint_url
# Check for proper scheme
if not check_url.startswith("wss://"):
logger.warning(
f"Realtime endpoint should use 'wss://' or 'https://' scheme, got: {endpoint_url}. "
"The endpoint may not work correctly."
)
return
# Call parent validation with the wss URL
super()._validate_url_for_target(check_url)
def _warn_if_irregular_endpoint(self, endpoint: str) -> None:
"""
Warns if the endpoint URL does not match expected patterns.
Args:
endpoint: The endpoint URL to validate
"""
# Expected patterns for realtime endpoints:
# Azure old format: wss://resource.openai.azure.com/openai/realtime?api-version=...
# Azure new format: wss://resource.openai.azure.com/openai/v1
# Platform OpenAI: wss://api.openai.com/v1
# Also accept https:// versions that will be converted to wss://
# Check for proper scheme (wss:// or https://)
if not endpoint.startswith(("wss://", "https://")):
logger.warning(
f"Realtime endpoint should start with 'wss://' or 'https://', got: {endpoint}. "
"This may cause connection issues."
)
return
# Pattern for Azure endpoints
azure_pattern = re.compile(
r"^(wss|https)://[a-zA-Z0-9\-]+\.openai\.azure\.com/"
r"(openai/(deployments/[^/]+/)?realtime(\?api-version=[^/]+)?|openai/v1|v1)$"
)
# Pattern for Platform OpenAI
platform_pattern = re.compile(r"^(wss|https)://api\.openai\.com/(v1(/realtime)?|realtime)$")
if not azure_pattern.match(endpoint) and not platform_pattern.match(endpoint):
logger.warning(
f"Realtime endpoint URL does not match expected Azure or Platform OpenAI patterns: {endpoint}. "
"Expected formats: 'wss://resource.openai.azure.com/openai/v1' or 'wss://api.openai.com/v1'"
)
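    # How the patterns above classify some example endpoints ({resource} and
    # my-resource are placeholders):
    #
    #     "wss://my-resource.openai.azure.com/openai/v1"                                  accepted
    #     "https://my-resource.openai.azure.com/openai/realtime?api-version=2024-10-01"   accepted
    #     "wss://api.openai.com/v1"                                                       accepted
    #     "wss://example.com/realtime"                                                    warned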
def _get_openai_client(self):
"""
Creates or returns the AsyncOpenAI client configured for Realtime API.
Uses the Azure GA approach with websocket_base_url.
"""
if self._realtime_client is None:
# Convert https:// to wss:// for websocket connections if needed
websocket_base_url = (
self._endpoint.replace("https://", "wss://")
if self._endpoint.startswith("https://")
else self._endpoint
)
logger.info(f"Creating realtime client with websocket_base_url: {websocket_base_url}")
self._realtime_client = AsyncOpenAI(
websocket_base_url=websocket_base_url,
api_key=self._api_key,
)
return self._realtime_client
async def connect(self, conversation_id: str):
"""
Connects to Realtime API using AsyncOpenAI client.
Returns the realtime connection.
"""
logger.info(f"Connecting to Realtime API: {self._endpoint}")
client = self._get_openai_client()
        # Enter the async context manager manually so the connection stays open after
        # this call returns; it is closed later via cleanup_conversation() or cleanup_target().
        connection = await client.realtime.connect(model=self._model_name).__aenter__()
        logger.info("Successfully connected to the Realtime API")
return connection
    def _set_system_prompt_and_config_vars(self, system_prompt: str) -> dict:
        """
        Creates the session configuration for the OpenAI client.

        Uses the Azure GA format with nested audio config.

        Args:
            system_prompt (str): The system instructions for the session.

        Returns:
            dict: The session configuration payload.
        """
session_config = {
"type": "realtime",
"instructions": system_prompt,
"output_modalities": ["audio"], # Use only audio modality
"audio": {
"input": {
"transcription": {
"model": "whisper-1",
},
"format": {
"type": "audio/pcm",
"rate": 24000,
},
},
"output": {
"format": {
"type": "audio/pcm",
"rate": 24000,
}
},
},
}
if self.voice:
session_config["audio"]["output"]["voice"] = self.voice # type: ignore[index]
return session_config
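    # For reference, _set_system_prompt_and_config_vars(system_prompt="Be terse") on a
    # target with voice="alloy" yields a payload of this shape:
    #
    #     {
    #         "type": "realtime",
    #         "instructions": "Be terse",
    #         "output_modalities": ["audio"],
    #         "audio": {
    #             "input": {
    #                 "transcription": {"model": "whisper-1"},
    #                 "format": {"type": "audio/pcm", "rate": 24000},
    #             },
    #             "output": {"format": {"type": "audio/pcm", "rate": 24000}, "voice": "alloy"},
    #         },
    #     }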
async def send_config(self, conversation_id: str):
"""
Sends the session configuration using OpenAI client.
Args:
conversation_id (str): Conversation ID
"""
# Extract system prompt from conversation history
system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id)
config_variables = self._set_system_prompt_and_config_vars(system_prompt=system_prompt)
connection = self._get_connection(conversation_id=conversation_id)
await connection.session.update(session=config_variables)
logger.info("Session configuration sent")
def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str:
"""
Retrieves the system prompt from conversation history.
Args:
conversation_id (str): The conversation ID
Returns:
str: The system prompt from conversation history, or a default if none found
"""
conversation = self._memory.get_conversation(conversation_id=conversation_id)
# Look for a system message at the beginning of the conversation
        if conversation:
first_message = conversation[0]
if first_message.message_pieces and first_message.message_pieces[0].role == "system":
return first_message.message_pieces[0].converted_value
# Return default system prompt if none found in conversation
return "You are a helpful AI assistant"
@limit_requests_per_minute
@pyrit_target_retry
    async def send_prompt_async(self, *, message: Message) -> list[Message]:
        """
        Sends a text or audio prompt to the Realtime API.

        Args:
            message (Message): A message containing a single "text" or "audio_path" piece.

        Returns:
            list[Message]: A single response message containing a text transcript piece
                and an audio_path piece pointing to the saved WAV file.
        """
conversation_id = message.message_pieces[0].conversation_id
if conversation_id not in self._existing_conversation:
connection = await self.connect(conversation_id=conversation_id)
self._existing_conversation[conversation_id] = connection
# Only send config when creating a new connection
await self.send_config(conversation_id=conversation_id)
# Give the server a moment to process the session update
await asyncio.sleep(0.5)
self._validate_request(message=message)
request = message.message_pieces[0]
        prompt_type = request.converted_value_data_type
        # The sequence of messages sent depends on the data type of the prompt
        if prompt_type == "audio_path":
            output_audio_path, result = await self.send_audio_async(
                filename=request.converted_value, conversation_id=conversation_id
            )
        elif prompt_type == "text":
            output_audio_path, result = await self.send_text_async(
                text=request.converted_value, conversation_id=conversation_id
            )
        else:
            raise ValueError(f"Unsupported prompt data type: {prompt_type}")
text_response_piece = construct_response_from_request(
request=request, response_text_pieces=[result.flatten_transcripts()], response_type="text"
).message_pieces[0]
audio_response_piece = construct_response_from_request(
request=request, response_text_pieces=[output_audio_path], response_type="audio_path"
).message_pieces[0]
response_entry = Message(message_pieces=[text_response_piece, audio_response_piece])
return [response_entry]
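    # A minimal end-to-end sketch (assumes the OPENAI_REALTIME_* environment variables
    # are set; `some_text_message` is a placeholder for a Message with one "text" piece):
    #
    #     import asyncio
    #
    #     async def main():
    #         target = RealtimeTarget(voice="alloy")
    #         try:
    #             responses = await target.send_prompt_async(message=some_text_message)
    #             text_piece, audio_piece = responses[0].message_pieces
    #             print(text_piece.converted_value)   # transcript of the audio response
    #             print(audio_piece.converted_value)  # path to the saved WAV file
    #         finally:
    #             await target.cleanup_target()
    #
    #     asyncio.run(main())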
async def save_audio(
self,
audio_bytes: bytes,
num_channels: int = 1,
sample_width: int = 2,
sample_rate: int = 16000,
output_filename: Optional[str] = None,
) -> str:
"""
Saves audio bytes to a WAV file.
Args:
audio_bytes (bytes): Audio bytes to save.
            num_channels (int): Number of audio channels. Defaults to 1 for the PCM16 format.
            sample_width (int): Sample width in bytes. Defaults to 2 for the PCM16 format.
            sample_rate (int): Sample rate in Hz. Defaults to 16000 Hz.
output_filename (str): Output filename. If None, a UUID filename will be used.
Returns:
str: The path to the saved audio file.
"""
data = data_serializer_factory(category="prompt-memory-entries", data_type="audio_path")
await data.save_formatted_audio(
data=audio_bytes,
output_filename=output_filename,
num_channels=num_channels,
sample_width=sample_width,
sample_rate=sample_rate,
)
return data.value
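    # Sanity-check sketch using only the standard library: a file written with the
    # defaults above should report mono PCM16 at 16000 Hz (`saved_path` is the value
    # returned by save_audio).
    #
    #     import wave
    #
    #     with wave.open(saved_path, "rb") as wav_file:
    #         assert wav_file.getnchannels() == 1      # num_channels default
    #         assert wav_file.getsampwidth() == 2      # sample_width default, in bytes
    #         assert wav_file.getframerate() == 16000  # sample_rate default, in Hz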
async def cleanup_target(self):
"""
Disconnects from the Realtime API connections.
"""
for conversation_id, connection in list(self._existing_conversation.items()):
if connection:
try:
await connection.close()
logger.info(f"Disconnected from {self._endpoint} with conversation ID: {conversation_id}")
except Exception as e:
logger.warning(f"Error closing connection for {conversation_id}: {e}")
self._existing_conversation = {}
if self._realtime_client:
try:
await self._realtime_client.close()
except Exception as e:
logger.warning(f"Error closing realtime client: {e}")
self._realtime_client = None
async def cleanup_conversation(self, conversation_id: str):
"""
Disconnects from the Realtime API for a specific conversation.
Args:
conversation_id (str): The conversation ID to disconnect from.
"""
connection = self._existing_conversation.get(conversation_id)
if connection:
try:
await connection.close()
logger.info(f"Disconnected from {self._endpoint} with conversation ID: {conversation_id}")
except Exception as e:
logger.warning(f"Error closing connection for {conversation_id}: {e}")
del self._existing_conversation[conversation_id]
async def send_response_create(self, conversation_id: str):
"""
Sends response.create using OpenAI client.
Args:
conversation_id (str): Conversation ID
"""
connection = self._get_connection(conversation_id=conversation_id)
await connection.response.create()
async def receive_events(self, conversation_id: str) -> RealtimeTargetResult:
"""
Continuously receive events from the OpenAI Realtime API connection.
Uses a robust "soft-finish" strategy to handle cases where response.done
may not arrive. After receiving audio.done, waits for a grace period
before soft-finishing if no response.done arrives.
Args:
conversation_id: conversation ID
Returns:
RealtimeTargetResult with audio data and transcripts
Raises:
ConnectionError: If connection is not valid
RuntimeError: If server returns an error
"""
connection = self._get_connection(conversation_id=conversation_id)
result = RealtimeTargetResult()
audio_done_received = False
GRACE_PERIOD_SEC = 1.0 # Wait 1 second after audio.done before soft-finishing
try:
# Create event iterator
event_iter = connection.__aiter__()
while True:
# If we've seen audio.done, wait with a short timeout for response.done
# Otherwise, wait indefinitely for events
timeout = GRACE_PERIOD_SEC if audio_done_received else None
try:
event = await asyncio.wait_for(event_iter.__anext__(), timeout=timeout)
except asyncio.TimeoutError:
# Soft-finish: audio.done was received but no response.done after grace period
if audio_done_received:
logger.warning(
f"Soft-finishing: No response.done {GRACE_PERIOD_SEC}s after audio.done. "
f"Audio bytes: {len(result.audio_bytes)}"
)
break
# Should not happen if timeout is None, but re-raise if it does
raise
except StopAsyncIteration:
# Connection closed normally
logger.debug("Event stream ended")
break
except Exception as conn_err:
# Handle websockets connection errors as soft-finish if we have audio
                    if "ConnectionClosed" in type(conn_err).__name__ and result.audio_bytes:
logger.warning(
f"Connection closed without response.done (likely API issue). "
f"Audio bytes received: {len(result.audio_bytes)}. Soft-finishing."
)
break
# Re-raise if not a connection close or no audio received
raise
event_type = event.type
logger.debug(f"Processing event type: {event_type}")
if event_type == "response.done":
self._handle_response_done_event(event=event, result=result)
logger.debug("Received response.done - finishing normally")
break
elif event_type == "error":
error_message = event.error.message if hasattr(event.error, "message") else str(event.error)
error_type = event.error.type if hasattr(event.error, "type") else "unknown"
logger.error(f"Received 'error' event: [{error_type}] {error_message}")
raise RuntimeError(f"Server error: [{error_type}] {error_message}")
elif event_type in ["response.audio.delta", "response.output_audio.delta"]:
audio_data = base64.b64decode(event.delta)
result.audio_bytes += audio_data
logger.debug(f"Decoded {len(audio_data)} bytes of audio data")
elif event_type in ["response.audio.done", "response.output_audio.done"]:
logger.debug(f"Received audio.done - will soft-finish in {GRACE_PERIOD_SEC}s if no response.done")
audio_done_received = True
elif event_type in ["response.audio_transcript.delta", "response.output_audio_transcript.delta"]:
# Capture transcript deltas as they arrive (needed when response.done never comes)
if hasattr(event, "delta") and event.delta:
result.transcripts.append(event.delta)
logger.debug(f"Captured transcript delta: {event.delta[:50]}...")
elif event_type in ["response.output_text.done"]:
logger.debug("Received text.done")
# Handle lifecycle events that we can safely log
elif event_type in [
"session.created",
"session.updated",
"conversation.created",
"conversation.item.created",
"conversation.item.added",
"conversation.item.done",
"input_audio_buffer.committed",
"input_audio_buffer.speech_started",
"input_audio_buffer.speech_stopped",
"conversation.item.input_audio_transcription.completed",
"response.created",
"response.output_item.added",
"response.output_item.created",
"response.output_item.done",
"response.content_part.added",
"response.content_part.done",
"response.audio_transcript.done",
"response.output_audio_transcript.done",
"response.output_text.delta",
"rate_limits.updated",
]:
logger.debug(f"Lifecycle event '{event_type}'")
else:
logger.debug(f"Unhandled event type '{event_type}'")
except Exception as e:
logger.error(f"An unexpected error occurred for conversation {conversation_id}: {e}")
raise
logger.debug(
f"Completed receive_events with {len(result.transcripts)} transcripts "
f"and {len(result.audio_bytes)} bytes of audio"
)
return result
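    # The soft-finish pattern above, distilled: wait without a timeout until the
    # sentinel event (audio.done) arrives, then wait with a short timeout so a missing
    # terminal event (response.done) cannot hang the loop.
    #
    #     timeout = GRACE_PERIOD_SEC if audio_done_received else None
    #     event = await asyncio.wait_for(event_iter.__anext__(), timeout=timeout)
    #
    # asyncio.wait_for(..., timeout=None) waits indefinitely, so the grace period only
    # takes effect once audio.done has been seen.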
def _get_connection(self, *, conversation_id: str):
"""
Get and validate the Realtime API connection for a conversation.
Args:
conversation_id: The conversation ID
Returns:
The Realtime API connection
Raises:
ConnectionError: If connection is not established
"""
connection = self._existing_conversation.get(conversation_id)
if connection is None:
raise ConnectionError(f"Realtime API connection is not established for conversation {conversation_id}")
return connection
@staticmethod
def _handle_response_done_event(*, event: Any, result: RealtimeTargetResult) -> None:
"""
Process a response.done event from OpenAI client.
Args:
event: The event object from OpenAI client
result: RealtimeTargetResult to update
Raises:
ValueError: If event structure doesn't match expectations
ServerErrorException: If response status is failed
Note:
We no longer extract transcripts here since we capture them from
transcript.delta events. This avoids duplicates and supports soft-finish
when response.done never arrives.
"""
logger.debug("Processing 'response.done' event")
response = event.response
# Check for failed status
status = response.status
if status == "failed":
error_details = RealtimeTarget._extract_error_details(response=response)
raise ServerErrorException(message=error_details)
# We used to extract transcript here, but now we collect it from delta events
# to support soft-finish when response.done doesn't arrive
logger.debug(f"Response completed successfully with {len(result.transcripts)} transcript fragments")
@staticmethod
def _extract_error_details(*, response: Any) -> str:
"""
Extract error details from a failed response.
Args:
response: The response object from OpenAI client
Returns:
A formatted error message
"""
if hasattr(response, "status_details") and response.status_details:
status_details = response.status_details
if hasattr(status_details, "error") and status_details.error:
error = status_details.error
error_type = error.type if hasattr(error, "type") else "unknown"
error_message = error.message if hasattr(error, "message") else "No error message provided"
return f"[{error_type}] {error_message}"
return "Unknown error occurred"
    async def send_text_async(self, text: str, conversation_id: str) -> Tuple[str, RealtimeTargetResult]:
        """
        Sends a text prompt using the OpenAI Realtime API client.

        Args:
            text (str): The prompt to send.
            conversation_id (str): Conversation ID.

        Returns:
            Tuple[str, RealtimeTargetResult]: The path to the saved output audio file
                and the result containing audio bytes and transcripts.
        """
connection = self._get_connection(conversation_id=conversation_id)
# Start listening for responses
        receive_task = asyncio.create_task(self.receive_events(conversation_id=conversation_id))
logger.info(f"Sending text message: {text}")
# Send conversation item
await connection.conversation.item.create(
item={
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": text}],
}
)
# Request response from model
await self.send_response_create(conversation_id=conversation_id)
# Wait for response - receive_events has its own soft-finish logic
        result = await receive_task
if not result.audio_bytes:
raise RuntimeError("No audio received from the server.")
# Close and recreate connection to avoid websockets library state issues with fragmented frames
# This prevents "cannot reset() while queue isn't empty" errors in multi-turn conversations
await self.cleanup_conversation(conversation_id=conversation_id)
new_connection = await self.connect(conversation_id=conversation_id)
self._existing_conversation[conversation_id] = new_connection
# Send session configuration to new connection
system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id)
session_config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt)
await new_connection.session.update(session=session_config)
# Azure GA uses 24000 Hz sample rate
output_audio_path = await self.save_audio(audio_bytes=result.audio_bytes, sample_rate=24000)
return output_audio_path, result
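    # Multi-turn sketch (illustrative only): send_prompt_async manages connections
    # automatically, but when calling send_text_async directly a connection must exist
    # first. Each turn tears the websocket down and re-creates it, re-sending only the
    # session configuration, so the new connection starts a fresh server-side session.
    #
    #     cid = "my-conversation-id"  # placeholder
    #     target._existing_conversation[cid] = await target.connect(conversation_id=cid)
    #     await target.send_config(conversation_id=cid)
    #     path1, result1 = await target.send_text_async(text="Hi there", conversation_id=cid)
    #     path2, result2 = await target.send_text_async(text="And now?", conversation_id=cid)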
    async def send_audio_async(self, filename: str, conversation_id: str) -> Tuple[str, RealtimeTargetResult]:
        """
        Sends an audio message using the OpenAI Realtime API client.

        Args:
            filename (str): The path to the audio file.
            conversation_id (str): Conversation ID.

        Returns:
            Tuple[str, RealtimeTargetResult]: The path to the saved output audio file
                and the result containing audio bytes and transcripts.
        """
connection = self._get_connection(conversation_id=conversation_id)
with wave.open(filename, "rb") as wav_file:
# Read WAV parameters
num_channels = wav_file.getnchannels()
sample_width = wav_file.getsampwidth() # Should be 2 bytes for PCM16
frame_rate = wav_file.getframerate()
num_frames = wav_file.getnframes()
audio_content = wav_file.readframes(num_frames)
        receive_task = asyncio.create_task(self.receive_events(conversation_id=conversation_id))
try:
audio_base64 = base64.b64encode(audio_content).decode("utf-8")
# Use conversation.item.create with input_audio (like Azure sample)
            logger.info(f"Sending audio message via conversation.item.create ({len(audio_base64)} base64 characters)")
await connection.conversation.item.create(
item={
"type": "message",
"role": "user",
"content": [{"type": "input_audio", "audio": audio_base64}],
}
)
except Exception as e:
logger.error(f"Error sending audio: {e}")
raise
logger.debug("Sending response.create")
await self.send_response_create(conversation_id=conversation_id)
logger.debug("Waiting for response events...")
# Wait for response - receive_events has its own soft-finish logic
        result = await receive_task
if not result.audio_bytes:
raise RuntimeError("No audio received from the server.")
# Close and recreate connection to avoid websockets library state issues with fragmented frames
# This prevents "cannot reset() while queue isn't empty" errors in multi-turn conversations
await self.cleanup_conversation(conversation_id=conversation_id)
new_connection = await self.connect(conversation_id=conversation_id)
self._existing_conversation[conversation_id] = new_connection
# Send session configuration to new connection
system_prompt = self._get_system_prompt_from_conversation(conversation_id=conversation_id)
session_config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt)
await new_connection.session.update(session=session_config)
        # The session configuration fixes the output format at 24000 Hz (mono PCM16),
        # so save at that rate rather than the input file's frame rate.
        output_audio_path = await self.save_audio(
            audio_bytes=result.audio_bytes,
            num_channels=num_channels,
            sample_width=sample_width,
            sample_rate=24000,
        )
return output_audio_path, result
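    # Input-preparation sketch using only the standard library: send_audio_async expects
    # a WAV file, and mono PCM16 at 24000 Hz matches the session's input format
    # (`pcm_samples` is a placeholder for raw little-endian 16-bit audio bytes).
    #
    #     import wave
    #
    #     with wave.open("input.wav", "wb") as wav_file:
    #         wav_file.setnchannels(1)      # mono
    #         wav_file.setsampwidth(2)      # 16-bit samples
    #         wav_file.setframerate(24000)  # matches the configured audio/pcm rate
    #         wav_file.writeframes(pcm_samples)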
async def _construct_message_from_response(self, response: Any, request: Any) -> Message:
"""
Not used in RealtimeTarget - message construction handled by receive_events.
This implementation exists to satisfy the abstract base class requirement.
"""
raise NotImplementedError("RealtimeTarget uses receive_events for message construction")
    def _validate_request(self, *, message: Message) -> None:
        """
        Validates the structure and content of a message for compatibility with this target.

        Args:
            message (Message): The message object.

        Raises:
            ValueError: If the message does not contain exactly one message piece.
            ValueError: If the message piece has a data type other than 'text' or 'audio_path'.
        """
# Check the number of message pieces
n_pieces = len(message.message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports one message piece. Received: {n_pieces} pieces.")
piece_type = message.message_pieces[0].converted_value_data_type
if piece_type not in ["text", "audio_path"]:
raise ValueError(f"This target only supports text and audio_path prompt input. Received: {piece_type}.")
    def is_json_response_supported(self) -> bool:
        """Indicates that this target does not support JSON response format."""
        return False