User Guide

Getting Started

Prerequisites

Installation

Install the Azure Storage Connector for PyTorch (azstoragetorch) with pip:

pip install azstoragetorch

Configuration

azstoragetorch should work without any explicit credential configuration.

azstoragetorch interfaces default to DefaultAzureCredential for credentials. DefaultAzureCredential automatically retrieves Microsoft Entra ID tokens based on your current environment. For more information on DefaultAzureCredential, see its documentation.

To override credentials, azstoragetorch interfaces accept a credential keyword argument and also accept SAS tokens in the query strings of provided Azure Storage URLs. See the API Reference for more details.
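For example, below is a minimal sketch of both override mechanisms using BlobIO. It assumes the credential keyword accepts an azure.core.credentials.AzureSasCredential; the blob URL and <sas-token> are placeholders for values you supply:

from azure.core.credentials import AzureSasCredential
from azstoragetorch.io import BlobIO

# Update URL with your own Azure Storage account, container, and blob name
BLOB_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>/<blob-name>"

# Pass a credential explicitly via the ``credential`` keyword argument.
# Assumes an AzureSasCredential is an accepted credential type; replace
# ``<sas-token>`` with a SAS token you have generated.
sas_credential = AzureSasCredential("<sas-token>")
with BlobIO(BLOB_URL, "rb", credential=sas_credential) as f:
    data = f.read()

# Alternatively, append a SAS token to the blob URL's query string.
with BlobIO(f"{BLOB_URL}?<sas-token>", "rb") as f:
    data = f.read()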

Saving and Loading PyTorch Models (Checkpointing)

PyTorch supports saving and loading trained models (i.e., checkpointing). The core PyTorch interfaces for saving and loading models are torch.save() and torch.load() respectively. Both of these functions accept a file-like object to be written to or read from.

azstoragetorch offers the azstoragetorch.io.BlobIO file-like object class to save and load models directly to and from Azure Blob Storage when using torch.save() and torch.load().

Saving a Model

To save a model to Azure Blob Storage, pass an azstoragetorch.io.BlobIO directly to torch.save(). When creating the BlobIO, specify the URL of the blob to save the model to and use write mode (i.e., wb):

import torch
import torchvision.models  # Install separately: ``pip install torchvision``
from azstoragetorch.io import BlobIO

# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

# Model to save. Replace with your own model.
model = torchvision.models.resnet18(weights="DEFAULT")

# Save trained model to Azure Blob Storage. This saves the model weights
# to a blob named "model_weights.pth" in the container specified by CONTAINER_URL.
with BlobIO(f"{CONTAINER_URL}/model_weights.pth", "wb") as f:
    torch.save(model.state_dict(), f)

Loading a Model

To load a model from Azure Blob Storage, pass an azstoragetorch.io.BlobIO directly to torch.load(). When creating the BlobIO, specify the URL of the blob storing the model weights and use read mode (i.e., rb):

import torch
import torchvision.models  # Install separately: ``pip install torchvision``
from azstoragetorch.io import BlobIO

# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

# Model to load weights for. Replace with your own model.
model = torchvision.models.resnet18()

# Load trained model from Azure Blob Storage.  This loads the model weights
# from the blob named "model_weights.pth" in the container specified by CONTAINER_URL.
with BlobIO(f"{CONTAINER_URL}/model_weights.pth", "rb") as f:
    model.load_state_dict(torch.load(f))

PyTorch Datasets

PyTorch offers the Dataset and DataLoader primitives for loading data samples. azstoragetorch provides implementations for both types of PyTorch datasets, map-style and iterable-style, to load data samples from Azure Blob Storage:

  • BlobDataset - a map-style dataset

  • IterableBlobDataset - an iterable-style dataset

Data samples returned from both datasets map one-to-one to blobs in Azure Blob Storage. Both classes can be provided directly to a PyTorch DataLoader. When instantiating these dataset classes, use one of their class methods:

  • from_container_url() - Instantiate dataset by listing blobs from an Azure Storage container.

  • from_blob_urls() - Instantiate dataset from provided blob URLs.

Instantiation directly using __init__() is not supported. Read sections below on how to use these class methods to create datasets.

Create Dataset from Azure Storage Container

To create an azstoragetorch dataset by listing blobs in a single Azure Storage container, use the dataset class's from_container_url() class method.

This method accepts the URL of the Azure Storage container to list blobs from. Listing is performed using the List Blobs API. For example:

from azstoragetorch.datasets import BlobDataset, IterableBlobDataset

# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

# Create a map-style dataset by listing blobs in the container specified by CONTAINER_URL.
map_dataset = BlobDataset.from_container_url(CONTAINER_URL)

# Create an iterable-style dataset by listing blobs in the container specified by CONTAINER_URL.
iterable_dataset = IterableBlobDataset.from_container_url(CONTAINER_URL)

The above examples list all blobs in the container. To only include blobs whose names start with a specific prefix, provide the prefix keyword argument:

from azstoragetorch.datasets import BlobDataset, IterableBlobDataset

# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

# Create a map-style dataset only including blobs whose name starts with the prefix "images/"
map_dataset = BlobDataset.from_container_url(CONTAINER_URL, prefix="images/")

# Create an iterable-style dataset only including blobs whose name starts with the prefix "images/"
iterable_dataset = IterableBlobDataset.from_container_url(CONTAINER_URL, prefix="images/")

Create Dataset from List of Blobs

To create an azstoragetorch dataset from a pre-defined list of blobs, use the dataset class's from_blob_urls() class method.

This method accepts a list of blob URLs to create the dataset from. For example:

from azstoragetorch.datasets import BlobDataset, IterableBlobDataset

# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

# List of blob URLs to create dataset from. Update with your own blob names.
blob_urls = [
    f"{CONTAINER_URL}/<blob-name-1>",
    f"{CONTAINER_URL}/<blob-name-2>",
    f"{CONTAINER_URL}/<blob-name-3>",
]

# Create a map-style dataset from the list of blob URLs
map_dataset = BlobDataset.from_blob_urls(blob_urls)

# Create an iterable-style dataset from the list of blob URLs
iterable_dataset = IterableBlobDataset.from_blob_urls(blob_urls)

Transforming Dataset Output

By default, each dataset sample is a dictionary representing a blob in the dataset. Each dictionary has the keys:

  • url: The full endpoint URL of the blob.

  • data: The content of the blob as bytes.

For example, when accessing a dataset sample:

print(map_dataset[0])

The returned sample has the following format:

{
    "url": "https://<account-name>.blob.core.windows.net/<container-name>/<blob-name>",
    "data": b"<blob-content>"
}

To override the output format, provide a transform callable to either from_blob_urls or from_container_url when creating the dataset. The transform callable accepts a single positional argument of type azstoragetorch.datasets.Blob representing a blob in the dataset. This Blob object can be used to retrieve properties and content of the blob as part of the transform callable.

Emulating the PyTorch transform tutorial, the example below shows how to transform a Blob object into a tuple of the blob's name and a torch.Tensor of its image content:

from azstoragetorch.datasets import BlobDataset, Blob
import PIL.Image  # Install separately: ``pip install pillow``
import torch
import torchvision.transforms  # Install separately: ``pip install torchvision``

# Update URL with your own Azure Storage account, container, and blob containing an image
IMAGE_BLOB_URL = "https://<storage-account-name>.blob.core.windows.net/<container-name>/<blob-image-name>"

# Define transform to convert blob to a tuple of (image_name, image_tensor)
def to_img_name_and_tensor(blob: Blob) -> tuple[str, torch.Tensor]:
    # Use blob reader to retrieve blob contents and then transform to an image tensor.
    with blob.reader() as f:
        image = PIL.Image.open(f)
        image_tensor = torchvision.transforms.ToTensor()(image)
    return blob.blob_name, image_tensor

# Provide the transform when creating the dataset
dataset = BlobDataset.from_blob_urls(
    IMAGE_BLOB_URL,
    transform=to_img_name_and_tensor,
)

print(dataset[0])  # Prints tuple of (image_name, image_tensor) for blob in dataset

The output should include the blob name and Tensor of the image:

("<blob-image-name>", tensor([...]))

Using Dataset with PyTorch DataLoader

Once instantiated, azstoragetorch datasets can be provided directly to a PyTorch DataLoader for loading samples:

from azstoragetorch.datasets import BlobDataset
from torch.utils.data import DataLoader

# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

dataset = BlobDataset.from_container_url(CONTAINER_URL)

# Create a DataLoader to load data samples from the dataset in batches of 32
dataloader = DataLoader(dataset, batch_size=32)

for batch in dataloader:
    print(batch["url"])  # Prints blob URLs for each 32 sample batch

Iterable-style Datasets with Multiple Workers

When using an IterableBlobDataset with a DataLoader that has multiple workers (i.e., num_workers > 1), the IterableBlobDataset automatically shards its data samples across the workers so that the DataLoader does not return duplicate samples from its workers:

from azstoragetorch.datasets import IterableBlobDataset
from torch.utils.data import DataLoader

# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"

dataset = IterableBlobDataset.from_container_url(CONTAINER_URL)

# Iterate over the dataset to get the number of samples in it
num_samples_from_dataset = len([blob["url"] for blob in dataset])

# Create a DataLoader to load data samples from the dataset in batches of 32 using 4 workers
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)

# Iterate over the DataLoader to get the number of samples returned from it
num_samples_from_dataloader = 0
for batch in dataloader:
    num_samples_from_dataloader += len(batch["url"])

# The number of samples returned from the dataset should be equal to the number of samples
# returned from the DataLoader. If the dataset did not handle sharding, the number of samples
# returned from the DataLoader would be ``num_workers`` times (i.e., four times) the number
# of samples in the dataset.
assert num_samples_from_dataset == num_samples_from_dataloader