How to define a PyTorch loss function?
Published on Aug. 22, 2023, 12:18 p.m.
In PyTorch, you can define a loss function by creating a new class that inherits from `torch.nn.Module` and overrides the `forward` method. Here is an example of how to define a mean squared error (MSE) loss function:
```python
import torch.nn as nn

class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, y_pred, y_true):
        # Mean of the squared element-wise differences
        return ((y_pred - y_true) ** 2).mean()
```
In this code, we define a new class called `MyLoss` that computes the mean squared error between the predicted values `y_pred` and the true values `y_true`. Because it subclasses `nn.Module`, an instance can be called like any other PyTorch module: `loss_fn(y_pred, y_true)` runs `forward`. The `__init__` method is where any parameters the computation needs would be set up; in this case there are none, so it only calls the parent constructor.
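As a quick sanity check, the sketch below compares `MyLoss` with PyTorch's built-in `nn.MSELoss` (whose default `reduction="mean"` averages over all elements) and confirms that gradients flow through it. No `backward` method has to be written: autograd differentiates the tensor operations in `forward` automatically. The random inputs and their shapes are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn


class MyLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        # Mean of the squared element-wise differences
        return ((y_pred - y_true) ** 2).mean()


torch.manual_seed(0)
y_pred = torch.randn(8, 3, requires_grad=True)  # illustrative predictions
y_true = torch.randn(8, 3)                      # illustrative targets

custom = MyLoss()(y_pred, y_true)
builtin = nn.MSELoss()(y_pred, y_true)  # built-in MSE, default reduction="mean"
print(torch.allclose(custom, builtin))  # True

# Autograd computes the gradient of forward() without a hand-written backward
custom.backward()
# d/dp mean((p - t)^2) = 2 * (p - t) / N, where N is the number of elements
expected_grad = 2 * (y_pred.detach() - y_true) / y_pred.numel()
print(torch.allclose(y_pred.grad, expected_grad))  # True
```

In practice you would usually reach for `nn.MSELoss` directly; writing it by hand is mainly useful as a template for losses PyTorch does not provide.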
To use our custom loss function in a training loop, we create an instance of it and call it on the model's output and the target in each iteration. The optimizer itself only receives the model's parameters.
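```python
import torch
import torch.nn as nn
import torch.optim as optim

# Placeholder model and data for illustration; substitute your own MyModel and dataset
model = nn.Linear(10, 1)
dataset = [(torch.randn(32, 10), torch.randn(32, 1)) for _ in range(100)]
num_epochs = 5

loss_fn = MyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    for input_data, target in dataset:
        output = model(input_data)      # forward pass
        loss = loss_fn(output, target)  # compute the custom loss
        optimizer.zero_grad()           # clear gradients from the previous step
        loss.backward()                 # compute new gradients
        optimizer.step()                # update the model parameters
```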
In this code, we create an instance of `MyLoss`, create an SGD optimizer over the model parameters, and use the loss function to compute the loss in each iteration of the training loop. `optimizer.zero_grad()` clears the gradients left over from the previous step, `loss.backward()` computes new gradients, and `optimizer.step()` updates the model parameters. Note that this is a simple example; more complex loss functions may require additional arguments and computation.
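For example, a loss might take a hyperparameter in `__init__` and an extra tensor in `forward`. The class below, `MaskedWeightedMSELoss`, is a hypothetical sketch (not a PyTorch API): it averages the squared error over only the elements selected by a boolean mask and scales the result by a weight fixed at construction time. The mask semantics and the weight are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


class MaskedWeightedMSELoss(nn.Module):
    """Hypothetical example: MSE over masked-in elements, scaled by a fixed weight."""

    def __init__(self, weight=1.0):
        super().__init__()
        # A plain float hyperparameter; a tensor that should move with .to(device)
        # could instead be stored via self.register_buffer(...)
        self.weight = weight

    def forward(self, y_pred, y_true, mask):
        # mask: boolean tensor, True where an element should count toward the loss
        sq_err = (y_pred - y_true) ** 2
        selected = sq_err[mask]  # keep only the masked-in elements
        if selected.numel() == 0:
            # The mean of an empty tensor is NaN; return a zero that stays in the graph
            return sq_err.sum() * 0.0
        return self.weight * selected.mean()


loss_fn = MaskedWeightedMSELoss(weight=0.5)
y_pred = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y_true = torch.tensor([1.0, 0.0, 3.0, 0.0])
mask = torch.tensor([True, True, False, False])

loss = loss_fn(y_pred, y_true, mask)
print(loss.item())  # 1.0  (0.5 * mean([0.0, 4.0]))
loss.backward()
print(y_pred.grad)  # tensor([0., 1., 0., 0.])
```

In a training loop, the extra argument is simply passed along with the output and target, e.g. `loss = loss_fn(output, target, mask)`.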