PyTorch Data Primitives
torch.utils.data.Dataset: stores the samples and their corresponding labels. You can retrieve one sample and its label at a time by indexing.
torch.utils.data.DataLoader: wraps an iterable around the Dataset for easy access. You can iterate over the dataset in batches.
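As a minimal sketch of how the two primitives fit together, the snippet below uses torch.utils.data.TensorDataset (a ready-made Dataset over in-memory tensors) in place of a real dataset. The random data and the sizes (100 samples, batch size 16) are arbitrary assumptions for illustration.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy data: 100 "images" of shape 1x28x28 and 100 integer labels in [0, 10)
features = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(features, labels)

# Dataset: index one (sample, label) pair at a time
x, y = dataset[0]
print(x.shape, y.shape)  # torch.Size([1, 28, 28]) torch.Size([])

# DataLoader: iterate over the same data in batches
loader = DataLoader(dataset, batch_size=16)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16])
```

The Dataset only knows how to return a single item; batching (stacking 16 samples into one tensor) is done by the DataLoader.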
Downloading Existing Datasets
To download the Fashion-MNIST dataset, for example:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
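train_data = datasets.FashionMNIST(
    root="data",          # directory where the data is stored
    train=True,           # training split
    download=True,        # download the data if it is not already under root
    transform=ToTensor()  # convert PIL images to float tensors
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,          # test split
    download=True,
    transform=ToTensor()
)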
In this example, a directory named data will be created in the current working directory and the Fashion-MNIST dataset will be downloaded into it.
Set train=True for the training set and train=False for the test set.
You can index into the dataset like a Python list: img, label = train_data[0].
Custom Datasets
import torch
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
Any custom dataset should inherit from torch.utils.data.Dataset and implement three methods:
__init__(self, annotations_file, img_dir, transform=None, target_transform=None)
__len__(self)
__getitem__(self, idx)
Wrapping Dataset in DataLoader
You can retrieve data in batches and shuffle it using DataLoader.
Reshuffling the data helps reduce model overfitting. With shuffle=True, the data is reshuffled at every epoch (an epoch ends once all batches have been seen).
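To see the per-epoch reshuffling, the sketch below iterates twice over a hypothetical 10-sample dataset whose samples are just the integers 0 to 9. The seeded torch.Generator passed through the DataLoader's generator argument is optional; it is used here only to make the shuffling reproducible.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy dataset: the samples are just the integers 0..9
dataset = TensorDataset(torch.arange(10))

# A seeded generator makes the shuffle order reproducible across runs
loader = DataLoader(dataset, batch_size=5, shuffle=True,
                    generator=torch.Generator().manual_seed(0))

epoch_orders = []
for epoch in range(2):
    # Each full pass over the loader is one epoch; a new permutation is drawn for it
    order = torch.cat([batch for (batch,) in loader]).tolist()
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")
```

Within each epoch every sample appears exactly once; only the order changes from one epoch to the next.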
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
A single iteration over the DataLoader will return a batch of samples and labels.
train_features, train_labels = next(iter(train_dataloader))
print(train_features.size())  # torch.Size([64, 1, 28, 28])
print(train_labels.size())    # torch.Size([64])
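next(iter(...)) fetches only the first batch. To process a whole epoch you loop over the DataLoader. If the dataset size is not a multiple of batch_size, the last batch is smaller unless drop_last=True (for Fashion-MNIST's 60,000 training images with batch_size=64, that gives 938 batches, the last holding 32 samples). The sketch below uses a hypothetical 10-sample in-memory dataset so the batch sizes are easy to check.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy data: 10 samples with 3 features each, plus integer labels
dataset = TensorDataset(torch.rand(10, 3), torch.arange(10))

# One epoch = one full loop over the DataLoader
loader = DataLoader(dataset, batch_size=4)
batch_sizes = [labels.size(0) for features, labels in loader]
print(batch_sizes)  # [4, 4, 2]

# drop_last=True discards the incomplete final batch
loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)
print(len(loader_drop))  # 2
```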
Transforms
All TorchVision datasets accept two parameters that take callables containing the transformation logic:
transform: modifies the features
target_transform: modifies the labels
For example, in the Fashion-MNIST dataset, features are in PIL image format and labels are integers. We want to change features to normalized float tensors and labels to one-hot encoded tensors.
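ToTensor() converts a PIL image or NumPy ndarray into a float tensor and scales its pixel intensities to the range [0., 1.]. Lambda() applies a user-defined lambda function; here it turns the integer label into a one-hot encoded tensor.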
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    # One-hot encode the integer label y into a float tensor of size 10
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)
    )
)
scatter_ writes the value 1 into a zero tensor of size 10 (the number of classes) at the index given by the label y, producing a one-hot encoded label.