How to define a PyTorch loss function?

Published on Aug. 22, 2023, 12:18 p.m.

In PyTorch, you can define a loss function by creating a new class that inherits from torch.nn.Module and overrides the forward method. Here is an example of how to define a mean squared error (MSE) loss function:

import torch.nn as nn

class MyLoss(nn.Module):
    def __init__(self):
        super().__init__()  # always initialize the nn.Module base class

    def forward(self, y_pred, y_true):
        # Mean of the element-wise squared differences
        return ((y_pred - y_true) ** 2).mean()

In this code, we define a new class called MyLoss that subclasses nn.Module and computes the mean squared error between the predicted values y_pred and the true values y_true. The __init__ method is where you would set up any parameters the computation needs. MSE needs none, so __init__ only calls the base-class constructor.
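When a loss does need configuration, __init__ is the natural place to store it. The sketch below is a hypothetical WeightedMSELoss (the class name, the `weight` tensor and the `reduction` argument are illustrative assumptions, not part of PyTorch) that scales each squared error by a fixed weight. The weight is stored with `register_buffer`, so it moves with the module on `.to(device)` but is not treated as a trainable parameter.

```python
import torch
import torch.nn as nn


class WeightedMSELoss(nn.Module):
    """Hypothetical MSE variant: each squared error is scaled by a fixed weight."""

    def __init__(self, weight: torch.Tensor, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction}")
        # A buffer is saved in state_dict and follows .to(device),
        # but the optimizer will not update it.
        self.register_buffer("weight", weight)
        self.reduction = reduction

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # self.weight broadcasts over the last dimension of the inputs
        squared_error = self.weight * (y_pred - y_true) ** 2
        if self.reduction == "mean":
            return squared_error.mean()
        return squared_error.sum()


y_pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y_true = torch.tensor([[1.5, 2.5], [2.0, 4.0]])

# With all weights equal to 1 it reduces to plain MSE, matching nn.MSELoss.
plain = WeightedMSELoss(torch.ones(2))
print(plain(y_pred, y_true).item(), nn.MSELoss()(y_pred, y_true).item())  # 0.375 0.375

# Count errors in the second output column twice as heavily, summed instead of averaged.
weighted = WeightedMSELoss(torch.tensor([1.0, 2.0]), reduction="sum")
print(weighted(y_pred, y_true).item())  # 1.75
```

Comparing against a built-in loss with neutral settings, as in the first print, is a cheap way to check a custom loss before training with it.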

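To use the custom loss function in a training loop, create an instance of it and call it on the model's outputs and the targets. The loss function is not passed to the optimizer; the optimizer only receives the model's parameters.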

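import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Stand-in model and random data for illustration; replace with your own.
model = nn.Linear(10, 1)
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=16, shuffle=True)
num_epochs = 5

loss_fn = MyLoss()  # called directly in the loop
optimizer = optim.SGD(model.parameters(), lr=0.001)  # receives only the model parameters

for epoch in range(num_epochs):
    for input_data, target in loader:
        output = model(input_data)      # forward pass
        loss = loss_fn(output, target)  # scalar loss tensor
        optimizer.zero_grad()           # clear gradients from the previous step
        loss.backward()                 # compute gradients of the loss
        optimizer.step()                # update the model parameters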

In this code, we create an instance of MyLoss, give the optimizer the model parameters, and call loss_fn in each iteration of the training loop to compute the loss. The backward method computes the gradients of the loss with respect to the model parameters, and the step method uses those gradients to update the parameters. Note that this is just a simple example; more complex loss functions may require additional arguments and computation.
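Because forward is built only from differentiable tensor operations, autograd derives the gradient automatically and no custom backward method is needed. The small check below is a sketch (the tensor values are arbitrary) that compares the gradient autograd produces for the MSE loss with the analytic result 2 * (y_pred - y_true) / N, where N is the number of elements.

```python
import torch
import torch.nn as nn


class MyLoss(nn.Module):
    # Same MSE loss as above, repeated so this snippet runs on its own.
    def forward(self, y_pred, y_true):
        return ((y_pred - y_true) ** 2).mean()


y_pred = torch.tensor([1.0, 2.0, 4.0], requires_grad=True)
y_true = torch.tensor([0.0, 2.0, 1.0])

loss = MyLoss()(y_pred, y_true)
loss.backward()  # autograd fills in y_pred.grad

# Analytic gradient of mean((y_pred - y_true) ** 2) with respect to y_pred
expected = 2 * (y_pred.detach() - y_true) / y_pred.numel()
print(torch.allclose(y_pred.grad, expected))  # True
```

If a loss uses non-differentiable steps (for example converting to NumPy inside forward), this check would fail or raise, which is a quick signal that the loss needs rewriting in tensor operations.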