Source code for pyrit.datasets.seed_datasets.seed_dataset_provider

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import inspect
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional

from tqdm import tqdm

from pyrit.models.seeds import SeedDataset

logger = logging.getLogger(__name__)


class SeedDatasetProvider(ABC):
    """
    Abstract base class for providing seed datasets with automatic registration.

    All concrete subclasses are automatically registered and can be discovered
    via the get_all_providers() class method. This enables automatic discovery
    of both local and remote dataset providers.

    Subclasses must implement:
        - fetch_dataset(): Fetch and return the dataset as a SeedDataset
        - dataset_name property: Human-readable name for the dataset
    """

    _registry: dict[str, type["SeedDatasetProvider"]] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Automatically register non-abstract subclasses.

        This is called when a class inherits from SeedDatasetProvider.
        """
        super().__init_subclass__(**kwargs)

        # Only register concrete (non-abstract) classes
        if not inspect.isabstract(cls) and getattr(cls, "should_register", True):
            SeedDatasetProvider._registry[cls.__name__] = cls
            logger.debug(f"Registered dataset provider: {cls.__name__}")

    @property
    @abstractmethod
    def dataset_name(self) -> str:
        """
        Return the human-readable name of the dataset.

        Returns:
            str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors")
        """
    @abstractmethod
    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
        """
        Fetch the dataset and return it as a SeedDataset.

        Args:
            cache: Whether to cache the fetched dataset. Defaults to True.
                Remote datasets will use DB_DATA_PATH for caching.

        Returns:
            SeedDataset: The fetched dataset with prompts.

        Raises:
            Exception: If the dataset cannot be fetched or processed.
        """
    @classmethod
    def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]:
        """
        Get all registered dataset provider classes.

        Returns:
            dict[str, type[SeedDatasetProvider]]: Dictionary mapping class names
                to provider classes.
        """
        return cls._registry.copy()
    @classmethod
    def get_all_dataset_names(cls) -> list[str]:
        """
        Get the names of all registered datasets.

        Returns:
            list[str]: Sorted list of dataset names from all registered providers.

        Raises:
            ValueError: If a provider cannot be instantiated or its dataset
                name cannot be read.

        Example:
            >>> names = SeedDatasetProvider.get_all_dataset_names()
            >>> print(f"Available datasets: {', '.join(names)}")
        """
        dataset_names = set()
        for provider_class in cls._registry.values():
            try:
                # Instantiate to get the dataset name
                provider = provider_class()
                dataset_names.add(provider.dataset_name)
            except Exception as e:
                raise ValueError(
                    f"Could not get dataset name from {provider_class.__name__}: {e}"
                ) from e

        return sorted(dataset_names)
    @classmethod
    async def fetch_datasets_async(
        cls,
        *,
        dataset_names: Optional[list[str]] = None,
        cache: bool = True,
        max_concurrency: int = 5,
    ) -> list[SeedDataset]:
        """
        Fetch all registered datasets with optional filtering and caching.

        Datasets are fetched concurrently for improved performance.

        Args:
            dataset_names: Optional list of dataset names to fetch. If None, fetches all.
                Names should match the dataset_name property of providers.
            cache: Whether to cache the fetched datasets. Defaults to True.
                This uses DB_DATA_PATH for caching remote datasets.
            max_concurrency: Maximum number of datasets to fetch concurrently.
                Defaults to 5. Set to 1 for fully sequential execution.

        Returns:
            list[SeedDataset]: List of all fetched datasets.

        Raises:
            ValueError: If any requested dataset name does not exist.
            Exception: If any dataset fails to load.

        Example:
            >>> # Fetch all datasets
            >>> all_datasets = await SeedDatasetProvider.fetch_datasets_async()
            >>>
            >>> # Fetch specific datasets
            >>> specific = await SeedDatasetProvider.fetch_datasets_async(
            ...     dataset_names=["harmbench", "DarkBench"]
            ... )
        """
        # Validate dataset names if specified
        if dataset_names is not None:
            available_names = cls.get_all_dataset_names()
            invalid_names = [name for name in dataset_names if name not in available_names]
            if invalid_names:
                raise ValueError(
                    f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}"
                )

        async def fetch_single_dataset(
            provider_name: str, provider_class: type["SeedDatasetProvider"]
        ) -> Optional[tuple[str, SeedDataset]]:
            """
            Fetch a single dataset, skipping providers excluded by the name filter.

            Returns:
                Optional[tuple[str, SeedDataset]]: Tuple of dataset name and dataset,
                    or None if filtered out.
            """
            provider = provider_class()

            # Apply the dataset name filter if specified
            if dataset_names is not None and provider.dataset_name not in dataset_names:
                logger.debug(f"Skipping {provider_name} - not in filter list")
                return None

            dataset = await provider.fetch_dataset(cache=cache)
            return (provider.dataset_name, dataset)

        # Create a semaphore to limit concurrency
        semaphore = asyncio.Semaphore(max_concurrency)

        # Progress tracking
        total_count = len(cls._registry)
        pbar = tqdm(
            total=total_count,
            desc="Loading datasets - this can take a few minutes",
            unit="dataset",
        )

        async def fetch_with_semaphore(
            provider_name: str, provider_class: type["SeedDatasetProvider"]
        ) -> Optional[tuple[str, SeedDataset]]:
            """
            Enforce the concurrency limit and update progress during a dataset fetch.

            Returns:
                Optional[tuple[str, SeedDataset]]: Tuple of dataset name and dataset,
                    or None if filtered out.
            """
            async with semaphore:
                result = await fetch_single_dataset(provider_name, provider_class)
                pbar.update(1)
                return result

        # Fetch all datasets with controlled concurrency and a progress bar
        tasks = [
            fetch_with_semaphore(provider_name, provider_class)
            for provider_name, provider_class in cls._registry.items()
        ]
        results = await asyncio.gather(*tasks)
        pbar.close()

        # Merge datasets that share the same name (e.g., multiple sources)
        datasets: dict[str, SeedDataset] = {}
        for result in results:
            # Skip None results (filtered-out datasets)
            if result is None:
                continue

            dataset_name, dataset = result
            if dataset_name in datasets:
                logger.info(f"Merging multiple sources for {dataset_name}.")
                existing_dataset = datasets[dataset_name]
                combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds)
                existing_dataset.seeds = combined_seeds
            else:
                datasets[dataset_name] = dataset

        logger.info(
            f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers"
        )
        return list(datasets.values())
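

# ---------------------------------------------------------------------------
# Usage sketch: defining a concrete provider. This is an illustrative example,
# not part of PyRIT: `MyLocalProvider` and its dataset are hypothetical, and
# the SeedDataset(seeds=...) construction is an assumption (the code above
# only shows that SeedDataset exposes a `seeds` attribute).
class MyLocalProvider(SeedDatasetProvider):
    @property
    def dataset_name(self) -> str:
        return "MyLocalDataset"

    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
        # Assumed constructor; build the SeedDataset however the real model
        # requires. An empty dataset keeps the sketch self-contained.
        return SeedDataset(seeds=[])


# Subclassing alone triggers __init_subclass__, so the provider is already
# registered and discoverable:
assert "MyLocalProvider" in SeedDatasetProvider.get_all_providers()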
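

# Usage sketch: discovering and fetching datasets. The name filter and
# concurrency limit below are arbitrary illustrations.
async def _example_fetch() -> None:
    names = SeedDatasetProvider.get_all_dataset_names()
    print(f"Available datasets: {', '.join(names)}")

    # Fetch a subset with at most two concurrent fetches; providers that
    # report the same dataset_name are merged into a single SeedDataset.
    datasets = await SeedDatasetProvider.fetch_datasets_async(
        dataset_names=names[:1],
        cache=True,
        max_concurrency=2,
    )
    for dataset in datasets:
        print(f"Fetched a dataset with {len(list(dataset.seeds))} seeds")


# To run the sketch: asyncio.run(_example_fetch())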