Image Classification with Transfer Learning: A Comprehensive Tutorial
Introduction
Transfer learning is a machine learning technique where a model trained on a large dataset is repurposed for a different but related task. In computer vision, this typically involves taking a model pre-trained on ImageNet (1.4 million images, 1000 classes) and adapting it to your specific classification problem. This approach dramatically reduces the data and compute needed to achieve high accuracy.
This tutorial covers the complete workflow: selecting a pre-trained model, preparing your data, fine-tuning with PyTorch, evaluating performance, and deploying the final model.
Prerequisites
pip install torch torchvision
pip install timm # PyTorch Image Models - extensive model zoo
pip install albumentations # Advanced augmentation
pip install scikit-learn # Metrics
pip install matplotlib seaborn # Visualization
pip install onnx onnxruntime # Export and deployment
System requirements:
- Python 3.8 or higher
- GPU with at least 6 GB VRAM for training; a CPU is sufficient for inference
- Basic understanding of neural networks and PyTorch
import torch
import torchvision
import timm

# Confirm the installed versions and whether a CUDA device is visible
print(f"PyTorch: {torch.__version__}")
print(f"Torchvision: {torchvision.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
print(f"Timm: {timm.__version__}")
print(f"Available timm models: {len(timm.list_models())}")
Understanding Transfer Learning
Transfer learning works because the early layers of a CNN learn universal features (edges, textures, patterns) that apply to almost any vision task. Only the later layers become task-specific.
There are two main strategies:
# Strategy 1: Feature extraction
def create_feature_extractor(model_name, num_classes):
    """Freeze all layers except the classification head."""
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze the classification head (its attribute name varies by architecture)
    if hasattr(model, 'classifier'):
        for param in model.classifier.parameters():
            param.requires_grad = True
    elif hasattr(model, 'fc'):
        for param in model.fc.parameters():
            param.requires_grad = True
    elif hasattr(model, 'head'):
        for param in model.head.parameters():
            param.requires_grad = True

    return model
# Strategy 2: Full fine-tuning with discriminative learning rates
def create_finetuning_model(model_name, num_classes):
    """All layers stay trainable; per-layer learning rates are set in the optimizer, not here."""
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
    # All parameters are trainable by default
    return model
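The fine-tuning function above leaves every layer trainable but does not itself assign different learning rates. The following is a minimal sketch of one common way to do that, assuming a geometric decay from the head back to the earliest layers. The helper names `layerwise_learning_rates` and `build_param_groups`, the `decay` factor, and the three-way split into groups are illustrative assumptions, not part of timm or PyTorch. Placeholder strings stand in for real parameters so the sketch runs without downloading a model.

```python
def layerwise_learning_rates(num_groups, head_lr=1e-3, decay=0.5):
    """Return one learning rate per group, ordered from earliest layers to the head.

    The last group (the head) gets head_lr; each earlier group is scaled
    down by another factor of `decay`. Both values are illustrative defaults.
    """
    if num_groups < 1:
        raise ValueError("num_groups must be at least 1")
    return [head_lr * decay ** (num_groups - 1 - i) for i in range(num_groups)]


def build_param_groups(layer_groups, head_lr=1e-3, decay=0.5):
    """Turn an ordered list of parameter lists into optimizer parameter groups.

    Each returned dict uses the "params"/"lr" keys that PyTorch optimizers
    accept for per-group options.
    """
    lrs = layerwise_learning_rates(len(layer_groups), head_lr, decay)
    return [{"params": list(params), "lr": lr}
            for params, lr in zip(layer_groups, lrs)]


# Hypothetical grouping: strings stand in for real parameter tensors.
example_groups = [
    ["stem.weight"],                      # earliest layers: smallest learning rate
    ["block1.weight", "block1.bias"],     # middle layers
    ["head.weight"],                      # classification head: largest learning rate
]
param_groups = build_param_groups(example_groups, head_lr=1e-3, decay=0.5)
for group in param_groups:
    print(f"{len(group['params'])} params at lr={group['lr']:.2e}")
```

With a real model, each inner list would hold the parameters of one stage (for example `list(model.fc.parameters())` for the head of a ResNet-style model), and the resulting list can be passed as the first argument to an optimizer such as `torch.optim.AdamW`. How to split a given timm architecture into stages depends on the model and is not shown here.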
Choosing a Pre-trained Model
Comparing Popular Architectures
import timm
def compare_models():