How to define a complex PyTorch model?
Published on Aug. 22, 2023, 12:18 p.m.
Defining a complex PyTorch model typically involves defining multiple layers and composing them using a Sequential
or a custom module. Here is an example of how to define a complex PyTorch model:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(MyModel, self).__init__()
self.layer1 = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5)
)
self.layer2 = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5)
)
self.layer3 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
In this example, MyModel
is a custom PyTorch module that takes as input a tensor of shape (batch_size, input_dim)
and outputs a tensor of shape (batch_size, output_dim)
. The model contains three layers: layer1
is a fully connected layer followed by batch normalization, ReLU activation, and dropout; layer2
is a second fully connected layer followed by batch normalization, ReLU activation, and dropout; and layer3
is a final fully connected layer.
The forward
method specifies how the input tensor x
is processed by each layer in sequence. Note that since we used nn.Sequential
to define the layers, we can simply pass the input tensor x
to each layer in succession using the dot .
notation.
This is just one example of a complex PyTorch model - the structure, number of layers, and activation functions used will depend on the specific task and the nature of the data being used.