# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from typing import Literal, Optional
from pyrit.datasets.dataset_helper import fetch_examples
from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
[docs]
def fetch_medsafetybench_dataset(
subset_name: Literal["train", "test", "generated", "all"] = "all",
cache: bool = True,
data_home: Optional[Path] = None,
output_csv_path: Optional[str] = None,
) -> SeedPromptDataset:
"""
Fetch MedSafetyBench examples (merged) and return them as a SeedPromptDataset.
Args:
subset_name (Literal): Choose from "train", "test", "generated", or "all".
cache (bool): Whether to cache the data locally.
data_home (Optional[Path]): Optional path to override default cache location.
output_csv_path (Optional[str]): Path where to save the combined CSV. If None, uses default naming.
Returns:
SeedPromptDataset: A dataset of prompts from MedSafetyBench.
Note:
For more information and access to the original dataset and related materials, visit:
https://github.com/AI4LIFE-GROUP/med-safety-bench.
Based on research in:
https://proceedings.neurips.cc/paper_files/paper/2024/hash/3ac952d0264ef7a505393868a70a46b6-Abstract-Datasets_and_Benchmarks_Track.html
Authors: Tessa Han, Aounon Kumar, Chirag Agarwal, Himabindu Lakkaraju.
"""
base_url = "https://raw.githubusercontent.com/AI4LIFE-GROUP/" "med-safety-bench/main/datasets"
sources = []
if subset_name == "test":
for model in ["gpt4", "llama2"]:
for category in range(1, 10):
sources.append(f"{base_url}/test/{model}/" f"med_safety_demonstrations_category_{category}.csv")
elif subset_name == "train":
for model in ["gpt4", "llama2"]:
for category in range(1, 10):
sources.append(f"{base_url}/train/{model}/" f"med_safety_demonstrations_category_{category}.csv")
elif subset_name == "generated":
for category in range(1, 10):
sources.append(f"{base_url}/med_harm_llama3/category_{category}.txt")
elif subset_name == "all":
for subset in ["test", "train"]:
for model in ["gpt4", "llama2"]:
for category in range(1, 10):
sources.append(f"{base_url}/{subset}/{model}/" f"med_safety_demonstrations_category_{category}.csv")
for category in range(1, 10):
sources.append(f"{base_url}/med_harm_llama3/category_{category}.txt")
else:
raise ValueError(
f"Invalid subset_name: {subset_name}. " "Expected one of: 'train', 'test', 'generated', 'all'."
)
all_prompts = []
combined_data = []
for source in sources:
examples = fetch_examples(
source=source,
source_type="public_url",
cache=cache,
data_home=data_home,
)
for ex in examples:
prompt = ex.get("harmful_medical_request") or ex.get("prompt")
if not prompt:
raise KeyError(f"No 'harmful_medical_request' or 'prompt' found in example from {source}")
url_parts = source.split("/")
model_type = url_parts[-2] if len(url_parts) >= 2 else "unknown"
filename = url_parts[-1]
category_str = ""
category = 0 # Use a default integer value
if filename.endswith(".txt"):
# MYPY FIX: Extract category as a string, then safely convert to int
category_str = filename.split("_")[-1].replace(".txt", "") if "_" in filename else ""
file_type = "generated"
else:
# MYPY FIX: Extract category as a string, then safely convert to int
category_str = filename.split("_")[-1].replace(".csv", "") if "_" in filename else ""
file_type = url_parts[-3] if len(url_parts) >= 3 else "unknown"
if category_str.isdigit():
category = int(category_str)
all_prompts.append(
SeedPrompt(
value=prompt,
data_type="text",
name="MedSafetyBench",
dataset_name="MedSafetyBench",
harm_categories=["medical safety"],
description=(
"Prompt from MedSafetyBench dataset - "
f"{model_type} model, category {category}, type {file_type}."
),
source=source,
)
)
combined_data.append(
{
"prompt": prompt,
"model_type": model_type,
"category": category,
"file_type": file_type,
"source": source,
}
)
return SeedPromptDataset(prompts=all_prompts)