多头自注意力中的att_mask和key_padding_mask有什么区别?

21
在PyTorch的MultiHeadAttention中,key_padding_maskattn_mask有什么区别:

key_padding_mask——如果提供了此参数,则键(key)中指定的填充元素将被注意力机制忽略。当给定二进制掩码并且值为True时,注意力层上相应的值将被忽略。当给定字节掩码并且值不为零时,注意力层上相应的值将被忽略。

attn_mask——2D或3D掩码,防止注意力机制关注特定位置。2D掩码将被广播到所有批次,而3D掩码允许为每个批次的条目指定不同的掩码。

提前感谢。

2个回答

29

key_padding_mask 用于屏蔽填充位置,即输入序列末尾之后的位置。这始终针对输入批次,并取决于批次中序列与最长序列的长度比较。它是一个形状为 batch size × input length 的二维张量。

另一方面,attn_mask 表示哪些键值对是有效的。在 Transformer 解码器中,使用三角形掩码来模拟推理时间,防止关注“未来”位置。这通常是用于 att_mask 的。如果它是一个二维张量,则形状为 input length × input length。您还可以拥有针对批次中每个项目具体的掩码。在这种情况下,您可以使用形状为 (batch size × num heads) × input length × input length 的三维张量。(因此,理论上,您可以使用三维的 att_mask 模拟 key_padding_mask。)


1
每个批次都有特定的掩码,这有什么目的呢?好奇。 - Brofessor
每个批次中的每个项目可能在不同位置都有填充。例如,如果输入是一系列句子,并且它们在开头或结尾处进行了填充,则我们需要为每个句子应用单独的掩码。对于解码器(指编码器输入的键和值),此掩码将是attn_mask和key_padding_mask的组合。 - Allohvk
1
当为批处理中的每个项目传递掩码时,模块是否对每个注意力头沿着0维使用顺序项目?例如,当batch_size=32num_heads=4时,att_mask[:4,:,:]是用于第1个项目(对于头1、2、3和4)的掩码吗? - skurp

3
我认为它们的作用相同:两个掩码都定义了查询和键之间不使用的注意力,两者唯一的区别在于你更喜欢哪种形状的掩码输入。 根据代码,似乎这两个掩码被合并/取并集,因此它们都扮演着相同的角色——即不使用查询和键之间的注意力。由于它们是取并集的:如果必须使用两个掩码,则两个掩码输入可以具有不同的值,或者根据所需形状方便地在任何掩码参数中输入掩码。下面是来自pytorch/functional.py第5227行左右的函数multi_head_attention_forward()的部分原始代码。
...
# merge key padding and attention masks
if key_padding_mask is not None:
    assert key_padding_mask.shape == (bsz, src_len), \
        f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
        expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
    if attn_mask is None:
        attn_mask = key_padding_mask
    elif attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.logical_or(key_padding_mask)
    else:
        attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)

如果您有不同的意见或我错了,请纠正我。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接