# Source code for pyrit.prompt_target.crucible_target
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional
from httpx import HTTPStatusError
from pyrit.common import default_values, net_utility
from pyrit.exceptions import EmptyResponseException, handle_bad_request_exception, pyrit_target_retry
from pyrit.models import PromptRequestResponse, construct_response_from_request
from pyrit.prompt_target import PromptTarget, limit_requests_per_minute
logger = logging.getLogger(__name__)
[docs]
class CrucibleTarget(PromptTarget):
    """Prompt target that POSTs a single text prompt to ``<endpoint>/score``.

    Each request carries the prompt under the ``"data"`` key of the request
    body and authenticates with an ``X-API-Key`` header. The raw response text
    is returned as the target's reply.
    """

    # Environment variable name handed to ``default_values.get_required_value``
    # together with the explicitly passed ``api_key``.
    API_KEY_ENVIRONMENT_VARIABLE: str = "CRUCIBLE_API_KEY"

    def __init__(
        self,
        *,
        endpoint: str,
        api_key: Optional[str] = None,
        max_requests_per_minute: Optional[int] = None,
    ) -> None:
        """Create the target.

        Args:
            endpoint: Base URL of the service. A trailing ``/`` is tolerated;
                ``/score`` is appended when the request is made.
            api_key: API key sent in the ``X-API-Key`` header. When omitted,
                resolution is delegated to ``default_values.get_required_value``
                with ``API_KEY_ENVIRONMENT_VARIABLE`` (presumably an
                environment-variable fallback that raises when no value is
                found -- confirm against that helper).
            max_requests_per_minute: Forwarded unchanged to the base-class
                constructor.
        """
        super().__init__(max_requests_per_minute=max_requests_per_minute)

        self._endpoint = endpoint
        self._api_key: str = default_values.get_required_value(
            env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
        )

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """Send the single text piece of ``prompt_request`` and wrap the reply.

        Args:
            prompt_request: Request holding exactly one text request piece.

        Returns:
            The response entry built from the service's reply, or the entry
            produced by ``handle_bad_request_exception`` when the service
            answers with HTTP 400.

        Raises:
            ValueError: If the request does not contain exactly one text piece.
            HTTPStatusError: For any HTTP error status other than 400.
        """
        self._validate_request(prompt_request=prompt_request)
        request = prompt_request.request_pieces[0]

        logger.info("Sending the following prompt to the prompt target: %s", request)

        try:
            response = await self._complete_text_async(request.converted_value)
            response_entry = construct_response_from_request(request=request, response_text_pieces=[response])
        except HTTPStatusError as bre:
            # Only a 400 is translated into a response entry (flagged as a
            # content-filter result); every other status propagates unchanged.
            if bre.response.status_code != 400:
                raise
            response_entry = handle_bad_request_exception(
                response_text=bre.response.text, request=request, is_content_filter=True
            )

        return response_entry

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """Reject requests that are not exactly one text piece.

        Raises:
            ValueError: If there is not exactly one request piece, or if that
                piece's converted value is not of data type ``"text"``.
        """
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")

    @pyrit_target_retry
    async def _complete_text_async(self, text: str) -> str:
        """POST ``text`` to the ``/score`` route and return the response body.

        Args:
            text: Prompt text placed under the ``"data"`` key of the payload.

        Returns:
            The non-empty response text.

        Raises:
            EmptyResponseException: If the response text is empty.
        """
        payload: dict[str, object] = {
            "data": text,
        }

        # Strip a trailing slash so the joined URL never contains "//score".
        resp = await net_utility.make_request_and_raise_if_error_async(
            endpoint_uri=f"{self._endpoint.rstrip('/')}/score",
            method="POST",
            request_body=payload,
            headers={"X-API-Key": self._api_key},
        )

        if not resp.text:
            raise EmptyResponseException()

        logger.info('Received the following response from the prompt target "%s"', resp.text)
        return resp.text