缩放的点乘注意力ScaledDotProductAttention
Published on Aug. 22, 2023, 12:11 p.m.
缩放点积注意力（Scaled Dot-Product Attention）就是 Transformer 中使用的注意力机制。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from typing import Optional, Tuple
# https://github.com/napoler/attentions/blob/master/attentions.py#L10
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention from "Attention Is All You Need".

    Computes ``softmax(query @ key^T / sqrt(dim)) @ value`` using batched
    matrix products, optionally hiding selected positions before the softmax.

    Args:
        dim (int): dimension used for scaling; raw scores are divided by
            ``sqrt(dim)``. Must be positive.

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): query vectors.
        - **key** (batch, k_len, d_model): key vectors.
        - **value** (batch, k_len, d_value): value vectors; must have the
          same length as ``key`` for the second batched product to work.
        - **mask** (optional): positions that are ``True`` (or non-zero for
          non-boolean masks) are excluded from attention. Either holds exactly
          ``batch * q_len * k_len`` elements (any layout, it is reshaped), or
          is broadcastable to ``(batch, q_len, k_len)`` under the usual
          right-aligned rules, e.g. ``(batch, 1, k_len)`` for a padding mask
          or ``(q_len, k_len)`` for a mask shared across the batch.

    Returns: context, attn
        - **context** (batch, q_len, d_value): attention-weighted sum of
          the values.
        - **attn** (batch, q_len, k_len): attention weights; each row sums
          to 1 over the key axis.

    Raises:
        ValueError: if ``dim`` is not positive.

    Note:
        A query row whose keys are all masked gets NaN weights, because the
        softmax of a row consisting only of ``-inf`` is undefined.
    """

    def __init__(self, dim: int):
        super().__init__()
        if dim <= 0:
            raise ValueError(f"dim must be positive, got {dim}")
        # Plain Python float (not a NumPy scalar) so the module holds no
        # NumPy state; numerically identical to np.sqrt(dim).
        self.sqrt_dim = math.sqrt(dim)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        # (batch, q_len, d_model) x (batch, d_model, k_len) -> (batch, q_len, k_len)
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            # masked_fill_ needs a boolean mask; treat non-zero as "masked".
            if mask.dtype != torch.bool:
                mask = mask.bool()
            # Same element count: keep the legacy behaviour of reinterpreting
            # the mask in the score's shape (reshape also copes with
            # non-contiguous masks). Otherwise rely on broadcasting.
            if mask.numel() == score.numel():
                mask = mask.reshape(score.size())
            # score is a fresh tensor, so the in-place fill touches no input.
            score.masked_fill_(mask, float("-inf"))

        attn = F.softmax(score, dim=-1)
        context = torch.bmm(attn, value)
        return context, attn