# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import pathlib
# Import PyRITInitializer for type checking (with TYPE_CHECKING to avoid circular imports)
from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Union, get_args
import dotenv
from pyrit.common import path
from pyrit.common.apply_defaults import reset_default_values
from pyrit.memory import (
AzureSQLMemory,
CentralMemory,
MemoryInterface,
SQLiteMemory,
)
if TYPE_CHECKING:
from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer
logger = logging.getLogger(__name__)
# Names of the supported central-memory backends. initialize_pyrit_async compares its
# memory_db_type argument against these constants to pick the memory implementation.
IN_MEMORY = "InMemory"
SQLITE = "SQLite"
AZURE_SQL = "AzureSQL"
# Static type for the accepted backend names. typing.Literal requires literal values, so the
# strings are repeated here rather than referencing the constants; keep both in sync.
MemoryDatabaseType = Literal["InMemory", "SQLite", "AzureSQL"]
def _load_environment_files(env_files: Optional[Sequence[pathlib.Path]], *, silent: bool = False) -> None:
    """
    Load variables from environment files into the process environment.

    Files are loaded one after another with override enabled, so a value defined in a
    later file replaces the value from an earlier one.

    Args:
        env_files: Explicit environment file paths to load, in order. When None, the
            ``.env`` and ``.env.local`` files under ``path.CONFIGURATION_DIRECTORY_PATH``
            are used, each only if it exists.
        silent: When True, this function emits no status messages at all.
            Defaults to False.

    Raises:
        ValueError: If an explicitly provided environment file does not exist.
    """
    verbose = not silent
    files_to_load: Sequence[pathlib.Path]

    if env_files is None:
        # No explicit files: fall back to whichever default files are present.
        config_dir = path.CONFIGURATION_DIRECTORY_PATH
        candidates = (config_dir / ".env", config_dir / ".env.local")
        files_to_load = [candidate for candidate in candidates if candidate.exists()]

        if verbose:
            if files_to_load:
                status = f"Found default environment files: {[str(f) for f in files_to_load]}"
            else:
                status = "No default environment files found. Using system environment variables only."
            _print_msg(status, quiet=silent, log=True)
    else:
        if verbose:
            _print_msg(f"Loading custom environment files: {[str(f) for f in env_files]}", quiet=silent, log=True)

        # Every explicitly requested file must exist; report the first one that does not.
        missing = next((candidate for candidate in env_files if not candidate.exists()), None)
        if missing is not None:
            raise ValueError(f"Environment file not found: {missing}")
        files_to_load = env_files

    for env_file in files_to_load:
        dotenv.load_dotenv(env_file, override=True, interpolate=True)
        if verbose:
            _print_msg(f"Loaded environment file: {env_file}", quiet=silent, log=True)
def _print_msg(message: str, quiet: bool, log: bool) -> None:
    """
    Emit an initialization message to stdout and/or the module logger.

    Args:
        message (str): The text to emit.
        quiet (bool): If True, the message is not printed to stdout.
        log (bool): If True, the message is also logged at INFO level.
    """
    should_print = not quiet
    if should_print:
        print(message)

    # Logging is independent of printing: a quiet call can still be logged.
    if not log:
        return
    logger.info(message)
def _load_initializers_from_scripts(
    *, script_paths: Sequence[Union[str, pathlib.Path]]
) -> Sequence["PyRITInitializer"]:
    """
    Build PyRITInitializer instances from external Python script files.

    Each script is executed as a module. Every class exposed by that module that is a
    subclass of PyRITInitializer (other than PyRITInitializer itself) is instantiated with
    no arguments. A class that fails to instantiate is skipped with a warning.

    Args:
        script_paths (Sequence[Union[str, pathlib.Path]]): Paths of the Python scripts to load.

    Returns:
        Sequence[PyRITInitializer]: The instances created from all scripts, in script order.

    Raises:
        FileNotFoundError: If a script path does not exist.
        ValueError: If a script is not a ``.py`` file, cannot be loaded, or yields no
            initializer instances.

    Example:
        Script content should be a subclass of PyRITInitializer e.g. like SimpleInitializer
    """
    import importlib.util

    # Imported lazily to avoid a circular import at module load time.
    from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

    collected = []

    for raw_path in script_paths:
        script = pathlib.Path(raw_path)

        if not script.exists():
            raise FileNotFoundError(f"Initialization script not found: {script}")
        if script.suffix != ".py":
            raise ValueError(f"Initialization script must be a Python file (.py): {script}")

        logger.info(f"Loading initializers from script: {script}")

        try:
            # Execute the script as a standalone module named after the file stem.
            spec = importlib.util.spec_from_file_location(f"init_script_{script.stem}", script)
            if spec is None or spec.loader is None:
                raise ValueError(f"Could not load initialization script: {script}")

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            found = []
            for attr_name in dir(module):
                candidate = getattr(module, attr_name)

                # Only concrete PyRITInitializer subclasses are of interest.
                if not isinstance(candidate, type):
                    continue
                if not issubclass(candidate, PyRITInitializer) or candidate is PyRITInitializer:
                    continue

                try:
                    found.append(candidate())
                    logger.debug(f"Found and instantiated {attr_name} in {script.name}")
                except Exception as e:
                    # A single broken class should not prevent the others from loading.
                    logger.warning(f"Could not instantiate {attr_name} from {script.name}: {e}")

            if not found:
                raise ValueError(
                    f"Initialization script {script} must contain at least one PyRITInitializer subclass. "
                    f"Define a class that inherits from PyRITInitializer."
                )

            collected.extend(found)
            logger.debug(f"Loaded {len(found)} initializer(s) from {script.name}")
        except Exception as e:
            logger.error(f"Error loading initializers from script {script}: {e}")
            raise

    return collected
async def _execute_initializers_async(*, initializers: Sequence["PyRITInitializer"]) -> None:
    """
    Run the given initializers in ascending ``execution_order``.

    All entries are type-checked up front, so nothing is executed if any entry is invalid.
    Each initializer is then validated and initialized in turn; the first failure is logged
    and re-raised, which stops the remaining initializers from running.

    Args:
        initializers: Sequence of PyRITInitializer instances to execute.

    Raises:
        ValueError: If an entry is not a PyRITInitializer instance.
        Exception: If an initializer's validation or initialization fails.
    """
    # Imported lazily to avoid a circular import at module load time.
    from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

    for candidate in initializers:
        if isinstance(candidate, PyRITInitializer):
            continue
        raise ValueError(
            f"All initializers must be PyRITInitializer instances. "
            f"Got {type(candidate).__name__}: {candidate}"
        )

    # Lower execution_order values run first.
    ordered = list(initializers)
    ordered.sort(key=lambda item: item.execution_order)

    for current in ordered:
        logger.info(f"Executing initializer: {current.name}")
        logger.debug(f"Description: {current.description}")
        try:
            current.validate()
            # The tracking variant records what the initializer configured.
            await current.initialize_with_tracking_async()
            logger.debug(f"Successfully executed initializer: {current.name}")
        except Exception as e:
            logger.error(f"Error executing initializer {current.name}: {e}")
            raise
# Public entry point: loads environment files, configures central memory, and runs initializers.
async def initialize_pyrit_async(
    memory_db_type: Union[MemoryDatabaseType, str],
    *,
    initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None,
    initializers: Optional[Sequence["PyRITInitializer"]] = None,
    env_files: Optional[Sequence[pathlib.Path]] = None,
    silent: bool = False,
    **memory_instance_kwargs: Any,
) -> None:
    """
    Initialize PyRIT: load environment files, set up central memory, and run initializers.

    Args:
        memory_db_type (MemoryDatabaseType): The MemoryDatabaseType string literal which indicates the memory
            instance to use for central memory. Options include "InMemory", "SQLite", and "AzureSQL".
        initialization_scripts (Optional[Sequence[Union[str, pathlib.Path]]]): Optional sequence of Python script
            paths. Each script must define at least one PyRITInitializer subclass; every such subclass
            found in a script is instantiated and executed.
        initializers (Optional[Sequence[PyRITInitializer]]): Optional sequence of PyRITInitializer instances
            to execute directly. They are combined with the initializers loaded from scripts and all of
            them are executed in ascending ``execution_order``.
        env_files (Optional[Sequence[pathlib.Path]]): Optional sequence of environment file paths to load
            in order. If not provided, will load default .env and .env.local files from PyRIT home if they exist.
            All paths must be valid pathlib.Path objects.
        silent (bool): If True, suppresses print statements about environment file loading.
            Defaults to False.
        **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance.

    Raises:
        ValueError: If an unsupported memory_db_type is provided (raised before any environment or
            default-value state is touched), if env_files contains non-existent files, or if an
            initialization script is not a valid initializer script.
        FileNotFoundError: If an initialization script path does not exist.
    """
    # Fail fast: reject an unsupported backend before overriding environment variables
    # or resetting default values, so bad input leaves global state untouched.
    if memory_db_type not in (IN_MEMORY, SQLITE, AZURE_SQL):
        raise ValueError(
            f"Memory database type '{memory_db_type}' is not a supported type {get_args(MemoryDatabaseType)}"
        )

    _load_environment_files(env_files=env_files, silent=silent)

    # Reset all default values before executing initialization scripts.
    # This ensures a clean state for each initialization.
    reset_default_values()

    # Set up memory BEFORE executing initialization scripts.
    # This is critical because initialization scripts may instantiate objects
    # (like prompt targets) that require central memory to be initialized.
    memory: MemoryInterface
    if memory_db_type == IN_MEMORY:
        logger.info("Using in-memory SQLite database.")
        memory = SQLiteMemory(db_path=":memory:", **memory_instance_kwargs)
    elif memory_db_type == SQLITE:
        logger.info("Using persistent SQLite database.")
        memory = SQLiteMemory(**memory_instance_kwargs)
    else:
        # Only AZURE_SQL can reach this branch; other values were rejected above.
        logger.info("Using AzureSQL database.")
        memory = AzureSQLMemory(**memory_instance_kwargs)
    CentralMemory.set_memory_instance(memory)

    # Combine directly provided initializers with those loaded from scripts.
    all_initializers = list(initializers) if initializers else []
    if initialization_scripts:
        all_initializers.extend(_load_initializers_from_scripts(script_paths=initialization_scripts))

    # Execute all initializers (sorted by execution_order).
    if all_initializers:
        await _execute_initializers_async(initializers=all_initializers)