Source code for pyrit.embedding.azure_text_embedding
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from openai import AzureOpenAI
from pyrit.common import default_values
from pyrit.embedding._text_embedding import _TextEmbedding
[docs]
class AzureTextEmbedding(_TextEmbedding):
API_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_KEY"
ENDPOINT_URI_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_ENDPOINT"
DEPLOYMENT_ENVIRONMENT_VARIABLE: str = "AZURE_OPENAI_EMBEDDING_DEPLOYMENT"
[docs]
def __init__(
self, *, api_key: str = None, endpoint: str = None, deployment: str = None, api_version: str = "2024-02-01"
) -> None:
"""Generate embedding using the Azure API
Args:
api_key: The API key to use. Defaults to environment variable.
endpoint: The API base to use, sometimes referred to as the api_base. Defaults to environment variable.
deployment: The engine to use, in AOAI referred to as deployment, in some APIs referred to as model.
Defaults to environment variable.
api_version: The API version to use. Defaults to "2024-02-01".
"""
api_key = default_values.get_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)
endpoint = default_values.get_required_value(
env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint
)
deployment = default_values.get_required_value(
env_var_name=self.DEPLOYMENT_ENVIRONMENT_VARIABLE, passed_value=deployment
)
self._client = AzureOpenAI(
api_key=api_key,
api_version=api_version,
azure_endpoint=endpoint,
azure_deployment=deployment,
)
self._model = deployment
super().__init__()