# Source code for pyrit.setup.initialization

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import pathlib

# Import PyRITInitializer for type checking (with TYPE_CHECKING to avoid circular imports)
from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Union, get_args

import dotenv

from pyrit.common import path
from pyrit.common.apply_defaults import reset_default_values
from pyrit.memory import (
    AzureSQLMemory,
    CentralMemory,
    MemoryInterface,
    SQLiteMemory,
)

if TYPE_CHECKING:
    from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# String identifiers for the supported memory backends; initialize_pyrit() compares its
# memory_db_type argument against these.
IN_MEMORY = "InMemory"
SQLITE = "SQLite"
AZURE_SQL = "AzureSQL"
# Type alias for the same three identifiers. The string values are repeated literally here,
# so they must be kept in sync with the constants above by hand.
MemoryDatabaseType = Literal["InMemory", "SQLite", "AzureSQL"]


def _load_environment_files() -> None:
    """
    Populate the process environment from ``.env`` and then ``.env.local``.

    Both files are looked up under ``path.HOME_PATH`` first. A file found there is loaded
    with ``override=True`` and ``interpolate=True`` and the load is logged. When a file is
    not present under ``path.HOME_PATH``:

    * for ``.env``, ``dotenv.load_dotenv`` is called without a path (and without requesting
      override), leaving the file lookup to python-dotenv;
    * for ``.env.local``, whatever ``dotenv.find_dotenv(".env.local")`` returns is loaded
      with ``override=True``.

    ``.env.local`` is always processed second so that its values take precedence over the
    base file.
    """
    env_file = path.HOME_PATH / ".env"
    env_local_file = path.HOME_PATH / ".env.local"

    # Base settings: prefer the file in HOME_PATH, otherwise let python-dotenv search.
    if not env_file.exists():
        dotenv.load_dotenv(verbose=True)
    else:
        dotenv.load_dotenv(env_file, override=True, interpolate=True)
        logger.info(f"Loaded {env_file}")

    # Local overrides: loaded after the base file so they win on conflicting keys.
    if not env_local_file.exists():
        discovered_local = dotenv.find_dotenv(".env.local")
        dotenv.load_dotenv(dotenv_path=discovered_local, override=True, verbose=True)
    else:
        dotenv.load_dotenv(env_local_file, override=True, interpolate=True)
        logger.info(f"Loaded {env_local_file}")


def _load_initializers_from_scripts(
    *, script_paths: Sequence[Union[str, pathlib.Path]]
) -> Sequence["PyRITInitializer"]:
    """
    Load PyRITInitializer instances from external Python files.

    Each script is executed as a standalone module named ``init_script_<file stem>``. Every
    PyRITInitializer subclass found in the resulting module namespace (other than
    PyRITInitializer itself) is instantiated with no arguments. A class that is bound to
    several names in the script (for example through an alias) is instantiated only once.
    A class whose constructor raises is skipped with a warning so that the remaining
    classes in the script can still be loaded.

    Args:
        script_paths (Sequence[Union[str, pathlib.Path]]): Sequence of file paths to Python scripts to load.

    Returns:
        Sequence[PyRITInitializer]: List of PyRITInitializer instances loaded from the scripts,
            in the order the scripts were given and, within one script, in the sorted name
            order produced by ``dir(module)``.

    Raises:
        FileNotFoundError: If a script path does not exist.
        ValueError: If a script path is not a ``.py`` file, no import spec could be created for it,
            or no initializer could be instantiated from it.
        Exception: Any error raised while executing a script is logged and re-raised unchanged.

    Example:
        Script content should be a subclass of PyRITInitializer e.g. like SimpleInitializer
    """
    import importlib.util

    # Import here to avoid circular imports
    from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

    loaded_initializers = []

    for script_path in script_paths:
        # Accept both str and Path inputs
        script = pathlib.Path(script_path)

        # Validate the script exists
        if not script.exists():
            raise FileNotFoundError(f"Initialization script not found: {script}")

        # Validate it's a Python file
        if script.suffix != ".py":
            raise ValueError(f"Initialization script must be a Python file (.py): {script}")

        logger.info(f"Loading initializers from script: {script}")

        try:
            # Load and execute the script as a module
            spec = importlib.util.spec_from_file_location(f"init_script_{script.stem}", script)
            if spec is None or spec.loader is None:
                raise ValueError(f"Could not load initialization script: {script}")

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            script_initializers = []
            # Classes already handled for this script, compared by identity so that an alias
            # (``Alias = MyInitializer``) does not make the same initializer run twice.
            seen_classes = []

            # NOTE(review): dir(module) lists every name bound in the script's namespace, so
            # PyRITInitializer subclasses that the script merely imports are instantiated as
            # well - confirm whether that is intended.
            for name in dir(module):
                obj = getattr(module, name)

                try:
                    is_initializer_class = (
                        isinstance(obj, type) and issubclass(obj, PyRITInitializer) and obj is not PyRITInitializer
                    )
                except TypeError:
                    # Some objects pass isinstance(obj, type) yet are rejected by issubclass()
                    # (for example parameterized generics such as list[int] on Python 3.9/3.10).
                    # They cannot be initializers, so skip them instead of failing the script.
                    is_initializer_class = False

                if not is_initializer_class:
                    continue
                if any(obj is known for known in seen_classes):
                    continue
                seen_classes.append(obj)

                try:
                    # Instantiate the initializer class
                    initializer = obj()
                    script_initializers.append(initializer)
                    logger.debug(f"Found and instantiated {name} in {script.name}")
                except Exception as e:
                    logger.warning(f"Could not instantiate {name} from {script.name}: {e}")
                    # Continue to try other classes rather than failing completely

            if not script_initializers:
                raise ValueError(
                    f"Initialization script {script} must contain at least one PyRITInitializer subclass. "
                    f"Define a class that inherits from PyRITInitializer."
                )

            loaded_initializers.extend(script_initializers)
            logger.debug(f"Loaded {len(script_initializers)} initializer(s) from {script.name}")

        except Exception as e:
            # Record which script failed, then let the caller see the original exception.
            logger.error(f"Error loading initializers from script {script}: {e}")
            raise

    return loaded_initializers


def _execute_initializers(*, initializers: Sequence["PyRITInitializer"]) -> None:
    """
    Run the given PyRITInitializer instances, ordered by ``execution_order``.

    Every element is type-checked up front, so nothing is executed when any element is not
    a PyRITInitializer. The initializers are then sorted ascending by ``execution_order``
    and, for each one, ``validate()`` is called followed by ``initialize_with_tracking()``.
    The first failure is logged and re-raised, which stops the remaining initializers.

    Args:
        initializers: Sequence of PyRITInitializer instances to execute.

    Raises:
        ValueError: If an initializer is not a PyRITInitializer instance.
        Exception: If an initializer's validation or initialization fails.
    """
    # Import here to avoid circular imports
    from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

    # Reject the whole batch before running anything if one element has the wrong type
    for candidate in initializers:
        if isinstance(candidate, PyRITInitializer):
            continue
        raise ValueError(
            f"All initializers must be PyRITInitializer instances. "
            f"Got {type(candidate).__name__}: {candidate}"
        )

    # Lower execution_order runs first; the sort is stable, so ties keep their input order
    execution_queue = list(initializers)
    execution_queue.sort(key=lambda item: item.execution_order)

    for current in execution_queue:
        logger.info(f"Executing initializer: {current.name}")
        logger.debug(f"Description: {current.description}")

        try:
            # Check preconditions, then apply the configuration with tracking
            current.validate()
            current.initialize_with_tracking()
            logger.debug(f"Successfully executed initializer: {current.name}")
        except Exception as exc:
            logger.error(f"Error executing initializer {current.name}: {exc}")
            raise


def initialize_pyrit(
    memory_db_type: Union[MemoryDatabaseType, str],
    *,
    initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None,
    initializers: Optional[Sequence["PyRITInitializer"]] = None,
    **memory_instance_kwargs: Any,
) -> None:
    """
    Initializes PyRIT with the provided memory instance and loads environment files.

    The steps run in this order: environment files are loaded, default values are reset,
    the memory instance is created and registered as central memory, and finally all
    initializers (those passed directly plus those loaded from scripts) are executed.

    Args:
        memory_db_type (MemoryDatabaseType): The MemoryDatabaseType string literal which indicates the memory
            instance to use for central memory. Options include "InMemory", "SQLite", and "AzureSQL".
            The legacy value "DuckDB" is accepted with a warning and treated as "InMemory".
        initialization_scripts (Optional[Sequence[Union[str, pathlib.Path]]]): Optional sequence of Python
            script paths that contain PyRITInitializer classes. Every PyRITInitializer subclass found in a
            script is discovered and instantiated automatically (see _load_initializers_from_scripts).
        initializers (Optional[Sequence[PyRITInitializer]]): Optional sequence of PyRITInitializer instances
            to execute directly. These provide type-safe, validated configuration with clear documentation.
        **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance.

    Raises:
        ValueError: If memory_db_type is not one of the supported types.
        Exception: Errors raised while loading initialization scripts or executing initializers
            propagate to the caller.
    """
    # Handle DuckDB deprecation before validation
    if memory_db_type == "DuckDB":
        logger.warning(
            "DuckDB is no longer supported and has been replaced by SQLite for better compatibility and performance. "
            "Please update your code to use SQLite instead. "
            "For migration guidance, see the SQLite Memory documentation at: "
            "doc/code/memory/1_sqlite_memory.ipynb. "
            "Using in-memory SQLite instead."
        )
        memory_db_type = IN_MEMORY

    _load_environment_files()

    # Reset all default values before executing initialization scripts
    # This ensures a clean state for each initialization
    reset_default_values()

    # Set up memory BEFORE executing initialization scripts
    # This is critical because initialization scripts may instantiate objects
    # (like prompt targets) that require central memory to be initialized
    memory: MemoryInterface
    if memory_db_type == IN_MEMORY:
        logger.info("Using in-memory SQLite database.")
        memory = SQLiteMemory(db_path=":memory:", **memory_instance_kwargs)
    elif memory_db_type == SQLITE:
        logger.info("Using persistent SQLite database.")
        memory = SQLiteMemory(**memory_instance_kwargs)
    elif memory_db_type == AZURE_SQL:
        logger.info("Using AzureSQL database.")
        memory = AzureSQLMemory(**memory_instance_kwargs)
    else:
        raise ValueError(
            f"Memory database type '{memory_db_type}' is not a supported type {get_args(MemoryDatabaseType)}"
        )
    CentralMemory.set_memory_instance(memory)

    # Combine directly provided initializers with those loaded from scripts
    all_initializers = list(initializers) if initializers else []

    # Load additional initializers from scripts
    if initialization_scripts:
        script_initializers = _load_initializers_from_scripts(script_paths=initialization_scripts)
        all_initializers.extend(script_initializers)

    # Execute all initializers (sorted by execution_order)
    if all_initializers:
        _execute_initializers(initializers=all_initializers)