How to use PyTorch Lightning to log training metrics to TensorBoard for visualization?
Published on Aug. 22, 2023, 12:19 p.m.
To log training metrics to TensorBoard for visualization using PyTorch Lightning, you can use the TensorBoardLogger. Here are the steps:
- Import TensorBoardLogger from pytorch_lightning.loggers.
- Initialize a TensorBoardLogger instance, passing in the folder where you want to store the logs.
- Pass the TensorBoardLogger instance to the Trainer object through the logger argument.
- Call self.log() inside training_step() (and validation_step(), if you have one) to record the metrics you want to track.
- Train your model with trainer.fit() as usual.
Here’s an example code snippet:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)  # define your model architecture here

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))  # define your forward pass here

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log('train_loss', loss)  # sent to TensorBoard at every step
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', F.cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

data_module = MyDataModule()  # your own LightningDataModule with train/val dataloaders
model = MyModel()
logger = TensorBoardLogger('logs/', name='my_model')
trainer = pl.Trainer(logger=logger, max_epochs=10)
trainer.fit(model, datamodule=data_module)
In this example, the TensorBoardLogger is set up to write its logs under the logs/ folder with the experiment name 'my_model' (the event files end up in logs/my_model/version_0, logs/my_model/version_1, and so on). During training, every metric recorded with self.log() is stored there and can be visualized using TensorBoard.
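To view the plots, start TensorBoard from a terminal and point it at the same folder that was passed to the logger (logs/ in this example):

tensorboard --logdir logs/

Then open the URL it prints, typically http://localhost:6006, in a browser.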
Note that the TensorBoardLogger also records validation and test metrics: anything passed to self.log() inside validation_step() or test_step() is written to the same event files. You can learn more about the options for logging metrics in PyTorch Lightning from the official documentation.
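As an illustration of those options, here is a minimal sketch of validation and test hooks that pass extra arguments to self.log(). It assumes the same model, trainer, and data module as in the snippet above; the metric names and the accuracy calculation are only examples:

class MyModel(pl.LightningModule):
    ...  # __init__, forward, training_step and configure_optimizers as in the snippet above

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # on_epoch=True aggregates the metric over the whole validation epoch,
        # and prog_bar=True also shows it in the progress bar
        self.log('val_loss', F.cross_entropy(logits, y), on_step=False, on_epoch=True)
        self.log('val_acc', (logits.argmax(dim=1) == y).float().mean(),
                 on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log('test_loss', F.cross_entropy(self(x), y))

# after training, trainer.test() writes the test metrics to the same TensorBoard logger
trainer.test(model, datamodule=data_module)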