# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional

from httpx import HTTPStatusError

from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
from pyrit.common import default_values, net_utility
from pyrit.exceptions import (
    EmptyResponseException,
    RateLimitException,
    handle_bad_request_exception,
    pyrit_target_retry,
)
from pyrit.models import ChatMessage, PromptRequestResponse, construct_response_from_request
from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute

logger = logging.getLogger(__name__)


class AzureMLChatTarget(PromptChatTarget):
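    """A chat target for instruct-tuned chat models deployed to Azure Machine Learning managed online endpoints."""
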
endpoint_uri_environment_variable: str = "AZURE_ML_MANAGED_ENDPOINT"
api_key_environment_variable: str = "AZURE_ML_KEY"

    def __init__(
self,
*,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
chat_message_normalizer: ChatMessageNormalizer = ChatMessageNop(),
max_new_tokens: int = 400,
temperature: float = 1.0,
top_p: float = 1.0,
        repetition_penalty: float = 1.2,
max_requests_per_minute: Optional[int] = None,
**param_kwargs,
) -> None:
"""
        Initializes an instance of the AzureMLChatTarget class. This target works with most instruct-tuned
        chat completion models deployed to Azure Machine Learning managed online endpoints
        (including, but not limited to: mistralai-Mixtral-8x7B-Instruct-v01, mistralai-Mistral-7B-Instruct-v01,
        Phi-3.5-MoE-instruct, Phi-3-mini-4k-instruct, Llama-3.2-3B-Instruct, and Meta-Llama-3.1-8B-Instruct).
        Create or adjust the environment variables (endpoint and key) as needed for the model you are using.

Args:
endpoint (str, Optional): The endpoint URL for the deployed Azure ML model.
Defaults to the value of the AZURE_ML_MANAGED_ENDPOINT environment variable.
api_key (str, Optional): The API key for accessing the Azure ML endpoint.
Defaults to the value of the AZURE_ML_KEY environment variable.
            chat_message_normalizer (ChatMessageNormalizer, Optional): The chat message normalizer.
                For models that do not allow system prompts, such as mistralai-Mixtral-8x7B-Instruct-v01,
                GenericSystemSquash() can be passed in. Defaults to ChatMessageNop(), which does not
                alter the chat messages.
max_new_tokens (int, Optional): The maximum number of tokens to generate in the response.
Defaults to 400.
temperature (float, Optional): The temperature for generating diverse responses. 1.0 is most random,
0.0 is least random. Defaults to 1.0.
top_p (float, Optional): The top-p value for generating diverse responses. It represents
the cumulative probability of the top tokens to keep. Defaults to 1.0.
            repetition_penalty (float, Optional): The repetition penalty for generating diverse responses.
                1.0 means no penalty; greater values (up to 2.0) penalize repeated tokens more heavily.
                Defaults to 1.2.
max_requests_per_minute (int, Optional): Number of requests the target can handle per
minute before hitting a rate limit. The number of requests sent to the target
will be capped at the value provided.
**param_kwargs: Additional parameters to pass to the model for generating responses. Example
parameters can be found here: https://huggingface.co/docs/api-inference/tasks/text-generation.
Note that the link above may not be comprehensive, and specific acceptable parameters may be
model-dependent. If a model does not accept a certain parameter that is passed in, it will be skipped
without throwing an error.
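
        Example:
            A minimal usage sketch (assumes AZURE_ML_MANAGED_ENDPOINT and AZURE_ML_KEY are set in
            your environment or .env file; the parameter values are illustrative)::

                target = AzureMLChatTarget(temperature=0.7, max_new_tokens=256)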
"""
PromptChatTarget.__init__(self, max_requests_per_minute=max_requests_per_minute)
self._initialize_vars(endpoint=endpoint, api_key=api_key)
self.chat_message_normalizer = chat_message_normalizer
self._max_new_tokens = max_new_tokens
self._temperature = temperature
self._top_p = top_p
self._repetition_penalty = repetition_penalty
self._extra_parameters = param_kwargs

    def _set_env_configuration_vars(
        self,
        endpoint_uri_environment_variable: Optional[str] = None,
        api_key_environment_variable: Optional[str] = None,
    ) -> None:
"""
        Sets the names of the environment variables from which to pull the endpoint URI and API key
        for the deployed Azure ML model. Use this function to point the target at whichever variable
        names your .env file uses; the corresponding endpoint URI and API key are then loaded from
        those variables. This is the recommended way to supply the URI and key for the model endpoint.
        Defaults to "AZURE_ML_MANAGED_ENDPOINT" and "AZURE_ML_KEY".

Args:
endpoint_uri_environment_variable (str): The environment variable name for the endpoint uri.
api_key_environment_variable (str): The environment variable name for the api key.
Returns:
None
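
        Example:
            An illustrative sketch (MY_ML_ENDPOINT and MY_ML_KEY are placeholder variable names
            assumed to be defined in your .env file, not PyRIT defaults)::

                target._set_env_configuration_vars(
                    endpoint_uri_environment_variable="MY_ML_ENDPOINT",
                    api_key_environment_variable="MY_ML_KEY",
                )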
"""
self.endpoint_uri_environment_variable = endpoint_uri_environment_variable or "AZURE_ML_MANAGED_ENDPOINT"
self.api_key_environment_variable = api_key_environment_variable or "AZURE_ML_KEY"
self._initialize_vars()

    def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str] = None) -> None:
"""
        Sets the endpoint and key for accessing the Azure ML model. Use this function to manually
        pass in your own endpoint URI and API key. Defaults to the values in the .env file for the
        variables stored in self.endpoint_uri_environment_variable and self.api_key_environment_variable
        (which default to "AZURE_ML_MANAGED_ENDPOINT" and "AZURE_ML_KEY", respectively). It is recommended
        to set these variables in the .env file and call _set_env_configuration_vars rather than passing
        the URI and key directly to this function or the target constructor.

Args:
endpoint (str): The endpoint uri for the deployed Azure ML model.
api_key (str): The API key for accessing the Azure ML endpoint.
Returns:
None
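
        Example:
            An illustrative sketch (the endpoint URI and key below are placeholders)::

                target._initialize_vars(
                    endpoint="https://my-deployment.westus.inference.ml.azure.com/score",
                    api_key="<my-api-key>",
                )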
"""
self._endpoint = default_values.get_required_value(
env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint
)
self._api_key = default_values.get_required_value(
env_var_name=self.api_key_environment_variable, passed_value=api_key
)

    def _set_model_parameters(
        self,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        **param_kwargs,
    ) -> None:
"""
Sets the model parameters for generating responses, offering the option to add additional ones not
explicitly listed.
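
        Example:
            An illustrative sketch; top_k is an extra, model-dependent parameter passed through
            **param_kwargs and skipped by models that do not accept it::

                target._set_model_parameters(temperature=0.7, top_k=50)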
"""
        # Use explicit None checks so that falsy-but-valid values (e.g., temperature=0.0) are honored.
        self._max_new_tokens = max_new_tokens if max_new_tokens is not None else self._max_new_tokens
        self._temperature = temperature if temperature is not None else self._temperature
        self._top_p = top_p if top_p is not None else self._top_p
        self._repetition_penalty = (
            repetition_penalty if repetition_penalty is not None else self._repetition_penalty
        )
        # Set any other parameters via additional keyword arguments.
        self._extra_parameters = param_kwargs

    @limit_requests_per_minute
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
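        """
        Sends a prompt to the deployed Azure ML endpoint asynchronously and returns the response.

        Args:
            prompt_request (PromptRequestResponse): The prompt request; must contain exactly one
                text request piece.

        Raises:
            EmptyResponseException: If the model returns an empty response.
            RateLimitException: If the endpoint returns HTTP 429.

        Returns:
            PromptRequestResponse: The response entry constructed from the model output, or an
                error response if the endpoint returns HTTP 400.
        """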
self._validate_request(prompt_request=prompt_request)
request = prompt_request.request_pieces[0]
messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)
messages.append(request.to_chat_message())
logger.info(f"Sending the following prompt to the prompt target: {request}")
try:
resp_text = await self._complete_chat_async(
messages=messages,
)
if not resp_text:
raise EmptyResponseException(message="The chat returned an empty response.")
response_entry = construct_response_from_request(request=request, response_text_pieces=[resp_text])
except HTTPStatusError as hse:
if hse.response.status_code == 400:
# Handle Bad Request
response_entry = handle_bad_request_exception(response_text=hse.response.text, request=request)
elif hse.response.status_code == 429:
raise RateLimitException()
else:
raise hse
        logger.info(
            "Received the following response from the prompt target: "
            f"{response_entry.request_pieces[0].converted_value}"
        )
return response_entry

    @pyrit_target_retry
async def _complete_chat_async(
self,
messages: list[ChatMessage],
) -> str:
"""
Completes a chat interaction by generating a response to the given input prompt.
This is a synchronous wrapper for the asynchronous _generate_and_extract_response method.
Args:
messages (list[ChatMessage]): The chat messages objects containing the role and content.
Raises:
Exception: For any errors during the process.
Returns:
str: The generated response message.
"""
headers = self._get_headers()
payload = self._construct_http_body(messages)
response = await net_utility.make_request_and_raise_if_error_async(
endpoint_uri=self._endpoint, method="POST", request_body=payload, headers=headers
)
return response.json()["output"]

    def _construct_http_body(
self,
messages: list[ChatMessage],
) -> dict:
"""Constructs the HTTP request body for the AML online endpoint."""
squashed_messages = self.chat_message_normalizer.normalize(messages)
messages_dict = [message.model_dump() for message in squashed_messages]
        # The parameters below include any extras passed in through **param_kwargs; parameters
        # the model does not accept are ignored by the endpoint.
data = {
"input_data": {
"input_string": messages_dict,
"parameters": {
"max_new_tokens": self._max_new_tokens,
"temperature": self._temperature,
"top_p": self._top_p,
"stop": ["</s>"],
"stop_sequences": ["</s>"],
"return_full_text": False,
"repetition_penalty": self._repetition_penalty,
}
| self._extra_parameters,
}
}
return data

    def _get_headers(self) -> dict:
        """Headers for accessing the inference endpoint deployed in AML.

        Returns:
            dict: Headers containing the bearer token (the AML key) and the JSON content type.
        """
headers: dict = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + self._api_key),
}
return headers

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
if len(prompt_request.request_pieces) != 1:
raise ValueError("This target only supports a single prompt request piece.")
if prompt_request.request_pieces[0].converted_value_data_type != "text":
raise ValueError("This target only supports text prompt input.")