Source code for pyrit.prompt_converter.audio_volume_converter

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import io
import logging
from typing import Any, Literal

import numpy as np
from scipy.io import wavfile

from pyrit.models import PromptDataType, data_serializer_factory
from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter

logger = logging.getLogger(__name__)


class AudioVolumeConverter(PromptConverter):
    """
    Changes the volume of an audio file by scaling the amplitude.

    A volume_factor > 1.0 increases the volume (louder), while a volume_factor < 1.0
    decreases it (quieter). A volume_factor of 1.0 leaves the audio unchanged.

    The converter scales all audio samples by the given factor:

    * Signed integer PCM is scaled around 0, rounded to the nearest integer and
      clipped to the valid range of the original data type.
    * Unsigned integer PCM (e.g. 8-bit WAV, whose silence level is 128) is scaled
      around the midpoint of its range so silence stays silent, then rounded and
      clipped.
    * Floating-point data is scaled but not clipped.

    Sample rate, sample data type, and number of channels are preserved.
    """

    SUPPORTED_INPUT_TYPES = ("audio_path",)
    SUPPORTED_OUTPUT_TYPES = ("audio_path",)

    #: Accepted audio formats for conversion.
    AcceptedAudioFormats = Literal["wav"]

    def __init__(
        self,
        *,
        output_format: AcceptedAudioFormats = "wav",
        volume_factor: float = 1.5,
    ) -> None:
        """
        Initialize the converter with the specified output format and volume factor.

        Args:
            output_format (str): The format of the audio file, defaults to "wav".
            volume_factor (float): The factor by which to scale the volume.
                Values > 1.0 increase volume, values < 1.0 decrease volume.
                Must be a finite number greater than 0. Defaults to 1.5.

        Raises:
            ValueError: If volume_factor is not positive or is not finite (NaN / inf).
        """
        if volume_factor <= 0:
            raise ValueError("volume_factor must be greater than 0.")
        # NaN slips past the comparison above (NaN <= 0 is False), and +inf would
        # turn every sample into inf/NaN; neither yields meaningful audio.
        if not np.isfinite(volume_factor):
            raise ValueError("volume_factor must be a finite number.")
        self._output_format = output_format
        self._volume_factor = volume_factor

    def _apply_volume(self, data: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
        """
        Scale audio samples by the volume factor and clip to the valid range.

        The operation is element-wise, so it works on mono (1-D) arrays as well as
        multi-channel (2-D, samples x channels) arrays.

        Args:
            data: numpy array of audio samples.

        Returns:
            numpy array with the volume adjusted, same shape and dtype as the input.
        """
        samples = data.astype(np.float64)

        if not np.issubdtype(data.dtype, np.integer):
            # Floating-point audio: scale only, no clipping.
            return (samples * self._volume_factor).astype(data.dtype)

        info = np.iinfo(data.dtype)
        if np.issubdtype(data.dtype, np.unsignedinteger):
            # Unsigned PCM (8-bit WAV is uint8) stores silence at the midpoint of the
            # range, not at 0; scaling raw values would shift the DC offset.
            midpoint = (float(info.max) + 1.0) / 2.0
            scaled = (samples - midpoint) * self._volume_factor + midpoint
        else:
            scaled = samples * self._volume_factor

        # Round (rather than truncate toward zero on the cast) to avoid a systematic
        # bias, then clip so the cast back to the original dtype cannot overflow.
        scaled = np.clip(np.rint(scaled), info.min, info.max)
        return scaled.astype(data.dtype)

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "audio_path") -> ConverterResult:
        """
        Convert the given audio file by changing its volume.

        The audio samples are scaled by the volume factor. For integer audio formats
        the result is rounded and clipped to prevent overflow.

        Args:
            prompt (str): File path to the audio file to be converted.
            input_type (PromptDataType): The type of input data.

        Returns:
            ConverterResult: The result containing the converted audio file path.

        Raises:
            ValueError: If the input type is not supported.
            Exception: If there is an error during the conversion process.
        """
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        try:
            # Create serializer to read audio data
            audio_serializer = data_serializer_factory(
                category="prompt-memory-entries", data_type="audio_path", extension=self._output_format, value=prompt
            )
            audio_bytes = await audio_serializer.read_data()

            # Decode the WAV bytes into (sample_rate, samples)
            sample_rate, data = wavfile.read(io.BytesIO(audio_bytes))

            # _apply_volume is element-wise and preserves dtype, so mono (1-D) and
            # multi-channel (2-D) data are handled in a single vectorized call.
            volume_data = self._apply_volume(data)

            # Write the processed data as a new WAV file
            output_bytes_io = io.BytesIO()
            wavfile.write(output_bytes_io, sample_rate, volume_data)

            # Save the converted bytes using the serializer
            await audio_serializer.save_data(data=output_bytes_io.getvalue())
            audio_serializer_file = str(audio_serializer.value)
            logger.info(
                "Volume changed by factor %.2f for [%s], and the audio was saved to [%s]",
                self._volume_factor,
                prompt,
                audio_serializer_file,
            )
        except Exception as e:
            logger.error("Failed to convert audio volume: %s", str(e))
            raise

        return ConverterResult(output_text=audio_serializer_file, output_type=input_type)