Train convolutional neural networks on image data with PyTorch and torchvision — from FashionMNIST warm-up to an MNIST capstone.
Goal of the lesson
By the end of this 3-hour session you should be able to:
- read images as `[C, H, W]` tensors and feed them to a model in batches,
- explain what a convolution does and why CNNs beat plain MLPs on images,
- compute the spatial output size of a `Conv2d`/`MaxPool2d` layer,
- assemble a small CNN with `Conv2d`, `MaxPool2d`, `BatchNorm2d`, `Dropout`,
- train it on FashionMNIST with batched loaders,
- visualize predictions and a confusion matrix,
- build an MNIST digit recognizer as a capstone and test it on hand-drawn digits.
Suggested timing
| Block | Topic |
|---|---|
| 20 min | Image tensors, transforms, DataLoader |
| 25 min | Convolution intuition: kernels, strides, padding |
| 25 min | Build a CNN, trace shapes through it |
| 35 min | Train and evaluate on FashionMNIST |
| 25 min | Visualize predictions, confusion matrix |
| 50 min | Capstone — MNIST digit recognizer |
Setup
```sh
uv init --python 3.12 vision
cd vision
uv add torch torchvision matplotlib
```

```python
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
```

Image tensors
PyTorch images are [C, H, W] tensors:
- `C` = channels (1 for grayscale, 3 for RGB),
- `H` = height in pixels,
- `W` = width in pixels.
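One gotcha worth knowing up front: matplotlib's `imshow` expects color images channels-last, `[H, W, C]`, while PyTorch stores them channels-first. A small sketch of the conversion (the tensor here is random, just to show the shapes):

```python
import torch

rgb = torch.rand(3, 28, 28)          # a fake RGB image in PyTorch's [C, H, W] layout

# matplotlib's imshow expects [H, W, C] for color images,
# so move the channel dim to the end before plotting:
channels_last = rgb.permute(1, 2, 0)
print(channels_last.shape)           # torch.Size([28, 28, 3])
```

For grayscale images, squeezing out the singleton channel dim is enough, as the plotting code below does.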
torchvision.datasets ships ready-to-use datasets that download themselves on first run.
```python
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

class_names = train_data.classes
print(len(train_data), len(test_data))
print(class_names)
```

`transforms.ToTensor()` converts a PIL image into a `[C, H, W]` tensor with values in `[0, 1]`.
```python
image, label = train_data[0]
print("image shape:", image.shape)  # torch.Size([1, 28, 28])
print("label      :", label, class_names[label])

plt.imshow(image.squeeze(), cmap="gray")
plt.title(class_names[label])
plt.axis("off"); plt.show()
```

`squeeze()` drops the singleton channel dim so matplotlib can plot an `(H, W)` array.
DataLoader
A DataLoader wraps a dataset and yields batches — small groups of samples — which is what we feed to the model during training.
```python
BATCH_SIZE = 32

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

batch_images, batch_labels = next(iter(train_loader))
print(batch_images.shape, batch_labels.shape)
# torch.Size([32, 1, 28, 28]) torch.Size([32])
```

Why batches?
- Memory. A whole dataset rarely fits in GPU memory.
- Speed. GPUs are happiest with hundreds of operations at once.
- Optimization. Stochastic gradient descent updates from one batch at a time. The noise from batch-to-batch variation actually helps generalization.
`shuffle=True` on the training loader prevents the model from memorizing the order. The test loader does not shuffle — it isn't being learned from.
Try it — read shapes
For each line, predict the shape before running:
```python
print(train_data[5][0].shape)
print(torch.stack([img for img, _ in [train_data[i] for i in range(8)]]).shape)
print(next(iter(train_loader))[0].shape)
```

- `torch.Size([1, 28, 28])` — single image, one channel.
- `torch.Size([8, 1, 28, 28])` — `stack` adds a new outer (batch) dim.
- `torch.Size([32, 1, 28, 28])` — batch of 32 from the loader.
Convolution intuition
A linear layer treats the input as a flat vector. That throws away the spatial structure of an image: pixel (0, 0) and pixel (0, 1) are siblings, while pixel (0, 0) and pixel (27, 27) are far apart.
A convolutional layer keeps the structure. It slides a small kernel (e.g. 3×3) over the image and computes a weighted sum at every position. The same kernel is applied everywhere — so a “vertical edge detector” works the same in the top-left as in the bottom-right.
A convolutional network stacks many such layers. Early layers learn simple features (edges, corners). Deep layers combine them into more abstract features (shapes, parts, whole objects).
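To make the sliding-window idea concrete, here is a minimal sketch that applies a hand-made vertical-edge kernel with `torch.nn.functional.conv2d` (the toy image and kernel values are invented for illustration):

```python
import torch
import torch.nn.functional as F

# A toy 6x6 "image": left half dark (0), right half bright (1).
image = torch.zeros(1, 1, 6, 6)
image[..., 3:] = 1.0

# A hand-made 3x3 vertical-edge kernel: negative left column, positive
# right column. conv2d expects shape [out_channels, in_channels, kH, kW].
kernel = torch.tensor([[-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0]]).view(1, 1, 3, 3)

out = F.conv2d(image, kernel)  # no padding: 6x6 input -> 4x4 output
print(out.squeeze())
# Each row reads [0., 3., 3., 0.]: the kernel fires only where its window
# straddles the dark/bright boundary, and the same weights are used everywhere.
print(out.shape)               # torch.Size([1, 1, 4, 4])
```

Because the same 9 weights slide over the whole image, this layer detects a vertical edge at any position, which a flat `Linear` layer would have to learn separately for every location.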
Output size formula
Given an input with side H, a kernel of size K, padding P, and stride S, the output side is:
```
out = (H + 2P - K) / S + 1
```

(The division is floored — leftover pixels that don't fit a full stride are discarded.)

Two common configurations:
| Layer | What stays / shrinks |
|---|---|
| `Conv2d(in, out, kernel_size=3, padding=1)` | Spatial size stays the same |
| `MaxPool2d(kernel_size=2, stride=2)` | Spatial size halves |
Memorize these two — almost every CNN you’ll write uses them.
Try it — compute output sizes
For each, compute the output spatial size:
| Input | Layer | Output |
|---|---|---|
| 28×28 | `Conv2d(_, _, k=3, padding=1)` | ? |
| 28×28 | `Conv2d(_, _, k=5, padding=0)` | ? |
| 28×28 | `MaxPool2d(2)` | ? |
| 64×64 | `Conv2d(_, _, k=3, stride=2, padding=1)` | ? |
- 28×28 — same size; that's the point of `padding=1` with `k=3`.
- 24×24 — `(28 + 0 - 5) / 1 + 1 = 24`.
- 14×14 — pooling halves both dims.
- 32×32 — `(64 + 2 - 3) / 2 + 1 = 32`.
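You can check answers like these against real layers by running a dummy tensor through them. A quick sketch (the 8 output channels are arbitrary, chosen just for the demo):

```python
import torch
from torch import nn

def conv_out(h: int, k: int, p: int = 0, s: int = 1) -> int:
    """out = floor((H + 2P - K) / S) + 1"""
    return (h + 2 * p - k) // s + 1

assert conv_out(28, k=3, p=1) == 28          # same-size conv
assert conv_out(28, k=5) == 24
assert conv_out(28, k=2, s=2) == 14          # MaxPool2d(2) obeys the same formula
assert conv_out(64, k=3, p=1, s=2) == 32

# Cross-check against a real layer:
x = torch.zeros(1, 1, 64, 64)
layer = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)
print(layer(x).shape)  # torch.Size([1, 8, 32, 32])
```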
A first model — flatten + Linear (baseline)
Before reaching for convolutions, try the simpler approach: flatten the image and feed it to a fully connected network.
```python
class FashionMLP(nn.Module):
    def __init__(self, in_features: int, hidden: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)


mlp = FashionMLP(in_features=28 * 28, hidden=128, out_features=len(class_names)).to(device)
```

This works, but it discards the spatial structure of the image. We'll come back to it for comparison.
Build a CNN
```python
class FashionCNN(nn.Module):
    def __init__(self, in_channels: int, hidden: int, out_features: int):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(hidden * 7 * 7, out_features),
        )

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        return self.classifier(x)


model = FashionCNN(in_channels=1, hidden=16, out_features=len(class_names)).to(device)
print(model)
```

How spatial dimensions evolve:
| Stage | Shape |
|---|---|
| Input | [B, 1, 28, 28] |
| After block 1 (`MaxPool2d(2)`) | [B, 16, 14, 14] |
| After block 2 (`MaxPool2d(2)`) | [B, 16, 7, 7] |
| After `Flatten` | [B, 16 × 7 × 7] = [B, 784] |
| After classifier | [B, 10] (one logit per class) |
Three new building blocks:
- `BatchNorm2d` — normalizes activations across a batch. Speeds up training and stabilizes it.
- `Dropout` — randomly zeros activations during training. A simple, effective regularizer.
- `Flatten` — collapses every dim except the batch into one.
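`Dropout` and `BatchNorm2d` behave differently in training and evaluation mode, which is why the loops later in this lesson call `model.train()` and `model.eval()`. A tiny demo with `Dropout` alone:

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()   # training mode: each activation is zeroed with probability 0.5;
out = drop(x)  # survivors are scaled by 1 / (1 - p) = 2 so the mean is preserved
print(out)     # a mix of 0.0 and 2.0 values

drop.eval()    # eval mode: dropout is a no-op
print(drop(x)) # identical to x
```

Forgetting `model.eval()` before evaluation is a classic bug: predictions stay noisy because dropout keeps firing.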
The classifier’s input size (hidden * 7 * 7) depends on the image size and the number of pooling steps. If you change the architecture you’ll need to recompute it, or run a dummy tensor through the convolutional blocks once and read .shape:
```python
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)
    print(model.block_2(model.block_1(dummy)).shape)  # torch.Size([1, 16, 7, 7])
```

Inspect the model

```python
n_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {n_params:,}")
```

You should see ~15k parameters for the CNN — tiny by modern standards.
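To see where those parameters live, count them per tensor with `named_parameters()`. A sketch on a single `Conv2d` (the same loop works on the full model):

```python
import torch
from torch import nn

# A Conv2d(1, 16, kernel_size=3) holds 16 * (1 * 3 * 3) = 144 weights
# plus one bias per output channel: 16.
conv = nn.Conv2d(1, 16, kernel_size=3, padding=1)
counts = {name: p.numel() for name, p in conv.named_parameters()}
print(counts)  # {'weight': 144, 'bias': 16}
```

Running the same loop over the whole model shows that the final `Linear` layer accounts for roughly half of the total — convolutions are cheap in parameters.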
Train and evaluate
The training loop is the same skeleton as in workflow, now with batches:
```python
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_step(model, loader):
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        total_correct += (logits.argmax(dim=1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, total_correct / total


def test_step(model, loader):
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    with torch.inference_mode():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            total_loss += loss_fn(logits, y).item() * x.size(0)
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total += x.size(0)
    return total_loss / total, total_correct / total


EPOCHS = 5
history = []
for epoch in range(EPOCHS):
    train_loss, train_acc = train_step(model, train_loader)
    test_loss, test_acc = test_step(model, test_loader)
    history.append((train_loss, train_acc, test_loss, test_acc))
    print(
        f"epoch {epoch + 1}  train loss={train_loss:.4f} acc={train_acc:.2%} "
        f"| test loss={test_loss:.4f} acc={test_acc:.2%}"
    )
```

Five epochs on CPU should land around 90% test accuracy. On GPU it takes well under a minute.
Plot loss and accuracy
```python
import numpy as np

h = np.array(history)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].plot(h[:, 0], label="train"); axes[0].plot(h[:, 2], label="test")
axes[0].set_title("loss"); axes[0].legend()
axes[1].plot(h[:, 1], label="train"); axes[1].plot(h[:, 3], label="test")
axes[1].set_title("accuracy"); axes[1].legend()
plt.show()
```

Visualize predictions
```python
import random

model.eval()
samples, truths = [], []
for _ in range(9):
    idx = random.randint(0, len(test_data) - 1)
    image, label = test_data[idx]
    samples.append(image)
    truths.append(label)

with torch.inference_mode():
    logits = model(torch.stack(samples).to(device))
    preds = logits.argmax(dim=1).cpu()

fig, axes = plt.subplots(3, 3, figsize=(8, 8))
for ax, image, truth, pred in zip(axes.flat, samples, truths, preds):
    ax.imshow(image.squeeze(), cmap="gray")
    color = "green" if pred == truth else "red"
    ax.set_title(f"{class_names[pred]}\n(true: {class_names[truth]})", color=color)
    ax.axis("off")
plt.tight_layout()
plt.show()
```

Misclassifications usually fall on similar-looking categories — pullover vs. coat, shirt vs. t-shirt.
Confusion matrix
```python
import numpy as np

model.eval()
confusion = np.zeros((10, 10), dtype=int)
with torch.inference_mode():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)
        for t, p in zip(y.cpu().numpy(), preds.cpu().numpy()):
            confusion[t, p] += 1

fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(confusion, cmap="Blues")
ax.set_xticks(range(10)); ax.set_xticklabels(class_names, rotation=45, ha="right")
ax.set_yticks(range(10)); ax.set_yticklabels(class_names)
for i in range(10):
    for j in range(10):
        ax.text(j, i, confusion[i, j], ha="center", va="center",
                color="white" if confusion[i, j] > confusion.max() / 2 else "black")
plt.colorbar(im)
plt.tight_layout()
plt.show()
```

Bright cells off the diagonal are the persistent confusions. They tell you what to fix: harder augmentation? More capacity? Better data?
Save the model
```python
from pathlib import Path

Path("models").mkdir(exist_ok=True)
torch.save(model.state_dict(), "models/fashion_cnn.pt")
```

Exercises
Warm-up
- Train the `FashionMLP` for the same 5 epochs and compare its test accuracy. By how much does the CNN win?
- Count parameters for both models. Is the CNN much bigger?
- Try `lr=0.01` and `lr=0.0001` for the CNN. Which converges fastest?
Architecture
- Add a third convolutional block. What does the classifier’s input size become? Does test accuracy improve?
- Replace `BatchNorm2d` with nothing. Train again. Did it slow down convergence?
- Increase `hidden` from 16 to 64. Does the gap between train and test accuracy change?
Data augmentation
- Add `transforms.RandomHorizontalFlip()` to the training transforms. Does it help on FashionMNIST? Why might it hurt on a "left-shoe vs right-shoe" dataset?
- Add `transforms.RandomCrop(28, padding=4)` to the training transforms. How does test accuracy change?
Visualization
- Pick the most-confused pair from the confusion matrix and look at 8 example mistakes side by side.
- Visualize the 16 first-layer kernels: `model.block_1[0].weight.data` is a tensor of shape `[16, 1, 3, 3]`. Plot each kernel as a 3×3 grayscale image.
Capstone — MNIST digit recognizer
MNIST is the classic 28×28 grayscale digit dataset. We’ll train a CNN, save it, and write a small inference script that classifies digits we draw ourselves.
Train the digit model
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

train_tf = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean/std
])
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_data = datasets.MNIST(root="data", train=True, download=True, transform=train_tf)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=test_tf)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)
```
```python
class DigitCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.25),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = DigitCNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 5
for epoch in range(EPOCHS):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    correct = total = 0
    with torch.inference_mode():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    print(f"epoch {epoch + 1}  test acc={correct / total:.4%}")

torch.save(model.state_dict(), "mnist_cnn.pt")
```

You should reach around 99% test accuracy in 5 epochs.
Inference on a hand-drawn digit
Open Paint (or any image editor), draw a single digit on a roughly square white canvas with a thick black brush, save it as my_digit.png. Then:
```python
from PIL import Image, ImageOps
import torch
from torchvision import transforms

# Re-import the same DigitCNN class definition.
from capstone import DigitCNN, device

model = DigitCNN().to(device)
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
model.eval()


def preprocess(path: str) -> torch.Tensor:
    image = Image.open(path).convert("L")  # grayscale
    image = ImageOps.invert(image)         # MNIST is white-on-black
    image = image.resize((28, 28))
    tensor = transforms.ToTensor()(image)
    tensor = transforms.Normalize((0.1307,), (0.3081,))(tensor)
    return tensor.unsqueeze(0).to(device)  # add batch dim


x = preprocess("my_digit.png")
with torch.inference_mode():
    probs = torch.softmax(model(x), dim=1).squeeze().cpu()

for digit, p in enumerate(probs):
    print(f"{digit}: {p.item():.2%}")
print("\nprediction:", probs.argmax().item(), f"({probs.max():.2%})")
```

Tips for getting good predictions on your own drawings:
- MNIST digits are white on black with a roughly centered digit. The `ImageOps.invert` step accounts for the white background most editors use.
- Use a thick brush — the digit should fill maybe 60–70% of the canvas height.
- Don't add extra marks; the model has only seen single, isolated digits.
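If predictions are still poor, one common refinement is to center the digit the way the original MNIST digits were prepared (size-normalized into a roughly 20×20 box, centered on a 28×28 canvas). The helper below is a hypothetical addition, not part of the script above; you could call it on the inverted image instead of the plain `resize`:

```python
from PIL import Image


def center_digit(image: Image.Image, size: int = 28, box: int = 20) -> Image.Image:
    """Crop to the digit's bounding box, scale it to fit a `box`x`box` square,
    and paste it centered on a black canvas, roughly mimicking how MNIST
    digits were normalized. Expects a white-on-black grayscale ("L") image."""
    bbox = image.getbbox()            # bounding box of the non-black pixels
    if bbox is None:                  # blank canvas: nothing to center
        return image.resize((size, size))
    digit = image.crop(bbox)
    scale = box / max(digit.size)     # fit the longer side into `box`
    w = max(1, round(digit.size[0] * scale))
    h = max(1, round(digit.size[1] * scale))
    digit = digit.resize((w, h))
    canvas = Image.new("L", (size, size), 0)
    canvas.paste(digit, ((size - w) // 2, (size - h) // 2))
    return canvas
```

This removes the sensitivity to where and how large you drew the digit, which the plain `resize((28, 28))` does not.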
Stretch goals
- Add a confusion matrix as in the FashionMNIST section.
- Use `torch.softmax(logits, dim=1).max()` as a confidence and reject predictions below 0.9 with a "not sure" message.
- Train and save a smaller version of the model (e.g. 8-channel features). How small can you go before accuracy collapses?
Recap
- Image tensors are `[C, H, W]`; batches add a leading `B`.
- A `DataLoader` yields batches efficiently; shuffle the train loader, not the test loader.
- A convolution slides a small kernel over an image — same weights everywhere, far fewer parameters than a fully connected equivalent.
- `Conv2d(_, _, k=3, padding=1)` keeps the size; `MaxPool2d(2)` halves it. Memorize that pair.
- `BatchNorm2d`, `Dropout`, and data augmentation are the three regularizers you will reach for first.
- Always plot a confusion matrix to see what the model still gets wrong.
The next chapter, Datasets, shows how to load your own data instead of relying on the built-in collections.