Source code for pyrit.auth.azure_auth

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time

import msal
from azure.core.credentials import AccessToken
from azure.identity import (
    AzureCliCredential,
    DefaultAzureCredential,
    InteractiveBrowserCredential,
    ManagedIdentityCredential,
    get_bearer_token_provider,
)

from pyrit.auth.auth_config import AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE, REFRESH_TOKEN_BEFORE_MSEC
from pyrit.auth.authenticator import Authenticator

logger = logging.getLogger(__name__)


class AzureAuth(Authenticator):
    """
    Authenticator backed by the Azure CLI login (``AzureCliCredential``).

    A token for the requested scope is fetched when the object is constructed and exposed as
    ``self.token``. ``refresh_token`` swaps it for a new one once the cached token is within
    ``REFRESH_TOKEN_BEFORE_MSEC`` milliseconds of its expiration time.
    """

    _access_token: AccessToken
    _tenant_id: str
    _token_scope: str
    token: str  # plain bearer-token string from the most recent fetch, exposed to callers

    def __init__(self, token_scope: str, tenant_id: str = ""):
        """
        Store the scope/tenant and fetch an initial token from the Azure CLI credential.

        Args:
            token_scope: Scope the token is requested for.
            tenant_id: Tenant passed to ``AzureCliCredential``; the empty-string default is
                forwarded unchanged.
        """
        self._tenant_id = tenant_id
        self._token_scope = token_scope
        self._fetch_new_token()

    def _fetch_new_token(self) -> None:
        """Request a fresh token from the Azure CLI and cache it on the instance."""
        cli_credential = AzureCliCredential(tenant_id=self._tenant_id)
        self._access_token = cli_credential.get_token(self._token_scope)
        # Make the raw token string available to the user.
        self.token = self._access_token.token

    def refresh_token(self) -> str:
        """
        Return a usable token, fetching a new one if the cached token is (nearly) expired.

        Returns:
            The current bearer-token string (new if a refresh happened, cached otherwise).
        """
        now_ms = int(time.time()) * 1_000
        expires_ms = int(self._access_token.expires_on) * 1_000
        # Consider the token stale REFRESH_TOKEN_BEFORE_MSEC ahead of its real expiry so the
        # caller gets a token that stays valid for a while after being handed out.
        refresh_at_ms = expires_ms - REFRESH_TOKEN_BEFORE_MSEC
        if now_ms >= refresh_at_ms:
            self._fetch_new_token()
        return self.token

    def get_token(self) -> str:
        """
        Return the cached token without checking its expiration.

        Returns:
            The bearer-token string currently stored on the instance.
        """
        return self.token
def get_access_token_from_azure_msi(*, client_id: str, scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE) -> str:
    """Connect to an AOAI endpoint via the managed identity credential attached to an Azure resource.

    For proper setup and configuration of MSI see
    https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/overview.

    Args:
        client_id: Client ID of the managed identity to authenticate as.
        scope: Scope the token is requested for. Defaults to AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE.

    Returns:
        The access token string.

    Raises:
        Exception: Any error from ManagedIdentityCredential is logged and re-raised unchanged.
    """
    try:
        credential = ManagedIdentityCredential(client_id=client_id)
        token = credential.get_token(scope)
        return token.token
    except Exception as e:
        logger.error("Failed to obtain token for '%s' with client ID '%s': %s", scope, client_id, e)
        raise


def get_access_token_from_msa_public_client(
    *, client_id: str, scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE
) -> str:
    """Use an MSA account to connect to an AOAI endpoint via interactive login.

    A browser window will open and ask for login credentials.

    Args:
        client_id: Client ID of the public client application registered for the login.
        scope: Scope the token is requested for. Defaults to AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE.

    Returns:
        The access token string.

    Raises:
        KeyError: If MSAL returns a result without an access token; the message includes the
            ``error`` / ``error_description`` MSAL reported.
        Exception: Any other error from MSAL is logged and re-raised unchanged.
    """
    try:
        app = msal.PublicClientApplication(client_id)
        result = app.acquire_token_interactive(scopes=[scope])
        if "access_token" not in result:
            # MSAL signals failure via "error"/"error_description" keys in the returned dict rather
            # than by raising. Surface them instead of an opaque KeyError('access_token');
            # KeyError is kept so callers that handled the previous failure mode keep working.
            error = result.get("error", "unknown_error")
            description = result.get("error_description", "no description returned")
            raise KeyError(f"MSAL returned no access token (error={error}, description={description})")
        return result["access_token"]
    except Exception as e:
        logger.error("Failed to obtain token for '%s' with client ID '%s': %s", scope, client_id, e)
        raise


def get_access_token_from_interactive_login(scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE) -> str:
    """Connect to an OpenAI endpoint with an interactive login from Azure.

    A browser window will open and ask for login credentials.

    Args:
        scope: Scope the token is requested for. Defaults to AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE
            (Azure Cognitive Services).

    Returns:
        The access token string.

    Raises:
        Exception: Any error from azure-identity is logged and re-raised unchanged.
    """
    try:
        token_provider = get_bearer_token_provider(InteractiveBrowserCredential(), scope)
        # Invoke the provider once to obtain a single token string.
        return token_provider()
    except Exception as e:
        logger.error("Failed to obtain token for '%s': %s", scope, e)
        raise


def get_token_provider_from_default_azure_credential(scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE):
    """Connect to an AOAI endpoint via the default Azure credential.

    Args:
        scope: Scope tokens are requested for. Defaults to AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE.

    Returns:
        A token provider: a zero-argument callable returning a bearer-token string when invoked.

    Raises:
        Exception: Any error while building the credential/provider is logged and re-raised unchanged.
    """
    try:
        token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope)
        return token_provider
    except Exception as e:
        logger.error("Failed to obtain token for '%s': %s", scope, e)
        raise