# Source code for pyrit.prompt_target.common.utils

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import functools
from typing import Callable, Optional

from pyrit.exceptions import PyritException


def validate_temperature(temperature: Optional[float]) -> None:
    """
    Validate that temperature parameter is within valid range.

    ``None`` is accepted and means the parameter is not set. Any other value
    must lie in the closed interval [0, 2].

    Args:
        temperature: The temperature value to validate (0-2 inclusive), or None.

    Raises:
        PyritException: If temperature is not between 0 and 2 (inclusive),
            including when it is NaN.
    """
    # Negated chained comparison: every comparison against NaN is False, so
    # `not (0 <= nan <= 2)` is True and NaN is rejected instead of slipping through
    # the previous `< 0 or > 2` check.
    if temperature is not None and not (0 <= temperature <= 2):
        raise PyritException(message="temperature must be between 0 and 2 (inclusive).")


def validate_top_p(top_p: Optional[float]) -> None:
    """
    Validate that top_p parameter is within valid range.

    ``None`` is accepted and means the parameter is not set. Any other value
    must lie in the closed interval [0, 1].

    Args:
        top_p: The top_p value to validate (0-1 inclusive), or None.

    Raises:
        PyritException: If top_p is not between 0 and 1 (inclusive),
            including when it is NaN.
    """
    # Negated chained comparison so NaN (for which all comparisons are False)
    # is rejected rather than passing validation.
    if top_p is not None and not (0 <= top_p <= 1):
        raise PyritException(message="top_p must be between 0 and 1 (inclusive).")


def limit_requests_per_minute(func: Callable) -> Callable:
    """
    A decorator to enforce rate limit of the target through setting requests per minute.
    This should be applied to all send_prompt_async() functions on PromptTarget and PromptChatTarget.

    Before each call, the decorated coroutine sleeps ``60 / rpm`` seconds, where ``rpm`` is
    read from the ``_max_requests_per_minute`` attribute of the first positional argument
    (the target instance). If that attribute is missing, ``None``, zero, or negative, no
    delay is applied.

    Args:
        func (Callable): The coroutine function to be decorated.

    Returns:
        Callable: The decorated coroutine function with a sleep introduced. Its name,
            docstring and other metadata are copied from ``func``.
    """

    # functools.wraps preserves __name__/__doc__/__qualname__ and sets __wrapped__,
    # so decorated send_prompt_async methods stay introspectable.
    @functools.wraps(func)
    async def set_max_rpm(*args, **kwargs):
        # Intended for methods: the target instance is the first positional argument.
        self = args[0]
        rpm = getattr(self, "_max_requests_per_minute", None)
        if rpm and rpm > 0:
            # The delay is applied per call; there is no shared state or lock, so
            # concurrent calls each sleep independently rather than being serialized.
            await asyncio.sleep(60 / rpm)
        return await func(*args, **kwargs)

    return set_max_rpm