# Source code for pyrit.embedding.azure_text_embedding

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional

from openai import AzureOpenAI

from pyrit.auth.azure_auth import AzureAuth, get_default_scope
from pyrit.common import default_values
from pyrit.embedding._text_embedding import _TextEmbedding


class AzureTextEmbedding(_TextEmbedding):
    """
    Text embedding client backed by an Azure OpenAI embedding deployment.

    The class-level constants name the environment variables consulted when
    the corresponding constructor argument is not supplied.
    """

    API_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_KEY"
    ENDPOINT_URI_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_ENDPOINT"
    DEPLOYMENT_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_DEPLOYMENT"

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        deployment: Optional[str] = None,
        api_version: str = "2024-02-01",
        use_entra_auth: bool = False,
    ) -> None:
        """
        Build the underlying ``AzureOpenAI`` client used to generate embeddings.

        Exactly one credential is used: an API key, or an Entra ID token.

        Args:
            api_key: API key for key-based authentication. Must be left unset
                when ``use_entra_auth`` is True. Otherwise resolved via
                ``API_KEY_ENVIRONMENT_VARIABLE`` when not passed.
            endpoint: Azure OpenAI endpoint (sometimes called the api_base).
                Resolved via ``ENDPOINT_URI_ENVIRONMENT_VARIABLE`` when not
                passed.
            deployment: Deployment name (called "engine" or "model" in some
                APIs). Resolved via ``DEPLOYMENT_ENVIRONMENT_VARIABLE`` when
                not passed.
            api_version: Azure OpenAI API version. Defaults to "2024-02-01".
            use_entra_auth: When True, authenticate with an Entra ID token
                instead of an API key. Defaults to False.

        Raises:
            ValueError: If ``use_entra_auth`` is True and ``api_key`` is also
                provided.
        """
        resolved_endpoint = default_values.get_required_value(
            env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE,
            passed_value=endpoint,
        )
        resolved_deployment = default_values.get_required_value(
            env_var_name=self.DEPLOYMENT_ENVIRONMENT_VARIABLE,
            passed_value=deployment,
        )

        # Settings common to both authentication modes; the credential is
        # added below depending on the mode selected.
        client_kwargs = {
            "api_version": api_version,
            "azure_endpoint": resolved_endpoint,
            "azure_deployment": resolved_deployment,
        }

        if use_entra_auth:
            # The two credential kinds are mutually exclusive.
            if api_key:
                raise ValueError("If using Entra ID auth, please do not specify api_key.")

            # The token is acquired once here, at construction time, and
            # handed to the client as a fixed value.
            token_scope = get_default_scope(resolved_endpoint)
            client_kwargs["azure_ad_token"] = AzureAuth(token_scope=token_scope).get_token()
        else:
            client_kwargs["api_key"] = default_values.get_required_value(
                env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE,
                passed_value=api_key,
            )

        self._client = AzureOpenAI(**client_kwargs)
        self._model = resolved_deployment

        super().__init__()