LSTM+Attention demo
Published on Aug. 22, 2023, 12:11 p.m.
#!/usr/bin/python
# -*- coding: utf-8 -*-
import torch
from torch import nn
import torch.nn.functional as F
# Define a plain LSTM model (unidirectional or bidirectional)
class Rnn(nn.Module):
    """
    Uni-/bi-directional LSTM classifier.
    """
    def __init__(self, in_dim, hidden_dim, n_layer, n_class, bidirectional):
        super(Rnn, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True,
                            bidirectional=bidirectional)
        if self.bidirectional:
            self.classifier = nn.Linear(hidden_dim * 2, n_class)
        else:
            self.classifier = nn.Linear(hidden_dim, n_class)

    def forward(self, x):
        out, (hn, _) = self.lstm(x)
        if self.bidirectional:
            # concatenate the last layer's final forward and backward hidden states
            out = torch.hstack((hn[-2, :, :], hn[-1, :, :]))
        else:
            # hidden state at the last time step
            out = out[:, -1, :]
        out = self.classifier(out)
        return out
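
# Minimal usage sketch (not part of the original demo): a quick shape check of Rnn.
# The sizes below are illustrative assumptions, not values from the post.
def _demo_rnn():
    model = Rnn(in_dim=32, hidden_dim=64, n_layer=2, n_class=5, bidirectional=True)
    x = torch.randn(4, 10, 32)      # (batch_size, seq_len, in_dim)
    logits = model(x)               # (batch_size, n_class)
    assert logits.shape == (4, 5)
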
class Attention(nn.Module):
    """
    Attention layer.
    """
    def __init__(self, rnn_size: int):
        super(Attention, self).__init__()
        self.w = nn.Linear(rnn_size, 1)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, H):
        # eq.9: M = tanh(H)
        M = self.tanh(H)  # (batch_size, word_pad_len, rnn_size)
        # eq.10: α = softmax(w^T M)
        alpha = self.w(M).squeeze(2)  # (batch_size, word_pad_len)
        alpha = self.softmax(alpha)  # (batch_size, word_pad_len)
        # eq.11: r = H α^T  (attention-weighted sum over time steps)
        r = H * alpha.unsqueeze(2)  # (batch_size, word_pad_len, rnn_size)
        r = r.sum(dim=1)  # (batch_size, rnn_size)
        return r, alpha
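
# Minimal sketch of the attention layer on its own (assumed shapes, not from the post):
# it returns the weighted sentence representation r and the attention weights alpha.
def _demo_attention():
    att = Attention(rnn_size=64)
    H = torch.randn(4, 10, 64)      # (batch_size, word_pad_len, rnn_size)
    r, alpha = att(H)
    assert r.shape == (4, 64) and alpha.shape == (4, 10)
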
class AttBiLSTM(nn.Module):
    def __init__(
        self,
        n_classes: int,
        emb_size: int,
        rnn_size: int,
        rnn_layers: int,
        dropout: float
    ):
        super(AttBiLSTM, self).__init__()
        self.rnn_size = rnn_size
        # bidirectional LSTM
        self.BiLSTM = nn.LSTM(
            emb_size, rnn_size,
            num_layers=rnn_layers,
            bidirectional=True,
            batch_first=True
        )
        self.attention = Attention(rnn_size)
        self.fc = nn.Linear(rnn_size, n_classes)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        rnn_out, _ = self.BiLSTM(x)
        # element-wise sum of the forward and backward direction outputs
        H = rnn_out[:, :, :self.rnn_size] + rnn_out[:, :, self.rnn_size:]
        # attention module
        r, alphas = self.attention(H)  # (batch_size, rnn_size), (batch_size, word_pad_len)
        # eq.12: h* = tanh(r)
        h = self.tanh(r)  # (batch_size, rnn_size)
        scores = self.fc(self.dropout(h))  # (batch_size, n_classes)
        return scores
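
# End-to-end sketch for AttBiLSTM on pre-embedded input (assumed sizes; the original
# demo does not include an embedding layer or a training loop).
def _demo_att_bilstm():
    model = AttBiLSTM(n_classes=5, emb_size=128, rnn_size=64, rnn_layers=2, dropout=0.5)
    x = torch.randn(4, 10, 128)     # (batch_size, word_pad_len, emb_size)
    scores = model(x)               # (batch_size, n_classes)
    assert scores.shape == (4, 5)

if __name__ == "__main__":
    # run the three shape-check sketches defined above
    _demo_rnn()
    _demo_attention()
    _demo_att_bilstm()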