How to use PyTorch Lightning to save and load models during training and inference?

Published on Aug. 22, 2023, 12:19 p.m.

To save and load models during training and inference with PyTorch Lightning, you can use the Trainer's save_checkpoint() method and the LightningModule's load_from_checkpoint() classmethod. Here are the steps:

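  1. You do not need to write your own saving method in the LightningModule. save_checkpoint() is a Trainer method: trainer.save_checkpoint(path) stores the model weights, the optimizer and learning-rate scheduler states, the current epoch and global step, and any hyperparameters recorded with self.save_hyperparameters() in your module's __init__().
  2. load_from_checkpoint() is a classmethod, so call it on your model class rather than on an instance: MyModel.load_from_checkpoint(path) builds a new model from the saved hyperparameters and restores its weights, ready for inference.
  3. To save during training at the end of each epoch or at other intervals, pass a ModelCheckpoint callback to the Trainer (its every_n_epochs and every_n_train_steps arguments set the interval) instead of calling fit() repeatedly. Even without a custom callback, the Trainer saves a checkpoint of the latest epoch under lightning_logs/version_N/checkpoints/ by default.
  4. To resume training rather than only run inference, pass the checkpoint path to trainer.fit(model, ckpt_path=path), which also restores the optimizer state and the epoch counter.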

Here’s an example code snippet:

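import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class MyModel(pl.LightningModule):
    def __init__(self, in_features=8, lr=1e-3):
        super().__init__()
        # store the constructor arguments in every checkpoint so that
        # load_from_checkpoint() can rebuild the model later
        self.save_hyperparameters()
        # define your model architecture here (one linear layer as a placeholder)
        self.layer = torch.nn.Linear(in_features, 1)

    def forward(self, x):
        # define your forward pass here
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # your training loop here
        x, y = batch
        loss = F.mse_loss(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # your validation loop here
        x, y = batch
        self.log("val_loss", F.mse_loss(self(x), y))

    def configure_optimizers(self):
        # the optimizer state is included in checkpoints automatically
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# placeholder random data; replace with your own DataLoaders or LightningDataModule
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
train_loader = DataLoader(dataset, batch_size=16)
val_loader = DataLoader(dataset, batch_size=16)

# keep one checkpoint per epoch: checkpoints/checkpoint_epoch_0.ckpt ... checkpoint_epoch_9.ckpt
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="checkpoint_epoch_{epoch}",
    auto_insert_metric_name=False,  # avoid names like "checkpoint_epoch_epoch=9"
    every_n_epochs=1,
    save_top_k=-1,  # keep every checkpoint, not only the latest one
)

model = MyModel()
trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_callback])
# a single fit() call runs all 10 epochs; the callback saves after each one
trainer.fit(model, train_loader, val_loader)

# a checkpoint can also be written manually at any point
trainer.save_checkpoint("checkpoints/final.ckpt")

# Load model for inference
model = MyModel.load_from_checkpoint("checkpoints/checkpoint_epoch_9.ckpt")
model.eval()

# Or resume training: ckpt_path also restores the optimizer state and epoch counter
trainer = pl.Trainer(max_epochs=20, callbacks=[checkpoint_callback])
trainer.fit(MyModel(), train_loader, val_loader, ckpt_path="checkpoints/checkpoint_epoch_9.ckpt")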

This example saves the model at the end of each training epoch and loads it for inference or resuming training. Note that this is just a basic example and you may need to modify it based on your specific needs.