# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Sequence, Union
from tinytag import TinyTag
from pyrit.common.path import PATHS_DICT
from pyrit.models import DataTypeSerializer
from pyrit.models.literals import ChatMessageRole
from pyrit.models.seed import Seed
logger = logging.getLogger(__name__)
@dataclass
class SeedPrompt(Seed):
    """
    Represents a seed prompt with various attributes and metadata.

    On construction, the prompt value is rendered against ``PATHS_DICT``. When no
    ``data_type`` is supplied, the data type is inferred from the rendered value
    (see ``__post_init__``).
    """

    # Alias for the prompt group
    prompt_group_alias: Optional[str] = None
    # Role of the prompt in a conversation (e.g., "user", "assistant")
    role: Optional[ChatMessageRole] = None
    # Sequence number for ordering prompts in a conversation; prompts with
    # the same sequence number are grouped together if they also share the same prompt_group_id
    sequence: Optional[int] = 0
    # Parameters that can be used in the prompt template
    parameters: Optional[Sequence[str]] = field(default_factory=list)

    def __post_init__(self) -> None:
        """
        Render the template value and infer ``data_type`` when it was not provided.

        ``self.value`` is replaced with ``render_template_value_silent(**PATHS_DICT)``.
        If ``data_type`` is falsy, it is inferred from the rendered value:

        - If the value is the path of an existing file, its extension is matched
          case-insensitively. Video extensions give ``"video_path"``, audio extensions
          give ``"audio_path"``, and image extensions give ``"image_path"``.
        - Any other value gives ``"text"``.

        ``"error"`` and ``"url"`` are never assigned implicitly.

        Raises:
            ValueError: If the value is an existing file whose extension is not one of
                the recognized video/audio/image extensions.
        """
        self.value = self.render_template_value_silent(**PATHS_DICT)

        if not self.data_type:
            # If data_type is not provided, infer it from the value.
            if os.path.isfile(self.value):
                _, ext = os.path.splitext(self.value)
                ext = ext.lstrip(".")
                # Match case-insensitively so that e.g. "photo.JPG" or "clip.MP4" are
                # recognized. The original spelling is kept for the error message.
                normalized_ext = ext.lower()
                if normalized_ext in {"mp4", "avi", "mov", "mkv", "ogv", "flv", "wmv", "webm"}:
                    self.data_type = "video_path"
                elif normalized_ext in {"flac", "mp3", "mpeg", "mpga", "m4a", "ogg", "wav"}:
                    self.data_type = "audio_path"
                elif normalized_ext in {"jpg", "jpeg", "png", "gif", "bmp", "tiff", "tif"}:
                    self.data_type = "image_path"
                else:
                    raise ValueError(f"Unable to infer data_type from file extension: {ext}")
            else:
                self.data_type = "text"

    @classmethod
    def from_yaml_with_required_parameters(
        cls,
        template_path: Union[str, Path],
        required_parameters: list[str],
        error_message: Optional[str] = None,
    ) -> "SeedPrompt":
        """
        Load a Seed from a YAML file and validate that it contains specific parameters.

        Args:
            template_path: Path to the YAML file containing the template.
            required_parameters: List of parameter names that must exist in the template's
                ``parameters``.
            error_message: Custom error message if validation fails. If None, a default
                message listing the required parameters is used.

        Returns:
            SeedPrompt: The object returned by ``cls.from_yaml_file`` (presumably an
                instance of ``cls``), after validation.

        Raises:
            ValueError: If the template has no ``parameters`` or does not contain all
                required parameters. The message ends with the loaded prompt's string form.
        """
        sp = cls.from_yaml_file(template_path)

        has_all_parameters = sp.parameters is not None and all(
            param in sp.parameters for param in required_parameters
        )
        if not has_all_parameters:
            if error_message is None:
                error_message = f"Template must have these parameters: {', '.join(required_parameters)}"
            raise ValueError(f"{error_message}: '{sp}'")

        return sp