transformer细节及代码实现2-Multi-head attention

transformer

transformer总体结构
在这里插入图片描述
多头注意力(Multi-head attention)是按比缩放的点积注意力(Scaled dot product attention)组成的,所以首先需要了解按比缩放的点积注意力。

下面是简化的encoder-decoder模型,对于整体的模型来说会有多个encoder和decoder。
在这里插入图片描述
在每一个encoder和decoder中可以简化成下图所示:encoder中包含self-attention模型和feed-forward两个模块。而decoder多包含了一个encoder-decoder attention 模块。具体原理会在下面提及
在这里插入图片描述
在每个encoder中的内部细节如下图所示:

输入的词通过enmbedding层并且与positional_encoding融合作为self-attention的输入,经过self-attention的结果经过add-normalize层进而输入到feed forward中,最后又通过一个add-normalize。值得一提的是(图中包含两个残差网络(虚线表示))
在这里插入图片描述

1.Scaled dot product attention(按比缩放的点积注意力)
按比缩放的点积注意力原理图为:
在这里插入图片描述
Transformer 使用的注意力函数有三个输入:Q(请求(query))、K(主键(key))、V(数值(value))。用于计算注意力权重的等式为:
query向量:query顾名思义,是负责寻找这个字的于其他字的相关度(通过其它字的key)
key向量:key向量就是用来于query向量作匹配,得到相关度评分的
value向量:Value vectors 是实际上的字的表示, 一旦我们得到了字的相关度评分,这些表示是用来加权求和的
得到每个字的 之后,我们要得到每个字和句子中其他字的相关关系,我们只需要把这个字的query去和其他字的key作匹配,然后得到分数,最后在通过其它字的value的加权求和(权重就是哪个分数)得到这个字的最终输出。

点积注意力被缩小了深度的平方根倍。这样做是因为对于较大的深度值,点积的大小会增大,从而推动 softmax 函数往仅有很小的梯度的方向靠拢,导致了一种很硬的(hard)softmax。

例如,假设 Q 和 K 的均值为0,方差为1。它们的矩阵乘积将有均值为0,方差为 dk。因此,dk 的平方根被用于缩放(而非其他数值),因为,Q 和 K 的矩阵乘积的均值本应该为 0,方差本应该为1,这样会获得一个更平缓的 softmax。

遮挡(mask)与 -1e9(接近于负无穷)相乘。这样做是因为遮挡与缩放的 Q 和 K 的矩阵乘积相加,并在 softmax 之前立即应用。目标是将这些单元归零,因为 softmax 的较大负数输入在输出中接近于零。

下面将举例说明这一过程:
如上面的总体描述图所示:
在这里插入图片描述
为了提高计算速度,将会采用矩阵的计算方式:
在这里插入图片描述
在这里插入图片描述

1def scaled_dot_product_attention(q, k, v, mask): 2 """计算注意力权重。 3 q, k, v 必须具有匹配的前置维度。 4 k, v 必须有匹配的倒数第二个维度,例如:seq_len_k = seq_len_v。 5 虽然 mask 根据其类型(填充或前瞻)有不同的形状, 6 但是 mask 必须能进行广播转换以便求和。 7 8 参数: 9 q: 请求的形状 == (..., seq_len_q, depth) 10 k: 主键的形状 == (..., seq_len_k, depth) 11 v: 数值的形状 == (..., seq_len_v, depth_v) 12 mask: Float 张量,其形状能转换成 13 (..., seq_len_q, seq_len_k)。默认为None。 14 15 返回值: 16 输出,注意力权重 17 """ 18 19 matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 20 21 # 缩放 matmul_qk 22 dk = tf.cast(tf.shape(k)[-1], tf.float32) 23 scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 24 25 # 将 mask 加入到缩放的张量上。 26 if mask is not None: 27 scaled_attention_logits += (mask * -1e9) 28 29 # softmax 在最后一个轴(seq_len_k)上归一化,因此分数 30 # 相加等于1。 31 attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) 32 33 output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) 34 35 return output, attention_weights 36 37

2.Multi-head attention
在这里插入图片描述
多头注意力由四部分组成:

线性层并分拆成多头。
按比缩放的点积注意力。
多头及联。
最后一层线性层

每个多头注意力块有三个输入:Q(请求)、K(主键)、V(数值)。这些输入经过线性(Dense)层,并分拆成多头。

将上面定义的 scaled_dot_product_attention 函数应用于每个头(进行了广播(broadcasted)以提高效率)。注意力这步必须使用一个恰当的 mask。然后将每个头的注意力输出连接起来(用tf.transpose 和 tf.reshape),并放入最后的 Dense 层。

Q、K、和 V 被拆分到了多个头,而非单个的注意力头,因为多头允许模型共同注意来自不同表示空间的不同位置的信息。在分拆后,每个头部的维度减少,因此总的计算成本与有着全部维度的单个注意力头相同。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

1class MultiHeadAttention(tf.keras.layers.Layer): 2 def __init__(self, d_model, num_heads): 3 super(MultiHeadAttention, self).__init__() 4 self.num_heads = num_heads 5 self.d_model = d_model 6 7 assert d_model % self.num_heads == 0 8 9 self.depth = d_model // self.num_heads 10 11 self.wq = tf.keras.layers.Dense(d_model) 12 self.wk = tf.keras.layers.Dense(d_model) 13 self.wv = tf.keras.layers.Dense(d_model) 14 15 self.dense = tf.keras.layers.Dense(d_model) 16 17 def split_heads(self, x, batch_size): 18 """ 19 分拆最后一个维度到 (num_heads, depth). 20 转置结果使得形状为 (batch_size, num_heads, seq_len, depth) 21 """ 22 x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 23 return tf.transpose(x, perm=[0, 2, 1, 3]) 24 25 def call(self, v, k, q, mask): 26 batch_size = tf.shape(q)[0] 27 28 q = self.wq(q) # (batch_size, seq_len, d_model) 29 k = self.wk(k) # (batch_size, seq_len, d_model) 30 v = self.wv(v) # (batch_size, seq_len, d_model) 31 32 q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 33 k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 34 v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 35 36 # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth) 37 # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) 38 scaled_attention, attention_weights = scaled_dot_product_attention( 39 q, k, v, mask) 40 41 scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth) 42 43 concat_attention = tf.reshape(scaled_attention, 44 (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model) 45 46 output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model) 47 48 return output, attention_weights 49 50

代码交流 2021