Model Training: Best Practices and Patterns
Published: November 2025 | 25 min read
The Model Training Lifecycle
Model training is the core of any machine learning project, where algorithms learn patterns from data to make predictions or decisions. A well-structured training pipeline is crucial for developing robust, maintainable, and high-performing models.
Key Components of Model Training
- Data Preparation
  - Feature engineering
  - Train/validation/test splits
  - Data augmentation
  - Class imbalance handling (see the sketch after this list)
- Model Architecture
  - Model selection
  - Architecture design
  - Custom layers/blocks
  - Transfer learning
- Training Process
  - Loss functions
  - Optimizers
  - Learning rate scheduling
  - Regularization techniques
- Evaluation
  - Metrics selection
  - Cross-validation
  - Error analysis
  - Model interpretation
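Most of these components appear in the PyTorch Lightning example below; class imbalance handling is the one it skips. A minimal sketch, assuming a map-style dataset whose targets attribute holds integer class labels (as in the torchvision datasets), is a hypothetical make_balanced_loader helper that oversamples rare classes with WeightedRandomSampler:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_balanced_loader(dataset, batch_size=64):
    # Count samples per class
    targets = torch.as_tensor(dataset.targets)
    class_counts = torch.bincount(targets)
    # Weight each sample inversely to its class frequency
    sample_weights = 1.0 / class_counts[targets].float()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )
    # shuffle must stay False (its default) when a sampler is supplied
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

An alternative with the same intent is to leave the loader unchanged and weight the loss instead, e.g. F.nll_loss(logits, y, weight=class_weights) with a per-class weights tensor.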
Implementation with PyTorch Lightning
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import datasets, transforms
class LitModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # Model architecture
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        # Metrics
        self.train_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=10)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=3
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1
            }
        }
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./data', batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download data
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets
        if stage == 'fit' or stage is None:
            mnist_full = datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000]
            )
        # Assign test dataset
        if stage == 'test' or stage is None:
            self.mnist_test = datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            num_workers=4
        )

    def test_dataloader(self):
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            num_workers=4
        )
def train_model():
    # Seed everything for reproducibility (pairs with deterministic=True below)
    pl.seed_everything(42, workers=True)

    # Initialize data module and model
    dm = MNISTDataModule(batch_size=64)
    model = LitModel(learning_rate=1e-3)

    # Initialize callbacks
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor='val_loss',
        dirpath='checkpoints/',
        filename='mnist-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
    )
    early_stop_callback = pl.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min'
    )

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=20,
        callbacks=[checkpoint_callback, early_stop_callback],
        accelerator='auto',
        devices=1,  # single device; 'auto' picks a GPU if available, else CPU
        log_every_n_steps=50,
        deterministic=True
    )

    # Train the model
    trainer.fit(model, dm)

    # Evaluate the best checkpoint on the test set
    trainer.test(datamodule=dm, ckpt_path='best')

    return model

if __name__ == "__main__":
    model = train_model()
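Running this script downloads MNIST, trains for at most 20 epochs with early stopping on validation loss, keeps the three best checkpoints, and evaluates the best one on the held-out test set.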
Advanced Model Training Techniques
1. Mixed Precision Training
trainer = pl.Trainer(
    precision='16-mixed',  # enable automatic mixed precision
    # ... other trainer args
)
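Mixed precision runs most operations in float16 while keeping a float32 copy of the weights (plus a loss scaler) for numerical stability, which typically cuts memory use and speeds up training on GPUs with tensor cores. On hardware with bfloat16 support, precision='bf16-mixed' is a drop-in alternative.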
2. Distributed Training
trainer = pl.Trainer(
    strategy='ddp',  # Distributed Data Parallel
    accelerator='gpu',
    devices=4,  # Number of GPUs
    # ... other trainer args
)
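Under DDP, each GPU runs its own process with a full copy of the model and sees a different shard of every batch; gradients are averaged across processes after each backward pass, so the effective batch size is devices * batch_size.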
3. Gradient Clipping
trainer = pl.Trainer(
    gradient_clip_val=0.5,
    gradient_clip_algorithm='norm',
    # ... other trainer args
)
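With gradient_clip_algorithm='norm', the total gradient norm is rescaled so it never exceeds gradient_clip_val; with 'value', each gradient element is clamped individually.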
Model Training Best Practices
- Reproducibility
  - Set random seeds
  - Version control everything
  - Log hyperparameters
- Monitoring
  - Track training metrics
  - Visualize learning curves
  - Monitor hardware utilization
- Regularization
  - Dropout
  - Weight decay
  - Early stopping
  - Data augmentation
- Hyperparameter Tuning (see the sketch after this list)
  - Grid/random search
  - Bayesian optimization
  - Population-based training
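As one concrete approach to the hyperparameter tuning item above, here is a minimal Bayesian-optimization sketch that reuses the LitModel and MNISTDataModule defined earlier. It assumes Optuna, which is not used elsewhere in this article:

import optuna
import pytorch_lightning as pl

def objective(trial):
    # Sample a learning rate on a log scale
    lr = trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True)
    model = LitModel(learning_rate=lr)
    dm = MNISTDataModule(batch_size=64)
    trainer = pl.Trainer(
        max_epochs=3,  # keep individual trials short
        enable_checkpointing=False,
        logger=False,
        accelerator='auto',
        devices=1,
    )
    trainer.fit(model, dm)
    # Return the metric the study should minimize
    return trainer.callback_metrics['val_loss'].item()

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=20)
print(study.best_params)

Grid and random search follow the same pattern with a different sampler, e.g. passing optuna.samplers.GridSampler or optuna.samplers.RandomSampler to create_study.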
Common Challenges and Solutions
| Challenge | Solution |
|---|---|
| Overfitting | Increase dropout, add L2 regularization, use more data |
| Underfitting | Increase model capacity, train longer, reduce regularization |
| Vanishing Gradients | Use proper weight initialization, batch normalization |
| Exploding Gradients | Gradient clipping, weight normalization |
| Slow Training | Mixed precision, larger batch size, better optimizer |
| High Memory Usage | Gradient checkpointing, smaller batch size, model parallelism |
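Several of these solutions map directly onto optimizer and Trainer settings already shown in this article. A brief sketch, with illustrative rather than tuned values:

import torch
import pytorch_lightning as pl

class RegularizedLitModel(LitModel):
    # Overfitting: add weight decay (L2 regularization) by switching to AdamW
    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-4
        )

# Exploding gradients, slow training, and high memory usage map to Trainer flags
trainer = pl.Trainer(
    gradient_clip_val=1.0,       # exploding gradients
    precision='16-mixed',        # slow training
    accumulate_grad_batches=4,   # high memory: smaller batches, same effective batch size
)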
Next Steps
- Implement model versioning
- Set up model evaluation pipelines
- Create model cards for documentation
- Plan for model monitoring in production