Source code for pyrit.common.download_hf_model
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
import logging
import os
from pathlib import Path
from typing import Optional

import httpx
from huggingface_hub import HfApi
logger = logging.getLogger(__name__)
def get_available_files(model_id: str, token: str):
    """Fetches available files for a model from the Hugging Face repository."""
    api = HfApi()
    try:
        model_info = api.model_info(model_id, token=token)
        available_files = [file.rfilename for file in model_info.siblings]

        # Perform simple validation: raise a ValueError if no files are available
        if not available_files:
            raise ValueError(f"No available files found for the model: {model_id}")

        return available_files
    except Exception as e:
        logger.error(f"Error fetching model files for {model_id}: {e}")
        return []
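# Illustrative usage (a sketch, not part of the module; "gpt2" and the HF_TOKEN
# environment variable are assumptions for this example only):
#
#   files = get_available_files("gpt2", token=os.environ.get("HF_TOKEN", ""))
#   # e.g. ["config.json", "model.safetensors", "tokenizer.json", ...]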
async def download_specific_files(model_id: str, file_patterns: Optional[list], token: str, cache_dir: Path):
    """
    Downloads specific files from a Hugging Face model repository.
    If file_patterns is None, downloads all files.

    Returns:
        List of URLs for the downloaded files.
    """
    os.makedirs(cache_dir, exist_ok=True)

    available_files = get_available_files(model_id, token)

    # If no file patterns are provided, download all available files
    if file_patterns is None:
        files_to_download = available_files
        logger.info(f"Downloading all files for model {model_id}.")
    else:
        # Filter files based on the patterns provided (plain substring match)
        files_to_download = [file for file in available_files if any(pattern in file for pattern in file_patterns)]
        if not files_to_download:
            logger.info(f"No files matched the patterns provided for model {model_id}.")
            return []

    # Generate download URLs directly
    base_url = f"https://huggingface.co/{model_id}/resolve/main/"
    urls = [base_url + file for file in files_to_download]

    # Download the files
    await download_files(urls, token, cache_dir)

    # Return the URLs so callers can see what was fetched, matching the docstring
    return urls
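# Illustrative call (a sketch; the model id, pattern, and directory are assumptions
# for the example, not values used by the module). Because the filter above is a
# plain substring match, ["safetensors"] selects "model.safetensors" but not
# "pytorch_model.bin":
#
#   urls = await download_specific_files("gpt2", ["safetensors"], token, Path("models"))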
async def download_chunk(url, headers, start, end, client):
    """Download a chunk of the file with a specified byte range."""
    range_header = {"Range": f"bytes={start}-{end}", **headers}
    response = await client.get(url, headers=range_header)
    response.raise_for_status()
    return response.content
async def download_file(url, token, download_dir, num_splits):
    """Download a file in multiple segments (splits) using byte-range requests."""
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(follow_redirects=True) as client:
        # Get the file size to determine chunk size
        response = await client.head(url, headers=headers)
        response.raise_for_status()
        file_size = int(response.headers["Content-Length"])
        chunk_size = file_size // num_splits

        # Prepare tasks for each chunk
        tasks = []
        file_name = url.split("/")[-1]
        file_path = Path(download_dir, file_name)

        for i in range(num_splits):
            start = i * chunk_size
            # The last split absorbs any remainder so the full file is covered
            end = start + chunk_size - 1 if i < num_splits - 1 else file_size - 1
            tasks.append(download_chunk(url, headers, start, end, client))

        # Download all chunks concurrently
        chunks = await asyncio.gather(*tasks)

        # Write chunks to the file in order
        with open(file_path, "wb") as f:
            for chunk in chunks:
                f.write(chunk)

    logger.info(f"Downloaded {file_name} to {file_path}")
async def download_files(urls: list[str], token: str, download_dir: Path, num_splits=3, parallel_downloads=4):
    """Download multiple files with parallel downloads and segmented downloading."""

    # Limit the number of parallel downloads
    semaphore = asyncio.Semaphore(parallel_downloads)

    async def download_with_limit(url):
        async with semaphore:
            await download_file(url, token, download_dir, num_splits)

    # Run downloads concurrently, but limit to parallel_downloads at a time
    await asyncio.gather(*(download_with_limit(url) for url in urls))
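# Minimal end-to-end sketch (illustrative, not part of the module). The model id,
# file pattern, cache directory, and HF_TOKEN environment variable below are
# assumptions for this example only.
if __name__ == "__main__":
    example_token = os.environ.get("HF_TOKEN", "")
    downloaded = asyncio.run(
        download_specific_files(
            model_id="gpt2",
            file_patterns=["json"],
            token=example_token,
            cache_dir=Path("hf_cache"),
        )
    )
    print(f"Downloaded file URLs: {downloaded}")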