Verifying a Multihead Attention implementation in a Transformer

I have implemented multi-head attention for a Transformer. There are so many ways to implement it that it gets confusing. Could someone verify whether my implementation is correct?
DotProductAttention is based on: https://www.tensorflow.org/tutorials/text/transformer#setup
import tensorflow as tf

def scaled_dot_product(q,k,v):
    #calculates Q . K(transpose)
    qkt = tf.matmul(q,k,transpose_b=True)
    #calculates scaling factor
    dk = tf.math.sqrt(tf.cast(q.shape[-1],dtype=tf.float32))
    scaled_qkt = qkt/dk
    softmax = tf.nn.softmax(scaled_qkt,axis=-1)
    
    z = tf.matmul(softmax,v)
    #shape: (m,Tx,depth), same shape as q,k,v
    return z

class MultiAttention(tf.keras.layers.Layer):
    def __init__(self,d_model,num_of_heads):
        super(MultiAttention,self).__init__()
        self.d_model = d_model
        self.num_of_heads = num_of_heads
        self.depth = d_model//num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(d_model)
        
    def call(self,x):
        
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](x)
            K = self.wk[i](x)
            V = self.wv[i](x)
            multi_attn.append(scaled_dot_product(Q,K,V))
            
        multi_head = tf.concat(multi_attn,axis=-1)
        multi_head_attention = self.wo(multi_head)
        return multi_head_attention

#Calling the attention 
multi = MultiAttention(d_model=512,num_of_heads=8)
m = 5; sequence_length = 4; word_embedding_dim = 512
sample_ip = tf.constant(tf.random.normal(shape=(m,sequence_length,word_embedding_dim)))
attn = multi(sample_ip)
#shape of op (attn): (5,4,512)

Are you looking for a code review? If so, I don't think this is the right place; Code Review would be better suited. - Innat
No, I'm looking more for a review of the multi-head attention logic, which you have explained in your answer. Thanks. - data_person
1 Answer

In your implementation, the scaled_dot_product function scales by the query dimension, whereas the original paper scales by the key dimension (√d_k). Since Q·Kᵀ requires the query and key to share the same last dimension, the result is numerically identical here, but it is cleaner to follow the paper's convention. Other than that, the implementation looks fine, though it is not very general.
class MultiAttention(tf.keras.layers.Layer):
    def __init__(self, num_of_heads, out_dim):
        super(MultiAttention,self).__init__()
        self.out_dim      = out_dim
        self.num_of_heads = num_of_heads
        self.depth        = self.out_dim // self.num_of_heads
        self.wq = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wk = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wv = [tf.keras.layers.Dense(self.depth) for i in range(num_of_heads)]
        self.wo = tf.keras.layers.Dense(self.out_dim)
        
    def call(self,x):
        multi_attn = []
        for i in range(self.num_of_heads):
            Q = self.wq[i](x)
            K = self.wk[i](x)
            V = self.wv[i](x)
            multi_attn.append(self.scaled_dot_product(Q,K,V))

        multi_head = tf.concat(multi_attn, axis=-1)
        multi_head_attention = self.wo(multi_head)
        return multi_head_attention

    def scaled_dot_product(self, q,k,v):
        qkt = tf.matmul(q, k, transpose_b=True)
        dk = tf.math.sqrt( tf.cast(k.shape[-1], dtype=tf.float32) )
        scaled_qkt = qkt/dk
        softmax = tf.nn.softmax(scaled_qkt, axis=-1)
        z = tf.matmul(softmax, v)
        return z

multi = MultiAttention(num_of_heads=3, out_dim=32)
sample_ip = tf.random.normal(shape=(2, 2, 32)); print(sample_ip.shape)
multi(sample_ip).shape

The general Transformer attention block can be illustrated as in the figure below: the first two linear layers produce the query and key and are responsible for generating the attention weight map, which then weights the value via matrix multiplication.
(Figure source: https://www.youtube.com/watch?v=mMa2PmYJlCo)
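
As a tiny numeric sketch of that weight map (the values below are arbitrary toy numbers chosen for illustration, not taken from the figure):

# One sequence of 3 tokens with depth 2; the softmax rows form the attention map,
# and multiplying it with v gives the weighted value vectors.
q = tf.constant([[[1., 0.], [0., 1.], [1., 1.]]])   # (1, 3, 2)
k = tf.constant([[[1., 0.], [0., 1.], [1., 1.]]])   # (1, 3, 2)
v = tf.constant([[[1., 2.], [3., 4.], [5., 6.]]])   # (1, 3, 2)
weights = tf.nn.softmax(tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(2.), axis=-1)
print(weights.shape)                 # (1, 3, 3): one row of weights per query position
print(tf.matmul(weights, v).shape)   # (1, 3, 2): weighted sum of the value vectors
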
I understand that you are trying to minimize the original TF tutorial code, but I think you should have added that reference to your original question. Also, the original implementation returns the attention weights (scores) together with the weighted feature map; I don't think you should skip that.
The original code you are following is more general and efficiently written (see the short equivalence sketch after the code below).
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        # scale matmul_qk
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        # add the mask to the scaled tensor.
        if mask is not None: scaled_attention_logits += (mask * -1e9)
        # softmax is normalized on the last axis (seq_len_k) so that the scores
        # add up to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention,  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
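
To see why the single projection plus split_heads is shape-equivalent to the per-head Dense layers from the question (a rough sketch under that assumption; the weights themselves are initialized independently, so only the shapes match):

d_model, num_heads = 512, 8
depth = d_model // num_heads
x = tf.random.uniform((5, 4, d_model))

# One Dense(d_model) for all heads, then reshape/transpose into heads ...
q_all = tf.keras.layers.Dense(d_model)(x)               # (5, 4, 512)
q_all = tf.reshape(q_all, (5, -1, num_heads, depth))    # (5, 4, 8, 64)
q_all = tf.transpose(q_all, perm=[0, 2, 1, 3])          # (5, 8, 4, 64)

# ... versus one Dense(depth) per head, stacked afterwards (as in the question).
q_per_head = tf.stack(
    [tf.keras.layers.Dense(depth)(x) for _ in range(num_heads)], axis=1)
print(q_all.shape, q_per_head.shape)  # both (5, 8, 4, 64); the fused version is one matmul instead of a Python loop
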

FYI, as of TF 2.4, the tf.keras.layers.MultiHeadAttention layer has been added officially.

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
input_tensor = tf.keras.Input(shape=[2, 2, 32]); print(input_tensor.shape)
print(layer(input_tensor, input_tensor).shape)

You can test both of them as follows:
# custom layer MHA
multi = MultiHeadAttention(d_model=512, num_heads=2)
y = tf.random.uniform((1, 60, 512))  
out, attn = multi(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))

# built-in layer 
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
y = tf.random.uniform((1, 60, 512))  
out, attn = layer(y, y, return_attention_scores=True)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 2, 60, 60]))
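
If you need the mask argument of the custom layer, here is a minimal padding-mask sketch (the mask shape below is an assumption; it only has to broadcast against the (batch, num_heads, seq_len_q, seq_len_k) logits):

# 0 is treated as the padding token id here (an arbitrary choice for the sketch).
seq = tf.constant([[7, 6, 0, 0], [1, 2, 3, 0]])
mask = tf.cast(tf.math.equal(seq, 0), tf.float32)   # 1.0 where padded
mask = mask[:, tf.newaxis, tf.newaxis, :]           # (batch, 1, 1, seq_len)

emb = tf.random.uniform((2, 4, 512))
mha = MultiHeadAttention(d_model=512, num_heads=2)
out, attn = mha(emb, k=emb, q=emb, mask=mask)
out.shape, attn.shape
(TensorShape([2, 4, 512]), TensorShape([2, 2, 4, 4]))
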

In the official MultiHeadAttention layer, what is the key_dim parameter used for? - Aritra Roy Gosthipaty
The size of each attention head for query and key. - Innat
That is what the docs say, but I don't really understand it. Does it refer to the size of the linear projection layers before the scaled dot-product attention? - Aritra Roy Gosthipaty
I am confused about key_dim and num_heads as well. I actually asked a question specifically about this: https://stackoverflow.com/questions/73900721/confusion-regarding-num-heads-key-dim-keras-layers-multiheadattention-in-the-t - kawingkelvin
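
To make key_dim concrete (a rough sketch of one reading of the docs, not the layer's actual internals): the query and key are each projected to num_heads * key_dim units and reshaped per head, so the attention logits are computed on key_dim-sized vectors regardless of the input's last dimension.

num_heads, key_dim = 2, 4
x = tf.random.uniform((1, 60, 512))
# Assumed per-head query/key projections of size key_dim (illustrative only).
wq = tf.keras.layers.Dense(num_heads * key_dim)
wk = tf.keras.layers.Dense(num_heads * key_dim)
q = tf.transpose(tf.reshape(wq(x), (1, -1, num_heads, key_dim)), perm=[0, 2, 1, 3])
k = tf.transpose(tf.reshape(wk(x), (1, -1, num_heads, key_dim)), perm=[0, 2, 1, 3])
scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(float(key_dim))
print(scores.shape)  # (1, 2, 60, 60), the same attention-map shape as above
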
