User Guide¶
Getting Started¶
Prerequisites¶
Python 3.9 or later installed
An Azure subscription and an Azure storage account
Installation¶
Install the Azure Storage Connector for PyTorch (azstoragetorch) with pip:
pip install azstoragetorch
Configuration¶
azstoragetorch should work without any explicit credential configuration. azstoragetorch interfaces default to DefaultAzureCredential for credentials. DefaultAzureCredential automatically retrieves Microsoft Entra ID tokens based on your current environment. For more information on DefaultAzureCredential, see its documentation.
To override credentials, azstoragetorch interfaces accept a credential keyword argument and also accept SAS tokens in the query string of provided Azure Storage URLs. See the API Reference for more details.
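For example, the sketch below shows both override options when opening a blob for reading. The ManagedIdentityCredential and <sas-token> values are illustrative placeholders; substitute the azure.identity credential type or SAS token appropriate for your environment:
from azure.identity import ManagedIdentityCredential
from azstoragetorch.io import BlobIO
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
# Option 1: pass an explicit credential using the ``credential`` keyword argument.
# ManagedIdentityCredential is only one example of an azure.identity credential type.
credential = ManagedIdentityCredential()
with BlobIO(f"{CONTAINER_URL}/<blob-name>", "rb", credential=credential) as f:
    data = f.read()
# Option 2: include a SAS token in the query string of the blob URL.
# Replace "<sas-token>" with a SAS token you have generated for the blob or container.
with BlobIO(f"{CONTAINER_URL}/<blob-name>?<sas-token>", "rb") as f:
    data = f.read()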
Saving and Loading PyTorch Models (Checkpointing)¶
PyTorch supports saving and loading trained models (i.e., checkpointing). The core PyTorch interfaces for saving and loading models are torch.save() and torch.load(), respectively. Both of these functions accept a file-like object to be written to or read from. azstoragetorch offers the azstoragetorch.io.BlobIO file-like object class to save and load models directly to and from Azure Blob Storage when using torch.save() and torch.load().
Saving a Model¶
To save a model to Azure Blob Storage, pass an azstoragetorch.io.BlobIO directly to torch.save(). When creating the BlobIO, specify the URL to the blob you'd like to save the model to and use write mode (i.e., wb):
import torch
import torchvision.models # Install separately: ``pip install torchvision``
from azstoragetorch.io import BlobIO
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
# Model to save. Replace with your own model.
model = torchvision.models.resnet18(weights="DEFAULT")
# Save trained model to Azure Blob Storage. This saves the model weights
# to a blob named "model_weights.pth" in the container specified by CONTAINER_URL.
with BlobIO(f"{CONTAINER_URL}/model_weights.pth", "wb") as f:
    torch.save(model.state_dict(), f)
Loading a Model¶
To load a model from Azure Blob Storage, pass an azstoragetorch.io.BlobIO directly to torch.load(). When creating the BlobIO, specify the URL to the blob storing the model weights and use read mode (i.e., rb):
import torch
import torchvision.models # Install separately: ``pip install torchvision``
from azstoragetorch.io import BlobIO
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
# Model to load weights for. Replace with your own model.
model = torchvision.models.resnet18()
# Load trained model from Azure Blob Storage. This loads the model weights
# from the blob named "model_weights.pth" in the container specified by CONTAINER_URL.
with BlobIO(f"{CONTAINER_URL}/model_weights.pth", "rb") as f:
    model.load_state_dict(torch.load(f))
PyTorch Datasets¶
PyTorch offers the Dataset and DataLoader primitives for loading data samples. azstoragetorch provides implementations for both types of PyTorch datasets, map-style and iterable-style datasets, to load data samples from Azure Blob Storage:
azstoragetorch.datasets.BlobDataset - Map-style dataset. Use this class for random access to data samples. The class eagerly lists samples in the dataset on instantiation.
azstoragetorch.datasets.IterableBlobDataset - Iterable-style dataset. Use this class when working with large datasets that may not fit in memory. The class lazily lists samples as the dataset is iterated over.
Data samples returned from both datasets map directly one-to-one to blobs in Azure Blob Storage. Both classes can be directly provided to a PyTorch DataLoader (read more here). When instantiating these dataset classes, use one of their class methods:
from_container_url() - Instantiate a dataset by listing blobs from an Azure Storage container.
from_blob_urls() - Instantiate a dataset from provided blob URLs.
Instantiating these classes directly with __init__() is not supported. Read the sections below to learn how to use these class methods to create datasets.
Create Dataset from Azure Storage Container¶
To create an azstoragetorch dataset by listing blobs in a single Azure Storage container, use the dataset class's corresponding from_container_url() method:
azstoragetorch.datasets.BlobDataset.from_container_url() for the map-style dataset
azstoragetorch.datasets.IterableBlobDataset.from_container_url() for the iterable-style dataset
The methods accept the URL to the Azure Storage container to list blobs from. Listing is performed using the List Blobs API. For example:
from azstoragetorch.datasets import BlobDataset, IterableBlobDataset
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
# Create a map-style dataset by listing blobs in the container specified by CONTAINER_URL.
map_dataset = BlobDataset.from_container_url(CONTAINER_URL)
# Create an iterable-style dataset by listing blobs in the container specified by CONTAINER_URL.
iterable_dataset = IterableBlobDataset.from_container_url(CONTAINER_URL)
The above examples list all blobs in the container. To only include blobs whose names start with a specific prefix, provide the prefix keyword argument:
from azstoragetorch.datasets import BlobDataset, IterableBlobDataset
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
# Create a map-style dataset only including blobs whose name starts with the prefix "images/"
map_dataset = BlobDataset.from_container_url(CONTAINER_URL, prefix="images/")
# Create an iterable-style dataset only including blobs whose name starts with the prefix "images/"
iterable_dataset = IterableBlobDataset.from_container_url(CONTAINER_URL, prefix="images/")
Create Dataset from List of Blobs¶
To create an azstoragetorch dataset from a pre-defined list of blobs, use the dataset class's corresponding from_blob_urls() method:
azstoragetorch.datasets.BlobDataset.from_blob_urls() for the map-style dataset
azstoragetorch.datasets.IterableBlobDataset.from_blob_urls() for the iterable-style dataset
The methods accept a list of blob URLs to create the dataset from. For example:
from azstoragetorch.datasets import BlobDataset, IterableBlobDataset
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
# List of blob URLs to create dataset from. Update with your own blob names.
blob_urls = [
    f"{CONTAINER_URL}/<blob-name-1>",
    f"{CONTAINER_URL}/<blob-name-2>",
    f"{CONTAINER_URL}/<blob-name-3>",
]
# Create a map-style dataset from the list of blob URLs
map_dataset = BlobDataset.from_blob_urls(blob_urls)
# Create an iterable-style dataset from the list of blob URLs
iterable_dataset = IterableBlobDataset.from_blob_urls(blob_urls)
Transforming Dataset Output¶
The default output format of dataset samples is a dictionary representing a blob in the dataset. Each dictionary has the keys:
url: The full endpoint URL of the blob.
data: The content of the blob as bytes.
For example, when accessing a dataset sample:
print(map_dataset[0])
It will have the following return format:
{
    "url": "https://<account-name>.blob.core.windows.net/<container-name>/<blob-name>",
    "data": b"<blob-content>"
}
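Because samples are plain dictionaries, they can be consumed with ordinary Python code. For example, a minimal sketch that writes a sample's content to a local file, reusing the map_dataset created above (the file-naming logic is illustrative):
import os
# Take the first sample from the dataset and derive a local filename from its blob URL
sample = map_dataset[0]
filename = os.path.basename(sample["url"])
# Write the blob's bytes to a local file
with open(filename, "wb") as f:
    f.write(sample["data"])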
To override the output format, provide a transform callable to either from_blob_urls or from_container_url when creating the dataset. The transform callable accepts a single positional argument of type azstoragetorch.datasets.Blob representing a blob in the dataset. This Blob object can be used to retrieve properties and content of the blob as part of the transform callable.
Emulating the PyTorch transform tutorial, the example below shows how to transform a Blob object to a torch.Tensor of a PIL.Image:
from azstoragetorch.datasets import BlobDataset, Blob
import PIL.Image # Install separately: ``pip install pillow``
import torch
import torchvision.transforms # Install separately: ``pip install torchvision``
# Update URL with your own Azure Storage account, container, and blob containing an image
IMAGE_BLOB_URL = "https://<storage-account-name>.blob.core.windows.net/<container-name>/<blob-image-name>"
# Define transform to convert blob to a tuple of (image_name, image_tensor)
def to_img_name_and_tensor(blob: Blob) -> tuple[str, torch.Tensor]:
    # Use blob reader to retrieve blob contents and then transform to an image tensor.
    with blob.reader() as f:
        image = PIL.Image.open(f)
        image_tensor = torchvision.transforms.ToTensor()(image)
    return blob.blob_name, image_tensor
# Provide transform to dataset constructor
dataset = BlobDataset.from_blob_urls(
    IMAGE_BLOB_URL,
    transform=to_img_name_and_tensor,
)
print(dataset[0]) # Prints tuple of (image_name, image_tensor) for blob in dataset
The output should include the blob name and Tensor of the image:
("<blob-image-name>", tensor([...]))
Using Dataset with PyTorch DataLoader¶
Once instantiated, azstoragetorch datasets can be provided directly to a PyTorch DataLoader for loading samples:
from azstoragetorch.datasets import BlobDataset
from torch.utils.data import DataLoader
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
dataset = BlobDataset.from_container_url(CONTAINER_URL)
# Create a DataLoader to load data samples from the dataset in batches of 32
dataloader = DataLoader(dataset, batch_size=32)
for batch in dataloader:
    print(batch["url"])  # Prints blob URLs for each 32 sample batch
Iterable-style Datasets with Multiple Workers¶
When using an IterableBlobDataset and DataLoader with multiple workers (i.e., num_workers > 1), the IterableBlobDataset automatically shards the data samples it returns across workers so that the DataLoader does not return duplicate samples from its workers:
from azstoragetorch.datasets import IterableBlobDataset
from torch.utils.data import DataLoader
# Update URL with your own Azure Storage account and container name
CONTAINER_URL = "https://<my-storage-account-name>.blob.core.windows.net/<my-container-name>"
dataset = IterableBlobDataset.from_container_url(CONTAINER_URL)
# Iterate over the dataset to get the number of samples in it
num_samples_from_dataset = len([blob["url"] for blob in dataset])
# Create a DataLoader to load data samples from the dataset in batches of 32 using 4 workers
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
# Iterate over the DataLoader to get the number of samples returned from it
num_samples_from_dataloader = 0
for batch in dataloader:
    num_samples_from_dataloader += len(batch["url"])
# The number of samples returned from the dataset should be equal to the number of samples
# returned from the DataLoader. If the dataset did not handle sharding, the number of samples
# returned from the DataLoader would be ``num_workers`` times (i.e., four times) the number
# of samples in the dataset.
assert num_samples_from_dataset == num_samples_from_dataloader