Source code for pyrit.auth.azure_auth
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import msal
import logging
from azure.core.credentials import AccessToken
from azure.identity import AzureCliCredential
from azure.identity import ManagedIdentityCredential
from azure.identity import InteractiveBrowserCredential
from azure.identity import DefaultAzureCredential
from azure.identity import get_bearer_token_provider
from pyrit.auth.auth_config import AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE, REFRESH_TOKEN_BEFORE_MSEC
from pyrit.auth.authenticator import Authenticator
logger = logging.getLogger(__name__)
[docs]
class AzureAuth(Authenticator):
"""
Azure CLI Authentication.
"""
_access_token: AccessToken
_tenant_id: str
_token_scope: str
[docs]
def __init__(self, token_scope: str, tenant_id: str = ""):
self._tenant_id = tenant_id
self._token_scope = token_scope
azure_creds = AzureCliCredential(tenant_id=tenant_id)
self._access_token = azure_creds.get_token(self._token_scope)
# Make the token available to the user
self.token = self._access_token.token
[docs]
def refresh_token(self) -> str:
"""Refresh the access token if it is expired.
Returns:
A token
"""
curr_epoch_time_in_ms = int(time.time()) * 1_000
access_token_epoch_expiration_time_in_ms = int(self._access_token.expires_on) * 1_000
# Adjust the expiration time to be before the actual expiration time so that user can use the token
# for a while before it expires. This improves user experience. The token is refreshed REFRESH_TOKEN_BEFORE_MSEC
# before it expires.
token_expires_on_in_ms = access_token_epoch_expiration_time_in_ms - REFRESH_TOKEN_BEFORE_MSEC
if token_expires_on_in_ms <= curr_epoch_time_in_ms:
# Token is expired, generate a new one
azure_creds = AzureCliCredential(tenant_id=self._tenant_id)
self._access_token = azure_creds.get_token(self._token_scope)
self.token = self._access_token.token
return self.token
[docs]
def get_token(self) -> str:
"""
Get the current token.
Returns: The current token
"""
return self.token
def get_access_token_from_azure_msi(*, client_id: str, scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE):
"""Connect to an AOAI endpoint via managed identity credential attached to an Azure resource.
For proper setup and configuration of MSI
https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/overview.
Args:
client id of the service
Returns:
Authentication token
"""
try:
credential = ManagedIdentityCredential(client_id=client_id)
token = credential.get_token(scope)
return token.token
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}")
raise
def get_access_token_from_msa_public_client(*, client_id: str, scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE):
"""Uses MSA account to connect to an AOAI endpoint via interactive login. A browser window
will open and ask for login credentials.
Args:
client id
Returns:
Authentication token
"""
try:
app = msal.PublicClientApplication(client_id)
result = app.acquire_token_interactive(scopes=[scope])
return result["access_token"]
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}")
raise
def get_access_token_from_interactive_login(scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE):
"""Connects to an OpenAI endpoint with an interactive login from Azure. A browser window will
open and ask for login credentials. The token will be scoped for Azure Cognitive services.
Returns:
Authentication token
"""
try:
token_provider = get_bearer_token_provider(InteractiveBrowserCredential(), scope)
return token_provider()
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}': {e}")
raise
def get_token_provider_from_default_azure_credential(scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE):
"""Connect to an AOAI endpoint via default Azure credential.
Returns:
Authentication token provider
"""
try:
token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope)
return token_provider
except Exception as e:
logger.error(f"Failed to obtain token for '{scope}': {e}")
raise