# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
from abc import ABC
from typing import Callable
from openai import RateLimitError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from pyrit.exceptions.exceptions_helpers import log_exception
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.prompt_request_response import PromptRequestResponse, construct_response_from_request
# Retry tuning values. Each one can be overridden by an environment variable of the
# same name; the value is converted with int(), so a non-numeric override fails at import.
RETRY_MAX_NUM_ATTEMPTS = int(os.environ.get("RETRY_MAX_NUM_ATTEMPTS", 10))  # passed to stop_after_attempt
RETRY_WAIT_MIN_SECONDS = int(os.environ.get("RETRY_WAIT_MIN_SECONDS", 5))  # passed as the backoff ``min``
RETRY_WAIT_MAX_SECONDS = int(os.environ.get("RETRY_WAIT_MAX_SECONDS", 220))  # passed as the backoff ``max``

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
[docs]
class PyritException(Exception, ABC):
    """
    Base exception that carries a numeric status code together with a message.

    Both values are stored on the instance (``status_code`` and ``message``) and are
    also combined into the text passed to ``Exception.__init__``.
    """

    def __init__(self, status_code: int = 500, *, message: str = "An error occurred"):
        """
        Args:
            status_code (int): Numeric status associated with the error. Defaults to 500.
            message (str): Description of the error (keyword-only).
                Defaults to "An error occurred".
        """
        self.status_code = status_code
        self.message = message
        super().__init__(f"Status Code: {status_code}, Message: {message}")

    def process_exception(self) -> str:
        """
        Logs and returns a string representation of the exception.

        Returns:
            str: A JSON object with the keys "status_code" and "message".
        """
        log_message = f"{self.__class__.__name__} encountered: Status Code: {self.status_code}, Message: {self.message}"
        logger.error(log_message)
        # Return a string representation of the exception so users can extract and parse
        return json.dumps({"status_code": self.status_code, "message": self.message})
[docs]
class BadRequestException(PyritException):
    """Raised for bad client requests; the status code defaults to 400."""

    def __init__(self, status_code: int = 400, *, message: str = "Bad Request"):
        """
        Args:
            status_code (int): Status code forwarded to the parent class. Defaults to 400.
            message (str): Error description (keyword-only). Defaults to "Bad Request".
        """
        super().__init__(status_code=status_code, message=message)
[docs]
class RateLimitException(PyritException):
    """Exception class for rate limit errors; the status code defaults to 429."""

    def __init__(self, status_code: int = 429, *, message: str = "Rate Limit Exception"):
        """
        Args:
            status_code (int): Status code forwarded to the parent class. Defaults to 429.
            message (str): Error description (keyword-only).
                Defaults to "Rate Limit Exception".
        """
        super().__init__(status_code, message=message)
[docs]
class EmptyResponseException(BadRequestException):
    """Raised when a response is empty; defaults to status 204 with the message "No Content"."""

    def __init__(self, status_code: int = 204, *, message: str = "No Content"):
        """
        Args:
            status_code (int): Status code forwarded to the parent class. Defaults to 204.
            message (str): Error description (keyword-only). Defaults to "No Content".
        """
        super().__init__(status_code, message=message)
[docs]
class InvalidJsonException(PyritException):
    """Exception class for invalid JSON response errors."""

    def __init__(self, *, message: str = "Invalid JSON Response"):
        """
        Args:
            message (str): Error description (keyword-only).
                Defaults to "Invalid JSON Response".
        """
        # No status code is passed, so the parent class default applies.
        super().__init__(message=message)
[docs]
class MissingPromptPlaceholderException(PyritException):
    """Raised when a prompt placeholder is missing."""

    def __init__(self, *, message: str = "No prompt placeholder"):
        """
        Args:
            message (str): Error description (keyword-only).
                Defaults to "No prompt placeholder".
        """
        # Only the message is forwarded; the status code is left to the parent default.
        super().__init__(message=message)
[docs]
def pyrit_target_retry(func: Callable) -> Callable:
    """
    Wraps ``func`` so that rate-limit and empty-response failures are retried.

    A call is retried when it raises ``RateLimitError`` or ``EmptyResponseException``.
    Waits between attempts use a randomized exponential backoff bounded by
    ``RETRY_WAIT_MIN_SECONDS`` and ``RETRY_WAIT_MAX_SECONDS``, and retrying stops after
    ``RETRY_MAX_NUM_ATTEMPTS`` attempts. ``log_exception`` is registered as the
    ``after`` hook, and ``reraise=True`` is set on the retry policy.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        Callable: The decorated function with retry logic applied.
    """
    retryable_errors = retry_if_exception_type(RateLimitError) | retry_if_exception_type(EmptyResponseException)
    backoff = wait_random_exponential(min=RETRY_WAIT_MIN_SECONDS, max=RETRY_WAIT_MAX_SECONDS)
    attempt_limit = stop_after_attempt(RETRY_MAX_NUM_ATTEMPTS)

    apply_policy = retry(
        reraise=True,
        retry=retryable_errors,
        wait=backoff,
        after=log_exception,
        stop=attempt_limit,
    )
    return apply_policy(func)
[docs]
def pyrit_json_retry(func: Callable) -> Callable:
    """
    Wraps ``func`` so that invalid-JSON failures are retried.

    A call is retried when it raises ``InvalidJsonException``. Waits between attempts
    use a randomized exponential backoff bounded by ``RETRY_WAIT_MIN_SECONDS`` and
    ``RETRY_WAIT_MAX_SECONDS``, and retrying stops after ``RETRY_MAX_NUM_ATTEMPTS``
    attempts. ``log_exception`` is registered as the ``after`` hook, and
    ``reraise=True`` is set on the retry policy.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        Callable: The decorated function with retry logic applied.
    """
    retryable_errors = retry_if_exception_type(InvalidJsonException)
    backoff = wait_random_exponential(min=RETRY_WAIT_MIN_SECONDS, max=RETRY_WAIT_MAX_SECONDS)
    attempt_limit = stop_after_attempt(RETRY_MAX_NUM_ATTEMPTS)

    apply_policy = retry(
        reraise=True,
        retry=retryable_errors,
        wait=backoff,
        after=log_exception,
        stop=attempt_limit,
    )
    return apply_policy(func)
[docs]
def pyrit_placeholder_retry(func: Callable) -> Callable:
    """
    Wraps ``func`` so that missing-placeholder failures are retried.

    A call is retried when it raises ``MissingPromptPlaceholderException``. No wait
    strategy is configured here, and retrying stops after ``RETRY_MAX_NUM_ATTEMPTS``
    attempts. ``log_exception`` is registered as the ``after`` hook, and
    ``reraise=True`` is set on the retry policy.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        Callable: The decorated function with retry logic applied.
    """
    retryable_errors = retry_if_exception_type(MissingPromptPlaceholderException)
    attempt_limit = stop_after_attempt(RETRY_MAX_NUM_ATTEMPTS)

    apply_policy = retry(
        reraise=True,
        retry=retryable_errors,
        after=log_exception,
        stop=attempt_limit,
    )
    return apply_policy(func)
[docs]
def handle_bad_request_exception(
    response_text: str,
    request: PromptRequestPiece,
    is_content_filter=False,
) -> PromptRequestResponse:
    """
    Turns a content-filter rejection into an error response; re-raises otherwise.

    The rejection is treated as a content-filter block when ``response_text`` contains
    "content_filter" or the usage-policy message, or when ``is_content_filter`` is truthy.

    Args:
        response_text (str): Text of the failed response; searched for the markers above
            and used as the message of the ``BadRequestException`` that is built.
        request (PromptRequestPiece): The request the error response is constructed from.
        is_content_filter (bool): Forces content-filter handling regardless of the text.

    Returns:
        PromptRequestResponse: Response built with ``response_type="error"`` and
        ``error="blocked"``, whose single text piece is the JSON string produced by
        ``BadRequestException.process_exception``.

    Raises:
        Exception: Whatever exception is currently being handled, when the text is not
            recognized as a content-filter block. The bare ``raise`` means this function
            is meant to be called from inside an ``except`` block; with no active
            exception Python raises ``RuntimeError`` instead.
    """
    usage_policy_marker = "Invalid prompt: your prompt was flagged as potentially violating our usage policy."
    blocked_by_filter = (
        "content_filter" in response_text or usage_policy_marker in response_text or is_content_filter
    )

    if not blocked_by_filter:
        # Not a content-filter error: propagate the exception the caller is handling.
        raise

    # Handle bad request error when content filter system detects harmful content
    filter_error = BadRequestException(400, message=response_text)
    serialized_error = filter_error.process_exception()
    return construct_response_from_request(
        request=request, response_text_pieces=[serialized_error], response_type="error", error="blocked"
    )