Source code for pyrit.prompt_converter.azure_speech_audio_to_text_converter

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    import azure.cognitiveservices.speech as speechsdk  # noqa: F401

from pyrit.auth.azure_auth import get_speech_config
from pyrit.common import default_values
from pyrit.models import PromptDataType
from pyrit.models.data_type_serializer import data_serializer_factory
from pyrit.prompt_converter import ConverterResult, PromptConverter

logger = logging.getLogger(__name__)


[docs] class AzureSpeechAudioToTextConverter(PromptConverter): """ Transcribes a .wav audio file into text using Azure AI Speech service. https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-to-text """ #: The name of the Azure region. AZURE_SPEECH_REGION_ENVIRONMENT_VARIABLE: str = "AZURE_SPEECH_REGION" #: The API key for accessing the service. AZURE_SPEECH_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_SPEECH_KEY" #: The resource ID for accessing the service when using Entra ID auth. AZURE_SPEECH_RESOURCE_ID_ENVIRONMENT_VARIABLE: str = "AZURE_SPEECH_RESOURCE_ID"
[docs] def __init__( self, azure_speech_region: Optional[str] = None, azure_speech_key: Optional[str] = None, azure_speech_resource_id: Optional[str] = None, use_entra_auth: bool = False, recognition_language: str = "en-US", ) -> None: """ Initializes the converter with Azure Speech service credentials and recognition language. Args: azure_speech_region (str, Optional): The name of the Azure region. azure_speech_key (str, Optional): The API key for accessing the service (if not using Entra ID auth). azure_speech_resource_id (str, Optional): The resource ID for accessing the service when using Entra ID auth. This can be found by selecting 'Properties' in the 'Resource Management' section of your Azure Speech resource in the Azure portal. use_entra_auth (bool): Whether to use Entra ID authentication. If True, azure_speech_resource_id must be provided. If False, azure_speech_key must be provided. Defaults to False. recognition_language (str): Recognition voice language. Defaults to "en-US". For more on supported languages, see the following link: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support Raises: ValueError: If the required environment variables are not set, if azure_speech_key is passed in when use_entra_auth is True, or if azure_speech_resource_id is passed in when use_entra_auth is False. """ self._azure_speech_region: str = default_values.get_required_value( env_var_name=self.AZURE_SPEECH_REGION_ENVIRONMENT_VARIABLE, passed_value=azure_speech_region, ) if use_entra_auth: if azure_speech_key: raise ValueError("If using Entra ID auth, please do not specify azure_speech_key.") self._azure_speech_resource_id = default_values.get_required_value( env_var_name=self.AZURE_SPEECH_RESOURCE_ID_ENVIRONMENT_VARIABLE, passed_value=azure_speech_resource_id, ) self._azure_speech_key = None else: if azure_speech_resource_id: raise ValueError("If using key auth, please do not specify azure_speech_resource_id.") self._azure_speech_key = default_values.get_required_value( env_var_name=self.AZURE_SPEECH_KEY_ENVIRONMENT_VARIABLE, passed_value=azure_speech_key, ) self._azure_speech_resource_id = None self._recognition_language = recognition_language # Create a flag to indicate when recognition is finished self.done = False
[docs] def input_supported(self, input_type: PromptDataType) -> bool: return input_type == "audio_path"
[docs] def output_supported(self, output_type: PromptDataType) -> bool: return output_type == "text"
[docs] async def convert_async(self, *, prompt: str, input_type: PromptDataType = "audio_path") -> ConverterResult: """ Converts the given audio file into its text representation. Args: prompt (str): File path to the audio file to be transcribed. input_type (PromptDataType): The type of input data. Returns: ConverterResult: The result containing the transcribed text. Raises: ValueError: If the input type is not supported or if the provided file is not a .wav file. """ if not self.input_supported(input_type): raise ValueError("Input type not supported") if not prompt.endswith(".wav"): raise ValueError("Please provide a .wav audio file. Compressed formats are not currently supported.") audio_serializer = data_serializer_factory( category="prompt-memory-entries", data_type="audio_path", value=prompt ) audio_bytes = await audio_serializer.read_data() try: transcript = self.recognize_audio(audio_bytes) except Exception as e: logger.error("Failed to convert audio file to text: %s", str(e)) raise return ConverterResult(output_text=transcript, output_type="text")
[docs] def recognize_audio(self, audio_bytes: bytes) -> str: """ Recognizes audio file and returns transcribed text. Args: audio_bytes (bytes): Audio bytes input. Returns: str: Transcribed text. """ try: import azure.cognitiveservices.speech as speechsdk # noqa: F811 except ModuleNotFoundError as e: logger.error( "Could not import azure.cognitiveservices.speech. " + "You may need to install it via 'pip install pyrit[speech]'" ) raise e speech_config = get_speech_config( resource_id=self._azure_speech_resource_id, key=self._azure_speech_key, region=self._azure_speech_region ) speech_config.speech_recognition_language = self._recognition_language # Create a PullAudioInputStream from the byte stream push_stream = speechsdk.audio.PushAudioInputStream() audio_config = speechsdk.audio.AudioConfig(stream=push_stream) # Instantiate a speech recognizer object speech_recognizer = speechsdk.SpeechRecognizer(speech_config=speech_config, audio_config=audio_config) # Create an empty list to store recognized text transcribed_text: list[str] = [] # Flag is set to False to indicate that recognition is not yet finished self.done = False # Connect callbacks to the events fired by the speech recognizer speech_recognizer.recognized.connect(lambda evt: self.transcript_cb(evt, transcript=transcribed_text)) speech_recognizer.recognizing.connect(lambda evt: logger.info("RECOGNIZING: {}".format(evt))) speech_recognizer.recognized.connect(lambda evt: logger.info("RECOGNIZED: {}".format(evt))) speech_recognizer.session_started.connect(lambda evt: logger.info("SESSION STARTED: {}".format(evt))) speech_recognizer.session_stopped.connect(lambda evt: logger.info("SESSION STOPPED: {}".format(evt))) # Stop continuous recognition when stopped or canceled event is fired speech_recognizer.canceled.connect(lambda evt: self.stop_cb(evt, recognizer=speech_recognizer)) speech_recognizer.session_stopped.connect(lambda evt: self.stop_cb(evt, recognizer=speech_recognizer)) # Start continuous recognition speech_recognizer.start_continuous_recognition_async() # Push the entire audio data into the stream push_stream.write(audio_bytes) push_stream.close() while not self.done: time.sleep(0.5) return "".join(transcribed_text)
[docs] def transcript_cb(self, evt: Any, transcript: list[str]) -> None: """ Callback function that appends transcribed text upon receiving a "recognized" event. Args: evt (speechsdk.SpeechRecognitionEventArgs): Event. transcript (list): List to store transcribed text. """ logger.info("RECOGNIZED: {}".format(evt.result.text)) transcript.append(evt.result.text)
[docs] def stop_cb(self, evt: Any, recognizer: Any) -> None: """ Callback function that stops continuous recognition upon receiving an event 'evt'. Args: evt (speechsdk.SpeechRecognitionEventArgs): Event. recognizer (speechsdk.SpeechRecognizer): Speech recognizer object. """ try: import azure.cognitiveservices.speech as speechsdk # noqa: F811 except ModuleNotFoundError as e: logger.error( "Could not import azure.cognitiveservices.speech. " + "You may need to install it via 'pip install pyrit[speech]'" ) raise e logger.info("CLOSING on {}".format(evt)) recognizer.stop_continuous_recognition_async() self.done = True if evt.result.reason == speechsdk.ResultReason.Canceled: cancellation_details = evt.result.cancellation_details logger.info("Speech recognition canceled: {}".format(cancellation_details.reason)) if cancellation_details.reason == speechsdk.CancellationReason.Error: logger.error("Error details: {}".format(cancellation_details.error_details)) elif cancellation_details.reason == speechsdk.CancellationReason.EndOfStream: logger.info("End of audio stream detected.")