# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from enum import Enum
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
MutableSequence,
Optional,
)
from pyrit.common import convert_local_image_to_data_url
from pyrit.exceptions import (
EmptyResponseException,
PyritException,
pyrit_target_retry,
)
from pyrit.models import (
Message,
MessagePiece,
PromptDataType,
PromptResponseError,
)
from pyrit.prompt_target import (
OpenAITarget,
PromptChatTarget,
limit_requests_per_minute,
)
from pyrit.prompt_target.common.utils import validate_temperature, validate_top_p
from pyrit.prompt_target.openai.openai_error_handling import _is_content_filter_error
logger = logging.getLogger(__name__)
# Tool function registry (agentic extension)
ToolExecutor = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]
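# A conforming executor is an async callable that receives the parsed function-call
# arguments and returns a JSON-serializable dict. A minimal sketch ("lookup_weather"
# and its fields are hypothetical, not part of PyRIT):
#
#     async def lookup_weather(args: dict[str, Any]) -> dict[str, Any]:
#         # args is the result of json.loads() on the model's "arguments" string
#         return {"city": args.get("city", "unknown"), "forecast": "sunny"}
#
# Executors are registered via OpenAIResponseTarget(custom_functions={...}).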
class MessagePieceType(str, Enum):
MESSAGE = "message"
REASONING = "reasoning"
IMAGE_GENERATION_CALL = "image_generation_call"
FILE_SEARCH_CALL = "file_search_call"
FUNCTION_CALL = "function_call"
WEB_SEARCH_CALL = "web_search_call"
COMPUTER_CALL = "computer_call"
CODE_INTERPRETER_CALL = "code_interpreter_call"
LOCAL_SHELL_CALL = "local_shell_call"
MCP_CALL = "mcp_call"
MCP_LIST_TOOLS = "mcp_list_tools"
MCP_APPROVAL_REQUEST = "mcp_approval_request"
class OpenAIResponseTarget(OpenAITarget, PromptChatTarget):
"""
This class enables communication with endpoints that support the OpenAI Response API.
This works with models such as o1, o3, and o4-mini.
    Depending on the endpoint, this supports a variety of inputs, outputs, and tool calls.
For more information, see the OpenAI Response API documentation:
https://platform.openai.com/docs/api-reference/responses/create
"""
def __init__(
self,
*,
custom_functions: Optional[Dict[str, ToolExecutor]] = None,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
extra_body_parameters: Optional[dict[str, Any]] = None,
fail_on_missing_function: bool = False,
**kwargs,
):
"""
Initializes the OpenAIResponseTarget with the provided parameters.
Args:
            custom_functions: Mapping of user-defined function names (e.g., "my_func") to
                async executors that receive the parsed arguments and return a
                JSON-serializable dict.
model_name (str, Optional): The name of the model.
If no value is provided, the OPENAI_RESPONSES_MODEL environment variable will be used.
endpoint (str, Optional): The target URL for the OpenAI service.
            api_key (str, Optional): The API key for accessing the service.
                Defaults to the OPENAI_RESPONSES_KEY environment variable.
            headers (str, Optional): Extra headers for the endpoint, provided as a JSON string.
max_requests_per_minute (int, Optional): Number of requests the target can handle per
minute before hitting a rate limit. The number of requests sent to the target
will be capped at the value provided.
max_output_tokens (int, Optional): The maximum number of tokens that can be
generated in the response. This value can be used to control
costs for text generated via API.
temperature (float, Optional): The temperature parameter for controlling the
randomness of the response.
top_p (float, Optional): The top-p parameter for controlling the diversity of the
response.
            is_json_supported (bool, Optional): If True, the target will support formatting responses as JSON by
                requesting a JSON response format in the request body. Official OpenAI models all support this,
                but if you are using this target with different models, is_json_supported should be set correctly
                to avoid issues when using adversarial infrastructure (e.g. Crescendo scorers will set this flag).
extra_body_parameters (dict, Optional): Additional parameters to be included in the request body.
            fail_on_missing_function: If True, raise when a function_call references
                an unknown function or omits the function name; if False, return a structured
                error so it can be wrapped as function_call_output and the model can
                potentially recover (e.g., pick another tool or ask for clarification).
httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
httpx.AsyncClient() constructor.
For example, to specify a 3 minute timeout: httpx_client_kwargs={"timeout": 180}
Raises:
PyritException: If the temperature or top_p values are out of bounds.
ValueError: If the temperature is not between 0 and 2 (inclusive).
ValueError: If the top_p is not between 0 and 1 (inclusive).
RateLimitException: If the target is rate-limited.
httpx.HTTPStatusError: If the request fails with a 400 Bad Request or 429 Too Many Requests error.
json.JSONDecodeError: If the response from the target is not valid JSON.
Exception: If the request fails for any other reason.
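        Example:
            A minimal construction sketch; the endpoint, key, model, and function
            below are placeholder values, not requirements:

            >>> async def get_time(args: dict) -> dict:
            ...     return {"time": "12:00", "timezone": args.get("timezone", "UTC")}
            >>> target = OpenAIResponseTarget(
            ...     model_name="o4-mini",
            ...     endpoint="https://api.openai.com/v1",
            ...     api_key="sk-placeholder",
            ...     custom_functions={"get_time": get_time},
            ... )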
"""
super().__init__(**kwargs)
# Validate temperature and top_p
validate_temperature(temperature)
validate_top_p(top_p)
self._temperature = temperature
self._top_p = top_p
self._max_output_tokens = max_output_tokens
# Reasoning parameters are not yet supported by PyRIT.
# See https://platform.openai.com/docs/api-reference/responses/create#responses-create-reasoning
# for more information.
self._extra_body_parameters = extra_body_parameters
# Per-instance tool/func registries:
self._custom_functions: Dict[str, ToolExecutor] = custom_functions or {}
self._fail_on_missing_function: bool = fail_on_missing_function
# Extract the grammar 'tool' if one is present
# See
# https://platform.openai.com/docs/guides/function-calling#context-free-grammars
self._grammar_name: str | None = None
if extra_body_parameters:
tools = extra_body_parameters.get("tools", [])
for tool in tools:
if tool.get("type") == "custom" and tool.get("format", {}).get("type") == "grammar":
if self._grammar_name is not None:
raise ValueError("Multiple grammar tools detected; only one is supported.")
tool_name = tool.get("name")
logger.debug("Detected grammar tool: %s", tool_name)
self._grammar_name = tool_name
def _set_openai_env_configuration_vars(self):
self.model_name_environment_variable = "OPENAI_RESPONSES_MODEL"
self.endpoint_environment_variable = "OPENAI_RESPONSES_ENDPOINT"
self.api_key_environment_variable = "OPENAI_RESPONSES_KEY"
def _get_target_api_paths(self) -> list[str]:
"""Return API paths that should not be in the URL."""
return ["/responses", "/v1/responses"]
def _get_provider_examples(self) -> dict[str, str]:
"""Return provider-specific example URLs."""
return {
".openai.azure.com": "https://{resource}.openai.azure.com/openai/v1",
"api.openai.com": "https://api.openai.com/v1",
}
async def _construct_input_item_from_piece(self, piece: MessagePiece) -> Dict[str, Any]:
"""
Convert a single inline piece into a Responses API content item.
Args:
piece: The inline piece (text or image_path).
Returns:
A dict in the Responses API content item shape.
Raises:
ValueError: If the piece type is not supported for inline content. Supported types are text and
image paths.
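        Example:
            Illustrative return shapes (values are placeholders)::

                {"type": "input_text", "text": "hello"}           # user/developer text
                {"type": "output_text", "text": "hi"}             # assistant text
                {"type": "input_image", "image_url": "data:..."}  # local image as data URL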
"""
if piece.converted_value_data_type == "text":
return {
"type": "input_text" if piece.role in ["developer", "user"] else "output_text",
"text": piece.converted_value,
}
if piece.converted_value_data_type == "image_path":
data_url = await convert_local_image_to_data_url(piece.converted_value)
return {"type": "input_image", "image_url": {"url": data_url}}
raise ValueError(f"Unsupported piece type for inline content: {piece.converted_value_data_type}")
async def _build_input_for_multi_modal_async(self, conversation: MutableSequence[Message]) -> List[Dict[str, Any]]:
"""
Build the Responses API `input` array.
Groups inline content (text/images) into role messages and emits tool artifacts
(reasoning, function_call, function_call_output, web_search_call, etc.) as top-level
items — per the Responses API schema.
Each Message is processed as a complete unit. All MessagePieces within a Message
share the same role, so content is accumulated and appended once per Message.
Args:
conversation: Ordered list of user/assistant/tool artifacts to serialize.
Returns:
A list of input items ready for the Responses API.
Raises:
            ValueError: If the conversation is empty, a message has no pieces, or a piece
                has an unsupported data type.
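        Example:
            Illustrative result for a system + user exchange (text values are
            placeholders)::

                [
                    {"role": "developer", "content": [{"type": "input_text", "text": "Be terse."}]},
                    {"role": "user", "content": [{"type": "input_text", "text": "Hi"}]},
                ]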
"""
if not conversation:
raise ValueError("Conversation cannot be empty")
input_items: List[Dict[str, Any]] = []
for msg_idx, message in enumerate(conversation):
pieces = message.message_pieces
if not pieces:
raise ValueError(
f"Failed to process conversation message at index {msg_idx}: Message contains no message pieces"
)
# System message (remapped to developer)
if pieces[0].role == "system":
system_content = []
for piece in pieces:
system_content.append({"type": "input_text", "text": piece.converted_value})
input_items.append({"role": "developer", "content": system_content})
continue
# All pieces in a Message share the same role
role = pieces[0].role
content: List[Dict[str, Any]] = []
for piece in pieces:
dtype = piece.converted_value_data_type
# Skip reasoning - it's stored in memory but not sent back to API
if dtype == "reasoning":
continue
# Inline content (text/images) - accumulate in content list
if dtype in {"text", "image_path"}:
content.append(await self._construct_input_item_from_piece(piece))
continue
# Top-level artifacts - emit as standalone items
if dtype not in {"function_call", "function_call_output", "tool_call"}:
raise ValueError(f"Unsupported data type '{dtype}' in message index {msg_idx}")
if dtype in {"function_call", "tool_call"}:
# Parse the stored JSON and filter to only API-expected fields
stored = json.loads(piece.original_value)
if dtype == "function_call":
# Only include fields the API expects for function_call
input_items.append(
{
"type": stored["type"],
"call_id": stored["call_id"],
"name": stored["name"],
"arguments": stored["arguments"],
}
)
elif dtype == "tool_call":
# Filter tool_call fields based on type
tool_type = stored.get("type")
if tool_type == "web_search_call":
# Web search call structure
input_items.append(
{
"type": stored["type"],
"call_id": stored.get("call_id"),
"query": stored.get("query"),
}
)
else:
# For unknown tool types, try to include only known fields
filtered = {"type": stored["type"]}
if "call_id" in stored:
filtered["call_id"] = stored["call_id"]
if "query" in stored:
filtered["query"] = stored["query"]
if "name" in stored:
filtered["name"] = stored["name"]
if "arguments" in stored:
filtered["arguments"] = stored["arguments"]
input_items.append(filtered)
if dtype == "function_call_output":
payload = json.loads(piece.original_value)
output = payload.get("output")
if not isinstance(output, str):
# Responses API requires string output; serialize if needed
output = json.dumps(output, separators=(",", ":"))
input_items.append(
{
"type": "function_call_output",
"call_id": payload["call_id"],
"output": output,
}
)
# Append accumulated inline content for this message
if content:
input_items.append({"role": role, "content": content})
return input_items
async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict:
"""
Construct the request body to send to the Responses API.
        NOTE: The Responses API nests the JSON response format under `text.format`;
        the top-level `response_format` field belongs to the Chat Completions API.
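        A sketch of the resulting body for a JSON-mode request ("o4-mini" is a
        placeholder; None-valued entries are filtered out before sending)::

            {
                "model": "o4-mini",
                "stream": False,
                "input": [...],
                "text": {"format": {"type": "json_object"}},
            }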
"""
input_items = await self._build_input_for_multi_modal_async(conversation)
body_parameters = {
"model": self._model_name,
"max_output_tokens": self._max_output_tokens,
"temperature": self._temperature,
"top_p": self._top_p,
"stream": False,
"input": input_items,
            # JSON response format for the Responses API is nested under text.format
            "text": {"format": {"type": "json_object"}} if is_json_response else None,
}
if self._extra_body_parameters:
body_parameters.update(self._extra_body_parameters)
# Filter out None values
return {k: v for k, v in body_parameters.items() if v is not None}
def _check_content_filter(self, response: Any) -> bool:
"""
Check if a Response API response has a content filter error.
Args:
response: A Response object from the OpenAI SDK.
Returns:
True if content was filtered, False otherwise.
"""
if hasattr(response, "error") and response.error is not None:
# Convert response to dict and use common filter detection
response_dict = response.model_dump()
return _is_content_filter_error(response_dict)
return False
def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]:
"""
Validate a Response API response for errors.
Checks for:
- Error responses (excluding content filtering which is checked separately)
- Invalid status
- Empty output
Args:
response: The Response object from the OpenAI SDK.
request: The original request MessagePiece.
Returns:
            None if the response is valid. Content-filter results are not returned here;
            they are detected separately by _check_content_filter.
Raises:
PyritException: For unexpected response structures or errors.
EmptyResponseException: When the API returns no valid output.
"""
# Check for error response - error is a ResponseError object or None
# (content_filter is handled by _check_content_filter)
if response.error is not None and response.error.code != "content_filter":
raise PyritException(message=f"Response error: {response.error.code} - {response.error.message}")
# Check status - should be "completed" for successful responses
if response.status != "completed":
raise PyritException(message=f"Unexpected status: {response.status}")
# Check for empty output
if not response.output:
logger.error("The response returned no valid output.")
raise EmptyResponseException(message="The response returned an empty response.")
return None
async def _construct_message_from_response(self, response: Any, request: MessagePiece) -> Message:
"""
Construct a Message from a Response API response.
Args:
response: The Response object from OpenAI SDK.
request: The original request MessagePiece.
Returns:
Message: Constructed message with extracted content from output sections.
"""
# Extract and parse message pieces from validated output sections
extracted_response_pieces: List[MessagePiece] = []
for section in response.output:
piece = self._parse_response_output_section(
section=section,
message_piece=request,
error=None, # error is already handled in validation
)
if piece is None:
continue
extracted_response_pieces.append(piece)
return Message(message_pieces=extracted_response_pieces)
@limit_requests_per_minute
@pyrit_target_retry
async def send_prompt_async(self, *, message: Message) -> list[Message]:
"""
Send prompt, handle agentic tool calls (function_call), return all messages.
The Responses API supports structured outputs and tool execution. This method handles both:
- Simple text/reasoning responses
- Agentic tool-calling loops that may require multiple back-and-forth exchanges
Args:
message: The initial prompt from the user.
Returns:
List of messages generated during the interaction (assistant responses and tool messages).
The normalizer will persist all of these to memory.
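        Example:
            A single tool round-trip ("get_time" is a hypothetical registered function)
            yields three messages, all returned in order::

                assistant: {"type": "function_call", "call_id": "c1", "name": "get_time", "arguments": "{}"}
                tool:      {"type": "function_call_output", "call_id": "c1", "output": "..."}
                assistant: the final text answer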
"""
self._validate_request(message=message)
message_piece: MessagePiece = message.message_pieces[0]
is_json_response = self.is_response_format_json(message_piece)
# Get full conversation history from memory and append the current message
conversation: MutableSequence[Message] = self._memory.get_conversation(
conversation_id=message_piece.conversation_id
)
conversation.append(message)
# Track all responses generated during this interaction
responses_to_return: list[Message] = []
# Main agentic loop - each back-and-forth creates a new message
tool_call_section: Optional[dict[str, Any]] = None
while True:
logger.info(f"Sending conversation with {len(conversation)} messages to the prompt target")
body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response)
# Use unified error handling - automatically detects Response and validates
result = await self._handle_openai_request(
api_call=lambda: self._async_client.responses.create(**body),
request=message,
)
# Add result to conversation and responses list
conversation.append(result)
responses_to_return.append(result)
# Extract tool call if present
tool_call_section = self._find_last_pending_tool_call(result)
# If no tool call, we're done
if not tool_call_section:
break
# Execute the tool/function
tool_output = await self._execute_call_section(tool_call_section)
# Create a new message with the tool output
tool_piece = self._make_tool_piece(tool_output, tool_call_section["call_id"], reference_piece=message_piece)
tool_message = Message(message_pieces=[tool_piece], skip_validation=True)
# Add tool output message to conversation and responses list
conversation.append(tool_message)
responses_to_return.append(tool_message)
# Continue loop to send tool result and get next response
# Return all responses (normalizer will persist all of them to memory)
return responses_to_return
def is_json_response_supported(self) -> bool:
"""Indicates that this target supports JSON response format."""
return True
def _parse_response_output_section(
self, *, section, message_piece: MessagePiece, error: Optional[PromptResponseError]
) -> MessagePiece | None:
"""
Parse model output sections, forwarding tool-calls for the agentic loop.
Args:
section: The section object from OpenAI SDK (Pydantic model).
message_piece: The original message piece.
error: Any error information from OpenAI.
Returns:
A MessagePiece for this section, or None to skip.
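        Example:
            Section-to-piece mapping sketch::

                message          -> text piece (first content item's text)
                reasoning        -> reasoning piece (stored in memory, never re-sent)
                function_call    -> function_call piece (serialized JSON)
                web_search_call  -> tool_call piece (serialized JSON)
                custom_tool_call -> text piece (grammar-constrained output)
                anything else    -> None (skipped)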
"""
section_type = section.type
piece_type: PromptDataType = "text" # Default, always set!
piece_value = ""
if section_type == MessagePieceType.MESSAGE:
section_content = section.content
if len(section_content) == 0:
raise EmptyResponseException(message="The chat returned an empty message section.")
piece_value = section_content[0].text
elif section_type == MessagePieceType.REASONING:
# Store reasoning in memory for debugging/logging, but won't be sent back to API
piece_value = json.dumps(
{
"id": section.id,
"type": section.type,
"summary": section.summary,
"content": section.content,
"encrypted_content": section.encrypted_content,
},
separators=(",", ":"),
)
piece_type = "reasoning"
elif section_type == MessagePieceType.FUNCTION_CALL:
# Only store fields the API expects for function_call (exclude status, etc.)
piece_value = json.dumps(
{
"type": "function_call",
"call_id": section.call_id,
"name": section.name,
"arguments": section.arguments,
},
separators=(",", ":"),
)
piece_type = "function_call"
elif section_type == MessagePieceType.WEB_SEARCH_CALL:
# Forward web_search_call with only API-expected fields
# Note: web search may have different field structure than function calls
web_search_data = {
"type": "web_search_call",
}
# Add optional fields if they exist
if hasattr(section, "call_id") and section.call_id:
web_search_data["call_id"] = section.call_id
if hasattr(section, "query") and section.query:
web_search_data["query"] = section.query
if hasattr(section, "id") and section.id:
web_search_data["id"] = section.id
piece_value = json.dumps(web_search_data, separators=(",", ":"))
piece_type = "tool_call"
elif section_type == "custom_tool_call":
            # The model emitted a custom tool call, which should correspond to the
            # registered context-free (Lark) grammar tool. See
            # https://platform.openai.com/docs/guides/function-calling#context-free-grammars
logger.debug("Detected custom_tool_call in response, assuming grammar constraint.")
extracted_grammar_name = section.name
if extracted_grammar_name != self._grammar_name:
msg = "Mismatched grammar name in custom_tool_call "
msg += f"(expected {self._grammar_name}, got {extracted_grammar_name})"
logger.error(msg)
raise ValueError(msg)
piece_value = section.input
if len(piece_value) == 0:
raise EmptyResponseException(message="The chat returned an empty message section.")
else:
# Other possible types are not yet handled in PyRIT
return None
# Handle empty response
if not piece_value:
raise EmptyResponseException(message="The chat returned an empty response.")
return MessagePiece(
role="assistant",
original_value=piece_value,
conversation_id=message_piece.conversation_id,
labels=message_piece.labels,
prompt_target_identifier=message_piece.prompt_target_identifier,
attack_identifier=message_piece.attack_identifier,
original_value_data_type=piece_type,
response_error=error or "none",
)
def _validate_request(self, *, message: Message) -> None:
"""
Validates the structure and content of a message for compatibility of this target.
Args:
message (Message): The message object.
Raises:
ValueError: If any of the message pieces have a data type other than supported set.
"""
# Some models may not support all of these; we accept them at the transport layer
# so the Responses API can decide. We include reasoning and function_call_output now.
allowed_types = {"text", "image_path", "function_call", "tool_call", "function_call_output", "reasoning"}
for message_piece in message.message_pieces:
if message_piece.converted_value_data_type not in allowed_types:
raise ValueError(f"Unsupported data type: {message_piece.converted_value_data_type}")
return
    # Agentic helpers
def _find_last_pending_tool_call(self, reply: Message) -> Optional[dict[str, Any]]:
"""
Return the last tool-call section in assistant messages, or None.
Looks for a piece whose value parses as JSON with a 'type' key matching function_call.
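        Example piece value that matches (illustrative values)::

            {"type": "function_call", "call_id": "c1", "name": "get_time", "arguments": "{}"}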
"""
for piece in reversed(reply.message_pieces):
if piece.role == "assistant":
try:
section = json.loads(piece.original_value)
except Exception:
continue
if section.get("type") == "function_call":
# Do NOT skip function_call even if status == "completed" — we still need to emit the output.
return section
return None
async def _execute_call_section(self, tool_call_section: dict[str, Any]) -> dict[str, Any]:
"""
Execute a function_call from the custom_functions registry.
Returns:
A dict payload (will be serialized and sent as function_call_output).
            If fail_on_missing_function=False and the function is unregistered or no function
            name was provided, returns a structured error such as:
{"error": "function_not_found", "missing_function": "<name>", "available_functions": [...]}
"""
name = tool_call_section.get("name")
if not name:
if self._fail_on_missing_function:
raise ValueError("Function call section missing 'name' field")
return {
"error": "missing_function_name",
"tool_call_section": tool_call_section,
}
args_json = tool_call_section.get("arguments", "{}")
try:
args = json.loads(args_json)
except Exception:
# If arguments are not valid JSON, surface a structured error (or raise)
if self._fail_on_missing_function:
raise ValueError(f"Malformed arguments for function '{name}': {args_json}")
logger.warning("Malformed arguments for function '%s': %s", name, args_json)
return {
"error": "malformed_arguments",
"function": name,
"raw_arguments": args_json,
}
fn = self._custom_functions.get(name)
if fn is None:
if self._fail_on_missing_function:
raise KeyError(f"Function '{name}' is not registered")
# Tolerant mode: return a structured error so we can wrap it as function_call_output
available = sorted(self._custom_functions.keys())
logger.warning("Function '%s' not registered. Available: %s", name, available)
return {
"error": "function_not_found",
"missing_function": name,
"available_functions": available,
}
return await fn(args)
def _make_tool_piece(self, output: dict[str, Any], call_id: str, *, reference_piece: MessagePiece) -> MessagePiece:
"""
Create a function_call_output MessagePiece.
Args:
output: The tool output to wrap.
call_id: The call ID for the function call.
reference_piece: A reference piece to copy conversation context from.
Returns:
A MessagePiece containing the function call output.
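        Example:
            Serialized ``original_value`` shape (illustrative values)::

                {"type": "function_call_output", "call_id": "c1", "output": "{\"ok\": true}"}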
"""
output_str = output if isinstance(output, str) else json.dumps(output, separators=(",", ":"))
return MessagePiece(
role="tool",
original_value=json.dumps(
{"type": "function_call_output", "call_id": call_id, "output": output_str},
separators=(",", ":"),
),
original_value_data_type="function_call_output",
conversation_id=reference_piece.conversation_id,
labels={"call_id": call_id},
prompt_target_identifier=reference_piece.prompt_target_identifier,
attack_identifier=reference_piece.attack_identifier,
)