🫐Python记忆组合透明度语言模型

Python | 浏览器 | 语言模型 | 推理 | 神经网络 | 数学 | 交互式 | 可视化 | 透明度 | 歌词生成 | 不确定性 | 鲁棒性 | 深度学习 | 估计基准 | 文本 | 前馈 | 记忆组合 | 注意力 | 视觉解释器 | 知识 | 信息

🎯要点

🎯浏览器语言推理识别神经网络 | 🎯不同语言秽语训练识别数据集 | 🎯交互式语言处理解释 Transformer 语言模型 | 🎯可视化Transformer 语言模型 | 🎯语言模型生成优质歌词 | 🎯模型不确定性和鲁棒性深度学习估计基准 | 🎯文本生成神经网络诗歌生成 | 🎯模型透明度 | 🎯验证揭示前馈Transformer 语言模型记忆组合 | 🎯可视化语言模型注意力 | 🎯Transformer语言模型文本解释器和视觉解释器 | 🎯分布式训练和推理模型 | 🎯知识获取模型 | 🎯信息提取模型 | 🎯文本生成模型 | 🎯语音图像视频模型

🍇Python注意力

注意力机制描述了神经网络中最近出现的一组新层,在过去几年中引起了广泛关注,尤其是在序列任务中。文献中对“注意力”有很多不同的定义,但我们在这里使用的定义如下:注意力机制描述了(序列)元素的加权平均值,其权重根据输入查询和元素的键动态计算。那么这到底是什么意思呢?目标是对多个元素的特征取平均值。但是,我们不希望对每个元素赋予相同的权重,而是希望根据它们的实际值赋予它们权重。换句话说,我们希望动态地决定我们更希望“关注”哪些输入。

💦缩放点积注意力

其中 Q、K、V 是查询、键和值向量的串联。

上图中的块 Mask (opt.) 表示对注意力矩阵中的特定条目进行可选屏蔽。例如,如果我们将具有不同长度的多个序列堆叠成一个批次,就会使用此功能。为了仍然受益于 PyTorch 中的并行化,我们将句子填充到相同的长度,并在计算注意力值时屏蔽填充标记。这通常是通过将相应的注意力逻辑设置为非常低的值来实现的。

在讨论了缩放点积注意力块的细节之后,我们可以在下面编写一个函数,在给定查询、键和值三元组的情况下计算输出特征:

 def scaled_dot_product(q, k, v, mask=None):
     d_k = q.size()[-1]
     attn_logits = torch.matmul(q, k.transpose(-2, -1))
     attn_logits = attn_logits / math.sqrt(d_k)
     if mask is not None:
         attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
     attention = F.softmax(attn_logits, dim=-1)
     values = torch.matmul(attention, v)
     return values, attention

请注意,上面的代码支持序列长度前面的任何附加维度,因此我们也可以将其用于批处理。但是,为了更好地理解,让我们生成一些随机查询、键和值向量,并计算注意力输出:

 seq_len, d_k = 3, 2
 pl.seed_everything(42)
 q = torch.randn(seq_len, d_k)
 k = torch.randn(seq_len, d_k)
 v = torch.randn(seq_len, d_k)
 values, attention = scaled_dot_product(q, k, v)
 print("Q\n", q)
 print("K\n", k)
 print("V\n", v)
 print("Values\n", values)
 print("Attention\n", attention)
 Q
  tensor([[ 0.3367,  0.1288],
         [ 0.2345,  0.2303],
         [-1.1229, -0.1863]])
 K
  tensor([[ 2.2082, -0.6380],
         [ 0.4617,  0.2674],
         [ 0.5349,  0.8094]])
 V
  tensor([[ 1.1103, -1.6898],
         [-0.9890,  0.9580],
         [ 1.3221,  0.8172]])
 Values
  tensor([[ 0.5698, -0.1520],
         [ 0.5379, -0.0265],
         [ 0.2246,  0.5556]])
 Attention
  tensor([[0.4028, 0.2886, 0.3086],
         [0.3538, 0.3069, 0.3393],
         [0.1303, 0.4630, 0.4067]])

💦多头注意力

缩放点积注意力允许网络参与序列。然而,序列元素通常需要关注多个不同方面,并且单个加权平均值并不是一个好的选择。这就是为什么我们将注意力机制扩展到多个头,即相同特征上的多个不同的查询键值三元组。具体来说,给定一个查询、键和值矩阵,我们将它们转换为 h 子查询、子键和子值,并独立地通过缩放的点积注意力。然后,我们连接头部并将它们与最终的权重矩阵组合起来。从数学上来说,我们可以将此操作表示为:

 def expand_mask(mask):
     assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
     if mask.ndim == 3:
         mask = mask.unsqueeze(1)
     while mask.ndim < 4:
         mask = mask.unsqueeze(0)
     return mask
 class MultiheadAttention(nn.Module):
 ​
     def __init__(self, input_dim, embed_dim, num_heads):
         super().__init__()
         assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."
 ​
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
 ​
 ​
         self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
         self.o_proj = nn.Linear(embed_dim, embed_dim)
 ​
         self._reset_parameters()
 ​
     def _reset_parameters(self):
 ​
         nn.init.xavier_uniform_(self.qkv_proj.weight)
         self.qkv_proj.bias.data.fill_(0)
         nn.init.xavier_uniform_(self.o_proj.weight)
         self.o_proj.bias.data.fill_(0)
 ​
     def forward(self, x, mask=None, return_attention=False):
         batch_size, seq_length, _ = x.size()
         if mask is not None:
             mask = expand_mask(mask)
         qkv = self.qkv_proj(x)
 ​
 ​
         qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
         qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
         q, k, v = qkv.chunk(3, dim=-1)
 ​
         values, attention = scaled_dot_product(q, k, v, mask=mask)
         values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
         values = values.reshape(batch_size, seq_length, self.embed_dim)
         o = self.o_proj(values)
 ​
         if return_attention:
             return o, attention
         else:
             return o

Last updated