How to define a simple PyTorch model?

Published on Aug. 22, 2023, 12:18 p.m.

To define a simple PyTorch model, create a new class that inherits from torch.nn.Module. In this class, you create the layers of your model in the __init__ method and assign them as attributes, which registers them (and their parameters) with the module. You then specify the model's computation in the forward method. Here is an example of a simple network with two linear layers and a ReLU activation function between them:

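import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module internals before adding layers
        self.fc1 = nn.Linear(10, 20)  # 10 input features -> 20 hidden units
        self.fc2 = nn.Linear(20, 1)   # 20 hidden units -> 1 output
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)   # (batch, 10) -> (batch, 20)
        x = self.relu(x)  # element-wise max(0, x)
        x = self.fc2(x)   # (batch, 20) -> (batch, 1)
        return x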

This code defines a class called MyModel. In the __init__ method, the two layers are created with nn.Linear(in_features, out_features): fc1 maps 10 input features to 20 hidden units, and fc2 maps those 20 units to a single output. The forward method specifies how the input flows through the layers: first fc1, then the ReLU activation, then fc2.
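To make the layer sizes concrete, here is a minimal sketch that inspects the parameters of the same network and recomputes its forward pass by hand. The class is repeated so the snippet runs on its own; the fixed seed and the batch of 3 samples are arbitrary choices for illustration. Each nn.Linear layer holds a weight of shape (out_features, in_features) and a bias of shape (out_features,), so fc1 has 20 × 10 + 20 = 220 parameters and fc2 has 1 × 20 + 1 = 21.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Same network as above: 10 -> 20 -> ReLU -> 1."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


torch.manual_seed(0)  # arbitrary seed so the random weights are reproducible
model = MyModel()

# nn.Linear stores its weight with shape (out_features, in_features)
for name, param in model.named_parameters():
    print(name, tuple(param.shape))

total = sum(p.numel() for p in model.parameters())
print("total parameters:", total)  # 220 + 21 = 241

# Recompute forward() by hand: y = relu(x W1^T + b1) W2^T + b2
x = torch.randn(3, 10)
hidden = torch.relu(x @ model.fc1.weight.T + model.fc1.bias)
manual = hidden @ model.fc2.weight.T + model.fc2.bias
print(torch.allclose(model(x), manual, atol=1e-5))  # small tolerance for float rounding
```

The ReLU layer has no parameters, so only fc1 and fc2 appear in named_parameters().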

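Once you have defined your model, you can create an instance of it and call it on input data to get predictions. Note that a freshly created model has randomly initialized weights, so its outputs are not meaningful until the model has been trained.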

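import torch

model = MyModel()                # create an instance (weights are randomly initialized)
input_data = torch.randn(1, 10)  # a batch of 1 sample with 10 features
output = model(input_data)       # runs forward(); output has shape (1, 1)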

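This code creates an instance of MyModel, generates a random input tensor of shape (1, 10) (one sample with 10 features), and passes it through the model to get a prediction of shape (1, 1). Call the model directly, as in model(input_data), rather than model.forward(input_data): nn.Module's __call__ runs forward for you and also runs any registered hooks.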
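The first dimension of the input is the batch dimension, so the same model can score several samples at once. A minimal sketch, assuming you only want predictions (no training): model.eval() switches layers such as dropout and batch norm to inference behavior (it has no effect on this particular network), and torch.no_grad() skips building the autograd graph. The batch size of 4 is an arbitrary choice.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Same network as above: 10 -> 20 -> ReLU -> 1."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = MyModel()
model.eval()  # inference mode; matters for dropout/batch norm, not for this model

batch = torch.randn(4, 10)  # 4 samples, 10 features each

with torch.no_grad():  # no gradient tracking needed for predictions
    preds = model(batch)

print(preds.shape)          # torch.Size([4, 1])
print(preds.requires_grad)  # False
```

Each row of preds is the (untrained) model's output for the corresponding row of batch.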
