Source code for pyrit.setup.initializers.scenarios.load_default_datasets
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Scenario Basic Dataset Loader.
If you don't have a database already, this can enable you to run all scenarios using
the pre-defined datasets in PyRIT. These are meant as a starting point only.
"""
import logging
import textwrap
from typing import List
from pyrit.cli.scenario_registry import ScenarioRegistry
from pyrit.datasets import SeedDatasetProvider
from pyrit.memory import CentralMemory
from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer
logger = logging.getLogger(__name__)
[docs]
class LoadDefaultDatasets(PyRITInitializer):
"""Load default datasets for all registered scenarios."""
@property
def name(self) -> str:
"""Return the name of this initializer."""
return "Default Dataset Loader for Scenarios"
@property
def execution_order(self) -> int:
"""Should be executed after most initializers."""
return 10
@property
def description(self) -> str:
"""Return a description of this initializer."""
return textwrap.dedent(
"""
This configuration uses the DatasetLoader to load default datasets into memory.
This will enable all scenarios to run. Datasets can be customized in memory.
Note: if you are using persistent memory, avoid calling this every time as datasets
can take time to load.
"""
).strip()
@property
def required_env_vars(self) -> List[str]:
"""Return the list of required environment variables."""
return []
[docs]
async def initialize_async(self) -> None:
"""Load default datasets from all registered scenarios."""
# Get ScenarioRegistry to discover all scenarios
registry = ScenarioRegistry()
# Collect all required datasets from all scenarios
all_required_datasets: List[str] = []
# Get all scenario names from registry
scenario_names = registry.get_scenario_names()
for scenario_name in scenario_names:
scenario_class = registry.get_scenario(scenario_name)
if scenario_class:
# Get required_datasets from the scenario class
try:
datasets = scenario_class.required_datasets()
all_required_datasets.extend(datasets)
logger.info(f"Scenario '{scenario_name}' requires datasets: {datasets}")
except Exception as e:
logger.warning(f"Could not get required datasets from scenario '{scenario_name}': {e}")
# Remove duplicates
unique_datasets = list(dict.fromkeys(all_required_datasets))
if not unique_datasets:
logger.warning("No datasets required by any scenario")
return
logger.info(f"Loading {len(unique_datasets)} unique datasets required by all scenarios")
# Fetch the datasets
dataset_list = await SeedDatasetProvider.fetch_datasets_async(
dataset_names=unique_datasets,
)
# Store datasets in CentralMemory
memory = CentralMemory.get_memory_instance()
await memory.add_seed_datasets_to_memory_async(datasets=dataset_list, added_by="LoadDefaultDatasets")
logger.info(f"Successfully loaded {len(dataset_list)} datasets into CentralMemory")