8.1.3. PyTorch Model Persistence#

PyTorch handles serialization differently from scikit-learn. A scikit-learn model is a single Python object that bundles its parameters with the code to use them. In PyTorch, the model architecture is defined in code (as a Python class), while the parameters (weights and biases learned during training) are stored separately. This separation is intentional: it gives you greater flexibility to change the code without losing the trained weights, and to load the same weights into different versions of an architecture for fine-tuning or analysis.

The practical implication is that saving and loading a PyTorch model always requires two steps: saving the weights to a file, and later re-instantiating the architecture in code before loading the weights into it.

8.1.3.1. Two Approaches to Saving PyTorch Models#

PyTorch offers two mechanisms:

  • Save the state dict: Save only the learned parameters (weights and biases). The architecture must be recreated from code before loading.

  • Save the entire model: Save both architecture and parameters together as a pickled object.

The PyTorch documentation recommends saving the state dict for better flexibility and maintainability. Saving the entire model couples the file to a specific class definition and Python environment, making it fragile across refactoring or version changes.
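
8.1.3.2. Basic Example: Saving and Loading a State Dict#

The recommended workflow can be sketched as follows, assuming a small SimpleNN classifier like the one used in the examples below (the exact layer layout here is illustrative):

```python
import torch
import torch.nn as nn

# A minimal architecture, standing in for the SimpleNN used in this section
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.net(x)

model = SimpleNN(input_size=10, hidden_size=50, output_size=2)

# Step 1: save only the learned parameters
torch.save(model.state_dict(), 'model_weights.pth')

# Step 2: recreate the architecture in code, then load the weights into it
restored = SimpleNN(input_size=10, hidden_size=50, output_size=2)
restored.load_state_dict(torch.load('model_weights.pth'))
restored.eval()  # switch dropout/batch-norm layers to inference behavior
```

Note that load_state_dict raises an error if the saved keys do not match the architecture, which catches accidental mismatches early.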

8.1.3.3. Saving Additional Information#

In practice, you should save more than just the weights: include the optimizer state, the current training epoch, and the model configuration needed to recreate the architecture.

Complete Checkpoint#

import torch
import torch.nn as nn
import torch.optim as optim

# Training setup (SimpleNN, train_one_epoch, validate, and the
# data loaders are assumed to be defined elsewhere)
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    # Training code...
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss = validate(model, val_loader)

# Save comprehensive checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_loss': train_loss,
    'val_loss': val_loss,
    'model_config': {
        'input_size': 10,
        'hidden_size': 50,
        'output_size': 2
    }
}

torch.save(checkpoint, 'checkpoint.pth')

Loading a Checkpoint#

# Recreate model and optimizer
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Load checkpoint
checkpoint = torch.load('checkpoint.pth')

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['train_loss']

# Resume training or evaluate
model.train()  # For training
# or
model.eval()  # For inference

8.1.3.4. Inference Mode: Saving for Deployment#

During training you need the optimizer state and loss history to resume training from a checkpoint. When deploying for inference, you only need what is required to make predictions: the model weights, the architecture configuration, and the preprocessing parameters. Saving a leaner inference package keeps your deployment artefact small and focused.

import torch

# After training is complete
model.eval()

# Save for inference
inference_package = {
    'model_state_dict': model.state_dict(),
    'model_config': {
        'input_size': 10,
        'hidden_size': 50,
        'output_size': 2
    },
    'input_mean': train_data.mean(axis=0).tolist(),  # train_data: NumPy array, used for normalization
    'input_std': train_data.std(axis=0).tolist(),
    'class_names': ['Class A', 'Class B']
}

torch.save(inference_package, 'model_inference.pth')

# Loading for inference
package = torch.load('model_inference.pth')

# Recreate model from config
config = package['model_config']
model = SimpleNN(
    input_size=config['input_size'],
    hidden_size=config['hidden_size'],
    output_size=config['output_size']
)

model.load_state_dict(package['model_state_dict'])
model.eval()

# Make predictions with the same preprocessing used at training time
def predict(input_data):
    # input_data: array-like of shape (1, input_size), i.e. a single sample
    mean = torch.tensor(package['input_mean'])
    std = torch.tensor(package['input_std'])

    with torch.no_grad():  # Disable gradient computation
        x = torch.as_tensor(input_data, dtype=torch.float32)
        normalized = (x - mean) / std  # Normalize with the saved statistics
        output = model(normalized)
        predicted_class = output.argmax(dim=1).item()

    return package['class_names'][predicted_class]

8.1.3.5. CPU vs GPU: Device Handling#

Models trained on one hardware configuration need explicit handling when loaded elsewhere. A model trained on a GPU is stored with CUDA tensor types; loading it on a CPU requires remapping those tensors. PyTorch’s map_location argument handles this cleanly.

Save on GPU, Load on CPU#

# Save (on GPU)
torch.save(model.state_dict(), 'model.pth')

# Load on CPU
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()

Save on CPU, Load on GPU#

# Save (on CPU)
torch.save(model.state_dict(), 'model.pth')

# Load on GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNN(input_size=10, hidden_size=50, output_size=2)
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
model.eval()

Flexible Device Loading#

def load_model_flexible(model_class, model_path, device=None):
    """Load model on appropriate device."""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    checkpoint = torch.load(model_path, map_location=device)
    
    config = checkpoint['model_config']
    model = model_class(**config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    return model, device

# Usage
model, device = load_model_flexible(SimpleNN, 'checkpoint.pth')

8.1.3.6. Saving Entire Model (Alternative Approach)#

Although not recommended, you can save the entire model, including its architecture, as a single pickled object.

# Save entire model
torch.save(model, 'entire_model.pth')

# Load entire model (PyTorch >= 2.6 requires weights_only=False
# to unpickle arbitrary class instances)
model = torch.load('entire_model.pth', weights_only=False)
model.eval()

Drawbacks of Saving Entire Model#

  • Less flexible: Architecture changes break loading

  • Version sensitivity: May fail across PyTorch versions

  • Larger files: Includes redundant information

  • Dependencies: Requires same class definitions

8.1.3.7. Complete Example: Production-Ready PyTorch Saving#

import torch
import torch.nn as nn
from datetime import datetime
import json

class ModelSaver:
    """Utility for saving PyTorch models with metadata."""
    
    @staticmethod
    def save_checkpoint(model, optimizer, epoch, train_loss, val_loss,
                       model_config, filepath, additional_info=None):
        """Save training checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'model_config': model_config,
            'timestamp': datetime.now().isoformat(),
            'pytorch_version': torch.__version__,
            'additional_info': additional_info or {}
        }
        
        torch.save(checkpoint, filepath)
    
    @staticmethod
    def save_for_inference(model, model_config, preprocessing_params,
                          filepath, class_names=None):
        """Save model for deployment/inference."""
        model.eval()
        
        package = {
            'model_state_dict': model.state_dict(),
            'model_config': model_config,
            'preprocessing': preprocessing_params,
            'class_names': class_names,
            'timestamp': datetime.now().isoformat(),
            'pytorch_version': torch.__version__
        }
        
        torch.save(package, filepath)
        
        # Save metadata as JSON
        metadata = {k: v for k, v in package.items() 
                   if k != 'model_state_dict'}
        with open(filepath.replace('.pth', '_metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=2, default=str)

    @staticmethod
    def load_for_inference(model_class, filepath, device=None):
        """Load model for inference."""
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        package = torch.load(filepath, map_location=device)
        
        # Recreate model
        model = model_class(**package['model_config'])
        model.load_state_dict(package['model_state_dict'])
        model.to(device)
        model.eval()

        return model, package, device

# Usage
model_config = {'input_size': 10, 'hidden_size': 50, 'output_size': 2}
preprocessing = {'mean': [0.5] * 10, 'std': [0.2] * 10}

ModelSaver.save_for_inference(
    model, 
    model_config, 
    preprocessing,
    'models/trained_model.pth',
    class_names=['Negative', 'Positive']
)

# Later, load for inference
model, package, device = ModelSaver.load_for_inference(
    SimpleNN, 
    'models/trained_model.pth'
)

8.1.3.8. Beyond Basics: Advanced Topics#

For production deployments, consider these advanced techniques (covered in PyTorch documentation):

  • TorchScript: Serialize models for production deployment

  • ONNX Export: Convert to ONNX format for cross-framework compatibility

  • Model Quantization: Reduce model size for deployment

  • TorchServe: Serve PyTorch models at scale
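
As a taste of the first of these, TorchScript tracing can be sketched as follows (using a toy model purely for illustration):

```python
import torch
import torch.nn as nn

# Toy model standing in for a trained network (for illustration only)
model = nn.Sequential(nn.Linear(10, 2))
model.eval()

# Trace the model with an example input to produce a TorchScript module
example_input = torch.randn(1, 10)
traced = torch.jit.trace(model, example_input)

# The saved file is self-contained: it can be loaded (even from C++)
# without the original Python class definition
traced.save('model_traced.pt')
loaded = torch.jit.load('model_traced.pt')

with torch.no_grad():
    output = loaded(example_input)
```

Unlike a pickled model, the traced file embeds the computation graph itself, so the loading environment never needs your Python source.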

8.1.3.9. Summary#

  • Use state_dict for saving model weights (recommended)

  • Save model configuration to recreate architecture

  • Include preprocessing parameters for deployment

  • Handle device mapping (CPU/GPU) properly

  • Use model.eval() and torch.no_grad() for inference

  • Save metadata for tracking and debugging

Proper PyTorch model persistence ensures models can be reliably loaded and used across different environments and hardware configurations.