如何为PyTorch中MultiHeadAttention的3D attn_mask参数准备数据

3

我目前正在尝试使用Transformers实现文本摘要的编码器-解码器架构。因此,我需要在模型的解码器端应用MultiHeadAttention。由于我希望确保模型不会注意到目标序列中未见过的标记,所以我需要使用3D注意力掩码(attn_mask参数)。

根据文档(https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html),掩码的形状必须为BATCH_SIZE * NUMBER_HEADS,SEQUENCE_LENGTH,SEQUENCE_LENGTH。这很好,因为它提供了在头之间使用不同注意力的可能性,但在我的情况下不需要...

但文档没有说明张量的第一维应该如何填充,我也找不到在实现中它实际被如何使用...

它是:

[
  [2D Attention for Batch 1 for Head 1]
  [2D Attention for Batch 2 for Head 1]
  ...
  [2D Attention for Batch 1 for Head 2]
  [2D Attention for Batch 2 for Head 2]
  ...
  [2D Attention for Batch n for Head n]
]

或者
[
  [2D Attention for Batch 1 for Head 1]
  [2D Attention for Batch 1 for Head 2]
  ...
  [2D Attention for Batch 2 for Head 1]
  [2D Attention for Batch 2 for Head 2]
  ...
  [2D Attention for Batch n for Head n]
]

希望有人知道,这样就太好了 :)

4
分析 PyTorch 的单元测试,看起来选项 2)是正确的,因为在单元测试中,他们使用 torch.repeat_interleave 而不是 torch.repeat(https://github.com/pytorch/pytorch/blob/c74c0c571880df886474be297c556562e95c00e0/test/test_nn.py#L5039)。 - cokeSchlumpf
没错,repeat_interleave 似乎是最好的选择。 - Nasheed Yasin
1个回答

1

在我看到由cokeSchlimpf发布的链接之前,我一直有同样的疑问。感谢分享。

概述:

如果我们想要为批次中的每个实例设置不同的mask = src_mask,则建议(这里)将src_attention_mask的形状设置为N.num_heads, T, S, 其中N是批处理大小,num_headsMultiHeadAttention模块中头部的数量。此外,T是目标序列长度,S是源序列长度。

链接上代码的说明:

假设mask的形状为N, T, S,那么通过torch.repeat_interleave(mask, num_heads, dim=0)重复每个掩码实例(总共有N个实例)num_heads次并堆叠以形成num_heads, T, S的数组。对于所有这样的N个掩码,再重复上述操作,最终得到一个形状为:

[
    [num_heads, T, S] # for example 1 in the batch
    [num_heads, T, S] # for example 2 in the batch
    .
    .
    .
    [num_heads, T, S] # for example N in the batch
] = [N.num_heads, T, S] # after concatenating along dim=0


下面是使用torch==1.12.1+cu102实现代码的简短片段。
import torch

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        enc_layer = torch.nn.TransformerEncoderLayer(d_model=16, nhead=8, batch_first=True)
        self.layer = torch.nn.TransformerEncoder(enc_layer, num_layers=1)

    def forward(self, x, src_mask, key_mask):
        return self.layer(x, mask=src_mask, src_key_padding_mask=key_mask)

mod = test()
mod.eval()
out = mod(x=torch.randn(2, 22, 16), src_mask=torch.ones(8*2, 22, 22), key_mask=torch.ones(2, 22))
print(out.shape)

希望这可以帮助你!

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接