How to perform LSTM in PyTorch?

Published on Aug. 22, 2023, 12:18 p.m.

To run an LSTM in PyTorch, use the nn.LSTM module. The following snippet defines a two-layer LSTM and runs a forward pass on a random batch of sequences:

import torch
import torch.nn as nn

# Model and input dimensions
input_size = 10
hidden_size = 20
num_layers = 2
batch_size = 5
seq_len = 15

# Define the LSTM module (batch_first=True means input is (batch, seq, feature))
lstm = nn.LSTM(input_size=input_size,
               hidden_size=hidden_size,
               num_layers=num_layers,
               batch_first=True)

# Create input tensor
x = torch.randn(batch_size, seq_len, input_size)

# Initial hidden state and cell state (optional; PyTorch uses zeros if omitted)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

# Forward pass through LSTM layer
out, (hn, cn) = lstm(x, (h0, c0))

In this example, we define an LSTM module with a given input_size, hidden_size, and num_layers, and set batch_first=True. We then create an input tensor x and initialize the hidden state h0 and cell state c0. Finally, we run a forward pass by calling the module as lstm(x, (h0, c0)), which invokes its forward method. It returns the output tensor out, which holds the top layer's hidden state at every time step, together with the final hidden state hn and final cell state cn of every layer.
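To make the return values concrete, the sketch below prints their shapes using the same dimensions as the snippet above. It also checks that, for a unidirectional LSTM, the last time step of out matches the top layer's entry in hn. The manual seed is only there for reproducibility, and the initial states are omitted so PyTorch starts from zeros.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible random input

input_size, hidden_size, num_layers = 10, 20, 2
batch_size, seq_len = 5, 15

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)

# Omitting (h0, c0): both initial states default to zeros
out, (hn, cn) = lstm(x)

print(out.shape)  # torch.Size([5, 15, 20]): top layer's hidden state at each step
print(hn.shape)   # torch.Size([2, 5, 20]): final hidden state of each layer
print(cn.shape)   # torch.Size([2, 5, 20]): final cell state of each layer

# Unidirectional LSTM: the last step of `out` is the top layer's final hidden state
last_step = out[:, -1, :]
print(torch.allclose(last_step, hn[-1]))  # True
```

This is why out[:, -1, :] and hn[-1] are often used interchangeably as a summary of the whole sequence; the equivalence does not hold for bidirectional LSTMs or padded, packed sequences.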

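Note that the input tensor has shape (batch_size, seq_len, input_size) only because batch_first=True; with the default batch_first=False it must be (seq_len, batch_size, input_size). The hidden state and cell state tensors always have shape (num_layers, batch_size, hidden_size) for a unidirectional LSTM, regardless of batch_first.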
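For comparison, here is a minimal sketch of the same forward pass with the default batch_first=False. Only the input and output layouts change; the initial states keep the same shape. It also shows that passing explicit zero states gives the same output as omitting them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible random input

input_size, hidden_size, num_layers = 10, 20, 2
batch_size, seq_len = 5, 15

# Default batch_first=False: input is (seq_len, batch_size, input_size)
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers)
x = torch.randn(seq_len, batch_size, input_size)

# States are (num_layers, batch_size, hidden_size) regardless of batch_first
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)  # torch.Size([15, 5, 20]): sequence dimension comes first
print(hn.shape)   # torch.Size([2, 5, 20])

# Explicit zero states produce the same output as omitting them
out_default, _ = lstm(x)
print(torch.allclose(out, out_default))  # True
```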
