# Source code for pyrit.prompt_target.openai.openai_tts_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Literal

from httpx import HTTPStatusError

from pyrit.common import net_utility
from pyrit.exceptions import RateLimitException, handle_bad_request_exception
from pyrit.models import PromptRequestResponse, construct_response_from_request, data_serializer_factory
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Values accepted for the ``model`` argument of ``OpenAITTSTarget.__init__``.
TTSModel = Literal["tts-1", "tts-1-hd"]
# Values accepted for the ``voice`` argument of ``OpenAITTSTarget.__init__``.
TTSVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
# Values accepted for the ``response_format`` argument of
# ``OpenAITTSTarget.__init__``; the chosen value is also used as the file
# extension when the returned audio is saved.
TTSResponseFormat = Literal["flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"]


class OpenAITTSTarget(OpenAITarget):
    """Prompt target that turns a text prompt into speech.

    The prompt text is POSTed to the ``audio/speech`` route of an OpenAI
    deployment, the returned bytes are saved through an ``audio_path`` data
    serializer, and the saved location is returned as the response piece.
    Only single-piece, single-turn, text requests are accepted.
    """

    def __init__(
        self,
        voice: TTSVoice = "alloy",
        response_format: TTSResponseFormat = "mp3",
        model: TTSModel = "tts-1",
        language: str = "en",
        api_version: str = "2024-03-01-preview",
        *args,
        **kwargs,
    ) -> None:
        """Configure the text-to-speech target.

        Args:
            voice: Voice name sent in the request body.
            response_format: Audio format sent in the request body and used
                as the extension of the saved audio file.
            model: Model name sent in the request body.
            language: Language code sent in the request body.
            api_version: Value of the ``api-version`` query parameter.
            *args: Forwarded unchanged to ``OpenAITarget.__init__``.
            **kwargs: Forwarded unchanged to ``OpenAITarget.__init__``.

        Raises:
            NotImplementedError: If ``use_aad_auth=True`` is passed; AAD
                authentication is not implemented for this target.
        """
        if kwargs.get("use_aad_auth") is True:
            raise NotImplementedError("AAD authentication not implemented for TTSTarget yet.")

        super().__init__(*args, **kwargs)

        self._voice = voice
        self._model = model
        self._response_format = response_format
        self._language = language
        self._api_version = api_version

    def _set_azure_openai_env_configuration_vars(self):
        """Set the names of the environment variables used for TTS configuration."""
        self.deployment_environment_variable = "AZURE_OPENAI_TTS_DEPLOYMENT"
        self.endpoint_uri_environment_variable = "AZURE_OPENAI_TTS_ENDPOINT"
        self.api_key_environment_variable = "AZURE_OPENAI_TTS_KEY"

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """Send a text prompt to the speech endpoint and store the audio.

        Args:
            prompt_request: Request holding exactly one text request piece.

        Returns:
            On success, a response whose single ``audio_path`` piece is the
            location of the saved audio. On an HTTP 400 reply, the entry
            produced by ``handle_bad_request_exception``.

        Raises:
            ValueError: If the request fails ``_validate_request``.
            RateLimitException: If the endpoint replies with HTTP 429.
            HTTPStatusError: For any other non-success HTTP status.
        """
        self._validate_request(prompt_request=prompt_request)
        request = prompt_request.request_pieces[0]

        logger.info(f"Sending the following prompt to the prompt target: {request}")

        # NOTE(review): the audio format is sent under the "file" key, while
        # the public OpenAI speech API names this field "response_format".
        # Left unchanged here -- confirm which key the endpoint honours.
        body: dict[str, object] = {
            "model": self._model,
            "input": request.converted_value,
            "voice": self._voice,
            "file": self._response_format,
            "language": self._language,
        }

        self._extra_headers["api-key"] = self._api_key

        try:
            # Note the openai client doesn't work here, potentially due to a mismatch
            response = await net_utility.make_request_and_raise_if_error_async(
                endpoint_uri=f"{self._endpoint}/openai/deployments/{self._deployment_name}/"
                f"audio/speech?api-version={self._api_version}",
                method="POST",
                headers=self._extra_headers,
                request_body=body,
            )
        except HTTPStatusError as hse:
            status_code = hse.response.status_code
            if status_code == 400:
                # Return immediately: no audio was produced, so there is no
                # `response` object to read and nothing to save.
                return handle_bad_request_exception(response_text=hse.response.text, request=request)
            if status_code == 429:
                raise RateLimitException() from hse
            raise

        logger.info("Received valid response from the prompt target")

        audio_response = data_serializer_factory(data_type="audio_path", extension=self._response_format)
        await audio_response.save_data(data=response.content)

        return construct_response_from_request(
            request=request, response_text_pieces=[str(audio_response.value)], response_type="audio_path"
        )

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """Reject requests this target cannot handle.

        Args:
            prompt_request: The request about to be sent.

        Raises:
            ValueError: If the request does not have exactly one piece, if
                that piece is not text, or if memory already holds messages
                for the piece's conversation id (multi-turn is unsupported).
        """
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")

        request = prompt_request.request_pieces[0]
        messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)

        if len(messages) > 0:
            raise ValueError("This target only supports a single turn conversation.")