Source code for pyrit.prompt_target.hugging_face.hugging_face_endpoint_target
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from pyrit.prompt_target import PromptTarget
from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.models.prompt_request_response import PromptRequestResponse, construct_response_from_request
logger = logging.getLogger(__name__)
[docs]
class HuggingFaceEndpointTarget(PromptTarget):
"""The HuggingFaceEndpointTarget interacts with HuggingFace models hosted on cloud endpoints.
Inherits from PromptTarget to comply with the current design standards.
"""
[docs]
def __init__(
self,
*,
hf_token: str,
endpoint: str,
model_id: str,
max_tokens: int = 400,
temperature: float = 1.0,
top_p: float = 1.0,
verbose: bool = False,
) -> None:
"""Initializes the HuggingFaceEndpointTarget with API credentials and model parameters.
Args:
hf_token (str): The Hugging Face token for authenticating with the Hugging Face endpoint.
endpoint (str): The endpoint URL for the Hugging Face model.
model_id (str): The model ID to be used at the endpoint.
max_tokens (int, Optional): The maximum number of tokens to generate. Defaults to 400.
temperature (float, Optional): The sampling temperature to use. Defaults to 1.0.
top_p (float, Optional): The cumulative probability for nucleus sampling. Defaults to 1.0.
verbose (bool, Optional): Flag to enable verbose logging. Defaults to False.
"""
super().__init__(verbose=verbose)
self.hf_token = hf_token
self.endpoint = endpoint
self.model_id = model_id
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
[docs]
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
"""
Sends a normalized prompt asynchronously to a cloud-based HuggingFace model endpoint.
Args:
prompt_request (PromptRequestResponse): The prompt request containing the input data and associated details
such as conversation ID and role.
Returns:
PromptRequestResponse: A response object containing generated text pieces as a list of `PromptRequestPiece`
objects. Each `PromptRequestPiece` includes the generated text and relevant information such as
conversation ID, role, and any additional response attributes.
Raises:
ValueError: If the response from the Hugging Face API is not successful.
Exception: If an error occurs during the HTTP request to the Hugging Face endpoint.
"""
self._validate_request(prompt_request=prompt_request)
request = prompt_request.request_pieces[0]
headers = {"Authorization": f"Bearer {self.hf_token}"}
payload: dict[str, object] = {
"inputs": request.converted_value,
"parameters": {
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
},
}
logger.info(f"Sending the following prompt to the cloud endpoint: {request.converted_value}")
try:
# Use the utility method to make the request
response = await make_request_and_raise_if_error_async(
endpoint_uri=self.endpoint,
method="POST",
request_body=payload,
headers=headers,
post_type="json",
)
response_data = response.json()
# Check if the response is a list and handle appropriately
if isinstance(response_data, list):
# Access the first element if it's a list and extract 'generated_text' safely
response_message = response_data[0].get("generated_text", "")
else:
response_message = response_data.get("generated_text", "")
prompt_response = construct_response_from_request(
request=request,
response_text_pieces=[response_message],
prompt_metadata=str({"model_id": self.model_id}),
)
return prompt_response
except Exception as e:
logger.error(f"Error occurred during HTTP request to the Hugging Face endpoint: {e}")
raise
def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
"""
Validates the provided prompt request response.
Args:
prompt_request (PromptRequestResponse): The prompt request to validate.
Raises:
ValueError: If the request is not valid for this target.
"""
if len(prompt_request.request_pieces) != 1:
raise ValueError("This target only supports a single prompt request piece.")
if prompt_request.request_pieces[0].converted_value_data_type != "text":
raise ValueError("This target only supports text prompt input.")