PyTorch - Vision

Train convolutional neural networks on image data with PyTorch and torchvision — from FashionMNIST warm-up to an MNIST capstone.

Goal of the lesson

By the end of this 3-hour session you should be able to:

  • read images as [C, H, W] tensors and feed them to a model in batches,
  • explain what a convolution does and why CNNs beat plain MLPs on images,
  • compute the spatial output size of a Conv2d / MaxPool2d layer,
  • assemble a small CNN with Conv2d, MaxPool2d, BatchNorm2d, Dropout,
  • train it on FashionMNIST with batched loaders,
  • visualize predictions and a confusion matrix,
  • build an MNIST digit recognizer as a capstone and test it on hand-drawn digits.

Suggested timing

Block     Topic
20 min    Image tensors, transforms, DataLoader
25 min    Convolution intuition: kernels, strides, padding
25 min    Build a CNN, trace shapes through it
35 min    Train and evaluate on FashionMNIST
25 min    Visualize predictions, confusion matrix
50 min    Capstone — MNIST digit recognizer

Setup

PowerShell
uv init --python 3.12 vision
cd vision
uv add torch torchvision matplotlib
main.py
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

Image tensors

PyTorch images are [C, H, W] tensors:

  • C = channels (1 for grayscale, 3 for RGB),
  • H = height in pixels,
  • W = width in pixels.
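To internalize the layout, build a throwaway RGB tensor by hand (rgb below is just an illustration, not part of the lesson's data). Note that matplotlib expects the channel axis last, so color images need a permute before plotting:

main.py
rgb = torch.rand(3, 64, 64)  # a made-up RGB image in [C, H, W]
print(rgb.shape)             # torch.Size([3, 64, 64])
# matplotlib wants [H, W, C], so move the channel axis last:
plt.imshow(rgb.permute(1, 2, 0)); plt.axis("off"); plt.show()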

torchvision.datasets ships ready-to-use datasets that download themselves on first run.

main.py
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
class_names = train_data.classes
print(len(train_data), len(test_data))
print(class_names)

transforms.ToTensor() converts a PIL image into a [C, H, W] tensor with values in [0, 1].

main.py
image, label = train_data[0]
print("image shape:", image.shape) # torch.Size([1, 28, 28])
print("label :", label, class_names[label])
plt.imshow(image.squeeze(), cmap="gray")
plt.title(class_names[label])
plt.axis("off"); plt.show()

squeeze() drops the singleton channel dim so matplotlib can plot a (H, W) array.

DataLoader

A DataLoader wraps a dataset and yields batches — small groups of samples — which is what we feed to the model during training.

main.py
BATCH_SIZE = 32
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
batch_images, batch_labels = next(iter(train_loader))
print(batch_images.shape, batch_labels.shape)
# torch.Size([32, 1, 28, 28]) torch.Size([32])

Why batches?

  • Memory. A whole dataset rarely fits in GPU memory.
  • Speed. GPUs parallelize across a batch; feeding one sample at a time leaves most of the hardware idle.
  • Optimization. Stochastic gradient descent updates the weights once per batch. The batch-to-batch noise in those updates tends to help generalization.

shuffle=True on the training loader prevents the model from learning anything from the fixed order of the samples. The test loader does not shuffle — it is only evaluated, never learned from.

Try it — read shapes

For each line, predict the shape before running:

main.py
print(train_data[5][0].shape)
print(torch.stack([train_data[i][0] for i in range(8)]).shape)
print(next(iter(train_loader))[0].shape)

Convolution intuition

A linear layer treats the input as a flat vector. That throws away the spatial structure of an image: pixel (0, 0) and pixel (0, 1) are immediate neighbors, while pixel (0, 0) and pixel (27, 27) are far apart — but once flattened, the layer cannot tell the difference.

A convolutional layer keeps the structure. It slides a small kernel (e.g. 3×3) over the image and computes a weighted sum at every position. The same kernel is applied everywhere — so a “vertical edge detector” works the same in the top-left as in the bottom-right.

A convolutional network stacks many such layers. Early layers learn simple features (edges, corners). Deep layers combine them into more abstract features (shapes, parts, whole objects).
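To make the sliding-kernel idea concrete, here is a small sketch (separate from the model we build later) that applies a hand-made vertical-edge kernel to the first FashionMNIST image with torch.nn.functional.conv2d:

main.py
import torch.nn.functional as F

# conv2d expects weights of shape [out_channels, in_channels, kH, kW].
kernel = torch.tensor([[-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0]]).reshape(1, 1, 3, 3)
image, _ = train_data[0]                    # [1, 28, 28]
edges = F.conv2d(image.unsqueeze(0), kernel, padding=1)
print(edges.shape)                          # torch.Size([1, 1, 28, 28])
plt.imshow(edges.squeeze().abs(), cmap="gray"); plt.axis("off"); plt.show()

A trained Conv2d layer learns kernels like this one on its own; we only hand-craft it here to see what a single kernel computes.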

Output size formula

Given an input with side H, a kernel of size K, padding P, and stride S, the output side is:

out = floor((H + 2P - K) / S) + 1

(The floor only matters when the stride does not divide evenly.)
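The formula is easy to mis-remember, so it is worth encoding once. This small helper (conv_out is our own name, not a PyTorch function) reproduces it with integer division:

main.py
def conv_out(h: int, k: int, p: int = 0, s: int = 1) -> int:
    """Output side of a Conv2d/MaxPool2d along one spatial dimension."""
    return (h + 2 * p - k) // s + 1

print(conv_out(32, k=3, p=1))   # 32: kernel 3 with padding 1 preserves size
print(conv_out(32, k=2, s=2))   # 16: the MaxPool2d(2) pattern halves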

Two common configurations:

Layer                                        Spatial size
Conv2d(in, out, kernel_size=3, padding=1)    stays the same
MaxPool2d(kernel_size=2, stride=2)           halves

Memorize these two — almost every CNN you’ll write uses them.

Try it — compute output sizes

For each, compute the output spatial size:

Input    Layer                                     Output
28×28    Conv2d(_, _, k=3, padding=1)              ?
28×28    Conv2d(_, _, k=5, padding=0)              ?
28×28    MaxPool2d(2)                              ?
64×64    Conv2d(_, _, k=3, stride=2, padding=1)    ?

A first model — flatten + Linear (baseline)

Before reaching for convolutions, try the simpler approach: flatten the image and feed it to a fully connected network.

main.py
class FashionMLP(nn.Module):
    def __init__(self, in_features: int, hidden: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)


mlp = FashionMLP(in_features=28 * 28, hidden=128, out_features=len(class_names)).to(device)

This works, but it discards the spatial structure of the image. We’ll come back to it for comparison.

Build a CNN

main.py
class FashionCNN(nn.Module):
    def __init__(self, in_channels: int, hidden: int, out_features: int):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 28×28 -> 14×14
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 14×14 -> 7×7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(hidden * 7 * 7, out_features),
        )

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        return self.classifier(x)


model = FashionCNN(in_channels=1, hidden=16, out_features=len(class_names)).to(device)
print(model)

How spatial dimensions evolve:

Stage                           Shape
Input                           [B, 1, 28, 28]
After block 1 (MaxPool2d(2))    [B, 16, 14, 14]
After block 2 (MaxPool2d(2))    [B, 16, 7, 7]
After Flatten                   [B, 16 × 7 × 7] = [B, 784]
After classifier                [B, 10] (one logit per class)

Three new building blocks:

  • BatchNorm2d — normalizes activations across a batch. Speeds up training and stabilizes it.
  • Dropout — randomly zeros activations during training and does nothing in eval mode (see the sketch after this list). A simple, effective regularizer.
  • Flatten — collapses every dim except the batch into one.
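Dropout is the one that behaves differently between model.train() and model.eval(): during training it zeros entries at random and scales the survivors by 1/(1 - p); in eval mode it is the identity. A quick sketch:

drop = nn.Dropout(p=0.5)
x = torch.ones(8)
drop.train()
print(drop(x))  # roughly half the entries zeroed, survivors scaled to 2.0
drop.eval()
print(drop(x))  # all ones: dropout is disabled at eval time

This is why the training loop below calls model.train() and model.eval() at the start of each phase.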
Note

The classifier’s input size (hidden * 7 * 7) depends on the image size and the number of pooling steps. If you change the architecture you’ll need to recompute it, or run a dummy tensor through the convolutional blocks once and read .shape:

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28, device=device)  # same device as the model
    print(model.block_2(model.block_1(dummy)).shape)  # torch.Size([1, 16, 7, 7])

Inspect the model

main.py
n_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {n_params:,}")

You should see roughly 15k parameters for the CNN — tiny by modern standards.

Train and evaluate

The training loop is the same skeleton as in the Workflow chapter, now with batches:

main.py
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_step(model, loader):
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)  # undo the batch mean so we can average per sample
        total_correct += (logits.argmax(dim=1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, total_correct / total


def test_step(model, loader):
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    with torch.inference_mode():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            total_loss += loss_fn(logits, y).item() * x.size(0)
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total += x.size(0)
    return total_loss / total, total_correct / total


EPOCHS = 5
history = []
for epoch in range(EPOCHS):
    train_loss, train_acc = train_step(model, train_loader)
    test_loss, test_acc = test_step(model, test_loader)
    history.append((train_loss, train_acc, test_loss, test_acc))
    print(
        f"epoch {epoch + 1} train loss={train_loss:.4f} acc={train_acc:.2%} "
        f"| test loss={test_loss:.4f} acc={test_acc:.2%}"
    )

Five epochs on CPU should land around 90% test accuracy. On GPU it takes well under a minute.

Plot loss and accuracy

main.py
import numpy as np
h = np.array(history)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].plot(h[:, 0], label="train"); axes[0].plot(h[:, 2], label="test")
axes[0].set_title("loss"); axes[0].legend()
axes[1].plot(h[:, 1], label="train"); axes[1].plot(h[:, 3], label="test")
axes[1].set_title("accuracy"); axes[1].legend()
plt.show()

Visualize predictions

main.py
import random

model.eval()
samples, truths = [], []
for _ in range(9):
    idx = random.randint(0, len(test_data) - 1)
    image, label = test_data[idx]
    samples.append(image); truths.append(label)
with torch.inference_mode():
    logits = model(torch.stack(samples).to(device))
    preds = logits.argmax(dim=1).cpu()
fig, axes = plt.subplots(3, 3, figsize=(8, 8))
for ax, image, truth, pred in zip(axes.flat, samples, truths, preds):
    ax.imshow(image.squeeze(), cmap="gray")
    color = "green" if pred == truth else "red"
    ax.set_title(f"{class_names[pred]}\n(true: {class_names[truth]})", color=color)
    ax.axis("off")
plt.tight_layout(); plt.show()

Misclassifications usually fall on similar-looking categories — pullover vs. coat, shirt vs. t-shirt.

Confusion matrix

main.py
import numpy as np

model.eval()
confusion = np.zeros((10, 10), dtype=int)
with torch.inference_mode():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)
        for t, p in zip(y.cpu().numpy(), preds.cpu().numpy()):
            confusion[t, p] += 1
fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(confusion, cmap="Blues")
ax.set_xticks(range(10)); ax.set_xticklabels(class_names, rotation=45, ha="right")
ax.set_yticks(range(10)); ax.set_yticklabels(class_names)
for i in range(10):
    for j in range(10):
        ax.text(j, i, confusion[i, j], ha="center", va="center",
                color="white" if confusion[i, j] > confusion.max() / 2 else "black")
plt.colorbar(im); plt.tight_layout(); plt.show()

Bright cells off the diagonal are the persistent confusions. They tell you what to fix: harder augmentation? More capacity? Better data?

Save the model

main.py
from pathlib import Path
Path("models").mkdir(exist_ok=True)
torch.save(model.state_dict(), "models/fashion_cnn.pt")
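To use the weights later, build a fresh instance with the same constructor arguments and load the state dict into it (the capstone's predict.py follows this same pattern):

main.py
restored = FashionCNN(in_channels=1, hidden=16, out_features=len(class_names)).to(device)
restored.load_state_dict(torch.load("models/fashion_cnn.pt", map_location=device))
restored.eval()  # switch BatchNorm/Dropout to inference behavior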

Exercises

Warm-up

  1. Train the FashionMLP for the same 5 epochs and compare its test accuracy. By how much does the CNN win?
  2. Count parameters for both models. Is the CNN much bigger?
  3. Try lr=0.01 and lr=0.0001 for the CNN. Which converges fastest?

Architecture

  1. Add a third convolutional block. What does the classifier’s input size become? Does test accuracy improve?
  2. Remove the BatchNorm2d layers and train again. Does convergence slow down?
  3. Increase hidden from 16 to 64. Does the gap between train and test accuracy change?

Data augmentation

  1. Add transforms.RandomHorizontalFlip() to the training transforms. Does it help on FashionMNIST? Why might it hurt on a “left-shoe vs right-shoe” dataset?
  2. Add transforms.RandomCrop(28, padding=4) to the training transforms. How does test accuracy change?

Visualization

  1. Pick the most-confused pair from the confusion matrix and look at 8 example mistakes side by side.
  2. Visualize the 16 first-layer kernels: model.block_1[0].weight.data is a tensor of shape [16, 1, 3, 3]. Plot each kernel as a 3×3 grayscale image.

Capstone — MNIST digit recognizer

MNIST is the classic 28×28 grayscale digit dataset. We’ll train a CNN, save it, and write a small inference script that classifies digits we draw ourselves.

Train the digit model

capstone.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

train_tf = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean/std
])
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_data = datasets.MNIST(root="data", train=True, download=True, transform=train_tf)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=test_tf)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)


class DigitCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(64 * 7 * 7, 128), nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


# Guard the training code so that `from capstone import DigitCNN`
# (as predict.py does below) doesn't retrain the model on import.
if __name__ == "__main__":
    model = DigitCNN().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    EPOCHS = 5
    for epoch in range(EPOCHS):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y)
            optimizer.zero_grad(); loss.backward(); optimizer.step()
        model.eval()
        correct = total = 0
        with torch.inference_mode():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(dim=1)
                correct += (preds == y).sum().item(); total += x.size(0)
        print(f"epoch {epoch + 1} test acc={correct / total:.4%}")
    torch.save(model.state_dict(), "mnist_cnn.pt")

You should reach around 99% test accuracy in 5 epochs.
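The 0.1307 and 0.3081 passed to Normalize are the mean and standard deviation of the raw MNIST training pixels. If you would rather verify them than take them on faith, a one-off sketch (it stacks all 60,000 images, roughly 180 MB):

raw = datasets.MNIST(root="data", train=True, download=True,
                     transform=transforms.ToTensor())
pixels = torch.stack([img for img, _ in raw])      # [60000, 1, 28, 28]
print(pixels.mean().item(), pixels.std().item())   # ~0.1307, ~0.3081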

Inference on a hand-drawn digit

Open Paint (or any image editor), draw a single digit on a roughly square white canvas with a thick black brush, save it as my_digit.png. Then:

predict.py
from PIL import Image, ImageOps
import torch
from torchvision import transforms

# Reuse the same DigitCNN class definition. capstone.py guards its training
# loop behind `if __name__ == "__main__":`, so this import doesn't retrain.
from capstone import DigitCNN, device

model = DigitCNN().to(device)
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
model.eval()


def preprocess(path: str) -> torch.Tensor:
    image = Image.open(path).convert("L")  # grayscale
    image = ImageOps.invert(image)         # MNIST is white-on-black
    image = image.resize((28, 28))
    tensor = transforms.ToTensor()(image)
    tensor = transforms.Normalize((0.1307,), (0.3081,))(tensor)
    return tensor.unsqueeze(0).to(device)  # add batch dim


x = preprocess("my_digit.png")
with torch.inference_mode():
    probs = torch.softmax(model(x), dim=1).squeeze().cpu()
for digit, p in enumerate(probs):
    print(f"{digit}: {p.item():.2%}")
print("\nprediction:", probs.argmax().item(), f"({probs.max():.2%})")

Tips for getting good predictions on your own drawings:

  • MNIST digits are white on black with a roughly centered digit. The ImageOps.invert step accounts for the white background most editors use; the optional sketch after this list shows one way to center the strokes too.
  • Use a thick brush — the digit should fill roughly 60–70% of the canvas height.
  • Don’t add extra marks; the model has only seen single, isolated digits.
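If predictions on your own drawings still look random, an off-center digit is the usual culprit. One optional fix, sketched here with a hypothetical center_digit helper you could call between the invert and resize steps of preprocess, is to crop to the strokes' bounding box and paste them onto a centered square canvas:

def center_digit(image: Image.Image) -> Image.Image:
    """Center the (already inverted) strokes on a square canvas with margin."""
    bbox = image.getbbox()             # bounding box of the non-black pixels
    if bbox is None:                   # blank canvas: nothing to center
        return image
    digit = image.crop(bbox)
    side = int(max(digit.size) * 1.4)  # ~40% margin around the strokes
    canvas = Image.new("L", (side, side), 0)
    canvas.paste(digit, ((side - digit.width) // 2, (side - digit.height) // 2))
    return canvas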

Stretch goals

  • Add a confusion matrix as in the FashionMNIST section.
  • Use torch.softmax(logits, dim=1).max() as a confidence and reject predictions below 0.9 with a “not sure” message.
  • Train and save a smaller version of the model (e.g. 8-channel features). How small can you go before accuracy collapses?

Recap

  • Image tensors are [C, H, W]; batches add a leading B.
  • A DataLoader yields batches efficiently; shuffle the train loader, not the test loader.
  • A convolution slides a small kernel over an image — same weights everywhere, far fewer parameters than a fully-connected equivalent.
  • Conv2d(_, _, k=3, padding=1) keeps the size; MaxPool2d(2) halves it. Memorize that pair.
  • BatchNorm2d, Dropout, and data augmentation are the three regularizers you will reach for first.
  • Always plot a confusion matrix to see what the model still gets wrong.

The next chapter, Datasets, shows how to load your own data instead of relying on the built-in collections.
