How to perform LSTM in PyTorch?
Published on Aug. 22, 2023, 12:18 p.m.
To run an LSTM in PyTorch, you can use the `nn.LSTM` module. Here is an example code snippet that shows how to define an LSTM module and run a forward pass through it:
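```python
import torch
import torch.nn as nn

# Define the LSTM module
input_size = 10   # number of features per time step
hidden_size = 20  # number of features in the hidden state
num_layers = 2    # number of stacked LSTM layers
batch_size = 5
seq_len = 15

lstm = nn.LSTM(input_size=input_size,
               hidden_size=hidden_size,
               num_layers=num_layers,
               batch_first=True)  # input and output are (batch, seq, feature)

# Create the input tensor: (batch_size, seq_len, input_size) because batch_first=True
x = torch.randn(batch_size, seq_len, input_size)

# Initialize the hidden state and cell state: (num_layers, batch_size, hidden_size)
h0 = torch.randn(num_layers, batch_size, hidden_size)
c0 = torch.randn(num_layers, batch_size, hidden_size)

# Forward pass through the LSTM
out, (hn, cn) = lstm(x, (h0, c0))
```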
In this example, we define an LSTM module with the given `input_size`, `hidden_size`, `num_layers`, and `batch_first` arguments. We then create an input tensor `x` and initialize the hidden state `h0` and cell state `c0`. Finally, we perform a forward pass by calling the module as `lstm(x, (h0, c0))`, which runs its `forward` method and returns the output tensor `out` (the top layer's hidden state at every time step) together with the final hidden state `hn` and cell state `cn` for each layer.
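A quick way to confirm what the forward pass returns is to print the shapes. The sketch below rebuilds a module with the same sizes as the example; the `torch.manual_seed` call and the zero initial states are assumptions added for reproducibility, not part of the original snippet. For a unidirectional LSTM, the last time step of `out` should match the top layer's entry in `hn`, because `out` holds the top layer's hidden state at every step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so repeated runs give the same numbers

input_size, hidden_size, num_layers = 10, 20, 2
batch_size, seq_len = 5, 15

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)

x = torch.randn(batch_size, seq_len, input_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

out, (hn, cn) = lstm(x, (h0, c0))

# out: top-layer hidden state at every time step, batch dimension first
print(out.shape)  # torch.Size([5, 15, 20])
# hn, cn: final hidden/cell state for each layer (batch_first does not apply here)
print(hn.shape, cn.shape)  # torch.Size([2, 5, 20]) torch.Size([2, 5, 20])

# For a unidirectional LSTM, the last step of out equals the top layer of hn
print(torch.allclose(out[:, -1, :], hn[-1]))  # True
```

This is handy when feeding the LSTM into a downstream layer: for sequence classification, `hn[-1]` (or equivalently `out[:, -1, :]`) is a common choice of summary vector.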
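Note that because `batch_first=True`, the input tensor must have shape `(batch_size, seq_len, input_size)`; with the default `batch_first=False` it would be `(seq_len, batch_size, input_size)`. The hidden state and cell state tensors must have shape `(num_layers, batch_size, hidden_size)` for a unidirectional LSTM regardless of `batch_first` (`(num_layers * 2, batch_size, hidden_size)` if `bidirectional=True`). The initial states are optional: if `(h0, c0)` is omitted, PyTorch initializes both to zeros.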
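The expected shapes change with the module's configuration. The sketch below is illustrative (the variable names and the bidirectional variant are not part of the original example): it uses the default `batch_first=False`, where the input is laid out as `(seq_len, batch_size, input_size)`, checks that omitting `(h0, c0)` gives the same result as passing explicit zeros, and shows that `bidirectional=True` doubles the first dimension of the state tensors and the feature dimension of `out`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size, num_layers = 10, 20, 2
batch_size, seq_len = 5, 15

# Default batch_first=False: input is (seq_len, batch_size, input_size)
lstm_seq_first = nn.LSTM(input_size, hidden_size, num_layers)
x = torch.randn(seq_len, batch_size, input_size)
out, (hn, cn) = lstm_seq_first(x)  # (h0, c0) omitted, so zeros are used
print(out.shape)  # torch.Size([15, 5, 20])
print(hn.shape)   # torch.Size([2, 5, 20])

# Omitting the initial state matches passing explicit zeros
zeros = torch.zeros(num_layers, batch_size, hidden_size)
out_zero, _ = lstm_seq_first(x, (zeros, zeros))
print(torch.allclose(out, out_zero))  # True

# Bidirectional: states have num_layers * 2 entries in dim 0,
# and out concatenates both directions (2 * hidden_size features)
bi_lstm = nn.LSTM(input_size, hidden_size, num_layers,
                  batch_first=True, bidirectional=True)
xb = torch.randn(batch_size, seq_len, input_size)
out_b, (hn_b, cn_b) = bi_lstm(xb)
print(out_b.shape)  # torch.Size([5, 15, 40])
print(hn_b.shape)   # torch.Size([4, 5, 20])
```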