How to define a simple PyTorch model?
Published on Aug. 22, 2023, 12:18 p.m.
To define a simple PyTorch model, you can create a new class that inherits from torch.nn.Module. In this class, you define the layers of your model in the __init__ method and specify the model's forward computation in the forward method. Here is an example of a simple network with two linear layers and a ReLU activation function:
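```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module internals before adding layers
        # Layers assigned as attributes are registered as submodules
        self.fc1 = nn.Linear(10, 20)  # 10 input features -> 20 hidden units
        self.fc2 = nn.Linear(20, 1)   # 20 hidden units -> 1 output
        self.relu = nn.ReLU()

    def forward(self, x):
        # Pass the input through the layers in order
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
```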
In this code, we define a new class called MyModel that has two linear layers and a ReLU activation function. In the __init__ method, we define the layers of the model and specify their input and output sizes using the nn.Linear class. In the forward method, we specify how the input is passed through the layers, applying the nn.ReLU activation after the first linear layer.
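Because the layers are assigned as attributes in __init__, nn.Module registers them as submodules, so their weights appear in model.parameters(). The sketch below checks this; it redefines MyModel so it runs on its own, and the nn.Sequential version at the end is an alternative way to write the same architecture, not part of the original example.

```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = MyModel()

# Submodules registered in __init__, listed in assignment order
for name, module in model.named_children():
    print(name, module)

# fc1: 10*20 weights + 20 biases = 220; fc2: 20*1 weights + 1 bias = 21
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 241

# Alternative (not from the original): the same layers via nn.Sequential,
# which needs no custom forward method because it applies modules in order
seq_model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))
print(sum(p.numel() for p in seq_model.parameters()))  # 241
```

Subclassing nn.Module is the more flexible choice when forward needs branching or multiple inputs; nn.Sequential is shorter for a plain stack of layers.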
Once you have defined your model, you can create an instance of it and use it to make predictions on your data.
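```python
import torch

model = MyModel()                # create an instance of the model
input_data = torch.randn(1, 10)  # a batch of 1 sample with 10 features
output = model(input_data)       # calling the model runs forward()
print(output.shape)              # torch.Size([1, 1])
```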
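This code creates an instance of MyModel, generates random input data of shape 1x10 (one sample with 10 features), and passes it through the model to get its output, a tensor of shape 1x1. Because the weights are randomly initialized and the model has not been trained, this output is not yet a meaningful prediction.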
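When running a model only for inference, it is common to pass a whole batch at once, call model.eval(), and disable gradient tracking with torch.no_grad(). A minimal sketch, assuming the same MyModel architecture (redefined here so the snippet is self-contained) and an arbitrary batch size of 32:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

torch.manual_seed(0)  # make the random weights and inputs reproducible
model = MyModel()
model.eval()  # evaluation mode; no effect here, but matters for dropout/batchnorm layers

batch = torch.randn(32, 10)  # 32 samples, 10 features each
with torch.no_grad():  # skip building the autograd graph during inference
    preds = model(batch)

print(preds.shape)          # torch.Size([32, 1])
print(preds.requires_grad)  # False
```

The first dimension is the batch dimension, so the same model handles one sample or many without any change to its definition.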