Source code for pyrit.prompt_target.crucible_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Optional

from httpx import HTTPStatusError

from pyrit.common import default_values, net_utility
from pyrit.exceptions import (
    EmptyResponseException,
    handle_bad_request_exception,
    pyrit_target_retry,
)
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target import PromptTarget, limit_requests_per_minute

logger = logging.getLogger(__name__)


[docs] class CrucibleTarget(PromptTarget): API_KEY_ENVIRONMENT_VARIABLE: str = "CRUCIBLE_API_KEY"
[docs] def __init__( self, *, endpoint: str, api_key: Optional[str] = None, max_requests_per_minute: Optional[int] = None, ) -> None: """ Initialize the Crucible target. Args: endpoint (str): The endpoint URL for the Crucible service. api_key (str, Optional): The API key for accessing the Crucible service. Defaults to the `CRUCIBLE_API_KEY` environment variable. max_requests_per_minute (int, Optional): Number of requests the target can handle per minute before hitting a rate limit. The number of requests sent to the target will be capped at the value provided. """ super().__init__(max_requests_per_minute=max_requests_per_minute, endpoint=endpoint) self._api_key: str = default_values.get_required_value( env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key )
@limit_requests_per_minute async def send_prompt_async(self, *, prompt_request: Message) -> Message: self._validate_request(prompt_request=prompt_request) request = prompt_request.message_pieces[0] logger.info(f"Sending the following prompt to the prompt target: {request}") try: response = await self._complete_text_async(request.converted_value) response_entry = construct_response_from_request(request=request, response_text_pieces=[response]) except HTTPStatusError as bre: if bre.response.status_code == 400: response_entry = handle_bad_request_exception( response_text=bre.response.text, request=request, is_content_filter=True ) else: raise return response_entry def _validate_request(self, *, prompt_request: Message) -> None: n_pieces = len(prompt_request.message_pieces) if n_pieces != 1: raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") piece_type = prompt_request.message_pieces[0].converted_value_data_type if piece_type != "text": raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") @pyrit_target_retry async def _complete_text_async(self, text: str) -> str: payload: dict[str, object] = { "data": text, } resp = await net_utility.make_request_and_raise_if_error_async( endpoint_uri=f"{self._endpoint.rstrip('/')}/score", method="POST", request_body=payload, headers={"X-API-Key": self._api_key}, ) if not resp.text: raise EmptyResponseException() logger.info(f'Received the following response from the prompt target "{resp.text}"') return resp.text