How to define a complex PyTorch model?

Published on Aug. 22, 2023, 12:18 p.m.

Defining a complex PyTorch model typically involves subclassing nn.Module, declaring the layers in __init__, and composing them with nn.Sequential or a custom forward method. Here is an example of how to define a complex PyTorch model:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MyModel, self).__init__()
        # Block 1: linear projection, batch norm, ReLU activation, dropout
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5)
        )
        # Block 2: same structure, operating on the hidden representation
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5)
        )
        # Output layer: maps hidden features to the output dimension
        self.layer3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Pass the input through each block in sequence
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

In this example, MyModel is a custom PyTorch module that takes a tensor of shape (batch_size, input_dim) as input and outputs a tensor of shape (batch_size, output_dim). The model contains three layers: layer1 and layer2 are each a fully connected layer followed by batch normalization, ReLU activation, and dropout, while layer3 is a final fully connected layer that produces the output.

The forward method specifies how the input tensor x is processed by each layer in turn. Note that each block, whether an nn.Sequential container or a single nn.Linear layer, is itself an nn.Module and can be called like a function, so the output of one block is simply passed as the input to the next.
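As a quick sanity check, the model can be instantiated and run on a batch of random data. The sizes below (a batch of 16 samples, 10 input features, 64 hidden units, 3 outputs) are arbitrary values chosen for illustration:

import torch

model = MyModel(input_dim=10, hidden_dim=64, output_dim=3)
model.eval()  # switch dropout and batch norm to inference behavior

x = torch.randn(16, 10)       # batch of 16 samples with 10 features each
with torch.no_grad():
    out = model(x)
print(out.shape)              # torch.Size([16, 3])

Calling model.eval() disables dropout and makes batch normalization use its running statistics, which is what you want when checking shapes or running inference rather than training.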

This is just one example of a complex PyTorch model; the structure, number of layers, and activation functions will depend on the specific task and the nature of the data being used.
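For instance, if the depth of the network should itself be configurable, the same building blocks can be assembled in a loop before being wrapped in nn.Sequential. The following is only a rough sketch of that idea; the DeepModel name and num_layers parameter are illustrative and not part of the example above:

import torch.nn as nn

class DeepModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3):
        super().__init__()
        blocks = []
        in_features = input_dim
        for _ in range(num_layers):
            # Each hidden block repeats the linear -> batch norm -> ReLU -> dropout pattern
            blocks += [
                nn.Linear(in_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.5),
            ]
            in_features = hidden_dim
        self.hidden = nn.Sequential(*blocks)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.out(self.hidden(x))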
