How to use PyTorch Lightning to log training metrics to TensorBoard for visualization?

Published on Aug. 22, 2023, 12:19 p.m.

To log training metrics to TensorBoard for visualization using PyTorch Lightning, you can use the TensorBoardLogger logger (it is a logger, not a callback). Here are the steps:

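  1. Import TensorBoardLogger from pytorch_lightning.loggers.
  2. Create a TensorBoardLogger instance, passing the directory where you want to store the logs (save_dir) and, optionally, a name for the experiment.
  3. In your LightningModule's training_step (and validation_step), call self.log() for each metric you want to record.
  4. Pass the TensorBoardLogger instance to the Trainer through the logger argument.
  5. Train your model with trainer.fit() as usual.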

Here’s an example code snippet:

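import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define your model architecture here
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        # define your forward pass here
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # your training loop here; self.log() sends the metric to the logger
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # your validation loop here
        x, y = batch
        self.log('val_loss', F.cross_entropy(self(x), y))

    def configure_optimizers(self):
        # Lightning requires an optimizer to train
        return torch.optim.Adam(self.parameters(), lr=1e-3)

data_module = MyDataModule()  # your LightningDataModule (not shown)
model = MyModel()
logger = TensorBoardLogger('logs/', name='my_model')
trainer = pl.Trainer(logger=logger, max_epochs=10)
trainer.fit(model, datamodule=data_module)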

In this example, TensorBoardLogger writes every metric logged with self.log() to the logs/ folder under the experiment name ‘my_model’, in a subfolder such as logs/my_model/version_0/ for the first run. To visualize the metrics, run tensorboard --logdir logs/ and open the URL it prints (by default http://localhost:6006).
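By default, TensorBoardLogger does not overwrite earlier runs: each run goes into a new version_N subfolder under logs/my_model/ unless you pass an explicit version argument. The sketch below uses a hypothetical helper, next_version_dir, that mimics this numbering convention in plain Python so you can see which folder a run will land in. It is an illustration under that assumption, not Lightning's own implementation.

```python
import os
import re


def next_version_dir(save_dir: str, name: str) -> str:
    """Return save_dir/name/version_N, where N is one more than the highest
    existing version_N subdirectory, or 0 if there is none.

    Hypothetical helper that mimics TensorBoardLogger's default versioning
    convention; it is not part of the Lightning API.
    """
    root = os.path.join(save_dir, name)
    versions = []
    if os.path.isdir(root):
        for entry in os.listdir(root):
            # Only directories named exactly version_<integer> count.
            match = re.fullmatch(r"version_(\d+)", entry)
            if match and os.path.isdir(os.path.join(root, entry)):
                versions.append(int(match.group(1)))
    next_version = max(versions) + 1 if versions else 0
    return os.path.join(root, f"version_{next_version}")
```

Because each run gets its own folder, pointing TensorBoard at the parent logs/ directory lets you compare several runs side by side.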

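Note that TensorBoardLogger records validation and test metrics the same way: call self.log() inside validation_step() or test_step(), and the values appear in TensorBoard alongside the training metrics. You can learn more about the logging options (for example on_step, on_epoch, and prog_bar) in the official PyTorch Lightning documentation.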
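Validation metrics usually appear in TensorBoard as one point per epoch rather than one per batch. By default, self.log() uses on_step=False and on_epoch=True inside validation_step, and on_step=True and on_epoch=False inside training_step. As far as we understand, the epoch-level value with the default mean reduction is a mean over batches weighted by batch size. The hypothetical epoch_mean helper below sketches that reduction in plain Python; it is not a Lightning function.

```python
from typing import Sequence


def epoch_mean(batch_values: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Batch-size-weighted mean of per-batch metric values.

    Hypothetical sketch of what an epoch-level metric such as val_loss
    represents when it is logged with on_epoch=True and mean reduction.
    """
    if not batch_values or len(batch_values) != len(batch_sizes):
        raise ValueError("need at least one batch and one size per value")
    # Weight each batch's value by how many samples it contained.
    total = sum(value * size for value, size in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)
```

The weighting matters when the last batch is smaller: for batch values [1.0, 4.0] with batch sizes [3, 1], the weighted mean is 1.75, while a plain average of the two values would be 2.5.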