# Source code for pyrit.common.initialization
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any, Literal, Optional, Union, get_args
import dotenv
from pyrit.common import path
logger = logging.getLogger(__name__)

# String identifiers accepted by ``initialize_pyrit`` to pick the central-memory backend.
IN_MEMORY: str = "InMemory"  # DuckDB opened on the special ":memory:" path
DUCK_DB: str = "DuckDB"  # persistent DuckDB database
AZURE_SQL: str = "AzureSQL"  # Azure SQL database

# Literal type spelling out the same three identifiers; ``typing.get_args`` on it yields the
# tuple of supported values that ``initialize_pyrit`` validates against.
MemoryDatabaseType = Literal["InMemory", "DuckDB", "AzureSQL"]
def _load_environment_files() -> None:
    """
    Populate the process environment from PyRIT's dotenv files.

    The ``.env`` file inside ``path.HOME_PATH`` is read first, with ``override=True`` so it replaces
    variables that are already set. If that file does not exist, python-dotenv's default
    ``load_dotenv()`` lookup is used instead (called without ``override``). Finally, ``.env.local``
    in the same directory, when present, is read with ``override=True`` so its values take precedence
    over whatever was loaded before it.
    """
    home_dir = path.HOME_PATH
    base_env = home_dir / ".env"
    local_env = home_dir / ".env.local"

    if not base_env.exists():
        # No .env in the PyRIT home directory: defer to python-dotenv's own discovery.
        dotenv.load_dotenv()
    else:
        dotenv.load_dotenv(base_env, override=True)
        logger.info(f"Loaded {base_env}")

    # Local overrides are applied last so they win over the base values loaded above.
    if local_env.exists():
        dotenv.load_dotenv(local_env, override=True)
        logger.info(f"Loaded {local_env}")
# [docs] (documentation-site link marker; the public entry point follows)
def initialize_pyrit(memory_db_type: Union[MemoryDatabaseType, str], **memory_instance_kwargs: Optional[Any]) -> None:
    """
    Set up PyRIT: load the environment files, then register the central memory instance.

    Args:
        memory_db_type (MemoryDatabaseType): Which backend to use for central memory. Must be one of
            "InMemory" (DuckDB with ``db_path=":memory:"``), "DuckDB" (persistent DuckDB), or "AzureSQL".
        **memory_instance_kwargs (Optional[Any]): Extra keyword arguments forwarded unchanged to the
            selected memory class (``DuckDBMemory`` or ``AzureSQLMemory``). With "InMemory", ``db_path``
            is already supplied, so passing it here as well raises ``TypeError`` (duplicate keyword).

    Raises:
        ValueError: If ``memory_db_type`` is not one of the supported values.
    """
    allowed = get_args(MemoryDatabaseType)
    if memory_db_type not in allowed:
        raise ValueError(f"Memory database type '{memory_db_type}' is not a supported type {allowed}")

    _load_environment_files()

    # The memory classes are imported here, after the environment files are loaded, rather than at
    # module level. NOTE(review): presumably for import-order / circular-import reasons — confirm.
    from pyrit.memory import AzureSQLMemory, CentralMemory, DuckDBMemory, MemoryInterface

    # Every branch below assigns ``memory``; the validation above guarantees one of them runs.
    memory: MemoryInterface
    if memory_db_type == IN_MEMORY:
        logger.info("Using in-memory DuckDB database.")
        memory = DuckDBMemory(db_path=":memory:", **memory_instance_kwargs)
    elif memory_db_type == DUCK_DB:
        logger.info("Using persistent DuckDB database.")
        memory = DuckDBMemory(**memory_instance_kwargs)
    else:
        # Only "AzureSQL" can reach this branch after validation.
        logger.info("Using AzureSQL database.")
        memory = AzureSQLMemory(**memory_instance_kwargs)

    CentralMemory.set_memory_instance(memory)