DotProductAttention
Published on Aug. 22, 2023, 12:11 p.m.
用所有值计算查询的点积,并应用 softmax 函数来获得值的权重
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from typing import Optional, Tuple
class DotProductAttention(nn.Module):
    """Unscaled dot-product attention of a query over a set of values.

    The attention scores are the batched dot products between every query
    vector and every value vector; a softmax over the value axis turns them
    into weights, and the context is the weighted sum of the values.

    Reference implementation:
    https://github.com/napoler/attentions/blob/d6a6b12b8a2b473e5a23e37a13c9547f00f8ffb4/attentions.py#L45

    Args:
        hidden_dim: Feature size used to build the ``normalize`` and
            ``out_projection`` sub-modules.

    Note:
        ``normalize`` and ``out_projection`` are created in ``__init__`` but
        are not applied in ``forward``. They are kept so the module's
        parameters and ``state_dict`` keys stay unchanged.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.normalize = nn.LayerNorm(hidden_dim)
        self.out_projection = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
        """Attend ``query`` over ``value``.

        Args:
            query: Tensor of shape ``(batch, query_len, dim)``.
            value: Tensor of shape ``(batch, value_len, dim)``.

        Returns:
            A ``(context, attn)`` tuple where ``attn`` has shape
            ``(batch, query_len, value_len)`` and sums to 1 along the last
            axis, and ``context`` has shape ``(batch, query_len, dim)``.
        """
        # (batch, query_len, dim) x (batch, dim, value_len)
        #   -> (batch, query_len, value_len)
        score = torch.bmm(query, value.transpose(1, 2))

        # Normalise over the value axis directly. The previous
        # flatten -> softmax -> unflatten round trip computed the same values
        # but raised on an empty batch, where the inferred -1 dimension of
        # ``view`` is ambiguous.
        attn = F.softmax(score, dim=-1)

        context = torch.bmm(attn, value)
        return context, attn
出处论文