2. How to use PyTorch Lightning to implement early stopping during training?
Published on Aug. 22, 2023, 12:19 p.m.
To implement early stopping during training using PyTorch Lightning, you can use the EarlyStopping callback. Here are the steps:
- Import the EarlyStopping callback from PyTorch Lightning.
- Initialize an instance of EarlyStopping, passing in the relevant arguments, such as the metric to monitor, the patience and the mode.
- Make sure your LightningModule logs the monitored metric, for example with self.log('val_loss', loss) in validation_step. Otherwise EarlyStopping cannot find the metric and, by default, raises an error.
- Pass the EarlyStopping instance to the Trainer object through the callbacks argument.
- Train your model with trainer.fit(model) as usual.
Here’s an example code snippet:
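import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define your model architecture here (placeholder layer)
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        # your training loop here; return the loss to optimize
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        # your validation loop here; log the metric EarlyStopping monitors
        x, y = batch
        self.log('val_loss', F.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        # define and return an optimizer here
        return torch.optim.Adam(self.parameters(), lr=1e-3)

data_module = MyDataModule()  # your LightningDataModule, defined elsewhere
model = MyModel()
early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')
trainer = pl.Trainer(callbacks=[early_stopping], max_epochs=10)
trainer.fit(model, datamodule=data_module)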
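In this example, the EarlyStopping callback monitors the validation loss with a patience of 3 and mode set to “min” (lower is better). Patience counts validation checks, which by default happen once per epoch, so if the validation loss does not improve for 3 consecutive epochs, training stops early. If it keeps improving, training ends after max_epochs=10 as usual.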
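To make the patience rule concrete, here is a simplified pure-Python sketch of the stopping decision. It is an illustration, not Lightning's implementation; the function name early_stop_epoch and the sample loss values are hypothetical. It treats an equal value as "no improvement" and includes a min_delta threshold, mirroring the EarlyStopping argument of the same name (default 0.0).

```python
def early_stop_epoch(values, patience=3, mode="min", min_delta=0.0):
    """Return the 0-based epoch at which training would stop,
    or None if every value in `values` is consumed.
    Simplified model of the patience rule (hypothetical helper)."""
    best = None
    wait = 0  # consecutive checks without improvement
    for epoch, value in enumerate(values):
        if best is None:
            improved = True  # the first check always sets the baseline
        elif mode == "min":
            improved = value < best - min_delta
        else:  # mode == "max"
            improved = value > best + min_delta
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: improvement stops after epoch 2.
losses = [0.90, 0.70, 0.65, 0.66, 0.65, 0.67, 0.50]
print(early_stop_epoch(losses, patience=3))  # 5
```

Here training stops after epoch index 5, so the lower loss at index 6 is never reached. That is the trade-off patience controls: a larger value tolerates longer plateaus at the cost of extra epochs.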