8.1.3. PyTorch Model Persistence
PyTorch handles serialization differently from scikit-learn. A scikit-learn model is a single Python object that bundles its parameters with the code to use them. In PyTorch, the model architecture is defined in code (as a Python class), while the parameters (weights and biases learned during training) are stored separately. This separation is intentional: it gives you greater flexibility to change the code without losing the trained weights, and to load the same weights into different versions of an architecture for fine-tuning or analysis.
The practical implication is that saving and loading a PyTorch model always requires two steps: saving the weights to a file, and later re-instantiating the architecture in code before loading the weights into it.
8.1.3.1. Two Approaches to Saving PyTorch Models
PyTorch offers two mechanisms:
Save the state dict: Save only the learned parameters (weights and biases). The architecture must be recreated from code before loading.
Save the entire model: Save both architecture and parameters together as a pickled object.
The PyTorch documentation recommends saving the state dict for better flexibility and maintainability. Saving the entire model couples the file to a specific class definition and Python environment, making it fragile across refactoring or version changes.
8.1.3.2. State Dict: The Recommended Approach
A state dictionary (state_dict) is a Python dictionary that maps each parameter name (built from the module's attribute names, e.g. fc1.weight) to the corresponding tensor of learned values.
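To make this concrete, here is a quick look at what a state_dict contains, using a throwaway nn.Sequential model (not the SimpleNN defined below):

```python
import torch.nn as nn

# A throwaway two-layer model just to inspect its state_dict.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Keys follow "<submodule name>.<parameter name>"; values are the tensors.
# The ReLU contributes no entries because it has no learnable parameters.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)
```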
Basic Pattern
import torch
import torch.nn as nn

# Define model architecture
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create and train model
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
# ... training code ...

# Save only the state dict
torch.save(model.state_dict(), 'model_weights.pth')

# Load the state dict
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)  # Recreate architecture
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # Set to evaluation mode
Why State Dict is Preferred
More flexible: Separate architecture code can be updated
More portable: Works across PyTorch versions better
Clearer: Explicit about what’s being saved
Better practices: Encourages modular code
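The flexibility point can be made concrete: load_state_dict accepts strict=False, so weights saved from one version of an architecture can be loaded into a refactored one as long as the parameter names still line up. A small sketch using an in-memory buffer instead of a file:

```python
import io
import torch
import torch.nn as nn

# Save weights from an original two-layer model.
old = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
buf = io.BytesIO()
torch.save(old.state_dict(), buf)
buf.seek(0)

# A refactored model that appends a parameter-free Softmax layer; the
# parameter names ('0.*', '1.*') are unchanged, so the old weights still fit.
new = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2), nn.Softmax(dim=1))
result = new.load_state_dict(torch.load(buf), strict=False)

# With strict=False, mismatched names are reported rather than raised.
print(result.missing_keys, result.unexpected_keys)  # [] []
```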
8.1.3.3. Saving Additional Information
In practice, you should save more than just weights—include optimizer state, training epoch, and model configuration.
Complete Checkpoint
import torch
import torch.nn as nn
import torch.optim as optim

# Training setup
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

# Training loop (train_one_epoch and validate stand in for your own helpers)
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss = validate(model, val_loader)

    # Save a comprehensive checkpoint at the end of each epoch
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'model_config': {
            'input_size': 10,
            'hidden_size': 50,
            'output_size': 2
        }
    }
    torch.save(checkpoint, 'checkpoint.pth')
Loading a Checkpoint
# Recreate model and optimizer
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['train_loss']
# Resume training or evaluate
model.train() # For training
# or
model.eval() # For inference
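One detail worth making explicit when resuming: training should continue from the epoch after the saved one, or the checkpointed epoch gets repeated. A self-contained sketch (the tiny model and the epoch numbers here are placeholders):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Pretend this checkpoint was written at the end of epoch 4 (0-indexed).
checkpoint = {
    'epoch': 4,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Resume at the *next* epoch so epoch 4 is not run twice.
start_epoch = checkpoint['epoch'] + 1
num_epochs = 10  # placeholder total
remaining_epochs = list(range(start_epoch, num_epochs))
print(remaining_epochs)  # [5, 6, 7, 8, 9]
```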
8.1.3.4. Inference Mode: Saving for Deployment
During training you need the optimizer state and loss history to resume training from a checkpoint. When deploying for inference, you only need what is required to make predictions: the model weights, the architecture configuration, and the preprocessing parameters. Saving a leaner inference package keeps your deployment artifact small and focused.
import torch

# After training is complete
model.eval()

# Save for inference (train_data is assumed to be a NumPy array of training inputs)
inference_package = {
    'model_state_dict': model.state_dict(),
    'model_config': {
        'input_size': 10,
        'hidden_size': 50,
        'output_size': 2
    },
    'input_mean': train_data.mean(axis=0).tolist(),  # For normalization
    'input_std': train_data.std(axis=0).tolist(),
    'class_names': ['Class A', 'Class B']
}
torch.save(inference_package, 'model_inference.pth')

# Loading for inference
package = torch.load('model_inference.pth')

# Recreate model from config
config = package['model_config']
model = SimpleNN(
    input_size=config['input_size'],
    hidden_size=config['hidden_size'],
    output_size=config['output_size']
)
model.load_state_dict(package['model_state_dict'])
model.eval()

# Make predictions with proper preprocessing
def predict(input_data):
    # Convert to a float tensor of shape (batch, features)
    input_tensor = torch.as_tensor(input_data, dtype=torch.float32)

    # Normalize using saved statistics (stored as plain lists in the package)
    mean = torch.tensor(package['input_mean'])
    std = torch.tensor(package['input_std'])
    normalized = (input_tensor - mean) / std

    with torch.no_grad():  # Disable gradient computation
        output = model(normalized)
    predicted_class = output.argmax(dim=1).item()
    return package['class_names'][predicted_class]
8.1.3.5. CPU vs GPU: Device Handling
Models trained on one hardware configuration need explicit handling when loaded elsewhere. A model trained on a GPU is stored with CUDA tensor types; loading it on a CPU requires remapping those tensors. PyTorch’s map_location argument handles this cleanly.
Save on GPU, Load on CPU
# Save (on GPU)
torch.save(model.state_dict(), 'model.pth')
# Load on CPU
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()
Save on CPU, Load on GPU
# Save (on CPU)
torch.save(model.state_dict(), 'model.pth')
# Load on GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
model.eval()
Flexible Device Loading
def load_model_flexible(model_class, model_path, device=None):
    """Load model on appropriate device."""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    checkpoint = torch.load(model_path, map_location=device)
    config = checkpoint['model_config']

    model = model_class(**config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, device

# Usage
model, device = load_model_flexible(SimpleNN, 'checkpoint.pth')
8.1.3.6. Saving Entire Model (Alternative Approach)
While not recommended, you can save the entire model including architecture.
# Save entire model
torch.save(model, 'entire_model.pth')

# Load entire model. Recent PyTorch versions (2.6+) default torch.load to
# weights_only=True, which refuses pickled objects; a full pickled model
# needs weights_only=False (only use this with files you trust).
model = torch.load('entire_model.pth', weights_only=False)
model.eval()
Drawbacks of Saving Entire Model
Less flexible: Architecture changes break loading
Version sensitivity: May fail across PyTorch versions
Larger files: Includes redundant information
Dependencies: Requires same class definitions
8.1.3.7. Complete Example: Production-Ready PyTorch Saving
import torch
import torch.nn as nn
from datetime import datetime
import json
class ModelSaver:
    """Utility for saving PyTorch models with metadata."""

    @staticmethod
    def save_checkpoint(model, optimizer, epoch, train_loss, val_loss,
                        model_config, filepath, additional_info=None):
        """Save training checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'model_config': model_config,
            'timestamp': datetime.now().isoformat(),
            'pytorch_version': torch.__version__,
            'additional_info': additional_info or {}
        }
        torch.save(checkpoint, filepath)

    @staticmethod
    def save_for_inference(model, model_config, preprocessing_params,
                           filepath, class_names=None):
        """Save model for deployment/inference."""
        model.eval()
        package = {
            'model_state_dict': model.state_dict(),
            'model_config': model_config,
            'preprocessing': preprocessing_params,
            'class_names': class_names,
            'timestamp': datetime.now().isoformat(),
            'pytorch_version': torch.__version__
        }
        torch.save(package, filepath)

        # Save metadata as JSON (everything except the weights)
        metadata = {k: v for k, v in package.items()
                    if k != 'model_state_dict'}
        with open(filepath.replace('.pth', '_metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=2, default=str)

    @staticmethod
    def load_for_inference(model_class, filepath, device=None):
        """Load model for inference."""
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        package = torch.load(filepath, map_location=device)

        # Recreate model
        model = model_class(**package['model_config'])
        model.load_state_dict(package['model_state_dict'])
        model.to(device)
        model.eval()
        return model, package, device

# Usage
model_config = {'input_size': 10, 'hidden_size': 50, 'output_size': 2}
preprocessing = {'mean': [0.5] * 10, 'std': [0.2] * 10}

ModelSaver.save_for_inference(
    model,
    model_config,
    preprocessing,
    'models/trained_model.pth',
    class_names=['Negative', 'Positive']
)

# Later, load for inference
model, package, device = ModelSaver.load_for_inference(
    SimpleNN,
    'models/trained_model.pth'
)
8.1.3.8. Beyond Basics: Advanced Topics
For production deployments, consider these advanced techniques (covered in PyTorch documentation):
TorchScript: Serialize models for production deployment
ONNX Export: Convert to ONNX format for cross-framework compatibility
Model Quantization: Reduce model size for deployment
TorchServe: Serve PyTorch models at scale
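As a taste of the first item, TorchScript bundles the architecture and the weights into a single artifact, so loading it does not require the original Python class definition. A minimal sketch using an in-memory buffer:

```python
import io
import torch
import torch.nn as nn

# Script a small model; the result carries its own graph plus weights.
model = nn.Sequential(nn.Linear(4, 2))
scripted = torch.jit.script(model)

# Serialize and reload without referencing the original module definition.
buf = io.BytesIO()
torch.jit.save(scripted, buf)
buf.seek(0)
restored = torch.jit.load(buf)

# The restored object is directly callable for inference.
out = restored(torch.zeros(1, 4))
print(tuple(out.shape))  # (1, 2)
```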
8.1.3.9. Summary
Use state_dict for saving model weights (recommended)
Save model configuration to recreate architecture
Include preprocessing parameters for deployment
Handle device mapping (CPU/GPU) properly
Use model.eval() and torch.no_grad() for inference
Save metadata for tracking and debugging
Proper PyTorch model persistence ensures models can be reliably loaded and used across different environments and hardware configurations.