How to initialize weights in a PyTorch model?

Published on Aug. 22, 2023, 12:19 p.m.

To initialize weights in a PyTorch model, use the torch.nn.init module, which provides functions that fill a tensor in place (their names end in an underscore). For example, torch.nn.init.xavier_uniform_() applies Xavier (Glorot) uniform initialization to a layer's weight tensor:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=10, out_features=5)
        # Overwrite the layer's default weights in place (Xavier/Glorot uniform).
        # The bias is not touched and keeps its default initialization.
        nn.init.xavier_uniform_(self.fc1.weight)

    def forward(self, x):
        x = self.fc1(x)
        return x

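In this example, the weights of the fc1 layer are initialized with the xavier_uniform_() function from the torch.nn.init module. It samples each weight from a uniform distribution U(-a, a), where a = gain * sqrt(6 / (fan_in + fan_out)) and gain defaults to 1. Scaling the range by the layer's input and output sizes is intended to keep the variance of activations and gradients roughly constant from layer to layer. Only the weight tensor is re-initialized here; the bias keeps the default values that nn.Linear assigns when the layer is created.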
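As a quick sanity check of the range described above, the following sketch computes the Xavier uniform bound by hand for a 10-to-5 linear layer and compares it with the largest weight magnitude after initialization. The variable names (layer, bound, max_abs) and the fixed seed are illustrative choices, not part of the original example.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed only to make the run repeatable

layer = nn.Linear(in_features=10, out_features=5)
nn.init.xavier_uniform_(layer.weight)

# For nn.Linear, fan_in is in_features and fan_out is out_features.
fan_in, fan_out = 10, 5
bound = math.sqrt(6.0 / (fan_in + fan_out))  # gain defaults to 1.0

# Every sampled weight should fall inside [-bound, bound].
max_abs = layer.weight.abs().max().item()
print(f"bound = {bound:.4f}, largest |weight| = {max_abs:.4f}")
```

The printed bound is about 0.6325 (the square root of 0.4), and the largest weight magnitude should never exceed it.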

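Other weight initialization functions available in torch.nn.init include xavier_normal_(), kaiming_uniform_(), and kaiming_normal_(), as well as simpler ones such as zeros_(), constant_(), and normal_(). The right choice depends mainly on the activation function: Xavier initialization is commonly paired with tanh or sigmoid activations, while Kaiming (He) initialization is designed for ReLU and its variants.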
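When a model has several layers, one common pattern is to put the initialization logic in a single function and pass it to Module.apply(), which calls it on every submodule. The sketch below assumes a small ReLU network and uses Kaiming uniform initialization with zeroed biases; the helper name init_weights and the layer sizes are hypothetical.

```python
import torch
import torch.nn as nn


def init_weights(module):
    # Hypothetical helper: re-initialize only the linear layers.
    if isinstance(module, nn.Linear):
        # Kaiming (He) uniform init with the gain for ReLU activations.
        nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 5),
)

# apply() visits every submodule (and the model itself) recursively.
model.apply(init_weights)
```

Keeping the logic in one function makes it easy to switch schemes later, for example to xavier_uniform_() if the activations change to tanh.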
