Source code for pyrit.prompt_target.openai.openai_completion_target
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional
from openai import NOT_GIVEN, NotGiven
from openai.types.completion import Completion
from pyrit.models import PromptRequestResponse, PromptResponse, construct_response_from_request
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute
logger = logging.getLogger(__name__)
[docs]
class OpenAICompletionTarget(OpenAITarget):
    """Prompt target that sends single-turn text prompts to an OpenAI completions endpoint.

    Each prompt is forwarded to ``completions.create`` on ``self._async_client`` (set up by
    ``OpenAITarget``), and the text of the first returned choice is recorded as the response.
    """

    def __init__(
        self,
        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
        temperature: float = 1.0,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        *args,
        **kwargs,
    ):
        """
        Args:
            max_tokens (int, Optional): The maximum number of tokens that can be generated in the
                completion. The token count of your prompt plus `max_tokens` cannot exceed the model's
                context length. Defaults to ``NOT_GIVEN``, the OpenAI SDK sentinel for leaving the
                parameter unset.
            temperature (float): Forwarded as ``temperature`` to the completions API. Defaults to 1.0.
            top_p (float): Forwarded as ``top_p`` to the completions API. Defaults to 1.0.
            frequency_penalty (float): Forwarded as ``frequency_penalty`` to the completions API.
                Defaults to 0.0.
            presence_penalty (float): Forwarded as ``presence_penalty`` to the completions API.
                Defaults to 0.0.
            *args: Positional arguments passed through to ``OpenAITarget.__init__``.
            **kwargs: Keyword arguments passed through to ``OpenAITarget.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._top_p = top_p
        self._frequency_penalty = frequency_penalty
        self._presence_penalty = presence_penalty

    def _set_azure_openai_env_configuration_vars(self):
        """Set the names of the environment variables holding the Azure OpenAI completion settings.

        NOTE(review): presumably invoked by ``OpenAITarget`` while resolving the deployment,
        endpoint and API key; confirm against the base class.
        """
        self.deployment_environment_variable = "AZURE_OPENAI_COMPLETION_DEPLOYMENT"
        self.endpoint_uri_environment_variable = "AZURE_OPENAI_COMPLETION_ENDPOINT"
        self.api_key_environment_variable = "AZURE_OPENAI_COMPLETION_KEY"

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """
        Sends a normalized prompt async to the prompt target.

        Args:
            prompt_request (PromptRequestResponse): A request containing exactly one text piece
                that starts a new conversation.

        Returns:
            PromptRequestResponse: The response entry built from the first completion choice,
                with the serialized ``PromptResponse`` attached as prompt metadata.

        Raises:
            ValueError: If the request fails ``_validate_request``.
        """
        self._validate_request(prompt_request=prompt_request)
        request = prompt_request.request_pieces[0]

        # Lazy %-formatting: the request is only rendered if INFO logging is enabled.
        logger.info("Sending the following prompt to the prompt target: %s", request)

        text_response: Completion = await self._async_client.completions.create(
            model=self._deployment_name,
            prompt=request.converted_value,
            top_p=self._top_p,
            temperature=self._temperature,
            frequency_penalty=self._frequency_penalty,
            presence_penalty=self._presence_penalty,
            max_tokens=self._max_tokens,
        )

        # The SDK types `Completion.usage` as Optional; fall back to zero token counts
        # instead of crashing with AttributeError when the service omits usage data.
        usage = text_response.usage
        prompt_response = PromptResponse(
            completion=text_response.choices[0].text,
            prompt=request.converted_value,
            id=text_response.id,
            completion_tokens=usage.completion_tokens if usage is not None else 0,
            prompt_tokens=usage.prompt_tokens if usage is not None else 0,
            total_tokens=usage.total_tokens if usage is not None else 0,
            model=text_response.model,
            object=text_response.object,
        )

        response_entry = construct_response_from_request(
            request=request,
            response_text_pieces=[prompt_response.completion],
            prompt_metadata=prompt_response.to_json(),
        )
        return response_entry

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """Ensure the request is a single text piece that opens a new conversation.

        Raises:
            ValueError: If the request does not have exactly one piece, if that piece's
                converted value is not text, or if memory already holds messages for the
                piece's conversation id (this target is single-turn only).
        """
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        request = prompt_request.request_pieces[0]
        if request.converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")

        messages = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)
        if messages:
            raise ValueError("This target only supports a single turn conversation.")