nn Sequential 中使用LSTM/GRU
Published on Aug. 22, 2023, 12:10 p.m.
由于LSTM/GRU这种是多个输出,所以没法直接使用,可以采用自定义模块的方式来转化成一个。
还有一个方案:
# I made a module called SelectItem to pick out an element from a tuple or list
class SelectItem(nn.Module):
def __init__(self, item_index):
super(SelectItem, self).__init__()
self._name = 'selectitem'
self.item_index = item_index
def forward(self, inputs):
return inputs[self.item_index]
# SelectItem can be used in Sequential to pick out the hidden state:
net = nn.Sequential(
nn.GRU(dim_in, dim_out, batch_first=True),
SelectItem(1)
)
方案来自于