缩放点积注意力允许网络参与序列。然而,序列元素通常需要关注多个不同方面,并且单个加权平均值并不是一个好的选择。这就是为什么我们将注意力机制扩展到多个头,即相同特征上的多个不同的查询键值三元组。具体来说,给定一个查询、键和值矩阵,我们将它们转换为 h 子查询、子键和子值,并独立地通过缩放的点积注意力。然后,我们连接头部并将它们与最终的权重矩阵组合起来。从数学上来说,我们可以将此操作表示为:
多头(Q,K,V)其中 head i=Concat( head 1,…, head h)WO= Attention (QWiQ,KWiK,VWiV)
def expand_mask(mask):
assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
if mask.ndim == 3:
mask = mask.unsqueeze(1)
while mask.ndim < 4:
mask = mask.unsqueeze(0)
return mask
class MultiheadAttention(nn.Module):
def __init__(self, input_dim, embed_dim, num_heads):
super().__init__()
assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
self.o_proj = nn.Linear(embed_dim, embed_dim)
self._reset_parameters()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.qkv_proj.weight)
self.qkv_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.o_proj.weight)
self.o_proj.bias.data.fill_(0)
def forward(self, x, mask=None, return_attention=False):
batch_size, seq_length, _ = x.size()
if mask is not None:
mask = expand_mask(mask)
qkv = self.qkv_proj(x)
qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
q, k, v = qkv.chunk(3, dim=-1)
values, attention = scaled_dot_product(q, k, v, mask=mask)
values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
values = values.reshape(batch_size, seq_length, self.embed_dim)
o = self.o_proj(values)
if return_attention:
return o, attention
else:
return o