# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import pathlib
# TYPE_CHECKING is imported so PyRITInitializer can be imported for annotations only (see below), avoiding a circular import at runtime
from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Union, get_args
import dotenv
from pyrit.common import path
from pyrit.common.apply_defaults import reset_default_values
from pyrit.memory import (
AzureSQLMemory,
CentralMemory,
MemoryInterface,
SQLiteMemory,
)
if TYPE_CHECKING:
from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Names of the supported memory backends. initialize_pyrit() compares its
# memory_db_type argument against these to pick the memory implementation.
IN_MEMORY: str = "InMemory"  # SQLite with db_path=":memory:" (nothing persisted)
SQLITE: str = "SQLite"  # SQLiteMemory with its default / caller-supplied settings
AZURE_SQL: str = "AzureSQL"  # AzureSQLMemory

# Type-level spelling of the same three names. Literal[...] needs the literal
# values themselves, so keep this list in sync with the constants above.
MemoryDatabaseType = Literal["InMemory", "SQLite", "AzureSQL"]
def _load_environment_files() -> None:
    """
    Load PyRIT's dotenv files into the process environment.

    Two files are considered, in this order:

    1. ``.env`` in ``path.HOME_PATH``. If present it is loaded with ``override=True``
       and interpolation enabled. If absent, python-dotenv's default discovery is used
       instead (``load_dotenv(verbose=True)``), which does NOT override existing variables.
    2. ``.env.local`` in ``path.HOME_PATH``, loaded with ``override=True`` so its values
       take precedence over anything loaded in step 1. If absent, python-dotenv's
       ``find_dotenv(".env.local")`` is used to look for one elsewhere, still with override.
    """
    home_dir = path.HOME_PATH
    base_env = home_dir / ".env"
    local_env = home_dir / ".env.local"

    # Step 1: base settings.
    if not base_env.exists():
        dotenv.load_dotenv(verbose=True)
    else:
        dotenv.load_dotenv(base_env, override=True, interpolate=True)
        logger.info(f"Loaded {base_env}")

    # Step 2: local overrides layered on top of the base settings.
    if not local_env.exists():
        discovered_local = dotenv.find_dotenv(".env.local")
        dotenv.load_dotenv(dotenv_path=discovered_local, override=True, verbose=True)
    else:
        dotenv.load_dotenv(local_env, override=True, interpolate=True)
        logger.info(f"Loaded {local_env}")
def _load_initializers_from_scripts(
    *, script_paths: Sequence[Union[str, pathlib.Path]]
) -> Sequence["PyRITInitializer"]:
    """
    Import Python scripts from disk and instantiate the PyRITInitializer subclasses they expose.

    Each script is executed as a fresh module (named ``init_script_<stem>``; it is not
    registered in ``sys.modules``). Every attribute of that module that is a class deriving
    from PyRITInitializer, other than PyRITInitializer itself, is instantiated with no
    arguments. Because discovery walks ``dir(module)``, subclasses the script merely imports
    are picked up too, not only those it defines.

    A class whose constructor raises is logged as a warning and skipped. The script is
    rejected only if no initializer could be produced from it.

    Args:
        script_paths (Sequence[Union[str, pathlib.Path]]): Paths of the ``.py`` scripts to load,
            processed in the given order.

    Returns:
        Sequence[PyRITInitializer]: The instantiated initializers from all scripts, in script
            order, and within a script in ``dir()`` (alphabetical) order.

    Raises:
        FileNotFoundError: If a script path does not exist.
        ValueError: If a script is not a ``.py`` file, cannot be turned into a module spec, or
            yields no initializers.
        Exception: Anything raised while executing a script's top-level code is logged and re-raised.

    Example:
        Script content should be a subclass of PyRITInitializer e.g. like SimpleInitializer
    """
    # Deferred import: a module-level import would create a circular import.
    from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

    collected = []
    for raw_path in script_paths:
        script_file = pathlib.Path(raw_path)

        # Cheap path checks first, before executing anything.
        if not script_file.exists():
            raise FileNotFoundError(f"Initialization script not found: {script_file}")
        if script_file.suffix != ".py":
            raise ValueError(f"Initialization script must be a Python file (.py): {script_file}")

        logger.info(f"Loading initializers from script: {script_file}")

        try:
            import importlib.util

            module_spec = importlib.util.spec_from_file_location(f"init_script_{script_file.stem}", script_file)
            if module_spec is None or module_spec.loader is None:
                raise ValueError(f"Could not load initialization script: {script_file}")
            script_module = importlib.util.module_from_spec(module_spec)
            module_spec.loader.exec_module(script_module)

            found_here = []
            for attr_name in dir(script_module):
                candidate = getattr(script_module, attr_name)
                is_concrete_subclass = (
                    isinstance(candidate, type)
                    and issubclass(candidate, PyRITInitializer)
                    and candidate is not PyRITInitializer
                )
                if not is_concrete_subclass:
                    continue
                try:
                    found_here.append(candidate())
                    logger.debug(f"Found and instantiated {attr_name} in {script_file.name}")
                except Exception as exc:
                    # Best effort: one class that fails to construct should not hide the others.
                    logger.warning(f"Could not instantiate {attr_name} from {script_file.name}: {exc}")

            if not found_here:
                raise ValueError(
                    f"Initialization script {script_file} must contain at least one PyRITInitializer subclass. "
                    "Define a class that inherits from PyRITInitializer."
                )

            collected.extend(found_here)
            logger.debug(f"Loaded {len(found_here)} initializer(s) from {script_file.name}")
        except Exception as exc:
            logger.error(f"Error loading initializers from script {script_file}: {exc}")
            raise

    return collected
def _execute_initializers(*, initializers: Sequence["PyRITInitializer"]) -> None:
    """
    Validate and run a batch of PyRITInitializer instances, lowest ``execution_order`` first.

    Every element is type-checked before any initializer runs, so a bad element aborts the
    batch before any configuration is applied. The sort is stable: initializers with equal
    ``execution_order`` keep their relative input order. For each initializer, ``validate()``
    is called and then ``initialize_with_tracking()``.

    Args:
        initializers: The PyRITInitializer instances to run.

    Raises:
        ValueError: If any element is not a PyRITInitializer instance.
        Exception: Whatever an initializer's ``validate()`` or ``initialize_with_tracking()``
            raises; it is logged and re-raised unchanged, and later initializers do not run.
    """
    # Deferred import: a module-level import would create a circular import.
    from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

    # Pass 1: reject the whole batch up front if anything is the wrong type.
    for candidate in initializers:
        if isinstance(candidate, PyRITInitializer):
            continue
        raise ValueError(
            f"All initializers must be PyRITInitializer instances. "
            f"Got {type(candidate).__name__}: {candidate}"
        )

    def _by_execution_order(item):
        return item.execution_order

    # Pass 2: run in ascending execution_order.
    for current in sorted(initializers, key=_by_execution_order):
        logger.info(f"Executing initializer: {current.name}")
        logger.debug(f"Description: {current.description}")
        try:
            current.validate()
            # Tracking variant records what the initializer configured.
            current.initialize_with_tracking()
            logger.debug(f"Successfully executed initializer: {current.name}")
        except Exception as err:
            logger.error(f"Error executing initializer {current.name}: {err}")
            raise
def initialize_pyrit(
    memory_db_type: Union[MemoryDatabaseType, str],
    *,
    initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None,
    initializers: Optional[Sequence["PyRITInitializer"]] = None,
    **memory_instance_kwargs: Any,
) -> None:
    """
    Initializes PyRIT: loads environment files, configures central memory, and runs initializers.

    Steps, in order:
      1. The deprecated "DuckDB" type is mapped to in-memory SQLite (with a warning), and any
         other unsupported type is rejected before global state is touched.
      2. ``.env`` / ``.env.local`` environment files are loaded.
      3. All registered default values are reset so each call starts from a clean state.
      4. The requested memory instance is created and registered with ``CentralMemory``. This
         happens before any initializer runs, because initializers may build objects that need
         central memory.
      5. ``initializers`` plus those discovered in ``initialization_scripts`` are executed,
         ordered by ``execution_order``.

    Args:
        memory_db_type (MemoryDatabaseType): The MemoryDatabaseType string literal which indicates the memory
            instance to use for central memory. Options include "InMemory", "SQLite", and "AzureSQL".
        initialization_scripts (Optional[Sequence[Union[str, pathlib.Path]]]): Optional sequence of paths to
            ``.py`` scripts. Every PyRITInitializer subclass found in a script's module namespace (including
            subclasses the script imports) is instantiated with no arguments and executed. A script must
            yield at least one initializer.
        initializers (Optional[Sequence[PyRITInitializer]]): Optional sequence of PyRITInitializer instances
            to execute directly. These provide type-safe, validated configuration with clear documentation.
        **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance.

    Raises:
        ValueError: If ``memory_db_type`` is not supported, or a script or initializer is invalid.
        FileNotFoundError: If an initialization script does not exist.
    """
    # Handle DuckDB deprecation before validation
    if memory_db_type == "DuckDB":
        logger.warning(
            "DuckDB is no longer supported and has been replaced by SQLite for better compatibility and performance. "
            "Please update your code to use SQLite instead. "
            "For migration guidance, see the SQLite Memory documentation at: "
            "doc/code/memory/1_sqlite_memory.ipynb. "
            "Using in-memory SQLite instead."
        )
        memory_db_type = IN_MEMORY

    # Fail fast: reject an unsupported type BEFORE loading env files or resetting defaults,
    # so a bad argument does not leave global configuration half-reset.
    supported_types = get_args(MemoryDatabaseType)
    if memory_db_type not in supported_types:
        raise ValueError(f"Memory database type '{memory_db_type}' is not a supported type {supported_types}")

    _load_environment_files()

    # Reset all default values before executing initialization scripts
    # This ensures a clean state for each initialization
    reset_default_values()

    # Set up memory BEFORE executing initialization scripts
    # This is critical because initialization scripts may instantiate objects
    # (like prompt targets) that require central memory to be initialized
    memory: MemoryInterface
    if memory_db_type == IN_MEMORY:
        logger.info("Using in-memory SQLite database.")
        memory = SQLiteMemory(db_path=":memory:", **memory_instance_kwargs)
    elif memory_db_type == SQLITE:
        logger.info("Using persistent SQLite database.")
        memory = SQLiteMemory(**memory_instance_kwargs)
    elif memory_db_type == AZURE_SQL:
        logger.info("Using AzureSQL database.")
        memory = AzureSQLMemory(**memory_instance_kwargs)
    else:
        # Defensive: only reachable if MemoryDatabaseType gains a value with no branch above.
        raise ValueError(f"Memory database type '{memory_db_type}' is not a supported type {supported_types}")

    CentralMemory.set_memory_instance(memory)

    # Combine directly provided initializers with those loaded from scripts
    all_initializers = list(initializers) if initializers else []

    # Load additional initializers from scripts
    if initialization_scripts:
        script_initializers = _load_initializers_from_scripts(script_paths=initialization_scripts)
        all_initializers.extend(script_initializers)

    # Execute all initializers (sorted by execution_order)
    if all_initializers:
        _execute_initializers(initializers=all_initializers)