# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import uuid
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from jinja2 import StrictUndefined, Template
from pyrit.common import utils
from pyrit.common.yaml_loadable import YamlLoadable
from pyrit.models.literals import PromptDataType
@dataclass
class SeedPrompt(YamlLoadable):
    """Represents a seed prompt with various attributes and metadata.

    Instances are built through the hand-written, keyword-only ``__init__``
    below; ``@dataclass`` does not replace an ``__init__`` that is defined in
    the class body, it only contributes its other generated methods.
    """

    id: Optional[uuid.UUID]
    value: str
    # Stays None until a hash is passed in or set_sha256_value_async() is awaited.
    value_sha256: Optional[str]
    data_type: PromptDataType
    name: Optional[str]
    dataset_name: Optional[str]
    harm_categories: Optional[List[str]]
    description: Optional[str]
    authors: Optional[List[str]]
    groups: Optional[List[str]]
    source: Optional[str]
    date_added: Optional[datetime]
    added_by: Optional[str]
    metadata: Optional[Dict[str, str]]
    parameters: Optional[List[str]]
    # Shared by every prompt that belongs to the same SeedPromptGroup.
    prompt_group_id: Optional[uuid.UUID]
    # Read by SeedPromptDataset to give prompts with the same alias one group ID.
    prompt_group_alias: Optional[str]
    # Sort key used by SeedPromptGroup to order its prompts.
    sequence: Optional[int]

    def __init__(
        self,
        *,
        id: Optional[uuid.UUID] = None,
        value: str,
        value_sha256: Optional[str] = None,
        data_type: PromptDataType,
        name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        harm_categories: Optional[List[str]] = None,
        description: Optional[str] = None,
        authors: Optional[List[str]] = None,
        groups: Optional[List[str]] = None,
        source: Optional[str] = None,
        date_added: Optional[datetime] = None,
        added_by: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        parameters: Optional[List[str]] = None,
        prompt_group_id: Optional[uuid.UUID] = None,
        prompt_group_alias: Optional[str] = None,
        sequence: Optional[int] = 0,
    ):
        """Create a seed prompt.

        Args:
            id: Identifier of the prompt. A random UUID is generated when omitted.
            value: The prompt content (required).
            value_sha256: Optional precomputed hash; stored as given.
            data_type: The data type of ``value`` (required).
            name, dataset_name, description, source, added_by: Stored as given.
            harm_categories, authors, groups, parameters: Stored as given;
                a missing or empty value becomes a new empty list.
            date_added: Timestamp of the prompt. When omitted (or None) the
                current time at construction is used.
            metadata: Stored as given (stays None when omitted).
            prompt_group_id: Stored as given.
            prompt_group_alias: Stored as given.
            sequence: Stored as given; defaults to 0.
        """
        self.id = id or uuid.uuid4()
        self.value = value
        self.value_sha256 = value_sha256
        self.data_type = data_type
        self.name = name
        self.dataset_name = dataset_name
        self.harm_categories = harm_categories or []
        self.description = description
        self.authors = authors or []
        self.groups = groups or []
        self.source = source
        # Resolve the timestamp per instance. A `datetime.now()` default in the
        # signature would be evaluated only once, when the module is imported.
        self.date_added = date_added if date_added is not None else datetime.now()
        self.added_by = added_by
        self.metadata = metadata
        self.parameters = parameters or []
        self.prompt_group_id = prompt_group_id
        self.prompt_group_alias = prompt_group_alias
        self.sequence = sequence

    def render_template_value(self, **kwargs) -> str:
        """Renders self.value as a Jinja2 template, applying the parameters in kwargs.

        Args:
            kwargs: Key-value pairs to replace in the SeedPrompt value.

        Returns:
            The rendered prompt text with the parameters applied.

        Raises:
            ValueError: If the prompt's data type is not "text", or if rendering
                fails (e.g. parameters are missing or invalid in the template).
        """
        if self.data_type != "text":
            raise ValueError(f"Cannot render non-text values as templates {self.data_type}")

        # StrictUndefined makes rendering fail on an undefined variable
        # instead of silently substituting an empty string.
        jinja_template = Template(self.value, undefined=StrictUndefined)

        try:
            return jinja_template.render(**kwargs)
        except Exception as e:
            # Chain the original error so the root cause stays in the traceback.
            raise ValueError(f"Error applying parameters: {str(e)}") from e

    async def set_sha256_value_async(self):
        """
        Computes the SHA256 hash of the prompt value asynchronously and stores it in `value_sha256`.

        It should be called after prompt `value` is serialized to text,
        as file paths used in the `value` may have changed from local to memory storage paths.

        Note, this method is async due to the blob retrieval. And because of that, we opted
        to take it out of main and setter functions. The disadvantage is that it must be explicitly called.
        """
        # NOTE(review): imported at call time rather than at module level,
        # presumably to avoid a circular import — confirm before hoisting.
        from pyrit.models.data_type_serializer import data_serializer_factory

        original_serializer = data_serializer_factory(
            category="seed-prompt-entries", data_type=self.data_type, value=self.value
        )
        self.value_sha256 = await original_serializer.get_sha256()
class SeedPromptGroup(YamlLoadable):
    """
    A group of prompts that need to be sent together.

    This class is useful when a target requires multiple (multimodal) prompt pieces to be grouped
    and sent together. All prompts in the group should share the same `prompt_group_id`.
    """

    prompts: List[SeedPrompt]

    def __init__(
        self,
        *,
        prompts: Union[List[SeedPrompt], List[Dict[str, Any]]],
    ):
        """Build a group from SeedPrompt objects and/or prompt dictionaries.

        Dictionaries are converted with ``SeedPrompt(**prompt)``. After
        conversion every prompt is given the same ``prompt_group_id`` and the
        prompts are ordered by their ``sequence``.

        Args:
            prompts: Non-empty list of SeedPrompt objects or dictionaries of
                SeedPrompt constructor arguments.

        Raises:
            ValueError: If ``prompts`` is empty, contains an entry that is
                neither a SeedPrompt nor a dict, or the prompts carry
                different group IDs.
        """
        if not prompts:
            raise ValueError("SeedPromptGroup cannot be empty.")

        self.prompts = []
        for prompt in prompts:
            if isinstance(prompt, SeedPrompt):
                self.prompts.append(prompt)
            elif isinstance(prompt, dict):
                self.prompts.append(SeedPrompt(**prompt))
            else:
                # Reject unknown entries instead of silently dropping them,
                # which could otherwise leave the group empty.
                raise ValueError("Prompts should be either dicts or SeedPrompt objects. Got something else.")

        self._enforce_consistent_group_id()

        # Keep the prompts in sequence order (stable for equal sequences).
        self.prompts.sort(key=lambda prompt: prompt.sequence)

    def _enforce_consistent_group_id(self):
        """
        Ensures that if any of the prompts already have a group ID set,
        they share the same ID. If none have a group ID set, assign a
        new UUID to all prompts.

        Raises:
            ValueError: If multiple different group IDs exist among the prompts.
        """
        existing_group_ids = {prompt.prompt_group_id for prompt in self.prompts if prompt.prompt_group_id is not None}

        if len(existing_group_ids) > 1:
            # More than one distinct group ID found among prompts.
            raise ValueError("Inconsistent group IDs found across prompts.")

        # Reuse the single existing ID if there is one, otherwise create a fresh one.
        group_id = existing_group_ids.pop() if existing_group_ids else uuid.uuid4()
        for prompt in self.prompts:
            prompt.prompt_group_id = group_id

    def is_single_request(self) -> bool:
        """Return True when all prompts in the group have the same `sequence` value."""
        unique_sequences = {prompt.sequence for prompt in self.prompts}
        return len(unique_sequences) <= 1

    def __repr__(self):
        return f"<SeedPromptGroup(prompts={len(self.prompts)} prompts)>"
class SeedPromptDataset(YamlLoadable):
    """
    SeedPromptDataset manages seed prompts plus optional top-level defaults.
    Prompts are stored as a List[SeedPrompt], so references to prompt properties
    are straightforward (e.g. ds.prompts[0].value).
    """

    data_type: Optional[str]
    name: Optional[str]
    dataset_name: Optional[str]
    harm_categories: Optional[List[str]]
    description: Optional[str]
    authors: Optional[List[str]]
    groups: Optional[List[str]]
    source: Optional[str]
    date_added: Optional[datetime]
    added_by: Optional[str]

    # Now the actual prompts
    prompts: List["SeedPrompt"]

    def __init__(
        self,
        *,
        prompts: Optional[Union[List[Dict[str, Any]], List[SeedPrompt]]] = None,
        data_type: Optional[PromptDataType] = "text",
        name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        harm_categories: Optional[List[str]] = None,
        description: Optional[str] = None,
        authors: Optional[List[str]] = None,
        groups: Optional[List[str]] = None,
        source: Optional[str] = None,
        date_added: Optional[datetime] = None,
        added_by: Optional[str] = None,
    ):
        """
        Initialize the dataset.

        Typically, you'll call from_dict or from_yaml_file so that top-level defaults
        are merged into each prompt. If you're passing prompts directly, they can be
        either a list of SeedPrompt objects or prompt dictionaries (which then get
        converted to SeedPrompt objects).

        Args:
            prompts: Non-empty list of SeedPrompt objects and/or dictionaries of
                SeedPrompt constructor arguments.
            data_type: Dataset-level data type; defaults to "text".
            name, dataset_name, harm_categories, description, source, added_by:
                Stored as given.
            authors, groups: Stored as given; a missing or empty value becomes
                a new empty list.
            date_added: Stored as given; the current time is used when omitted.

        Raises:
            ValueError: If ``prompts`` is missing or empty, or contains an entry
                that is neither a dict nor a SeedPrompt.
        """
        if not prompts:
            raise ValueError("SeedPromptDataset cannot be empty.")

        # Store top-level fields
        self.data_type = data_type
        self.name = name
        self.dataset_name = dataset_name
        self.harm_categories = harm_categories
        self.description = description
        self.authors = authors or []
        self.groups = groups or []
        self.source = source
        self.date_added = date_added or datetime.now()
        self.added_by = added_by

        # Convert any dictionaries in `prompts` to SeedPrompt objects
        self.prompts = []
        for p in prompts:
            if isinstance(p, dict):
                self.prompts.append(SeedPrompt(**p))
            elif isinstance(p, SeedPrompt):
                self.prompts.append(p)
            else:
                raise ValueError("Prompts should be either dicts or SeedPrompt objects. Got something else.")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SeedPromptDataset":
        """
        Builds a SeedPromptDataset by merging top-level defaults into each item in 'prompts'.

        The ``data`` dictionary passed in is not modified.

        Args:
            data: Mapping with a "prompts" list of prompt dictionaries; every
                other key is treated as a dataset-level default and is also
                passed to the dataset constructor.

        Returns:
            The dataset built from the merged prompt dictionaries.

        Raises:
            ValueError: If any merged prompt already contains "prompt_group_id",
                or if there are no prompts.
        """
        # Work on a shallow copy so the caller's dict keeps its "prompts" entry.
        dataset_defaults = dict(data)  # everything except "prompts" is top-level
        prompts_data = dataset_defaults.pop("prompts", None) or []

        merged_prompts = []
        for p in prompts_data:
            # Merge dataset-level fields with the prompt-level fields.
            # NOTE(review): presumably prompt-level values win on conflict —
            # confirm against utils.combine_dict.
            merged = utils.combine_dict(dataset_defaults, p)

            # List-valued fields are combined from both levels rather than overridden.
            for list_field in ("harm_categories", "authors", "groups"):
                merged[list_field] = utils.combine_list(
                    dataset_defaults.get(list_field, []),
                    p.get(list_field, []),
                )

            if "data_type" not in merged:
                merged["data_type"] = dataset_defaults.get("data_type", "text")

            merged_prompts.append(merged)

        # Group IDs are generated below; they must not come from the input data.
        for prompt in merged_prompts:
            if "prompt_group_id" in prompt:
                raise ValueError("prompt_group_id should not be set in prompt data")

        cls._set_seed_prompt_group_id_by_alias(seed_prompts=merged_prompts)

        # Now create the dataset with the newly merged prompt dicts
        return cls(prompts=merged_prompts, **dataset_defaults)

    @staticmethod
    def _set_seed_prompt_group_id_by_alias(seed_prompts: List[Dict[str, Any]]) -> None:
        """
        Sets "prompt_group_id" on every prompt dict, in place, based on "prompt_group_alias".

        Prompts that share the same (truthy) alias receive the same new UUID; prompts
        without an alias each receive their own UUID. This is important so that
        prompt_group_alias can be set in yaml to group prompts.

        Args:
            seed_prompts: Prompt dictionaries to update in place.
        """
        alias_to_group_id: Dict[Any, uuid.UUID] = {}

        for prompt in seed_prompts:
            alias = prompt.get("prompt_group_alias")
            if alias:
                if alias not in alias_to_group_id:
                    alias_to_group_id[alias] = uuid.uuid4()
                prompt["prompt_group_id"] = alias_to_group_id[alias]
            else:
                prompt["prompt_group_id"] = uuid.uuid4()

    @staticmethod
    def group_seed_prompts_by_prompt_group_id(seed_prompts: List[SeedPrompt]) -> List[SeedPromptGroup]:
        """
        Groups the given list of SeedPrompts by their prompt_group_id and creates
        SeedPromptGroup instances.

        Args:
            seed_prompts: A list of SeedPrompt objects.

        Returns:
            A list of SeedPromptGroup objects, with prompts grouped by prompt_group_id.
        """
        # Group seed prompts by `prompt_group_id`; a prompt without one gets
        # a throwaway key so it ends up in a group of its own.
        grouped_prompts = defaultdict(list)
        for prompt in seed_prompts:
            grouped_prompts[prompt.prompt_group_id or uuid.uuid4()].append(prompt)

        # SeedPromptGroup orders its prompts by sequence on construction,
        # so no pre-sorting is needed here.
        return [SeedPromptGroup(prompts=group_prompts) for group_prompts in grouped_prompts.values()]

    def __repr__(self):
        return f"<SeedPromptDataset(prompts={len(self.prompts)} prompts)>"