这一章我们来补充Transfomer的相应代码。
Scaled Dot-Product Attention:
# * (Q):Query,表示当前 token 想查询什么信息;
# * (K):Key,表示每个 token 能提供什么匹配特征;
# * (V):Value,表示每个 token 真正携带的内容信息;
# * (d_k):Key 向量的维度;
# * (\sqrt{d_k}):缩放因子,用于防止点积结果过大。
# * (Q):Query,表示当前 token 想查询什么信息; # * (K):Key,表示每个 token 能提供什么匹配特征; # * (V):Value,表示每个 token 真正携带的内容信息; # * (d_k):Key 向量的维度; # * (\sqrt{d_k}):缩放因子,用于防止点积结果过大。 import torch import torch.nn.functional as F import math def scaled_dot_product_attention(Q, K, V, mask=None): """ Scaled Dot-Product Attention 参数: Q: Query 矩阵,形状为 [batch_size, seq_len_q, d_k] K: Key 矩阵,形状为 [batch_size, seq_len_k, d_k] V: Value 矩阵,形状为 [batch_size, seq_len_k, d_v] mask: 可选掩码矩阵,形状可以广播到 [batch_size, seq_len_q, seq_len_k] 返回: output: Attention 输出,形状为 [batch_size, seq_len_q, d_v] attention_weights: 注意力权重矩阵,形状为 [batch_size, seq_len_q, seq_len_k] """ # 1. 获取 Key 向量维度 d_k d_k = Q.size(-1) # 2. 计算 QK^T # K.transpose(-2, -1) 表示交换 K 的最后两个维度 scores = torch.matmul(Q, K.transpose(-2, -1)) # 3. 除以 sqrt(d_k),防止点积结果过大 scores = scores / math.sqrt(d_k) # 4. 如果有 mask,则把被 mask 的位置设为一个非常小的值 # 这样经过 softmax 后,这些位置的权重接近 0 if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # 5. 对最后一个维度做 softmax,得到注意力权重 attention_weights = F.softmax(scores, dim=-1) # 6. 用注意力权重对 V 加权求和 output = torch.matmul(attention_weights, V) return output, attention_weights if __name__ == '__main__': # batch_size = 1 # seq_len = 3 # d_k = 4 # d_v = 2 Q = torch.tensor([ [[1.0, 0.0, 1.0, 0.0],[0.0, 1.0, 0.0, 1.0],[1.0, 1.0, 0.0, 0.0]]]) K = torch.tensor([ [[1.0, 0.0, 1.0, 0.0],[0.0, 1.0, 0.0, 1.0],[1.0, 1.0, 0.0, 0.0]]]) V = torch.tensor([ [[1.0, 2.0],[3.0, 4.0],[5.0, 6.0]]]) # ====================== # 执行 Attention # ====================== output, attention_weights = scaled_dot_product_attention(Q, K, V) print("Q shape:", Q.shape) print("K shape:", K.shape) print("V shape:", V.shape) print("\nAttention Weights:") print(attention_weights) print("\nAttention Output:") print(output)#表示新的Multi-Head Attention
在 Transformer 中,单头 Attention 只能在一个表示空间中计算 token 之间的关系。Multi-Head Attention 的思想是:
把 Q、K、V 投影到多个子空间
↓
每个子空间独立计算 Attention
↓
把多个 head 的结果拼接起来
↓
再经过一个线性层得到最终输出
import torch import torch.nn as nn import torch.nn.functional as F import math def scaled_dot_product_attention(Q, K, V, mask=None): """ Scaled Dot-Product Attention Q: [batch_size, num_heads, seq_len_q, d_k] K: [batch_size, num_heads, seq_len_k, d_k] V: [batch_size, num_heads, seq_len_k, d_v] mask: 可选,形状可广播到 [batch_size, num_heads, seq_len_q, seq_len_k] """ d_k = Q.size(-1) # [batch_size, num_heads, seq_len_q, seq_len_k] scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attention_weights = F.softmax(scores, dim=-1) # [batch_size, num_heads, seq_len_q, d_v] output = torch.matmul(attention_weights, V) return output, attention_weights class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): """ d_model: 输入向量维度 num_heads: 注意力头数 """ super().__init__() assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除" self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads # 分别生成 Q、K、V 的线性层 self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) # 多个 head 拼接后,再经过输出线性层 self.W_o = nn.Linear(d_model, d_model) def split_heads(self, x): """ 将输入拆分成多个注意力头 输入: x: [batch_size, seq_len, d_model] 输出: x: [batch_size, num_heads, seq_len, d_k] """ batch_size, seq_len, d_model = x.size() # [batch_size, seq_len, num_heads, d_k] x = x.view(batch_size, seq_len, self.num_heads, self.d_k) # [batch_size, num_heads, seq_len, d_k] x = x.transpose(1, 2) return x def combine_heads(self, x): """ 将多个注意力头重新拼接回来 输入: x: [batch_size, num_heads, seq_len, d_k] 输出: x: [batch_size, seq_len, d_model] """ batch_size, num_heads, seq_len, d_k = x.size() # [batch_size, seq_len, num_heads, d_k] x = x.transpose(1, 2) # contiguous 用于保证内存连续,便于 view 操作 x = x.contiguous().view(batch_size, seq_len, self.d_model) return x def forward(self, Q, K, V, mask=None): """ Q, K, V: [batch_size, seq_len, d_model] """ # 1. 线性变换,生成 Q、K、V Q = self.W_q(Q) K = self.W_k(K) V = self.W_v(V) # 2. 拆分成多个 head Q = self.split_heads(Q) K = self.split_heads(K) V = self.split_heads(V) # 3. 每个 head 内部独立计算 Scaled Dot-Product Attention attention_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask) # 4. 将多个 head 拼接回来 attention_output = self.combine_heads(attention_output) # 5. 输出线性变换 output = self.W_o(attention_output) return output, attention_weights # ====================== # 测试数据 # ====================== if __name__ == '__main__': torch.manual_seed(42) batch_size = 1 seq_len = 3 d_model = 4 num_heads = 2 # 输入序列 X # 可以理解为 3 个 token,每个 token 是 4 维向量 X = torch.tensor([ [ [1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0] ] ]) mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads) output, attention_weights = mha(X, X, X) # 这说明Multi - HeadAttention并不会改变token数量,也不会改变最终的hiddensize。 # 它只是让模型在多个子空间中分别计算注意力,然后再把这些信息融合起来。 print("Input shape:", X.shape) print("Output shape:", output.shape) print("Attention weights shape:", attention_weights.shape) print("\nAttention Weights:") print(attention_weights) print("\nOutput:") print(output)Causal Mask
Causal Mask 的作用可以总结为一句话:保证第 (t) 个 token 只能看到第 (t) 个 token 以及它之前的 token,不能看到未来 token。
import torch import torch.nn.functional as F import math def scaled_dot_product_attention(Q, K, V, mask=None): """ 带 mask 的 Scaled Dot-Product Attention。 Q: [batch_size, seq_len, d_k] K: [batch_size, seq_len, d_k] V: [batch_size, seq_len, d_v] mask: [seq_len, seq_len] """ d_k = Q.size(-1) # 1. 计算 QK^T scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) # 2. 如果有 mask,则把未来位置设置为一个非常小的值 # mask == 0 的地方表示不能关注 if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # 3. softmax 得到注意力权重 attention_weights = F.softmax(scores, dim=-1) # 4. 对 V 加权求和 output = torch.matmul(attention_weights, V) return output, attention_weights def create_causal_mask(seq_len): """ 创建 causal mask。 """ return torch.tril(torch.ones(seq_len, seq_len)) if __name__ == '__main__': Q = torch.tensor([ [[1.0, 0.0, 1.0, 0.0],[0.0, 1.0, 0.0, 1.0],[1.0, 1.0, 0.0, 0.0],[0.0, 0.0, 1.0, 1.0]] ]) K = torch.tensor([ [[1.0, 0.0, 1.0, 0.0],[0.0, 1.0, 0.0, 1.0],[1.0, 1.0, 0.0, 0.0],[0.0, 0.0, 1.0, 1.0]] ]) V = torch.tensor([ [[1.0, 1.0],[2.0, 2.0],[3.0, 3.0],[4.0, 4.0]] ]) seq_len = Q.size(1) mask = create_causal_mask(seq_len) output, attention_weights = scaled_dot_product_attention(Q, K, V, mask) print("Causal Mask:") print(mask) print("\nAttention Weights:") print(attention_weights) print("\nAttention Output:") print(output)Sinusoidal Positional Encoding
Transformer 中的 Self-Attention 本身并不知道 token 的顺序。
例如:
我 喜欢 你
你 喜欢 我
这两个句子的 token 集合很相似,但顺序不同,语义完全不同。因此,Transformer 需要给每个 token 加上位置信息。原始 Transformer 使用的是 Sinusoidal Positional Encoding,也就是正弦余弦位置编码。
import torch import math def positional_encoding(seq_len, d_model): """ 生成 Sinusoidal Positional Encoding。 参数: seq_len: 序列长度 d_model: token 向量维度 返回: pe: [seq_len, d_model] 每一行表示一个位置的位置编码 """ # 初始化位置编码矩阵 pe = torch.zeros(seq_len, d_model) # position: [seq_len, 1] # 表示每个 token 的位置编号 position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # div_term: [d_model / 2] # 控制不同维度上的频率 div_term = torch.exp( torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) ) # 偶数维使用 sin pe[:, 0::2] = torch.sin(position * div_term) # 奇数维使用 cos pe[:, 1::2] = torch.cos(position * div_term) return pe if __name__ == '__main__': # ====================== # 测试数据 # ====================== # 假设 batch_size = 1 # seq_len = 4 # d_model = 8 token_embedding = torch.tensor([ [ [0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80], [0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90], [0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 1.00], [0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 1.00, 1.10] ] ]) batch_size, seq_len, d_model = token_embedding.shape pe = positional_encoding(seq_len, d_model) # pe: [seq_len, d_model] # token_embedding: [batch_size, seq_len, d_model] # PyTorch 会自动广播 pe 到 batch 维度 x = token_embedding + pe print("Token Embedding shape:", token_embedding.shape) print("Positional Encoding shape:", pe.shape) print("Final Input shape:", x.shape) print("\nToken Embedding:") print(token_embedding) print("\nPositional Encoding:") print(pe) print("\nToken Embedding + Positional Encoding:") print(x)