Source code for pyrit.datasets.dataset_helper

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import hashlib
import io
import tempfile
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, TextIO

import requests

from pyrit.common.csv_helper import read_csv, write_csv
from pyrit.common.json_helper import read_json, read_jsonl, write_json, write_jsonl
from pyrit.common.path import DB_DATA_PATH
from pyrit.common.text_helper import read_txt, write_txt

# Define the type for the file handlers
FileHandlerRead = Callable[[TextIO], List[Dict[str, str]]]
FileHandlerWrite = Callable[[TextIO, List[Dict[str, str]]], None]

FILE_TYPE_HANDLERS: Dict[str, Dict[str, Callable]] = {
    "json": {"read": read_json, "write": write_json},
    "jsonl": {"read": read_jsonl, "write": write_jsonl},
    "csv": {"read": read_csv, "write": write_csv},
    "txt": {"read": read_txt, "write": write_txt},
}


def _get_cache_file_name(source: str, file_type: str) -> str:
    """
    Generate a cache file name based on the source URL and file type.
    """
    hash_source = hashlib.md5(source.encode("utf-8")).hexdigest()
    return f"{hash_source}.{file_type}"


def _read_cache(cache_file: Path, file_type: str) -> List[Dict[str, str]]:
    """
    Read data from cache.
    """
    with cache_file.open("r", encoding="utf-8") as file:
        if file_type in FILE_TYPE_HANDLERS:
            return FILE_TYPE_HANDLERS[file_type]["read"](file)
        else:
            valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
            raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.")


def _write_cache(cache_file: Path, examples: List[Dict[str, str]], file_type: str):
    """
    Write data to cache.
    """
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    with cache_file.open("w", encoding="utf-8") as file:
        if file_type in FILE_TYPE_HANDLERS:
            FILE_TYPE_HANDLERS[file_type]["write"](file, examples)
        else:
            valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
            raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.")


def _fetch_from_public_url(source: str, file_type: str) -> List[Dict[str, str]]:
    """
    Fetch examples from a repository.
    """
    response = requests.get(source)
    if response.status_code == 200:
        if file_type in FILE_TYPE_HANDLERS:
            if file_type == "json":
                return FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text))
            else:
                return FILE_TYPE_HANDLERS[file_type]["read"](
                    io.StringIO("\n".join(response.text.splitlines()))
                )  # noqa: E501
        else:
            valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
            raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.")
    else:
        raise Exception(f"Failed to fetch examples from public URL. Status code: {response.status_code}")


def _fetch_from_file(source: str, file_type: str) -> List[Dict[str, str]]:
    """
    Fetch examples from a local file.
    """
    with open(source, "r", encoding="utf-8") as file:
        if file_type in FILE_TYPE_HANDLERS:
            return FILE_TYPE_HANDLERS[file_type]["read"](file)
        else:
            valid_types = ", ".join(FILE_TYPE_HANDLERS.keys())
            raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.")


[docs] def fetch_examples( source: str, source_type: Literal["public_url", "file"] = "public_url", cache: bool = True, data_home: Optional[Path] = None, ) -> List[Dict[str, str]]: """ Fetch examples from a specified source with caching support. Example usage >>> examples = fetch_examples( >>> source='https://raw.githubusercontent.com/KutalVolkan/many-shot-jailbreaking-dataset/5eac855/examples.json', >>> source_type='public_url' >>> ) Args: source (str): The source from which to fetch examples. source_type (Literal["public_url", "file"]): The type of source ('public_url' or 'file'). cache (bool): Whether to cache the fetched examples. Defaults to True. data_home (Optional[Path]): Directory to store cached data. Defaults to None. Returns: List[Dict[str, str]]: A list of examples. """ file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") if not data_home: data_home = DB_DATA_PATH / "seed-prompt-entries" else: data_home = Path(data_home) cache_file = data_home / _get_cache_file_name(source, file_type) if cache and cache_file.exists(): return _read_cache(cache_file, file_type) if source_type == "public_url": examples = _fetch_from_public_url(source, file_type) elif source_type == "file": examples = _fetch_from_file(source, file_type) if cache: _write_cache(cache_file, examples, file_type) else: with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=f".{file_type}") as temp_file: FILE_TYPE_HANDLERS[file_type]["write"](temp_file, examples) return examples