PyTorch Data Primitives

  • torch.utils.data.Dataset: stores the samples and their corresponding labels
    • You can retrieve one sample and its label at a time by indexing
  • torch.utils.data.DataLoader: wraps an iterable around a Dataset for easy access to the samples
    • You can iterate over the dataset in batches

Table of contents
  1. Downloading Existing Datasets
  2. Custom Datasets
  3. Wrapping Dataset in DataLoader
  4. Transforms

Downloading Existing Datasets

To download the Fashion-MNIST dataset, for example:

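import torch
from torchvision import datasets
from torchvision.transforms import ToTensor

# Training split (train=True), downloaded into ./data if not already present
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()  # convert each PIL image to a float tensor
)
# Test split (train=False)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)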

In this example, a directory named data will be created in the current working directory and the Fashion-MNIST dataset will be downloaded into it.

Set train=True for the training set and train=False for the test set.

You can index into the dataset just like a Python list: img, label = train_data[0].
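Each indexed item is an (image, label) pair. With ToTensor() the image is a float tensor of shape [1, 28, 28], and the label is an integer class index from 0 to 9. The sketch below turns that index into a readable class name. The dictionary follows the published Fashion-MNIST class list. The helper name describe_sample is hypothetical, and a zero tensor stands in for a real image so the snippet runs without downloading anything.

```python
import torch

# Fashion-MNIST class indices and their names
FASHION_MNIST_LABELS = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}


def describe_sample(img: torch.Tensor, label: int) -> str:
    """Summarize one (image, label) pair (hypothetical helper, not part of torchvision)."""
    channels, height, width = img.shape
    return f"{FASHION_MNIST_LABELS[label]}: {channels}x{height}x{width}, dtype={img.dtype}"


# Stand-in for `img, label = train_data[0]` so this runs offline
img, label = torch.zeros(1, 28, 28), 9
print(describe_sample(img, label))  # Ankle boot: 1x28x28, dtype=torch.float32
```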


Custom Datasets

import torch
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

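Any custom dataset should inherit from torch.utils.data.Dataset and implement three methods:

  • __init__(self, annotations_file, img_dir, transform=None, target_transform=None): runs once when the dataset is created, e.g. to load the annotations and store the image directory and transforms
  • __len__(self): returns the number of samples in the dataset
  • __getitem__(self, idx): loads and returns the sample and its label at index idx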

Wrapping Dataset in DataLoader

You can retrieve data in batches and shuffle it using DataLoader.

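Shuffling the training data stops consecutive batches from following the dataset's storage order, which helps the model generalize and reduces overfitting. With shuffle=True, the data is reshuffled at the start of every epoch (one full pass over all batches).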

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # order does not matter for evaluation
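With shuffle=True, a new random order is drawn every time the DataLoader is iterated. Each epoch still visits every sample exactly once, but usually in a different order. The sketch below shows this with a small in-memory TensorDataset, used only so the example runs without Fashion-MNIST. The seed value is arbitrary and only makes the run reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten samples whose "features" are just their own indices
dataset = TensorDataset(torch.arange(10))
# A seeded generator makes the shuffled order reproducible across runs
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=torch.Generator().manual_seed(0))

epoch_orders = []
for epoch in range(2):
    # Each full pass over `loader` is one epoch; the order is redrawn when the pass starts
    order = torch.cat([batch for (batch,) in loader]).tolist()
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")
```

Both printed lists contain 0 through 9 exactly once, and the two orders will generally differ.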

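Each iteration over the DataLoader returns one batch: a tensor of features and a tensor of labels, each with batch_size=64 entries along the first dimension.

train_features, train_labels = next(iter(train_dataloader))
print(train_features.size())  # torch.Size([64, 1, 28, 28])
print(train_labels.size())    # torch.Size([64])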
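An epoch contains ceil(N / batch_size) batches. The last batch holds the remainder unless drop_last=True is passed. For the 60,000-image Fashion-MNIST training split with batch_size=64, that gives 938 batches, and the last one contains 32 samples. The sketch below checks the same arithmetic on a synthetic 100-sample TensorDataset, which stands in for real data.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 random 1x28x28 "images" with integer labels, standing in for a real dataset
features = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=64)
batch_sizes = [x.shape[0] for x, y in loader]
print(len(loader), batch_sizes)  # 2 [64, 36]

# drop_last=True discards the incomplete final batch
loader_drop = DataLoader(dataset, batch_size=64, drop_last=True)
print(len(loader_drop))  # 1

# Batches per epoch for the full Fashion-MNIST training split
print(math.ceil(60000 / 64))  # 938
```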

Transforms

TorchVision datasets accept two parameters:

  • transform: modifies the features
  • target_transform: modifies the labels

For example, in the Fashion-MNIST dataset, features are in PIL image format and labels are integers. We want to change features to normalized float tensors and labels to one-hot encoded tensors.

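  • ToTensor() converts a PIL image or NumPy ndarray into a FloatTensor and scales its pixel intensities to the range [0., 1.].
  • Lambda() wraps any user-defined function as a transform; here it turns the integer label into a one-hot encoded tensor.

import torch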
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),  # PIL image -> float tensor in [0., 1.]
    # integer label y -> one-hot float tensor of length 10
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)
    ),
)

scatter_ starts from a zero tensor of size 10 (one entry per class) and writes value=1 at the index given by the label y, producing a one-hot encoding.
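The same one-hot vector can also be built with torch.nn.functional.one_hot, which may read more clearly. That function returns an int64 tensor, so a .float() call is needed to match the float dtype of the scatter_ version. The function names one_hot_scatter and one_hot_functional below are illustrative.

```python
import torch
import torch.nn.functional as F


def one_hot_scatter(y: int, num_classes: int = 10) -> torch.Tensor:
    # Zero vector with one slot per class; scatter_ writes 1 at position y in place
    return torch.zeros(num_classes, dtype=torch.float).scatter_(dim=0, index=torch.tensor([y]), value=1)


def one_hot_functional(y: int, num_classes: int = 10) -> torch.Tensor:
    # F.one_hot returns int64, so cast to float to match the scatter_ version
    return F.one_hot(torch.tensor(y), num_classes=num_classes).float()


print(one_hot_scatter(3))  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
print(torch.equal(one_hot_scatter(3), one_hot_functional(3)))  # True
```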