# Source code for pyrit.prompt_target.ollama_chat_target

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Optional

from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
from pyrit.common import default_values, net_utility
from pyrit.models import ChatMessage, PromptRequestPiece, PromptRequestResponse
from pyrit.models import construct_response_from_request
from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute


# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class OllamaChatTarget(PromptChatTarget):
    """Chat target that sends a conversation to an Ollama endpoint over HTTP.

    Each call POSTs a JSON body (``model``, ``messages``, ``stream: False``)
    to ``self.endpoint`` and returns the text found under
    ``message.content`` in the JSON reply. The payload and reply shape look
    like Ollama's chat API -- NOTE(review): confirm the configured endpoint
    is the chat route.

    Attributes:
        ENDPOINT_URI_ENVIRONMENT_VARIABLE: Environment variable name handed to
            ``default_values.get_required_value`` when resolving the endpoint.
        MODEL_NAME_ENVIRONMENT_VARIABLE: Environment variable name handed to
            ``default_values.get_required_value`` when resolving the model.
    """

    ENDPOINT_URI_ENVIRONMENT_VARIABLE = "OLLAMA_ENDPOINT"
    MODEL_NAME_ENVIRONMENT_VARIABLE = "OLLAMA_MODEL_NAME"

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        model_name: Optional[str] = None,
        chat_message_normalizer: Optional[ChatMessageNormalizer] = None,
        max_requests_per_minute: Optional[int] = None,
    ) -> None:
        """Configure the endpoint, the model and the message normalizer.

        Args:
            endpoint: URI the chat request is POSTed to. When falsy, the value
                is resolved by ``default_values.get_required_value`` using
                ``ENDPOINT_URI_ENVIRONMENT_VARIABLE`` (presumably read from the
                environment -- confirm against that helper).
            model_name: Model identifier placed in the request body. When
                falsy, resolved the same way using
                ``MODEL_NAME_ENVIRONMENT_VARIABLE``.
            chat_message_normalizer: Normalizer applied to the messages before
                they are serialized. ``None`` (the default) selects a fresh
                ``ChatMessageNop`` instance.
            max_requests_per_minute: Forwarded unchanged to the base-class
                initializer.
        """
        super().__init__(max_requests_per_minute=max_requests_per_minute)

        self.endpoint = endpoint or default_values.get_required_value(
            env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint
        )
        self.model_name = model_name or default_values.get_required_value(
            env_var_name=self.MODEL_NAME_ENVIRONMENT_VARIABLE, passed_value=model_name
        )
        # The default normalizer is created here rather than in the signature:
        # a call in a default argument runs once at definition time and the
        # resulting object would be shared by every instance.
        self.chat_message_normalizer = (
            ChatMessageNop() if chat_message_normalizer is None else chat_message_normalizer
        )

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """Send one text prompt, together with its stored conversation, to the endpoint.

        Args:
            prompt_request: Request holding exactly one text request piece.

        Returns:
            The value of ``construct_response_from_request`` built from the
            request piece and the single response text.

        Raises:
            ValueError: If the request fails ``_validate_request`` or the
                endpoint returns an empty (falsy) response text.
        """
        self._validate_request(prompt_request=prompt_request)
        request: PromptRequestPiece = prompt_request.request_pieces[0]

        history = self._memory.get_chat_messages_with_conversation_id(conversation_id=request.conversation_id)
        # Build a new list instead of appending in place so the list handed
        # back by the memory object is left untouched.
        messages = [*history, request.to_chat_message()]

        # Lazy %-style arguments: the message is only rendered when INFO is enabled.
        logger.info("Sending the following prompt to the prompt target: %s %s", self, request)

        resp = await self._complete_chat_async(messages=messages)

        if not resp:
            raise ValueError("The chat returned an empty response.")

        logger.info('Received the following response from the prompt target "%s"', resp)
        return construct_response_from_request(request=request, response_text_pieces=[resp])

    async def _complete_chat_async(
        self,
        messages: list[ChatMessage],
    ) -> str:
        """POST the messages to ``self.endpoint`` and return the reply text.

        Args:
            messages: Full conversation to send, newest message last.

        Returns:
            The value under ``["message"]["content"]`` of the JSON reply.

        Raises:
            KeyError: If the JSON reply lacks ``message`` or ``content``
                (assuming the reply decodes to a mapping).
        """
        headers = self._get_headers()
        payload = self._construct_http_body(messages)

        response = await net_utility.make_request_and_raise_if_error_async(
            endpoint_uri=self.endpoint, method="POST", request_body=payload, headers=headers
        )

        return response.json()["message"]["content"]

    def _construct_http_body(
        self,
        messages: list[ChatMessage],
    ) -> dict:
        """Build the JSON-serializable request body.

        The messages are first passed through ``self.chat_message_normalizer``
        and then dumped with ``model_dump(exclude_none=True)``.

        Args:
            messages: Conversation to serialize.

        Returns:
            A dict with the keys ``model``, ``messages`` and ``stream``
            (always ``False``).
        """
        squashed_messages = self.chat_message_normalizer.normalize(messages)
        messages_list = [message.model_dump(exclude_none=True) for message in squashed_messages]
        return {
            "model": self.model_name,
            "messages": messages_list,
            "stream": False,
        }

    def _get_headers(self) -> dict:
        """Return the HTTP headers sent with every request (JSON content type)."""
        return {
            "Content-Type": "application/json",
        }

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """Reject requests this target cannot handle.

        Args:
            prompt_request: Request to check.

        Raises:
            ValueError: If the request does not contain exactly one piece, or
                if that piece's ``converted_value_data_type`` is not ``"text"``.
        """
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")