PyTorch Data Primitives
torch.utils.data.Dataset: stores the samples and their corresponding labels. You can retrieve one sample and its label at a time by indexing.
torch.utils.data.DataLoader: wraps an iterable around the Dataset for easy access to the samples. You can iterate over the dataset in batches.
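To make the difference between the two concrete, here is a minimal sketch. It uses a small in-memory TensorDataset as a stand-in (an illustrative choice, not part of these notes) so nothing has to be downloaded. Indexing the Dataset returns one (sample, label) pair, while iterating over the DataLoader returns stacked batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (illustrative assumption): 6 samples with 3 features each, plus integer labels.
features = torch.arange(18, dtype=torch.float32).reshape(6, 3)
labels = torch.tensor([0, 1, 0, 1, 0, 1])
dataset = TensorDataset(features, labels)

# Dataset: one sample and its label per index.
x0, y0 = dataset[0]
print(x0, y0)

# DataLoader: iterates over the dataset in batches of 4 (the last batch holds the remaining 2).
loader = DataLoader(dataset, batch_size=4)
batch_shapes = [(xb.shape, yb.shape) for xb, yb in loader]
print(batch_shapes)  # feature batches of shape [4, 3] and then [2, 3]
```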
Downloading Existing Datasets
To download the Fashion-MNIST dataset, for example:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
train_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
In this example, a directory named data will be created in the current working directory and the Fashion-MNIST dataset will be downloaded into it.
Set train=True for the training set and train=False for the test set.
You can index into the dataset just like a Python list: img, label = train_data[0].
Custom Datasets
import torch
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
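Any custom dataset should inherit from torch.utils.data.Dataset and implement three methods (the __init__ arguments shown are those of an image dataset whose labels are stored in an annotations CSV file):
__init__(self, annotations_file, img_dir, transform=None, target_transform=None): runs once, when the dataset object is created
__len__(self): returns the number of samples in the dataset
__getitem__(self, idx): loads and returns the sample and its label at index idx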
Wrapping Dataset in DataLoader
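You can retrieve data in batches and reshuffle it using DataLoader.
Reshuffling the training data at every epoch helps reduce model overfitting, because the model does not see the samples in the same order on every pass. With shuffle=True, the data is reshuffled at the start of each epoch (one full pass over all batches).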
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # evaluation does not depend on sample order
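To see that shuffle=True draws a fresh order on every pass, here is a small sketch on a toy TensorDataset of 10 labelled samples (an illustrative assumption, not Fashion-MNIST). Each epoch visits every sample exactly once, but the visiting order is generally different from one epoch to the next.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # make the shuffling reproducible for this demo

# Toy dataset: sample i has feature value i and label i.
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # shape [10, 1]
labels = torch.arange(10)
loader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=True)

epoch_orders = []
for epoch in range(2):
    # Each new iteration over the loader (i.e. each epoch) uses a fresh random permutation.
    order = torch.cat([y for _, y in loader]).tolist()
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")
```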
Each iteration over the DataLoader returns a batch of features and labels.
train_features, train_labels = next(iter(train_dataloader))
# train_features.size() -> torch.Size([64, 1, 28, 28]); train_labels.size() -> torch.Size([64])
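next(iter(...)) only pulls the first batch. A training loop normally iterates over the whole DataLoader once per epoch. The sketch below uses a hypothetical in-memory stand-in of 150 random 1x28x28 "images" (so nothing is downloaded) to show that len(loader) is the number of batches per epoch and that the final batch can be smaller unless drop_last=True is passed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for Fashion-MNIST: 150 random 1x28x28 "images" with labels in [0, 10).
images = torch.rand(150, 1, 28, 28)
targets = torch.randint(0, 10, (150,))
dataset = TensorDataset(images, targets)

loader = DataLoader(dataset, batch_size=64, shuffle=True)
print(len(loader))  # batches per epoch: ceil(150 / 64) = 3

batch_sizes = []
for features, labels in loader:
    # features: [batch, 1, 28, 28], labels: [batch]
    batch_sizes.append(features.shape[0])
print(batch_sizes)  # two full batches of 64, then the remaining 22 samples

# drop_last=True discards the incomplete final batch.
loader_drop = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
print(len(loader_drop))  # 150 // 64 = 2
```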
Transforms
All TorchVision datasets accept two parameters:
transform: modifies the features
target_transform: modifies the labels
For example, in the Fashion-MNIST dataset, features are in PIL image format and labels are integers. We want to change features to normalized float tensors and labels to one-hot encoded tensors.
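ToTensor() converts a PIL image or NumPy ndarray into a float tensor and scales the pixel intensities to the range [0., 1.].
Lambda() applies a user-defined lambda function; here it is used as target_transform to turn the integer label into a one-hot encoded tensor.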
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(
lambda y:
torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
)
scatter_ writes the value 1 into the all-zeros tensor at the index given by the label y, producing a one-hot vector of length 10.
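To check the scatter_ step on its own, the sketch below applies the same expression to a few labels outside any dataset and compares it with torch.nn.functional.one_hot, an alternative that yields the same vector (as int64, so it is cast to float here). The helper name one_hot_scatter is hypothetical.

```python
import torch
import torch.nn.functional as F


def one_hot_scatter(y, num_classes=10):
    # Start from all zeros and write 1.0 at position y along dim 0.
    return torch.zeros(num_classes, dtype=torch.float).scatter_(
        dim=0, index=torch.tensor([y]), value=1
    )


for y in [0, 3, 9]:
    via_scatter = one_hot_scatter(y)
    # F.one_hot returns int64, so cast to float to match the scatter_ result.
    via_one_hot = F.one_hot(torch.tensor(y), num_classes=10).float()
    print(y, via_scatter.tolist())
    assert torch.equal(via_scatter, via_one_hot)
```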