How to initialize weights in a PyTorch model?
Published on Aug. 22, 2023, 12:19 p.m.
To initialize weights in a PyTorch model, you can use the torch.nn.init module, which provides a variety of functions for initializing a model's weights. For example, you can use the torch.nn.init.xavier_uniform_() function to initialize the weights from a uniform distribution:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(in_features=10, out_features=5)
        # Overwrite fc1's weights in place with Xavier (Glorot) uniform initialization
        nn.init.xavier_uniform_(self.fc1.weight)

    def forward(self, x):
        x = self.fc1(x)
        return x
In this example, we initialize the weights of the fc1 layer using the xavier_uniform_() function from the torch.nn.init module. This draws the weights from a uniform distribution whose range is chosen to keep the activations flowing through the model within reasonable bounds.
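If your model has many layers, you don't need to call an init function on each weight by hand. A common pattern is to pass a small helper to Module.apply(), which visits every submodule recursively. Here is a minimal sketch of that pattern; the helper name init_weights is just an illustration, not part of PyTorch:

def init_weights(module):
    # Apply Xavier uniform init to every Linear layer's weight,
    # and zero out its bias if it has one.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = MyModel()
model.apply(init_weights)  # recursively applies init_weights to every submodule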
Other weight initialization functions available in torch.nn.init include xavier_normal_(), kaiming_uniform_(), and kaiming_normal_(), among others. The choice of function depends on the specific needs of your model and the activation function you are using.
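For instance, for layers followed by ReLU activations, Kaiming (He) initialization is usually a better fit than Xavier, since it accounts for the nonlinearity when scaling the weights. A rough sketch (the layer sizes here are arbitrary, chosen only for illustration):

fc = nn.Linear(in_features=128, out_features=64)
# Scale the weights for a ReLU nonlinearity
nn.init.kaiming_uniform_(fc.weight, nonlinearity='relu')
nn.init.zeros_(fc.bias)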
I hope this helps! Let me know if you have any other questions or if there’s anything else I can help you with.