PyTorch - Custom datasets

Load your own data into PyTorch with Dataset, DataLoader, ImageFolder, samplers and a transforms pipeline.

Goal of the lesson

By the end of this 3-hour session you should be able to:

  • explain the contract of torch.utils.data.Dataset,
  • use ImageFolder for the common image-classification layout,
  • write your own Dataset for any data you have on disk,
  • compose a transforms pipeline with deterministic preprocessing and stochastic augmentation,
  • balance unbalanced batches with WeightedRandomSampler,
  • build a small image classifier from a folder of images you assembled yourself.

Suggested timing

Block    Topic
15 min   Why custom datasets, the Dataset contract
25 min   Get the data, ImageFolder
30 min   Transforms — deterministic vs. augmentation
30 min   Write a Dataset from scratch
25 min   DataLoader knobs and WeightedRandomSampler
55 min   Capstone — your own image classifier

Why custom datasets

Built-in datasets like FashionMNIST and MNIST are training wheels. The moment you have a real project you’ll be loading your own files: photos in folders, audio in WAV files, sensor logs in CSV, MRI scans in DICOM, etc.

PyTorch has a small, composable API for that:

Building block                     Purpose
torch.utils.data.Dataset           Your data, indexed by integer.
torch.utils.data.DataLoader        Wraps a Dataset to deliver batches, shuffling, parallel loading.
torchvision.datasets.ImageFolder   A ready-made Dataset for images organized by folder.
torchvision.transforms             Image preprocessing and augmentation.
torch.utils.data.Sampler           Decides which indices to draw on each epoch.

We’ll use a small subset of Food-101 (pizza, steak, sushi) as a running example. The same code patterns work for medical images, satellite images, audio spectrograms, or anything else you load from disk.

Setup

PowerShell
uv init --python 3.12 datasets
cd datasets
uv add torch torchvision matplotlib pillow requests
main.py
from pathlib import Path
import zipfile
import requests
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import datasets, transforms
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

Get the data

The dataset ships as a zip on the mrdbourke/pytorch-deep-learning repo. Download it once and unpack it into data/.

main.py
DATA_PATH = Path("data")
IMAGE_PATH = DATA_PATH / "pizza_steak_sushi"
if not IMAGE_PATH.is_dir():
    IMAGE_PATH.mkdir(parents=True, exist_ok=True)
    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    zip_file = DATA_PATH / "pizza_steak_sushi.zip"
    zip_file.write_bytes(requests.get(url).content)
    with zipfile.ZipFile(zip_file, "r") as z:
        z.extractall(IMAGE_PATH)
    zip_file.unlink()
train_dir = IMAGE_PATH / "train"
test_dir = IMAGE_PATH / "test"
print(list(train_dir.iterdir()))

The folder structure is the standard convention for image classification:

pizza_steak_sushi/
├── train/
│ ├── pizza/
│ ├── steak/
│ └── sushi/
└── test/
├── pizza/
├── steak/
└── sushi/

The directory name is the label. This is so common that ImageFolder automates it.
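You can re-derive that label mapping yourself from nothing but folder names. A sketch using a throwaway directory (the tempfile root is purely for illustration — in practice you'd point at train_dir):

```python
import tempfile
from pathlib import Path

# Fake the layout: three empty class folders under a temporary root.
root = Path(tempfile.mkdtemp())
for name in ["pizza", "steak", "sushi"]:
    (root / name).mkdir()

# What ImageFolder derives from the layout (re-derived here by hand):
classes = sorted(d.name for d in root.iterdir() if d.is_dir())
class_to_idx = {name: i for i, name in enumerate(classes)}
print(class_to_idx)  # {'pizza': 0, 'steak': 1, 'sushi': 2}
```

Sorting the names is what makes the index assignment stable across runs.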

Inspect the data

Before doing anything else, look at a few samples:

main.py
import random

class_dirs = [d for d in train_dir.iterdir() if d.is_dir()]
fig, axes = plt.subplots(3, 3, figsize=(8, 8))
for ax, cls_dir in zip(axes, class_dirs):
    images = list(cls_dir.glob("*.jpg"))
    for sub_ax, image_path in zip(ax, random.sample(images, 3)):
        sub_ax.imshow(Image.open(image_path))
        sub_ax.set_title(cls_dir.name); sub_ax.axis("off")
plt.tight_layout(); plt.show()

The images vary in size, lighting, and framing. Welcome to real data.

Transforms

Models expect tensors of a fixed size and dtype. torchvision.transforms builds a pipeline that runs on each image as it is loaded.

main.py
train_transforms = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
])
test_transforms = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
])

Two rules of thumb:

  • The same deterministic preprocessing (resize, normalize) is applied to both train and test.
  • Augmentations (random flips, crops, color jitter, TrivialAugmentWide) belong only on the training transforms. The test set must measure how well you do on the real data, not on randomly modified data.

transforms.ToTensor() converts a PIL image to a [C, H, W] tensor in [0, 1]. Put it last — most other transforms operate on PIL images, not tensors.
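As a sanity check, what ToTensor computes for image-like uint8 data can be sketched with plain tensor ops — taking the HWC→CHW permute and /255 scaling described above as its defining behavior (ToTensor itself additionally handles PIL inputs):

```python
import torch

# A fake 4x4 RGB image as uint8 in HWC layout — the layout PIL/numpy use.
hwc_uint8 = torch.randint(0, 256, (4, 4, 3), dtype=torch.uint8)

# HWC uint8 in [0, 255]  ->  CHW float in [0, 1]
chw_float = hwc_uint8.permute(2, 0, 1).float() / 255.0

print(chw_float.shape)  # torch.Size([3, 4, 4])
print(float(chw_float.min()), float(chw_float.max()))  # both inside [0, 1]
```

If a transform downstream complains about dtype or channel order, checking these two properties is usually the fastest diagnosis.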

Visualize what augmentation does

main.py
sample_path = next(train_dir.glob("**/*.jpg"))
sample = Image.open(sample_path)
fig, axes = plt.subplots(1, 5, figsize=(15, 4))
for ax in axes:
    ax.imshow(train_transforms(sample).permute(1, 2, 0))
    ax.axis("off")
plt.suptitle("same image, different augmentations")
plt.show()

Run it twice. The augmented versions vary; the underlying image is the same.

Try it — design augmentations for a new task

For each task, suggest transforms you would and wouldn’t include in training. Why?

  • Recognising pets vs. wild animals from photos.
  • Reading handwritten digits like MNIST.
  • Detecting cracks in concrete from drone imagery.

Option 1 — ImageFolder

When the data is already organized by folder, ImageFolder is the shortest path from disk to Dataset:

main.py
train_data = datasets.ImageFolder(root=train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(root=test_dir, transform=test_transforms)
print(train_data.classes) # ['pizza', 'steak', 'sushi']
print(train_data.class_to_idx) # {'pizza': 0, 'steak': 1, 'sushi': 2}
print(len(train_data), len(test_data))
image, label = train_data[0]
print(image.shape, label) # torch.Size([3, 64, 64]) 0

ImageFolder walks the directory once at construction time, builds the list of paths, and reads the corresponding image lazily on each __getitem__.

Option 2 — Write your own Dataset

ImageFolder covers most needs but you’ll often have to write a Dataset yourself: a CSV of (path, label) pairs, multi-label data, paired images, audio, sensor logs, etc.

A Dataset only has to implement three methods:

main.py
class ImageFolderCustom(Dataset):
    def __init__(self, root: Path, transform=None):
        self.paths = sorted(p for p in root.glob("*/*.jpg"))
        self.transform = transform
        self.classes = sorted(d.name for d in root.iterdir() if d.is_dir())
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int):
        path = self.paths[index]
        image = Image.open(path).convert("RGB")
        label = self.class_to_idx[path.parent.name]
        if self.transform:
            image = self.transform(image)
        return image, label

custom_train = ImageFolderCustom(train_dir, transform=train_transforms)
print(len(custom_train), custom_train.classes)
image, label = custom_train[0]
print(image.shape, label)

The contract:

Method           Role
__len__          Returns the number of samples.
__getitem__(i)   Returns the i-th (input, target) pair.
__init__         Builds whatever indexing/lookup makes __getitem__ cheap.

Two important properties:

  • Index-by-int. dataset[0], dataset[1], … dataset[len(dataset) - 1]. The DataLoader decides the order.
  • Lazy. Don’t load all images in __init__ — store paths and read on demand. Otherwise large datasets won’t fit in memory.
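To see how little the contract demands, here is a toy map-style dataset over synthetic data — no files, no torchvision (the class name and the x → x² task are purely illustrative):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class SquaresDataset(Dataset):
    """Toy dataset: input i, target i**2."""
    def __init__(self, n: int):
        self.n = n  # cheap __init__: just remember the size

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int):
        # Produced on demand — the stand-in for reading a file from disk.
        return torch.tensor([float(index)]), torch.tensor(float(index) ** 2)

ds = SquaresDataset(10)
print(len(ds), ds[3])      # 10 (tensor([3.]), tensor(9.))
xb, yb = next(iter(DataLoader(ds, batch_size=4)))
print(xb.shape, yb.shape)  # torch.Size([4, 1]) torch.Size([4])
```

Note that DataLoader's default collate function stacked the per-sample tensors into batches without any extra code from us.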

A Dataset from a CSV

A common real-world layout is images/ plus a CSV with path,label rows:

images/dog_001.jpg,dog
images/cat_017.jpg,cat
...
csv_dataset.py
import csv
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

class CSVImageDataset(Dataset):
    def __init__(self, csv_path: Path, image_root: Path, transform=None):
        with open(csv_path) as f:
            self.rows = list(csv.reader(f))
        self.image_root = image_root
        self.transform = transform
        self.classes = sorted({label for _, label in self.rows})
        self.class_to_idx = {label: i for i, label in enumerate(self.classes)}

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        rel_path, label = self.rows[index]
        image = Image.open(self.image_root / rel_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.class_to_idx[label]

You’ll reuse this pattern over and over — just swap the storage layer (CSV, JSON, database, S3) and the loader (Image.open, torchaudio.load, pandas.read_parquet).
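That factoring can even be made explicit by injecting the loader as a parameter — a sketch with hypothetical names, using a stand-in "loader" so the example runs without any files on disk:

```python
from torch.utils.data import Dataset

class RecordDataset(Dataset):
    """Generic (record -> sample) dataset, names illustrative.

    `records` is any indexable of (source, label) pairs;
    `load` turns a source (path, row id, S3 key, ...) into a sample.
    """
    def __init__(self, records, load, transform=None):
        self.records = list(records)
        self.load = load
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        source, label = self.records[index]
        sample = self.load(source)
        if self.transform:
            sample = self.transform(sample)
        return sample, label

# Swap the loader, keep the class. Here a dummy loader "loads" text;
# for images you would pass e.g. lambda p: Image.open(p).convert("RGB").
ds = RecordDataset([("a.txt", 0), ("b.txt", 1)], load=lambda src: src.upper())
print(ds[1])  # ('B.TXT', 1)
```

Only the `load` callable changes between an image dataset, an audio dataset, and a parquet-backed one.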

Try it — multi-label dataset

Imagine each photo has multiple labels (["beach", "sunset"], ["cat", "indoors"]). Sketch the changes you’d make to __getitem__ and to the loss function.

DataLoader

DataLoader wraps any Dataset and yields batches.

main.py
BATCH_SIZE = 32
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)
# torch.Size([32, 3, 64, 64]) torch.Size([32])

The most useful parameters:

Parameter     Meaning
batch_size    How many samples per batch.
shuffle       Reorder indices every epoch. True for train, False for test.
num_workers   Background worker processes. 0 = main thread; >0 = parallel loading.
pin_memory    Faster CPU→GPU transfers (set True if using a GPU).
drop_last     Drop the last batch if it’s smaller than batch_size.
sampler       Replaces the default sequential/random sampler — see below.
Note

On Windows, num_workers > 0 requires the training code to live inside an if __name__ == "__main__": guard, because each worker re-imports the script. Start with num_workers=0 while developing, then bump it up once everything works.
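The guard pattern itself looks like this — a structural sketch with a throwaway TensorDataset; in practice your real loaders and training loop go inside main():

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    ds = TensorDataset(torch.arange(8.0))
    # num_workers > 0 is safe here: each worker process re-imports
    # this file, but the guard below stops it from re-running main().
    loader = DataLoader(ds, batch_size=4, num_workers=2)
    for (batch,) in loader:
        print(batch.shape)  # torch.Size([4]) twice

if __name__ == "__main__":
    main()
```

Module-level code (imports, class definitions) still runs in every worker, which is exactly what lets the workers reconstruct the Dataset.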

Inspect a batch

main.py
def show_batch(images, labels, classes, n: int = 8):
    fig, axes = plt.subplots(1, n, figsize=(n * 2, 2.5))
    for ax, image, label in zip(axes, images[:n], labels[:n]):
        ax.imshow(image.permute(1, 2, 0)) # CHW -> HWC for matplotlib
        ax.set_title(classes[label])
        ax.axis("off")
    plt.tight_layout(); plt.show()
show_batch(images, labels, train_data.classes)

permute(1, 2, 0) is the matplotlib convention — [H, W, C] instead of PyTorch’s [C, H, W].

Class imbalance — WeightedRandomSampler

If one class has 10× more samples than another, the model will learn to predict the majority class even when it shouldn’t. WeightedRandomSampler fixes that by drawing minority-class samples more often.

main.py
import collections
# Per-class counts
counts = collections.Counter(label for _, label in train_data.samples)
print(counts)
# Per-sample weight: inverse of its class count
class_weight = {cls: 1.0 / count for cls, count in counts.items()}
sample_weights = [class_weight[label] for _, label in train_data.samples]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)
balanced_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    sampler=sampler, # mutually exclusive with shuffle=True
    num_workers=0,
)
# Confirm batches are now roughly balanced
batch_labels = next(iter(balanced_loader))[1]
print(collections.Counter(batch_labels.tolist()))

The pizza/steak/sushi dataset is already balanced, but the technique transfers to any imbalanced classification task.
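To see the rebalancing on data that actually is skewed, here is a self-contained sketch with a synthetic 90/10 label split (labels only, no images needed):

```python
import collections

import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)
labels = [0] * 90 + [1] * 10                      # heavy 90/10 imbalance
counts = collections.Counter(labels)
weights = [1.0 / counts[label] for label in labels]

sampler = WeightedRandomSampler(weights, num_samples=1000, replacement=True)
drawn = collections.Counter(labels[i] for i in sampler)  # sampler yields indices
print(drawn)  # roughly 500 of each class, despite the 90/10 raw split
```

Without the sampler, class 1 would appear in about 10% of draws; with inverse-frequency weights it is drawn about half the time.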

Train a small CNN

The training loop is the same as the vision chapter. The only changes are the input channels (3 for RGB) and the spatial size after pooling.

main.py
class TinyCNN(nn.Module):
    def __init__(self, in_channels: int, hidden: int, out_features: int):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden * 16 * 16, out_features),
        )

    def forward(self, x):
        x = self.block_1(x); x = self.block_2(x); return self.classifier(x)

model = TinyCNN(in_channels=3, hidden=10, out_features=len(train_data.classes)).to(device)

For 64×64 input, two MaxPool2d(2) stages bring you down to 16×16, hence hidden * 16 * 16.
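Rather than doing that arithmetic by hand, you can let a dummy forward pass measure the flattened size — a common trick, sketched here with the same block shapes as TinyCNN:

```python
import torch
from torch import nn

hidden = 10
blocks = nn.Sequential(  # convs with padding=1 keep H and W; each pool halves them
    nn.Conv2d(3, hidden, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(hidden, hidden, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)

with torch.inference_mode():
    # Push one zero image through and read off the flattened feature count.
    n_features = blocks(torch.zeros(1, 3, 64, 64)).flatten(1).shape[1]
print(n_features)  # 2560 == hidden * 16 * 16
```

If you later change the input resolution or add a pooling stage, the measured value updates itself instead of silently breaking the Linear layer.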

With this little data the model will not generalize well — that is exactly the problem the transfer-learning chapter solves. For now, train for a few epochs and observe the gap between train and test accuracy.

main.py
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
    model.eval()
    correct = total = 0
    with torch.inference_mode():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item(); total += x.size(0)
    print(f"epoch {epoch + 1} test acc={correct / total:.2%}")

Exercises

Warm-up

  1. Print the first 5 paths in train_data.samples and confirm __getitem__(i) opens the matching file.
  2. Build a histogram of image sizes (width × height) before resizing. How varied is the data?
  3. Add transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) (ImageNet stats) to both train and test transforms. Why exactly the same to both?

Custom datasets

  1. Write a Dataset that reads a CSV with columns (path, label) instead of relying on the folder layout.
  2. Modify ImageFolderCustom to also accept *.png and *.jpeg.
  3. Write a Dataset whose targets are bounding-box coordinates (x1, y1, x2, y2) parsed from a JSON file. (Imagine a detection problem.)

Augmentation

  1. Increase the resize to (128, 128). What changes about the classifier’s input size?
  2. Replace TrivialAugmentWide with RandAugment(num_ops=2, magnitude=9). Did test accuracy improve?

Sampler

  1. Delete half of the pizza folder. Train once without a sampler and once with WeightedRandomSampler. Compare per-class accuracy.
  2. Set num_workers=2 inside an if __name__ == "__main__": block on Windows and measure the speedup.

Capstone — your own image classifier

Build a 3-class image classifier from scratch using your own photos. This trains every part of the pipeline you’ve learned: collecting data, organizing it, transforming it, loading it, and training on it.

Step 1 — pick three classes

Anything you can take 30+ photos of: pencil/pen/marker, apple/orange/banana, your three favourite mugs. Aim for variety — different lighting, angles, backgrounds.

Step 2 — assemble the folder

Take 30 training photos and 10 test photos per class. Resize them so none is larger than ~1000 px on the long side (your phone images are way too big). Lay them out as:

my_dataset/
├── train/
│ ├── class_a/
│ │ ├── 0001.jpg
│ │ └── ...
│ ├── class_b/
│ └── class_c/
└── test/
├── class_a/
├── class_b/
└── class_c/

Step 3 — load it

capstone.py
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device = "cuda" if torch.cuda.is_available() else "cpu"
data_root = Path("my_dataset")
train_tf = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.RandomHorizontalFlip(),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_tf = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_data = datasets.ImageFolder(data_root / "train", transform=train_tf)
test_data = datasets.ImageFolder(data_root / "test", transform=test_tf)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8, shuffle=False)
print(train_data.classes)

Step 4 — train a small CNN

capstone.py
import torch.nn as nn
class MyCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(64 * 12 * 12, 64), nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
model = MyCNN(num_classes=len(train_data.classes)).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
EPOCHS = 30
for epoch in range(EPOCHS):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
    model.eval()
    correct = total = 0
    with torch.inference_mode():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item(); total += x.size(0)
    print(f"epoch {epoch + 1:2d} test acc={correct / total:.2%}")

With 30 photos per class you should not expect great test accuracy — 65–80% is typical. The point of this capstone is the pipeline, not the score. The next chapter, Transfer learning, shows how to push that number much higher with the same data.

Step 5 — predict on a new photo

capstone.py
from PIL import Image
def predict(path: Path):
    image = Image.open(path).convert("RGB")
    tensor = test_tf(image).unsqueeze(0).to(device)
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(tensor), dim=1).squeeze().cpu()
    cls = train_data.classes[probs.argmax().item()]
    return cls, probs
label, probs = predict(Path("test_photo.jpg"))
print(f"{label}: {probs.max():.1%}")

Stretch goals

  • Compute and display a confusion matrix on your test set (see the vision chapter).
  • Add WeightedRandomSampler if some classes ended up with more samples than others.
  • Save the dataset class list to a JSON file alongside the model so you don’t depend on folder ordering at inference time.
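The class-list idea from the last bullet might look like this (the file name classes.json is just a suggestion):

```python
import json
from pathlib import Path

classes = ["class_a", "class_b", "class_c"]  # train_data.classes in practice

# Save next to the model checkpoint...
Path("classes.json").write_text(json.dumps(classes))

# ...and reload at inference time, independent of any folder layout.
loaded = json.loads(Path("classes.json").read_text())
print(loaded == classes)  # True
```

Without this, predictions are only correct as long as the inference machine reproduces the exact alphabetical folder ordering that ImageFolder saw at training time.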

Recap

  • A Dataset is anything with __len__ and __getitem__. Keep __init__ cheap, load on demand.
  • ImageFolder is the shortcut for train/<class>/<image>.jpg layouts.
  • Build a transforms pipeline: deterministic preprocessing + augmentation in train, deterministic preprocessing only in test.
  • DataLoader turns a Dataset into batches. shuffle for train, not for test. num_workers > 0 needs if __name__ == "__main__": on Windows.
  • For imbalanced data, swap shuffle=True for a WeightedRandomSampler.

The next chapter, Transfer learning, shows you how to leverage models pretrained on millions of images to get strong results from a tiny dataset like the one you just built.
