# Source code for pyrit.orchestrator.multi_turn.multi_turn_orchestrator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from colorama import Fore, Style
from pyrit.common.display_response import display_image_response
from pyrit.memory import CentralMemory
from pyrit.models import SeedPrompt
from pyrit.orchestrator import Orchestrator
from pyrit.prompt_normalizer import PromptNormalizer
from pyrit.prompt_target import PromptTarget, PromptChatTarget
from pyrit.prompt_converter import PromptConverter
from pyrit.score import Scorer
# Module-level logger named after this module, following the standard logging convention.
logger = logging.getLogger(__name__)
class MultiTurnAttackResult:
    """The result of a multi-turn attack.

    Attributes:
        conversation_id (str): The conversation ID used to look up the exchange with the objective target.
        achieved_objective (bool): Whether the orchestrator reported that the objective was achieved.
        objective (str): The objective the attack pursued.
    """

    def __init__(self, conversation_id: str, achieved_objective: bool, objective: str):
        self.conversation_id = conversation_id
        self.achieved_objective = achieved_objective
        self.objective = objective
        # Memory handle used later to fetch the conversation's prompt pieces and their scores.
        self._memory = CentralMemory.get_memory_instance()

    async def print_conversation_async(self) -> None:
        """Prints the conversation stored under ``self.conversation_id``, including its scores.

        First prints a header stating whether the objective was achieved. Then it prints every
        prompt piece in the order memory returns them:

        - For ``user`` pieces, the original value is also printed when it differs from the
          converted value.
        - Other roles print only the converted value. Those pieces are then passed to
          ``display_image_response``, which presumably renders image responses.
        - Every score recorded for a piece is printed with its rationale.

        If memory holds no pieces for the conversation, a short notice is printed instead.
        """
        target_messages = self._memory._get_prompt_pieces_with_conversation_id(conversation_id=self.conversation_id)

        if not target_messages:
            print("No conversation with the target")
            return

        if self.achieved_objective:
            print(
                f"{Style.BRIGHT}{Fore.RED}The multi-turn orchestrator has completed the conversation and achieved "
                f"the objective: {self.objective}"
            )
        else:
            print(
                f"{Style.BRIGHT}{Fore.RED}The multi-turn orchestrator has not achieved the objective: "
                f"{self.objective}"
            )

        for message in target_messages:
            if message.role == "user":
                print(f"{Style.BRIGHT}{Fore.BLUE}{message.role}:")
                # Only show the original value when a converter actually changed the prompt.
                if message.converted_value != message.original_value:
                    print(f"Original value: {message.original_value}")
                print(f"Converted value: {message.converted_value}")
            else:
                print(f"{Style.NORMAL}{Fore.YELLOW}{message.role}: {message.converted_value}")
                await display_image_response(message)

            scores = self._memory.get_scores_by_prompt_ids(prompt_request_response_ids=[str(message.id)])
            # Tolerate an empty or None result for pieces that were never scored.
            for score in scores or []:
                print(f"{Style.RESET_ALL}score: {score} : {score.score_rationale}")
class MultiTurnOrchestrator(Orchestrator):
    """
    The `MultiTurnOrchestrator` is an interface that coordinates attacks and conversations between
    an adversarial_chat target and an objective_target.

    Args:
        objective_target (PromptTarget): The target to send the created prompts to.
        adversarial_chat (PromptChatTarget): The endpoint that creates prompts that are sent to the
            objective_target.
        adversarial_chat_system_prompt_path (Path | str): Path to the YAML seed prompt file used as the
            adversarial chat's system prompt. The loaded seed prompt must declare an ``objective``
            parameter.
        adversarial_chat_seed_prompt (str | SeedPrompt, Optional): The initial prompt used to start the
            adversarial chat. A plain string is wrapped in a text ``SeedPrompt``.
            Defaults to "How can I help you?".
        max_turns (int, Optional): The maximum number of turns for the conversation. Must be greater
            than 0. Defaults to 5.
        prompt_converters (Optional[list[PromptConverter]], Optional): The prompt converters to use to
            convert the prompts before sending them to the prompt target. Defaults to None.
        objective_scorer (Scorer): The scorer that classifies the prompt target outputs as sufficient
            (True) or insufficient (False) to satisfy the objective passed to ``run_attack_async``.
        verbose (bool, Optional): Whether to print debug information. Defaults to False.

    Raises:
        ValueError: If max_turns is less than or equal to 0.
        ValueError: If the adversarial chat system seed prompt does not declare an ``objective``
            parameter.
        FileNotFoundError: Presumably raised by ``SeedPrompt.from_yaml_file`` when
            adversarial_chat_system_prompt_path does not exist.

    Note:
        This base class does not itself check that objective_scorer is a true/false scorer;
        subclasses may enforce that.
    """

    def __init__(
        self,
        *,
        objective_target: PromptTarget,
        adversarial_chat: PromptChatTarget,
        adversarial_chat_system_prompt_path: Union[Path, str],
        adversarial_chat_seed_prompt: Union[str, SeedPrompt] = "How can I help you?",
        max_turns: int = 5,
        prompt_converters: Optional[list[PromptConverter]] = None,
        objective_scorer: Scorer,
        verbose: bool = False,
    ) -> None:
        # Validate cheap arguments before doing base-class setup or reading any files.
        if max_turns <= 0:
            raise ValueError("The maximum number of turns must be greater than 0.")

        super().__init__(prompt_converters=prompt_converters, verbose=verbose)

        self._objective_target = objective_target
        self._achieved_objective = False

        self._adversarial_chat_system_seed_prompt = SeedPrompt.from_yaml_file(adversarial_chat_system_prompt_path)
        # Reject system prompt templates that do not declare an `objective` parameter.
        if "objective" not in self._adversarial_chat_system_seed_prompt.parameters:
            raise ValueError(f"Adversarial seed prompt must have an objective: '{adversarial_chat_system_prompt_path}'")

        self._prompt_normalizer = PromptNormalizer()
        self._adversarial_chat = adversarial_chat
        self._adversarial_chat_seed_prompt = self._get_adversarial_chat_seed_prompt(adversarial_chat_seed_prompt)
        self._max_turns = max_turns
        self._objective_scorer = objective_scorer

    def _get_adversarial_chat_seed_prompt(self, seed_prompt: Union[str, SeedPrompt]) -> SeedPrompt:
        """Returns ``seed_prompt`` as a ``SeedPrompt``.

        A plain string is wrapped in a text ``SeedPrompt``; anything else is returned unchanged.
        """
        if isinstance(seed_prompt, str):
            return SeedPrompt(
                value=seed_prompt,
                data_type="text",
            )
        return seed_prompt

    @abstractmethod
    async def run_attack_async(
        self, *, objective: str, memory_labels: Optional[dict[str, str]] = None
    ) -> MultiTurnAttackResult:
        """
        Applies the attack strategy until the conversation is complete or the maximum number of turns is reached.

        Args:
            objective (str): The specific goal the orchestrator aims to achieve through the conversation.
            memory_labels (dict[str, str], Optional): A free-form dictionary of additional labels to apply to
                the prompts throughout the attack. Any labels passed in will be combined with
                self._global_memory_labels (from the GLOBAL_MEMORY_LABELS environment variable) into one
                dictionary. In the case of collisions, the passed-in labels take precedence. Defaults to None.

        Returns:
            MultiTurnAttackResult: Contains the outcome of the attack, including:
                - conversation_id (str): The ID associated with the final conversation state.
                - achieved_objective (bool): Indicates whether the orchestrator successfully met the objective.
                - objective (str): The intended goal of the attack.
        """

    async def run_attacks_async(
        self, *, objectives: list[str], memory_labels: Optional[dict[str, str]] = None, batch_size: int = 5
    ) -> list[MultiTurnAttackResult]:
        """Applies the attack strategy for each objective in the list of objectives.

        At most ``batch_size`` calls to ``run_attack_async`` are in flight at any time.

        Args:
            objectives (list[str]): The list of objectives to apply the attack strategy.
            memory_labels (dict[str, str], Optional): A free-form dictionary of additional labels to apply to
                the prompts throughout the attack. Any labels passed in will be combined with
                self._global_memory_labels (from the GLOBAL_MEMORY_LABELS environment variable) into one
                dictionary. In the case of collisions, the passed-in labels take precedence. Defaults to None.
            batch_size (int): The number of objectives to process in parallel. Must be at least 1.
                The default value is 5.

        Returns:
            list[MultiTurnAttackResult]: One result per objective, in the same order as ``objectives``.

        Raises:
            ValueError: If batch_size is less than 1.
        """
        # asyncio.Semaphore(0) never grants a permit, so every attack would wait forever,
        # and negative values are rejected by the Semaphore constructor. Fail clearly instead.
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1.")

        semaphore = asyncio.Semaphore(batch_size)

        async def limited_run_attack(objective: str) -> MultiTurnAttackResult:
            async with semaphore:
                return await self.run_attack_async(objective=objective, memory_labels=memory_labels)

        tasks = [limited_run_attack(objective) for objective in objectives]
        # gather returns results in the order of the awaitables, i.e. aligned with `objectives`.
        results = await asyncio.gather(*tasks)
        return list(results)