Source code for pyrit.memory.central_memory
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional

from pyrit.common import default_values
from pyrit.memory import AzureSQLMemory, DuckDBMemory
from pyrit.memory.memory_interface import MemoryInterface
logger = logging.getLogger(__name__)
[docs]
class CentralMemory:
    """
    Provides a centralized memory instance shared across the framework.

    An instance registered through ``set_memory_instance`` is reused by every later call to
    ``get_memory_instance``. If no instance has been registered, ``get_memory_instance`` creates
    one on first use: AzureSQLMemory when both Azure SQL configuration values are found,
    otherwise DuckDBMemory. The created instance is cached on the class and reused afterwards.
    """

    # The shared memory instance; None until it is set explicitly or created lazily.
    _memory_instance: Optional[MemoryInterface] = None

    @classmethod
    def set_memory_instance(cls, passed_memory: MemoryInterface) -> None:
        """
        Set a provided memory instance as the central instance for subsequent calls.

        Args:
            passed_memory (MemoryInterface): The memory instance to set as the central instance.
        """
        cls._memory_instance = passed_memory
        logger.info("Central memory instance set to: %s", type(cls._memory_instance).__name__)

    @classmethod
    def get_memory_instance(cls) -> MemoryInterface:
        """
        Returns the centralized memory instance, creating it on first use.

        If an instance has already been set (via ``set_memory_instance``) or created by an
        earlier call, that instance is returned. Otherwise the settings
        AZURE_SQL_DB_CONNECTION_STRING and AZURE_STORAGE_ACCOUNT_RESULTS_CONTAINER_URL are
        looked up. When BOTH are non-empty an AzureSQLMemory is created; otherwise a
        DuckDBMemory is created. The new instance is stored on the class so later calls reuse it.

        Returns:
            MemoryInterface: The central memory instance.
        """
        existing = cls._memory_instance
        # None is the "not initialized" sentinel; compare explicitly so that an instance's own
        # truthiness (e.g. a __bool__ or __len__) can never cause it to be silently replaced.
        if existing is not None:
            logger.info("Reusing existing memory instance: %s", type(existing).__name__)
            return existing

        # Check for Azure SQL settings. No explicit value is passed in, so the lookup is keyed
        # on the environment variable names (assumes get_non_required_value falls back to the
        # environment when passed_value is empty -- confirm against pyrit.common.default_values).
        empty_passed_value = ""
        azure_sql_db_conn_string = default_values.get_non_required_value(
            env_var_name="AZURE_SQL_DB_CONNECTION_STRING", passed_value=empty_passed_value
        )
        results_container_url = default_values.get_non_required_value(
            env_var_name="AZURE_STORAGE_ACCOUNT_RESULTS_CONTAINER_URL", passed_value=empty_passed_value
        )

        # Both Azure settings are required for AzureSQLMemory; otherwise use DuckDBMemory.
        memory: MemoryInterface
        if azure_sql_db_conn_string and results_container_url:
            logger.info("Using AzureSQLMemory as central memory.")
            memory = AzureSQLMemory(
                connection_string=azure_sql_db_conn_string, results_container_url=results_container_url
            )
        else:
            logger.info("Using DuckDBMemory due to missing Azure SQL DB configuration.")
            memory = DuckDBMemory()

        # Cache the newly created instance so subsequent calls reuse it.
        cls._memory_instance = memory
        return memory