# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import re
from abc import abstractmethod
from typing import Any, Awaitable, Callable, Optional
from urllib.parse import urlparse
from openai import (
AsyncOpenAI,
BadRequestError,
ContentFilterFinishReasonError,
RateLimitError,
)
from openai._exceptions import (
APIConnectionError,
APIStatusError,
APITimeoutError,
AuthenticationError,
)
from pyrit.common import default_values
from pyrit.exceptions.exception_classes import (
RateLimitException,
handle_bad_request_exception,
)
from pyrit.models import Message, MessagePiece
from pyrit.prompt_target import PromptChatTarget
from pyrit.prompt_target.openai.openai_error_handling import (
_extract_error_payload,
_extract_request_id_from_exception,
_extract_retry_after_from_exception,
)
logger = logging.getLogger(__name__)
# [docs] (Sphinx cross-reference artifact from HTML extraction)
class OpenAITarget(PromptChatTarget):
"""
Abstract base class for OpenAI-based prompt targets.
This class provides common functionality for interacting with OpenAI API
endpoints, handling authentication, rate limiting, and request/response processing.
Read more about the various models here:
https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models.
"""
# Name of the environment variable holding extra request headers as a JSON object string.
ADDITIONAL_REQUEST_HEADERS: str = "OPENAI_ADDITIONAL_REQUEST_HEADERS"
# Names of the env vars that configuration is read from. Each subclass assigns
# these in _set_openai_env_configuration_vars(), which __init__ calls before
# reading them via default_values.get_required_value().
model_name_environment_variable: str
endpoint_environment_variable: str
api_key_environment_variable: str
# Created by _initialize_openai_client() at the end of __init__; None until then.
_async_client: Optional[AsyncOpenAI] = None
# [docs] (Sphinx cross-reference artifact from HTML extraction)
def __init__(
self,
*,
model_name: Optional[str] = None,
endpoint: Optional[str] = None,
api_key: Optional[str | Callable[[], str | Awaitable[str]]] = None,
headers: Optional[str] = None,
max_requests_per_minute: Optional[int] = None,
httpx_client_kwargs: Optional[dict] = None,
) -> None:
"""
Initialize an instance of OpenAITarget.
Args:
model_name (str, Optional): The name of the model.
If no value is provided, the environment variable will be used (set by subclass).
endpoint (str, Optional): The target URL for the OpenAI service.
api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service,
or a callable that returns an access token. For Azure endpoints with Entra authentication,
pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint)).
Defaults to the target-specific API key environment variable.
headers (str, Optional): Extra headers of the endpoint (JSON).
max_requests_per_minute (int, Optional): Number of requests the target can handle per
minute before hitting a rate limit. The number of requests sent to the target
will be capped at the value provided.
httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
`httpx.AsyncClient()` constructor.
"""
self._headers: dict = {}
self._httpx_client_kwargs = httpx_client_kwargs or {}
# The explicit `headers` argument takes precedence over the
# OPENAI_ADDITIONAL_REQUEST_HEADERS environment variable.
request_headers = default_values.get_non_required_value(
env_var_name=self.ADDITIONAL_REQUEST_HEADERS, passed_value=headers
)
# Only parse when we actually received a JSON string; any other value is
# silently ignored. (json.loads raises json.JSONDecodeError on malformed input.)
if request_headers and isinstance(request_headers, str):
self._headers = json.loads(request_headers)
# Subclass hook: must assign the *_environment_variable names read below,
# so it has to run before the get_required_value() calls.
self._set_openai_env_configuration_vars()
self._model_name: str = default_values.get_required_value(
env_var_name=self.model_name_environment_variable, passed_value=model_name
)
endpoint_value = default_values.get_required_value(
env_var_name=self.endpoint_environment_variable, passed_value=endpoint
)
# Initialize parent with endpoint and model_name
PromptChatTarget.__init__(
self, max_requests_per_minute=max_requests_per_minute, endpoint=endpoint_value, model_name=self._model_name
)
# API key is required - either from parameter or environment variable
self._api_key = default_values.get_required_value( # type: ignore[assignment]
env_var_name=self.api_key_environment_variable, passed_value=api_key
)
self._initialize_openai_client()
def _extract_deployment_from_azure_url(self, url: str) -> str:
"""
Extract deployment/model name from Azure OpenAI URL.
Azure URLs have formats like:
- https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions
- https://{resource}.openai.azure.com/openai/deployments/{deployment}/responses
Args:
url: The Azure endpoint URL.
Returns:
The deployment name, or empty string if not found.
"""
# Match /deployments/{deployment_name}/
match = re.search(r"/deployments/([^/]+)/", url)
if match:
deployment = match.group(1)
logger.info(f"Extracted deployment name from URL: {deployment}")
return deployment
return ""
def _warn_old_azure_url_format(self, url: str) -> None:
"""
Warn users about old Azure URL format without modifying the URL.
Old formats that trigger warnings:
- Deployment in path: /openai/deployments/{deployment}/...
- API version in query: ?api-version=X
These can appear independently or together.
Recommended new format: https://{resource}.openai.azure.com/openai/v1
Pass deployment name as model_name parameter.
Args:
url: The Azure endpoint URL to validate.
"""
parsed = urlparse(url)
suggested_url = f"{parsed.scheme}://{parsed.netloc}/openai/v1"
# Check for both deployment in path and api-version
deployment = self._extract_deployment_from_azure_url(url)
has_api_version = "api-version" in parsed.query
# Build the specific issue description
if deployment and has_api_version:
issue_desc = "with deployment in path and api-version parameter"
recommendation = (
f"with deployment '{deployment}' passed as model_name parameter and api-version parameter removed"
)
elif deployment:
issue_desc = "with deployment in path"
recommendation = f"with deployment '{deployment}' passed as model_name parameter"
elif has_api_version:
issue_desc = "with api-version parameter"
recommendation = "without api-version parameter"
else:
return # No issues found
logger.warning(
f"Old Azure URL format detected {issue_desc}. "
f"Current URL: {url}. "
f"Recommended format: {suggested_url} {recommendation}. "
f"Old format URLs will be deprecated in a future release. "
f"See https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation "
"for more information."
)
@abstractmethod
def _get_target_api_paths(self) -> list[str]:
"""
Return list of API paths that should not be in the URL for this target.
The SDK automatically appends these paths, so they shouldn't be in the base URL.
Returns:
List of API paths (e.g., ["/chat/completions", "/v1/chat/completions"])
"""
pass
@abstractmethod
def _get_provider_examples(self) -> dict[str, str]:
"""
Return provider-specific example URLs for this target.
Used in warnings to show users the correct format.
Returns:
Dict mapping provider patterns to example URLs
(e.g., {".openai.azure.com": "https://{resource}.openai.azure.com/openai/v1"})
"""
pass
def _validate_url_for_target(self, endpoint_url: str) -> None:
"""
Validate the URL format for this specific target and warn about issues.
Checks for:
- API-specific paths that should not be in the URL
- Query parameters like api-version
This method does NOT modify the URL - it only logs warnings.
Args:
endpoint_url: The endpoint URL to validate.
"""
# Check for API paths that shouldn't be in the URL
api_paths = self._get_target_api_paths()
provider_examples = self._get_provider_examples()
for api_path in api_paths:
if api_path in endpoint_url:
self._warn_url_with_api_path(endpoint_url, api_path, provider_examples)
break # Only warn once
# Warn if query parameters are present
self._warn_url_with_query_params(endpoint_url)
def _warn_azure_url_path_issues(self, endpoint_url: str) -> None:
"""
Warn about Azure URL path structure issues without modifying the URL.
Expected formats:
- Azure OpenAI: https://{resource}.openai.azure.com/openai/v1
- Azure Foundry: https://{resource}.models.ai.azure.com (no /openai/v1 needed)
Args:
endpoint_url: The Azure endpoint URL to validate.
"""
parsed = urlparse(endpoint_url)
if ".openai.azure.com" in endpoint_url:
# Check for various issues with Azure OpenAI URLs
path = parsed.path.rstrip("/")
if not path or path == "":
logger.warning(
f"Azure OpenAI URL is missing path structure. "
f"Current: {endpoint_url}. "
f"Recommended: {endpoint_url.rstrip('/')}/openai/v1"
)
elif path == "/openai":
logger.warning(
f"Azure OpenAI URL is missing /v1 suffix. "
f"Current: {endpoint_url}. "
f"Recommended: {endpoint_url.rstrip('/')}/v1"
)
elif not path.endswith("/openai/v1") and not path.startswith("/openai/v1"):
# Check if it has an API extension that should be removed
if any(
api_path in path
for api_path in [
"/chat/completions",
"/responses",
"/completions",
"/videos",
"/images/generations",
"/audio/speech",
]
):
# This is handled by target-specific validation
pass
elif "/openai" not in path:
logger.warning(
f"Azure OpenAI URL should include /openai/v1 path. "
f"Current: {endpoint_url}. "
f"Recommended: {parsed.scheme}://{parsed.netloc}/openai/v1"
)
def _initialize_openai_client(self) -> None:
"""
Initialize the OpenAI client using AsyncOpenAI.
Validates the URL format and warns about potential issues, but does NOT modify
the user-provided URL. This allows flexibility for custom endpoints and non-standard
providers while helping users identify common configuration mistakes.
Supported formats:
- Platform OpenAI: https://api.openai.com/v1
- Azure OpenAI: https://{resource}.openai.azure.com/openai/v1
- Azure Foundry: https://{resource}.models.ai.azure.com
- Anthropic: https://api.anthropic.com/v1
- Google Gemini: https://generativelanguage.googleapis.com/v1beta/openai
- Custom endpoints: Any format (warnings may be shown but URL is not modified)
"""
# Merge custom headers with httpx_client_kwargs
httpx_kwargs = self._httpx_client_kwargs.copy()
if self._headers:
httpx_kwargs.setdefault("default_headers", {}).update(self._headers)
# NOTE(review): httpx_kwargs is splatted into AsyncOpenAI(...) below rather than into an
# httpx.AsyncClient(); confirm this matches the httpx_client_kwargs contract documented in __init__.
# Determine if this is Azure OpenAI based on the endpoint
is_azure = "azure" in self._endpoint.lower() if self._endpoint else False
# Warn about old Azure format but don't modify
warned_old_format = False
if is_azure:
parsed_url = urlparse(self._endpoint)
# Check if it has api-version query parameter OR /deployments/ in path
has_api_version = "api-version" in parsed_url.query
has_deployments = "/deployments/" in parsed_url.path
if has_deployments or has_api_version:
self._warn_old_azure_url_format(self._endpoint)
warned_old_format = True
# Validate URL format for target-specific issues
# Skip if we already warned about old format (to avoid duplicate warnings)
if not warned_old_format:
self._validate_url_for_target(self._endpoint)
# Warn about Azure path structure issues
if is_azure:
self._warn_azure_url_path_issues(self._endpoint)
# Use endpoint as-is - the user knows their provider best
base_url = self._endpoint
# Pass api_key directly to the SDK - it handles both strings and callables
self._async_client = AsyncOpenAI(
base_url=base_url,
api_key=self._api_key,
**httpx_kwargs,
)
async def _handle_openai_request(
self,
*,
api_call: Callable,
request: Message,
) -> Message:
"""
Unified error handling wrapper for all OpenAI SDK requests.
This method wraps any OpenAI SDK call and handles all common error scenarios:
- Content filtering (both proactive checks and SDK exceptions)
- Bad request errors (400s with content filter detection)
- Rate limiting (429s with retry-after extraction)
- API status errors (other HTTP errors)
- Transient errors (timeouts, connection issues)
- Authentication errors
Automatically detects the response type and applies appropriate validation and content
filter checks via abstract methods. On success, constructs and returns a Message object.
Args:
api_call: Async callable that invokes the OpenAI SDK method.
request: The Message representing the user's request (for error responses).
Returns:
Message: The constructed message response (success or error).
Raises:
RateLimitException: For 429 rate limit errors.
APIStatusError: For other API status errors.
APITimeoutError: For transient infrastructure errors.
APIConnectionError: For transient infrastructure errors.
AuthenticationError: For authentication failures.
"""
# Success path first; each except clause below maps one SDK error family to
# either an error Message (content filter / 400) or a re-raised exception.
try:
# Execute the API call
response = await api_call()
# Extract MessagePiece for validation and construction (most targets use single piece)
request_piece = request.message_pieces[0] if request.message_pieces else None
# Check for content filter via subclass implementation
if self._check_content_filter(response):
return self._handle_content_filter_response(response, request_piece)
# Validate response via subclass implementation
error_message = self._validate_response(response, request_piece)
if error_message:
return error_message
# Construct and return Message from validated response
return await self._construct_message_from_response(response, request_piece)
except ContentFilterFinishReasonError as e:
# Content filter error raised by SDK during parse/structured output flows
request_id = _extract_request_id_from_exception(e)
logger.error(f"Content filter error (SDK raised). request_id={request_id} error={e}")
# Convert exception to response-like object for consistent handling
error_str = str(e)
# Minimal stand-in exposing the model_dump_json() interface that
# _handle_content_filter_response expects from a real SDK response.
class _ErrorResponse:
def model_dump_json(self):
return error_str
request_piece = request.message_pieces[0] if request.message_pieces else None
return self._handle_content_filter_response(_ErrorResponse(), request_piece)
except BadRequestError as e:
# Handle 400 errors - includes input policy filters and some Azure output-filter 400s
payload, is_content_filter = _extract_error_payload(e)
request_id = _extract_request_id_from_exception(e)
# Safely serialize payload for logging
try:
payload_str = payload if isinstance(payload, str) else json.dumps(payload)[:200]
except (TypeError, ValueError):
# If JSON serialization fails (e.g., contains non-serializable objects), use str()
payload_str = str(payload)[:200]
logger.warning(
f"BadRequestError request_id={request_id} is_content_filter={is_content_filter} "
f"payload={payload_str}"
)
request_piece = request.message_pieces[0] if request.message_pieces else None
return handle_bad_request_exception(
response_text=str(payload),
request=request_piece,
error_code=400,
is_content_filter=is_content_filter,
)
except RateLimitError as e:
# SDK's RateLimitError (429)
request_id = _extract_request_id_from_exception(e)
retry_after = _extract_retry_after_from_exception(e)
logger.warning(f"RateLimitError request_id={request_id} retry_after={retry_after} error={e}")
# retry_after is logged only; RateLimitException carries no backoff hint.
raise RateLimitException()
except APIStatusError as e:
# Other API status errors - check for 429 here as well
request_id = _extract_request_id_from_exception(e)
if getattr(e, "status_code", None) == 429:
retry_after = _extract_retry_after_from_exception(e)
logger.warning(f"429 via APIStatusError request_id={request_id} retry_after={retry_after}")
raise RateLimitException()
else:
logger.exception(
f"APIStatusError request_id={request_id} status={getattr(e, 'status_code', None)} error={e}"
)
raise
except (APITimeoutError, APIConnectionError) as e:
# Transient infrastructure errors - these are retryable
request_id = _extract_request_id_from_exception(e)
logger.warning(f"Transient API error ({e.__class__.__name__}) request_id={request_id} error={e}")
raise
except AuthenticationError as e:
# Authentication errors - non-retryable, surface quickly
request_id = _extract_request_id_from_exception(e)
logger.error(f"Authentication error request_id={request_id} error={e}")
raise
@abstractmethod
async def _construct_message_from_response(self, response: Any, request: MessagePiece) -> Message:
"""
Construct a Message from the OpenAI SDK response.
This method extracts the relevant data from the SDK response object and
constructs a Message with appropriate message pieces. It may include
async operations like saving files for image/audio/video responses.
Args:
response: The response object from OpenAI SDK (e.g., ChatCompletion, Response, etc.).
request: The original request MessagePiece.
Returns:
Message: Constructed message with extracted content.
"""
pass
def _check_content_filter(self, response: Any) -> bool:
"""
Check if the response indicates content filtering.
Override this method in subclasses that need content filter detection.
Default implementation returns False (no content filter).
Args:
response: The response object from OpenAI SDK.
Returns:
bool: True if content filter detected, False otherwise.
"""
return False
def _handle_content_filter_response(self, response: Any, request: MessagePiece) -> Message:
    """
    Turn a content-filtered SDK response into an error Message.

    Args:
        response: The response object from OpenAI SDK.
        request: The original request message piece.

    Returns:
        Message object with error type indicating content was filtered.
    """
    logger.warning("Output content filtered by content policy.")
    serialized_response = response.model_dump_json()
    return handle_bad_request_exception(
        response_text=serialized_response,
        request=request,
        error_code=200,
        is_content_filter=True,
    )
def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]:
"""
Validate the response and return error Message if needed.
Override this method in subclasses that need custom response validation.
Default implementation returns None (no validation errors).
Args:
response: The response object from OpenAI SDK.
request: The original request MessagePiece.
Returns:
Optional[Message]: Error Message if validation fails, None otherwise.
Raises:
Various exceptions for validation failures.
"""
return None
@abstractmethod
def _set_openai_env_configuration_vars(self) -> None:
    """
    Set model_name_environment_variable, endpoint_environment_variable,
    and api_key_environment_variable which are read from .env file.

    Called by __init__ before the environment variables are resolved, so
    subclasses must assign all three class attributes here.
    """
    raise NotImplementedError
def _warn_url_with_api_path(
self, endpoint_url: str, api_path: str, provider_examples: dict[str, str] = None
) -> None:
"""
Warn if URL includes API-specific path that should be handled by the SDK.
Args:
endpoint_url: The endpoint URL to check.
api_path: The API path to check for (e.g., "/chat/completions", "/responses").
provider_examples: Optional dict mapping provider patterns to example base URLs.
"""
if api_path in endpoint_url:
parsed = urlparse(endpoint_url)
base_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path.replace(api_path, '')}"
message = (
f"URL includes API path '{api_path}' which the OpenAI SDK handles automatically. "
f"Current URL: {endpoint_url}. "
f"Recommended: Remove '{api_path}' from the URL. "
)
# Add provider-specific guidance
if provider_examples:
for pattern, example in provider_examples.items():
if pattern in endpoint_url:
message += f"Example: {example}. "
break
else:
message += f"Suggested: {base_url}. "
logger.warning(message)
def _warn_url_with_query_params(self, endpoint_url: str) -> None:
"""
Warn if URL includes query parameters like api-version.
Args:
endpoint_url: The endpoint URL to check.
"""
parsed = urlparse(endpoint_url)
if parsed.query:
base_url = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
logger.warning(
f"URL includes query parameters '{parsed.query}' which should be removed. "
f"Current URL: {endpoint_url}. "
f"Recommended: {base_url}"
)
def _warn_if_irregular_endpoint(self, expected_url_regex) -> None:
"""
Validate that the endpoint URL ends with one of the expected routes for this OpenAI target.
Args:
expected_url_regex: Expected regex pattern(s) for this target. Should be a list of regex strings.
Prints a warning if the endpoint doesn't match any of the expected routes.
This validation helps ensure the endpoint is configured correctly for the specific API.
"""
if not self._endpoint or not expected_url_regex:
return
# Use urllib to extract the path part and normalize it
parsed_url = urlparse(self._endpoint)
normalized_route = parsed_url.path.lower().rstrip("/")
# Check if the endpoint matches any of the expected regex patterns
for regex_pattern in expected_url_regex:
if re.search(regex_pattern, normalized_route):
return
# No matches found, log warning
if len(expected_url_regex) == 1:
# Convert regex back to human-readable format for the warning
pattern_str = expected_url_regex[0].replace(r"[^/]+", "*").replace("$", "")
expected_routes_str = pattern_str
else:
# Convert all regex patterns to human-readable format
readable_patterns = [p.replace(r"[^/]+", "*").replace("$", "") for p in expected_url_regex]
expected_routes_str = f"one of: {', '.join(readable_patterns)}"
logger.warning(
f"The provided endpoint URL {parsed_url} does not match any of the expected formats: {expected_routes_str}."
f"This may be intentional, especially if you are using an endpoint other than Azure or OpenAI."
f"For more details and guidance, please see the .env_example file in the repository."
)
# [docs] (Sphinx cross-reference artifact from HTML extraction)
@abstractmethod
def is_json_response_supported(self) -> bool:
"""
Abstract method to determine if JSON response format is supported by the target.
Returns:
bool: True if JSON response is supported, False otherwise.
"""
pass