Source code for pyrit.prompt_target.hugging_face.hugging_face_chat_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import json
import logging
import os
from typing import TYPE_CHECKING, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig

from pyrit.common import default_values
from pyrit.common.download_hf_model import download_specific_files
from pyrit.exceptions import EmptyResponseException, pyrit_target_retry
from pyrit.models.prompt_request_response import PromptRequestResponse, construct_response_from_request
from pyrit.prompt_target import PromptChatTarget

# Module-level logger shared by everything in this module.
logger = logging.getLogger(__name__)

# torch is imported here only for static type checkers (it backs the string annotation
# "torch.dtype"), so this module does not import torch at load time; the runtime import
# happens lazily inside HuggingFaceChatTarget.__init__.
if TYPE_CHECKING:
    import torch


class HuggingFaceChatTarget(PromptChatTarget):
    """The HuggingFaceChatTarget interacts with HuggingFace models, specifically for conducting red teaming activities.

    Inherits from PromptChatTarget to comply with the current design standards.
    """

    # Class-level cache for model and tokenizer (shared by all instances).
    _cached_model = None
    _cached_tokenizer = None
    _cached_model_id = None

    # Class-level flag to enable or disable cache
    _cache_enabled = True

    # Define the environment variable name for the Hugging Face token
    HUGGINGFACE_TOKEN_ENVIRONMENT_VARIABLE = "HUGGINGFACE_TOKEN"

    def __init__(
        self,
        *,
        model_id: Optional[str] = None,
        model_path: Optional[str] = None,
        hf_access_token: Optional[str] = None,
        use_cuda: bool = False,
        tensor_format: str = "pt",
        necessary_files: Optional[list] = None,
        max_new_tokens: int = 20,
        temperature: float = 1.0,
        top_p: float = 1.0,
        skip_special_tokens: bool = True,
        trust_remote_code: bool = False,
        device_map: Optional[str] = None,
        torch_dtype: Optional["torch.dtype"] = None,
        attn_implementation: Optional[str] = None,
    ) -> None:
        """Initialize the target and schedule asynchronous loading of the model and tokenizer.

        Loading is scheduled with ``asyncio.create_task``, so the constructor must be called
        while an asyncio event loop is running.

        Args:
            model_id: Hugging Face Hub model ID to download and load. Mutually exclusive with ``model_path``.
            model_path: Local directory to load the model and tokenizer from. Mutually exclusive
                with ``model_id``.
            hf_access_token: Hugging Face token. Only resolved when ``model_id`` is given, via
                ``default_values.get_required_value`` together with the ``HUGGINGFACE_TOKEN``
                environment variable.
            use_cuda: Run on CUDA instead of CPU.
            tensor_format: Value passed as ``return_tensors`` when tokenizing.
            necessary_files: Specific files to download for ``model_id``; ``None`` downloads all files.
            max_new_tokens: Maximum number of tokens generated per response.
            temperature: Passed to ``model.generate``.
            top_p: Passed to ``model.generate``.
            skip_special_tokens: Whether special tokens are dropped when decoding the response.
            trust_remote_code: Forwarded to ``from_pretrained`` for both tokenizer and model.
            device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained`` when not ``None``.
                When set, the model is left where ``from_pretrained`` placed it.
            torch_dtype: Forwarded to ``AutoModelForCausalLM.from_pretrained`` when not ``None``.
            attn_implementation: Forwarded to ``AutoModelForCausalLM.from_pretrained`` when not ``None``.

        Raises:
            ValueError: If neither or both of ``model_id`` and ``model_path`` are provided.
            ModuleNotFoundError: If torch cannot be imported.
            RuntimeError: If ``use_cuda`` is True but CUDA is not available.
        """
        super().__init__()

        if not model_id and not model_path:
            raise ValueError("Either `model_id` or `model_path` must be provided.")
        if model_id and model_path:
            raise ValueError("Provide only one of `model_id` or `model_path`, not both.")

        self.model_id = model_id
        self.model_path = model_path
        self.use_cuda = use_cuda
        self.tensor_format = tensor_format
        self.trust_remote_code = trust_remote_code
        self.device_map = device_map
        self.torch_dtype = torch_dtype
        self.attn_implementation = attn_implementation

        # Only get the Hugging Face token if a model ID is provided (local paths need no token).
        if model_id:
            self.huggingface_token = default_values.get_required_value(
                env_var_name=self.HUGGINGFACE_TOKEN_ENVIRONMENT_VARIABLE, passed_value=hf_access_token
            )
        else:
            self.huggingface_token = None

        # torch is imported lazily so that importing this module does not require it.
        try:
            import torch
        except ModuleNotFoundError:
            logger.error("Could not import torch. You may need to install it via 'pip install pyrit[all]'")
            raise

        # Fail fast, before choosing a device or scheduling any loading work.
        if self.use_cuda and not torch.cuda.is_available():
            raise RuntimeError("CUDA requested but not available.")

        # Determine the device (CUDA availability was verified above when requested).
        self.device = "cuda" if self.use_cuda else "cpu"
        logger.info(f"Using device: {self.device}")

        # Set necessary files if provided, otherwise None triggers a download of all files.
        self.necessary_files = necessary_files

        # Set the default parameters for the model generation
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.skip_special_tokens = skip_special_tokens

        # Start loading in the background; send_prompt_async awaits this task before generating.
        self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer())

    def _load_from_path(self, path: str, **kwargs):
        """Load the tokenizer and model from a local directory.

        Args:
            path: Directory containing the tokenizer and model files.
            **kwargs: Extra keyword arguments forwarded to ``AutoModelForCausalLM.from_pretrained``.
        """
        logger.info(f"Loading model and tokenizer from path: {path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=self.trust_remote_code)
        self.model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=self.trust_remote_code, **kwargs)

    def is_model_id_valid(self) -> bool:
        """
        Check if the HuggingFace model ID is valid.

        The resolved Hugging Face token is passed along so that gated/private models
        the token can access are not reported as invalid.

        :return: True if the model configuration could be loaded, False otherwise.
        """
        try:
            # Attempt to load the configuration of the model
            PretrainedConfig.from_pretrained(self.model_id, token=self.huggingface_token)
            return True
        except Exception as e:
            logger.error(f"Invalid HuggingFace model ID {self.model_id}: {e}")
            return False

    async def load_model_and_tokenizer(self):
        """Loads the model and tokenizer, downloading if necessary.

        For ``model_path`` the files are loaded directly from that directory. For ``model_id``
        the files are first downloaded (all files, or only ``necessary_files``) into
        ``~/.cache/huggingface/hub/models--<org>--<name>`` and then loaded from there.
        Reuses (and, when enabled, populates) the class-level cache keyed by the model identifier.

        Raises:
            Exception: If the model loading fails (the original exception is re-raised).
        """
        # Identifier used for caching and logging; also covers the model_path case where model_id is None.
        model_identifier = self.model_path or self.model_id
        try:
            # Only forward the optional loading kwargs the caller actually set.
            optional_model_kwargs = {
                key: value
                for key, value in {
                    "device_map": self.device_map,
                    "torch_dtype": self.torch_dtype,
                    "attn_implementation": self.attn_implementation,
                }.items()
                if value is not None
            }

            # Check if the model is already cached.
            # NOTE(review): the cache key is only the identifier, so a cached model is reused even if
            # this instance asked for a different device / dtype / device_map.
            if HuggingFaceChatTarget._cache_enabled and HuggingFaceChatTarget._cached_model_id == model_identifier:
                logger.info(f"Using cached model and tokenizer for {model_identifier}.")
                self.model = HuggingFaceChatTarget._cached_model
                self.tokenizer = HuggingFaceChatTarget._cached_tokenizer
                return

            if self.model_path:
                # Load the tokenizer and model from the local directory
                logger.info(f"Loading model from local path: {self.model_path}...")
                self._load_from_path(self.model_path, **optional_model_kwargs)
            else:
                # Define the default Hugging Face cache directory
                cache_dir = os.path.join(
                    os.path.expanduser("~"),
                    ".cache",
                    "huggingface",
                    "hub",
                    f"models--{self.model_id.replace('/', '--')}",
                )

                # None downloads every file; a list restricts the download to those files.
                if self.necessary_files is None:
                    logger.info(f"Downloading all files for {self.model_id}...")
                else:
                    logger.info(f"Downloading specific files for {self.model_id}...")
                await download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)

                # Load the tokenizer and model from the specified directory
                logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_id, cache_dir=cache_dir, trust_remote_code=self.trust_remote_code
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    cache_dir=cache_dir,
                    trust_remote_code=self.trust_remote_code,
                    **optional_model_kwargs,
                )

            # With a device_map, from_pretrained has already placed the model (possibly across several
            # devices); moving it wholesale with .to() would undo that placement.
            if self.device_map is None:
                self.model = self.model.to(self.device)

            logger.info(f"Model loaded: {type(self.model)}")
            logger.info(f"Tokenizer loaded: {type(self.tokenizer)}")

            # Cache the loaded model and tokenizer if caching is enabled
            if HuggingFaceChatTarget._cache_enabled:
                HuggingFaceChatTarget._cached_model = self.model
                HuggingFaceChatTarget._cached_tokenizer = self.tokenizer
                HuggingFaceChatTarget._cached_model_id = model_identifier

            logger.info(f"Model {model_identifier} loaded successfully.")

        except Exception as e:
            logger.error(f"Error loading model {model_identifier}: {e}")
            raise

    @pyrit_target_retry
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """
        Sends a normalized prompt asynchronously to the HuggingFace model.

        Waits for background model loading, wraps the single text piece as a user chat message,
        generates up to ``max_new_tokens`` tokens, and returns only the newly generated text.

        Raises:
            ValueError: If the request is not a single text piece, or the tokenizer has no chat template.
            EmptyResponseException: If the decoded response is empty.
        """
        # Wait for the loading task scheduled in __init__ (re-raises if loading failed).
        await self.load_model_and_tokenizer_task

        self._validate_request(prompt_request=prompt_request)
        request = prompt_request.request_pieces[0]
        prompt_template = request.converted_value

        logger.info(f"Sending the following prompt to the HuggingFace model: {prompt_template}")

        # Prepare the input messages using chat templates
        messages = [{"role": "user", "content": prompt_template}]

        # Apply chat template via the _apply_chat_template method
        tokenized_chat = self._apply_chat_template(messages)
        input_ids = tokenized_chat["input_ids"]
        attention_mask = tokenized_chat["attention_mask"]

        logger.info(f"Tokenized chat: {input_ids}")

        try:
            # Keep the model on the target device, unless a device_map already placed it.
            if self.device_map is None:
                self.model.to(self.device)

            # Record the length of the input tokens to later extract only the generated tokens
            input_length = input_ids.shape[-1]

            # NOTE(review): temperature/top_p only take effect when sampling is enabled; do_sample is not
            # passed here, so that depends on the model's own generation config — confirm intent.
            logger.info("Generating response from model...")
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
            )

            logger.info(f"Generated IDs: {generated_ids}")

            # Extract the assistant's response by slicing the generated tokens after the input tokens
            generated_tokens = generated_ids[0][input_length:]

            # Decode the assistant's response from the generated token IDs
            assistant_response = self.tokenizer.decode(
                generated_tokens, skip_special_tokens=self.skip_special_tokens
            ).strip()

            if not assistant_response:
                raise EmptyResponseException()

            logger.info(f"Assistant's response: {assistant_response}")

            model_identifier = self.model_id or self.model_path

            return construct_response_from_request(
                request=request,
                response_text_pieces=[assistant_response],
                prompt_metadata=json.dumps({"model_id": model_identifier}),
            )

        except Exception as e:
            logger.error(f"Error occurred during inference: {e}")
            raise

    def _apply_chat_template(self, messages):
        """
        Apply the tokenizer's chat template to ``messages``, tokenize them, and move the result to ``self.device``.

        Raises:
            ValueError: If the tokenizer has no chat template (only instruct models are supported).
        """
        # Check if the tokenizer has a chat template
        if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
            logger.info("Tokenizer has a chat template. Applying it to the input messages.")

            # Apply the chat template to format and tokenize the messages
            tokenized_chat = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors=self.tensor_format,
                return_dict=True,
            ).to(self.device)
            return tokenized_chat
        else:
            error_message = (
                "Tokenizer does not have a chat template. "
                "This model is not supported, as we only support instruct models with a chat template."
            )
            logger.error(error_message)
            raise ValueError(error_message)

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """
        Validates the provided prompt request response.

        Raises:
            ValueError: If the request has more than one piece or the piece is not text.
        """
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")

    @classmethod
    def enable_cache(cls):
        """Enables the class-level cache."""
        cls._cache_enabled = True
        logger.info("Class-level cache enabled.")

    @classmethod
    def disable_cache(cls):
        """Disables the class-level cache and clears the cache."""
        cls._cache_enabled = False
        cls._cached_model = None
        cls._cached_tokenizer = None
        cls._cached_model_id = None
        logger.info("Class-level cache disabled and cleared.")