unilm_mask 注意力矩阵生成

Published on Aug. 22, 2023, 12:10 p.m.

unilm v1注意力矩阵生成

[caption id=”attachment_5072” align=”alignnone” width=”1080”]unilm v1注意力矩阵 unilm v1注意力矩阵[/caption]

def unilm_mask(inputs, s):
    idxs = torch.cumsum(s, dim=1)
    mask = idxs[:, None, :] <= idxs[:, :, None]
    mask = mask[:, None].squeeze(1)
    return mask.to(dtype=torch.int64)
def create_lm_mask(attention_mask, direction='l2r'):

    seq_len = attention_mask.size(-1)
    if attention_mask.ndim == 2:
        attention_mask = attention_mask.view(-1, 1, seq_len)

    idxs = torch.arange(0, seq_len).to(attention_mask)
    if direction == 'l2r':
        triu = (idxs.unsqueeze(-1) >= idxs).float()
    elif direction == 'r2l':
        triu = (idxs.unsqueeze(-1) <= idxs).float()

    attention_mask = (attention_mask + triu > 1).float()
    return attention_mask

unilm_mask= unilm_mask(inputs['input_ids'],inputs['token_type_ids'])
print("unilm_mask",unilm_mask)

create_lm_mask=create_lm_mask(inputs['input_ids'])
print("create_lm_mask",create_lm_mask)

unilm_mask tensor([[[1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1]]])

create_lm_mask tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1.]]])

tkitAutoMask实现
https://docs.terrychan.org/tkit-automask/

UniLM 2.0

UniLM 2.0

更多示例
https://colab.research.google.com/drive/11IDalP2xNYWzF4gIz6T3yTjp53UqzkOe#scrollTo=gFeycxpykrCx

https://github.com/SunnyGJing/t5-pegasus-chinese/blob/d34eeca4fc4c9ce78356e47a8968ff6ec07d4111/bert4torch/train.py#L59

https://www.tqwba.com/x_d/jishu/284603.html