2. How to use PyTorch Lightning to implement early stopping during training?

Published on Aug. 22, 2023, 12:19 p.m.

To implement early stopping during training using PyTorch Lightning, you can use the EarlyStopping callback. Here are the steps:

  1. Import the EarlyStopping callback from PyTorch Lightning.
  2. Initialize an EarlyStopping instance, passing in arguments such as the metric to monitor, the patience, and the mode.
  3. Pass the EarlyStopping instance to the Trainer through the callbacks argument.
  4. Train your model with trainer.fit() as usual.

Here’s an example code snippet:

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define your model architecture here
        self.layer = torch.nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        # your training step here
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        # your validation step here; log the metric that
        # EarlyStopping monitors, under the same name
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        # define and return an optimizer here
        return torch.optim.Adam(self.parameters(), lr=1e-3)

data_module = MyDataModule()  # your LightningDataModule with train/val dataloaders
model = MyModel()
early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')
trainer = pl.Trainer(callbacks=[early_stopping], max_epochs=10)
trainer.fit(model, datamodule=data_module)

In this example, the EarlyStopping callback monitors the validation loss with a patience of 3 epochs and mode set to "min": if the validation loss does not improve for 3 consecutive epochs, training stops early. Note that the monitored metric must be logged under the exact name passed to monitor ('val_loss' here) via self.log in your LightningModule, otherwise the callback has nothing to watch.
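Under the hood, the patience mechanism amounts to a counter of epochs without improvement. Here is a minimal plain-Python sketch of that logic (independent of Lightning; the function name and min_delta default are illustrative, not Lightning's actual internals):

```python
def early_stop_epoch(losses, patience=3, min_delta=0.0):
    """Return the epoch index at which training would stop in mode='min',
    or None if it runs to completion."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best = loss  # improvement: record it and reset the counter
            wait = 0
        else:
            wait += 1    # no improvement this epoch
            if wait >= patience:
                return epoch
    return None

# val_loss improves for three epochs, then plateaus for three:
print(early_stop_epoch([0.9, 0.7, 0.5, 0.5, 0.6, 0.55]))  # -> 5
```

With patience=3, the counter reaches 3 at epoch 5 (indices 3, 4, 5 show no improvement over the best loss of 0.5), which is when training stops.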