Complete Guide to PyTorch Lightning: Deep Learning Made Simple
PyTorch Lightning is a high-level framework built on PyTorch that simplifies training deep learning models. It separates research code (the model, loss, and optimizer) from engineering code (the training loop, device placement, checkpointing), which makes projects cleaner, easier to scale, and easier to reproduce.
In this tutorial, we'll learn PyTorch Lightning from basics to advanced usage with practical examples.
Why PyTorch Lightning?
PyTorch Lightning Advantages:
Comparison: PyTorch vs PyTorch Lightning
| Aspect | Pure PyTorch | PyTorch Lightning |
|--------|--------------|-------------------|
| Training Loop | Manual | Automatic |
| Multi-GPU | Manual implementation | 1 line change |
| Mixed Precision | Manual setup | Simple flag |
| Checkpointing | Manual | Built-in |
| Logging | Manual | Integrated |
| Code Organization | Free-form | Structured |
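
To make the "Pure PyTorch" column concrete, here is a minimal sketch of a training run written by hand. The synthetic data, the layer sizes, and the checkpoint file name are assumptions chosen for illustration. The comments mark the steps that the table lists as manual.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in data (hypothetical): 64 samples, 10 features, 3 classes.
features = torch.randn(64, 10)
labels = torch.randint(0, 3, (64,))
loader = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Device placement: manual.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

epoch_losses = []
for epoch in range(3):  # Training loop: manual.
    model.train()
    running_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_losses.append(running_loss / len(loader))
    # Logging: manual.
    print(f"epoch {epoch}: mean loss = {epoch_losses[-1]:.4f}")

# Checkpointing: manual.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), checkpoint_path)
```

Each commented step is boilerplate that the Lightning Trainer is designed to take over, so that only the model and the per-batch logic remain in user code.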
Installation
Install PyTorch Lightning
# Install with pip
pip install lightning
# Or install with conda
conda install lightning -c conda-forge
# Install with extras (for logging, etc.)
pip install "lightning[extra]"
# Install a specific version
pip install lightning==2.1.0
Verify Installation
import lightning as L
import torch
print(f"Lightning version: {L.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
Core Concepts
Lightning Structure
PyTorch Lightning
├── LightningModule # Model + Training Logic
├── LightningDataModule # Data Loading
├── Trainer # Training Orchestration
├── Callbacks # Custom Behaviors
└── Loggers # Experiment Tracking
1. LightningModule
LightningModule is the core abstraction that combines model and training logic:
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F


class LitModel(L.LightningModule):
    def __init__(self, input_size, hidden_size, num_classes, learning_rate=1e-3):
        super().__init__()
        # Save the constructor arguments to self.hparams and to checkpoints
        self.save_hyperparameters()
        # Define model architecture
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Forward pass for inference."""
        x = F.relu(self.layer1(x))
        x = self.dropout(x)
        x = F.relu(self.layer2(x))
        x = self.dropout(x)
        x = self.layer3(x)
        return x

    def training_step(self, batch, batch_idx):
        """Single training step."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Log metrics
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Single validation step."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Single test step."""
        x, y = batch
        logits = self(x)