PyTorch - Transfer learning

Reuse pretrained torchvision models on your own dataset to reach strong accuracy with very little data and very little training time.

Goal of the lesson

By the end of this 3-hour session you should be able to:

  • explain why pretrained features generalize across vision tasks,
  • distinguish feature extraction from fine-tuning,
  • load a pretrained torchvision model and inspect its parts,
  • replace the classifier head and freeze the backbone,
  • apply the model’s own preprocessing transforms,
  • train a classifier on a small dataset and reach >90% accuracy in minutes,
  • compare two architectures and report on the winner.

Suggested timing

  • 15 min: What transfer learning is and why it works
  • 25 min: Load a pretrained model, inspect its layers
  • 25 min: Replace head, freeze backbone, train
  • 30 min: Evaluate, predict, save
  • 25 min: Fine-tuning the backbone
  • 60 min: Capstone — pretrained model on your own dataset

Why transfer learning

Training a CNN from scratch needs a lot of data and a lot of compute. The first few layers of any vision network learn very generic features — edges, corners, textures, colors — that are useful for almost any image task. The deeper layers combine those into more task-specific patterns.

Transfer learning is the trick that makes deep learning practical for small projects:

  1. take a model pretrained on a large dataset (typically ImageNet, ~1.3M images, 1000 classes),
  2. keep most of its weights,
  3. retrain only the final classifier on your data.

Two flavors:

  • Feature extraction: train only the new classifier head. Best for a small dataset in a domain similar to ImageNet; the default starting point.
  • Fine-tuning: train the new head plus (some of) the backbone, with a tiny learning rate. For larger datasets, or domains that drift noticeably from ImageNet (medical, satellite, drawings).

We’ll start with feature extraction and add fine-tuning at the end.

Setup

PowerShell
uv init --python 3.12 transfer
cd transfer
uv add torch torchvision matplotlib pillow requests
main.py
from pathlib import Path
import zipfile
import requests
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, models
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

We’ll reuse the pizza/steak/sushi dataset from the datasets chapter:

main.py
DATA_PATH = Path("data")
IMAGE_PATH = DATA_PATH / "pizza_steak_sushi"
if not IMAGE_PATH.is_dir():
    IMAGE_PATH.mkdir(parents=True, exist_ok=True)
    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    zip_file = DATA_PATH / "pizza_steak_sushi.zip"
    zip_file.write_bytes(requests.get(url).content)
    with zipfile.ZipFile(zip_file, "r") as z:
        z.extractall(IMAGE_PATH)
    zip_file.unlink()
train_dir = IMAGE_PATH / "train"
test_dir = IMAGE_PATH / "test"

Load a pretrained model

torchvision.models exposes architectures plus their pretrained weights. Each set of weights advertises:

  • the model itself,
  • the transforms it expects at inference time,
  • metadata (number of parameters, ImageNet accuracy).
main.py
weights = models.EfficientNet_B0_Weights.DEFAULT
print("acc on ImageNet:", weights.meta["_metrics"]["ImageNet-1K"])
print("input size :", weights.transforms())
print("classes :", len(weights.meta["categories"]))
model = models.efficientnet_b0(weights=weights).to(device)

efficientnet_b0 is a strong general-purpose CNN that runs comfortably on CPU. For larger datasets you’ll reach for efficientnet_b2, resnet50, or convnext_tiny.

Use the model’s own transforms

This is the single most important rule of transfer learning: preprocess your data exactly the way the pretrained model was trained to expect. The weights already have a “view” of the world that includes a specific resize and normalization.

main.py
auto_transforms = weights.transforms()
print(auto_transforms)
train_data = datasets.ImageFolder(root=train_dir, transform=auto_transforms)
test_data = datasets.ImageFolder(root=test_dir, transform=auto_transforms)
class_names = train_data.classes
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

auto_transforms for EfficientNet_B0_Weights.DEFAULT resizes to 256, center-crops to 224, converts to tensor, and normalizes with the ImageNet mean/std.

Inspect the model

main.py
print(model)

Two logical parts:

  • a feature extractor (model.features) — many convolutional blocks,
  • a classifier (model.classifier) — a small head that maps features to 1000 ImageNet classes.
main.py
print(model.classifier)
# Sequential(
#   (0): Dropout(p=0.2, inplace=True)
#   (1): Linear(in_features=1280, out_features=1000, bias=True)
# )

The 1280 is the number of features the convolutional backbone produces — it depends on the architecture. Each torchvision model has its own. We’ll need to know this number for the next step.

Feature extraction — replace the head, freeze the backbone

main.py
for parameter in model.features.parameters():
    parameter.requires_grad = False

model.classifier = nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1280, out_features=len(class_names)),
).to(device)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,}")

You should see only a few thousand trainable parameters out of millions — the backbone is frozen, only the new classifier learns.

Note

Freezing means requires_grad = False. The optimizer sees only the unfrozen parameters, gradient memory is tiny, and training is fast even on CPU.

Try it — different architectures

What changes if you swap in a different backbone? For each model below, find the attribute that holds the classifier head and the head's in_features:

  • efficientnet_b0: head at model.classifier[1], in_features = 1280
  • resnet18: head at model.fc, in_features = ?
  • mobilenet_v3_small: head at model.classifier[3], in_features = ?
  • vit_b_16: head at model.heads.head, in_features = ?

Train

main.py
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
def train_step(model, loader):
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        total_correct += (logits.argmax(dim=1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, total_correct / total

def test_step(model, loader):
    model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    with torch.inference_mode():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            total_loss += loss_fn(logits, y).item() * x.size(0)
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total += x.size(0)
    return total_loss / total, total_correct / total

EPOCHS = 5
for epoch in range(EPOCHS):
    train_loss, train_acc = train_step(model, train_loader)
    test_loss, test_acc = test_step(model, test_loader)
    print(
        f"epoch {epoch + 1} train loss={train_loss:.4f} acc={train_acc:.2%} "
        f"| test loss={test_loss:.4f} acc={test_acc:.2%}"
    )

You should see test accuracy jump above 90% within a few epochs — far better than what a small CNN trained from scratch on the same 225 images can manage. That is the value of transfer learning.

The total training time on CPU is on the order of a couple of minutes; on a GPU, seconds.

Predict on a new image

main.py
from PIL import Image
def predict(image_path: Path):
    image = Image.open(image_path).convert("RGB")
    tensor = auto_transforms(image).unsqueeze(0).to(device)
    model.eval()
    with torch.inference_mode():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze().cpu()
    pred = probs.argmax().item()
    return class_names[pred], probs[pred].item(), image
label, confidence, image = predict(next(test_dir.glob("**/*.jpg")))
plt.imshow(image)
plt.title(f"{label} ({confidence:.1%})")
plt.axis("off"); plt.show()

unsqueeze(0) adds the batch dimension — the model expects [B, C, H, W], not [C, H, W].

Save and reload

main.py
Path("models").mkdir(exist_ok=True)
torch.save(model.state_dict(), "models/efficientnet_food.pt")

When loading, you must rebuild the same architecture (with the replaced classifier) before calling load_state_dict. A small helper makes that ergonomic:

main.py
def build_model(num_classes: int) -> nn.Module:
    m = models.efficientnet_b0(weights=None)
    m.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(1280, num_classes),
    )
    return m

reloaded = build_model(num_classes=len(class_names)).to(device)
reloaded = build_model(num_classes=len(class_names)).to(device)
reloaded.load_state_dict(torch.load("models/efficientnet_food.pt", map_location=device))
reloaded.eval()

Fine-tuning — when to thaw the backbone

Feature extraction is usually enough. When it isn’t:

  • the head accuracy plateaus below your target,
  • the domain is far from ImageNet (medical scans, satellite, art, line drawings).

Once the head has converged, unfreeze the rest of the network and keep training with a much smaller learning rate:

main.py
for parameter in model.features.parameters():
    parameter.requires_grad = True

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(3):
    train_loss, train_acc = train_step(model, train_loader)
    test_loss, test_acc = test_step(model, test_loader)
    print(f"fine-tune {epoch + 1} test acc={test_acc:.2%}")

Two things to remember:

  • The learning rate must be much smaller (often 100×) than the head learning rate. Otherwise the new gradients destroy the pretrained weights.
  • Don’t fine-tune for too long. The pretrained backbone is the reason this works at all; pushing too hard erases it.

You can also unfreeze only part of the backbone — typically the last block or two:

main.py
# Freeze everything
for parameter in model.parameters():
    parameter.requires_grad = False
# Unfreeze just the last feature block + the classifier
for parameter in model.features[-1].parameters():
    parameter.requires_grad = True
for parameter in model.classifier.parameters():
    parameter.requires_grad = True

This is a good middle ground when you don’t have a lot of data: you adapt the most task-specific layers without disturbing the early generic ones.

Compare two architectures

Drop-in replacement: swap efficientnet_b0 for resnet18 and run the whole pipeline again.

main.py
weights = models.ResNet18_Weights.DEFAULT
auto_transforms = weights.transforms()
resnet = models.resnet18(weights=weights).to(device)
for p in resnet.parameters():
    p.requires_grad = False
resnet.fc = nn.Linear(512, len(class_names)).to(device)
# rebuild the loaders with the new transforms, then call train_step / test_step

Track:

  • final test accuracy,
  • training time per epoch,
  • model size on disk.

Different architectures trade these off differently.

Exercises

Warm-up

  1. Replace efficientnet_b0 with resnet18. Compare final test accuracy and training time.
  2. Try mobilenet_v3_small (head: model.classifier[3], in_features 1024). How does it compare in speed and accuracy?
  3. Print the names of the layers being trained: [name for name, p in model.named_parameters() if p.requires_grad].

Architecture inspection

  1. Pass a single dummy torch.zeros(1, 3, 224, 224) through model.features and read .shape. Confirm the spatial dimensions match what the head expects.
  2. Visualize the first 16 filters of model.features[0][0] (the first Conv2d). They should look like edge / color detectors even though we never trained them.

Fine-tuning

  1. After training the head, unfreeze the last block of model.features only (model.features[-1]) and continue with lr=1e-5. Does test accuracy improve? Does training cost change much?
  2. Fine-tune the whole backbone with lr=1e-5 for one extra epoch. Then again with lr=1e-3. What goes wrong with the higher rate?

Confidence

  1. Plot the confidence (probs.max()) on every test image, separated into “correct” and “incorrect”. A well-calibrated model should be most uncertain on its mistakes.
  2. Implement a “not sure” rejection: if probs.max() < 0.6, return the string "unsure" instead of a class. Apply it to the test set and report rejection rate.

Capstone — pretrained model on your own dataset

Reuse the dataset of your own photos from the datasets capstone. Replace the from-scratch CNN with a pretrained efficientnet_b0 and observe how much better it does — same data, same training time, much better accuracy.

Step 1 — load your data with the model’s transforms

capstone.py
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, models
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
weights = models.EfficientNet_B0_Weights.DEFAULT
auto_transforms = weights.transforms()
data_root = Path("my_dataset") # the folder you assembled previously
train_data = datasets.ImageFolder(data_root / "train", transform=auto_transforms)
test_data = datasets.ImageFolder(data_root / "test", transform=auto_transforms)
class_names = train_data.classes
print(class_names)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8, shuffle=False)

Step 2 — feature extraction

capstone.py
model = models.efficientnet_b0(weights=weights).to(device)
for p in model.features.parameters():
    p.requires_grad = False
model.classifier = nn.Sequential(
    nn.Dropout(0.2, inplace=True),
    nn.Linear(1280, len(class_names)),
).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
def epoch():
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        loss = loss_fn(model(x), y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
    model.eval()
    correct = total = 0
    with torch.inference_mode():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item(); total += x.size(0)
    return correct / total

for e in range(10):
    print(f"epoch {e + 1} test acc={epoch():.2%}")

Compare the test accuracy with what the from-scratch CNN reached on the same data.

Step 3 — short fine-tune

After the head converges, unfreeze just the last block and continue with a low learning rate:

capstone.py
for p in model.features[-1].parameters():
    p.requires_grad = True

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for e in range(3):
    print(f"fine-tune {e + 1} test acc={epoch():.2%}")

Step 4 — predict on a new photo

capstone.py
from PIL import Image
def predict(path: Path) -> tuple[str, float]:
    image = Image.open(path).convert("RGB")
    tensor = auto_transforms(image).unsqueeze(0).to(device)
    model.eval()
    with torch.inference_mode():
        probs = torch.softmax(model(tensor), dim=1).squeeze().cpu()
    return class_names[probs.argmax().item()], probs.max().item()
print(predict(Path("test_photo.jpg")))

Step 5 — save and write a small CLI

capstone.py
from pathlib import Path
import json
torch.save(model.state_dict(), "my_classifier.pt")
Path("classes.json").write_text(json.dumps(class_names))
# tiny inference script — call as: python predict_cli.py path/to/photo.jpg
predict_cli.py
import json, sys
from pathlib import Path
import torch
from torch import nn
from torchvision import models
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
class_names = json.loads(Path("classes.json").read_text())
weights = models.EfficientNet_B0_Weights.DEFAULT
auto_transforms = weights.transforms()
model = models.efficientnet_b0(weights=None)
model.classifier = nn.Sequential(nn.Dropout(0.2, inplace=True), nn.Linear(1280, len(class_names)))
model.load_state_dict(torch.load("my_classifier.pt", map_location=device))
model.eval().to(device)
image = Image.open(sys.argv[1]).convert("RGB")
tensor = auto_transforms(image).unsqueeze(0).to(device)
with torch.inference_mode():
probs = torch.softmax(model(tensor), dim=1).squeeze().cpu()
print(class_names[probs.argmax().item()], f"{probs.max():.1%}")

You now have a small classifier that runs on any laptop, predicts in milliseconds, and was trained in minutes from a few dozen photos.

Stretch goals

  • Add a confusion matrix and identify the class the model still struggles with.
  • Compare two backbones (efficientnet_b0 vs resnet18) on your dataset and report final accuracy + size on disk.
  • Add transforms.RandomHorizontalFlip() to the training transforms only (you’ll need to wrap auto_transforms so it doesn’t run before the augmentation). Does it help?

Recap

  • Pretrained models give you ~99% of the work for free. Use them as a starting point on every realistic vision task.
  • Always use the model’s own transforms for input preprocessing.
  • Default approach: replace the head, freeze the backbone, train the head with lr=1e-3.
  • Fine-tune (unfreeze part of the backbone with a tiny lr) only after the head converges, and only if you have data to support it.
  • The whole transfer-learning workflow is the same five steps as before — only the model setup is special.

That ends the core series. The patterns you have learned cover most everyday computer-vision projects. From here, take a real problem you care about, follow the same workflow, and don’t be surprised when it works.
