# Source code for pyrit.exceptions.exception_classes

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os
from abc import ABC
from typing import Callable, Optional

from openai import RateLimitError
from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_result,
    stop_after_attempt,
    wait_random_exponential,
)

from pyrit.exceptions.exceptions_helpers import log_exception
from pyrit.models import Message, MessagePiece, construct_response_from_request

logger = logging.getLogger(__name__)


def _get_custom_result_retry_max_num_attempts() -> int:
    """Get the maximum number of retry attempts for custom result retry decorator."""
    return int(os.getenv("CUSTOM_RESULT_RETRY_MAX_NUM_ATTEMPTS", 10))


def _get_retry_max_num_attempts() -> int:
    """Get the maximum number of retry attempts."""
    return int(os.getenv("RETRY_MAX_NUM_ATTEMPTS", 10))


def _get_retry_wait_min_seconds() -> int:
    """Get the minimum wait time in seconds between retries."""
    return int(os.getenv("RETRY_WAIT_MIN_SECONDS", 5))


def _get_retry_wait_max_seconds() -> int:
    """Get the maximum wait time in seconds between retries."""
    return int(os.getenv("RETRY_WAIT_MAX_SECONDS", 220))


class PyritException(Exception, ABC):
    """Base exception for PyRIT carrying a status code and a human-readable message.

    Both values are stored as attributes (``status_code``, ``message``) and are
    also folded into the standard exception text.
    """

    def __init__(self, status_code=500, *, message: str = "An error occurred"):
        self.status_code = status_code
        self.message = message
        summary = f"Status Code: {status_code}, Message: {message}"
        super().__init__(summary)

    def process_exception(self) -> str:
        """Log this exception at ERROR level and return it serialized as JSON.

        Returns:
            str: A JSON object string with ``status_code`` and ``message`` keys,
            so callers can extract and parse the error details.
        """
        exception_name = type(self).__name__
        logger.error(f"{exception_name} encountered: Status Code: {self.status_code}, Message: {self.message}")
        payload = {"status_code": self.status_code, "message": self.message}
        return json.dumps(payload)
class BadRequestException(PyritException):
    """Signals a bad client request; the status code defaults to 400."""

    def __init__(self, status_code: int = 400, *, message: str = "Bad Request"):
        super().__init__(status_code=status_code, message=message)
class RateLimitException(PyritException):
    """Exception class for rate limit errors (status code defaults to 429, Too Many Requests).

    ``pyrit_target_retry`` retries calls that raise this exception.
    """

    def __init__(self, status_code: int = 429, *, message: str = "Rate Limit Exception"):
        super().__init__(status_code, message=message)
class ServerErrorException(PyritException):
    """Signals an opaque 5xx error returned by a server.

    Besides the status code and message, the instance keeps the optional raw
    response ``body`` (``None`` when not supplied).
    """

    def __init__(self, status_code: int = 500, *, message: str = "Server Error", body: Optional[str] = None):
        self.body = body
        super().__init__(status_code, message=message)
class EmptyResponseException(BadRequestException):
    """Signals that a response came back empty; the status code defaults to 204 (No Content).

    It subclasses ``BadRequestException``, so handlers for that type also catch
    it, and ``pyrit_target_retry`` retries calls that raise it.
    """

    def __init__(self, status_code: int = 204, *, message: str = "No Content"):
        super().__init__(status_code, message=message)
class InvalidJsonException(PyritException):
    """Exception class for invalid JSON responses.

    Only the message is configurable; the status code is left at the base
    class default. ``pyrit_json_retry`` retries calls that raise this exception.
    """

    def __init__(self, *, message: str = "Invalid JSON Response"):
        super().__init__(message=message)
class MissingPromptPlaceholderException(PyritException):
    """Signals that an expected prompt placeholder was not found.

    Only the message is configurable; the status code is left at the base
    class default. ``pyrit_placeholder_retry`` retries calls that raise it.
    """

    def __init__(self, *, message: str = "No prompt placeholder"):
        super().__init__(message=message)
def pyrit_custom_result_retry(retry_function: Callable, retry_max_num_attempts: Optional[int] = None) -> Callable:
    """
    Build a decorator that retries a function based on the value it returns.

    The decorated function is called again whenever ``retry_function(result)``
    is truthy. Between attempts, tenacity waits with randomized exponential
    backoff bounded by ``RETRY_WAIT_MIN_SECONDS`` / ``RETRY_WAIT_MAX_SECONDS``.
    ``log_exception`` is called after each attempt. Retrying stops once the
    attempt limit is reached.

    Args:
        retry_function (Callable): Predicate called with the decorated function's
            result; a truthy return value triggers another attempt.
        retry_max_num_attempts (Optional[int]): Maximum number of attempts. When
            it is None (or any other falsy value, such as 0), the
            ``CUSTOM_RESULT_RETRY_MAX_NUM_ATTEMPTS`` environment variable is
            used instead, defaulting to 10.

    Returns:
        Callable: A decorator that takes the function to wrap and returns it
        with the retry logic applied.
    """

    def inner_retry(func: Callable) -> Callable:
        # Limits and wait bounds are resolved when the decorator is applied,
        # not on every call of the wrapped function.
        max_attempts = retry_max_num_attempts or _get_custom_result_retry_max_num_attempts()
        return retry(
            reraise=True,
            retry=retry_if_result(retry_function),
            wait=wait_random_exponential(min=_get_retry_wait_min_seconds(), max=_get_retry_wait_max_seconds()),
            after=log_exception,
            stop=stop_after_attempt(max_attempts),
        )(func)

    return inner_retry
def pyrit_target_retry(func: Callable) -> Callable:
    """
    A decorator to apply retry logic with exponential backoff to a function.

    Retries the function when it raises ``openai.RateLimitError``,
    ``EmptyResponseException`` or ``RateLimitException``. Between attempts it
    waits with randomized exponential backoff bounded by
    ``RETRY_WAIT_MIN_SECONDS`` / ``RETRY_WAIT_MAX_SECONDS`` (defaults 5 and 220).
    It calls ``log_exception`` after each attempt and stops after
    ``RETRY_MAX_NUM_ATTEMPTS`` attempts (default 10). Because ``reraise=True``
    is set, the last exception is re-raised once attempts are exhausted.

    The environment variables are read once, when the decorator is applied,
    not on every call of the decorated function.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        Callable: The decorated function with retry logic applied.
    """
    return retry(
        reraise=True,
        # A tuple is matched with isinstance, equivalent to OR-ing one predicate per type.
        retry=retry_if_exception_type((RateLimitError, EmptyResponseException, RateLimitException)),
        wait=wait_random_exponential(min=_get_retry_wait_min_seconds(), max=_get_retry_wait_max_seconds()),
        after=log_exception,
        stop=stop_after_attempt(_get_retry_max_num_attempts()),
    )(func)
def pyrit_json_retry(func: Callable) -> Callable:
    """
    Wrap ``func`` so that it is retried whenever it raises ``InvalidJsonException``.

    No wait is configured between attempts. ``log_exception`` runs after each
    attempt. Retrying stops after ``RETRY_MAX_NUM_ATTEMPTS`` attempts
    (default 10), and the last exception is then re-raised. The attempt limit
    is read when the decorator is applied.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        Callable: The decorated function with retry logic applied.
    """
    attempt_limit = _get_retry_max_num_attempts()
    decorator = retry(
        reraise=True,
        retry=retry_if_exception_type(InvalidJsonException),
        after=log_exception,
        stop=stop_after_attempt(attempt_limit),
    )
    return decorator(func)
def pyrit_placeholder_retry(func: Callable) -> Callable:
    """
    Wrap ``func`` so that it is retried whenever it raises ``MissingPromptPlaceholderException``.

    No wait is configured between attempts. ``log_exception`` runs after each
    attempt. Retrying stops after ``RETRY_MAX_NUM_ATTEMPTS`` attempts
    (default 10), and the last exception is then re-raised. The attempt limit
    is read when the decorator is applied.

    Args:
        func (Callable): The function to be decorated.

    Returns:
        Callable: The decorated function with retry logic applied.
    """
    attempt_limit = _get_retry_max_num_attempts()
    decorator = retry(
        reraise=True,
        retry=retry_if_exception_type(MissingPromptPlaceholderException),
        after=log_exception,
        stop=stop_after_attempt(attempt_limit),
    )
    return decorator(func)
def handle_bad_request_exception(
    response_text: str,
    request: MessagePiece,
    is_content_filter=False,
    error_code: int = 400,
) -> Message:
    """
    Turn a content-filter bad-request error into a "blocked" error response.

    Intended to be called from inside an ``except`` block. A content-filter
    case is recognised in any of these ways:

    * ``response_text`` contains ``"content_filter"``;
    * ``response_text`` contains the usage-policy "Invalid prompt" message;
    * ``is_content_filter`` is truthy.

    In that case, a ``BadRequestException`` is built and logged via
    ``process_exception``. Its JSON text becomes the single text piece of a
    response built from ``request``, with ``response_type="error"`` and
    ``error="blocked"``.

    Args:
        response_text (str): Error text returned by the target.
        request (MessagePiece): The request the error response is built for.
        is_content_filter (bool): Forces the content-filter path when truthy.
        error_code (int): Status code recorded on the ``BadRequestException``.

    Returns:
        Message: The "blocked" error response.

    Raises:
        Exception: In the non-content-filter case, the exception currently
            being handled is re-raised via a bare ``raise``. If no exception is
            active, Python raises ``RuntimeError`` instead.
    """
    # Keep the original evaluation order: substring checks first, then the flag.
    flagged_by_filter = (
        "content_filter" in response_text
        or "Invalid prompt: your prompt was flagged as potentially violating our usage policy." in response_text
        or is_content_filter
    )
    if not flagged_by_filter:
        # Not a content-filter rejection: propagate the active exception unchanged.
        raise

    blocked_error = BadRequestException(error_code, message=response_text)
    error_payload = blocked_error.process_exception()
    return construct_response_from_request(
        request=request,
        response_text_pieces=[error_payload],
        response_type="error",
        error="blocked",
    )