用Python从零构建Transformer:实战代码与核心原理拆解
在自然语言处理领域,Transformer架构已经成为现代AI模型的基石。但对于许多开发者来说,仅通过数学公式和理论图解来理解这一复杂架构往往事倍功半。本文将采用完全不同的路径——我们将用Python和PyTorch从零开始实现一个简化但功能完整的Transformer模型,通过可运行的代码来揭示其内部工作机制。
1. 环境准备与基础架构
在开始编码前,我们需要明确几个关键设计决策。我们的简化版Transformer将保留原始架构的核心组件,但会适当减少层数和维度以方便实验。具体配置如下:
import torch import torch.nn as nn import math class TransformerConfig: def __init__(self): self.vocab_size = 10000 # 词汇表大小 self.max_len = 512 # 最大序列长度 self.d_model = 256 # 模型维度 self.n_head = 4 # 注意力头数 self.d_ff = 1024 # 前馈网络维度 self.n_layers = 2 # 编码器/解码器层数 self.dropout = 0.1 # Dropout率关键组件依赖关系可以用以下表格表示:
| 组件 | 输入 | 输出 | 依赖关系 |
|---|---|---|---|
| 词嵌入 | token索引 | d_model维向量 | 无 |
| 位置编码 | 位置索引 | d_model维向量 | 词嵌入 |
| 多头注意力 | Q,K,V矩阵 | 加权特征表示 | 所有前序层 |
| 前馈网络 | 注意力输出 | 变换后特征 | 当前层注意力 |
注意:实际实现中每个组件都需要包含残差连接和层归一化,这是Transformer稳定训练的关键
2. 实现位置编码与词嵌入
位置编码是Transformer理解序列顺序的关键。不同于RNN的递归结构,Transformer需要显式的位置信息。我们采用原始论文中的正弦/余弦函数方案:
class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=512): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x): return x + self.pe[:, :x.size(1)]词嵌入层将离散的token索引映射为连续向量空间:
class Embeddings(nn.Module): def __init__(self, config): super().__init__() self.token_emb = nn.Embedding(config.vocab_size, config.d_model) self.pos_enc = PositionalEncoding(config.d_model, config.max_len) self.dropout = nn.Dropout(config.dropout) self.scale = math.sqrt(config.d_model) def forward(self, x): token_emb = self.token_emb(x) * self.scale return self.dropout(self.pos_enc(token_emb))位置编码可视化效果:
- 低频维度:变化缓慢,捕获长距离依赖
- 高频维度:变化快速,编码局部位置信息
- 奇偶维度:分别使用正弦和余弦函数
3. 自注意力机制实现
自注意力是Transformer最核心的创新。我们首先实现单头注意力:
def attention(query, key, value, mask=None, dropout=None): d_k = query.size(-1) scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) p_attn = scores.softmax(dim=-1) if dropout is not None: p_attn = dropout(p_attn) return torch.matmul(p_attn, value), p_attn扩展为多头注意力:
class MultiHeadAttention(nn.Module): def __init__(self, config): super().__init__() assert config.d_model % config.n_head == 0 self.d_k = config.d_model // config.n_head self.n_head = config.n_head self.linears = nn.ModuleList([nn.Linear(config.d_model, config.d_model) for _ in range(4)]) self.dropout = nn.Dropout(config.dropout) def forward(self, query, key, value, mask=None): batch_size = query.size(0) query, key, value = [ lin(x).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) for lin, x in zip(self.linears, (query, key, value)) ] x, attn = attention(query, key, value, mask, self.dropout) x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k) return self.linears[-1](x)注意力模式对比:
| 类型 | 计算复杂度 | 并行度 | 长程依赖 |
|---|---|---|---|
| RNN | O(n) | 低 | 衰减严重 |
| CNN | O(kn) | 中 | 有限感受野 |
| 自注意力 | O(n²) | 高 | 完美捕获 |
4. 前馈网络与编码器层
前馈网络为每个位置提供非线性变换能力:
class PositionwiseFeedForward(nn.Module): def __init__(self, config): super().__init__() self.w_1 = nn.Linear(config.d_model, config.d_ff) self.w_2 = nn.Linear(config.d_ff, config.d_model) self.dropout = nn.Dropout(config.dropout) def forward(self, x): return self.w_2(self.dropout(self.w_1(x).relu()))完整的编码器层组合了注意力和前馈网络:
class EncoderLayer(nn.Module): def __init__(self, config): super().__init__() self.self_attn = MultiHeadAttention(config) self.feed_forward = PositionwiseFeedForward(config) self.norm1 = nn.LayerNorm(config.d_model) self.norm2 = nn.LayerNorm(config.d_model) self.dropout = nn.Dropout(config.dropout) def forward(self, x, mask): attn_out = self.self_attn(x, x, x, mask) x = self.norm1(x + self.dropout(attn_out)) ff_out = self.feed_forward(x) return self.norm2(x + self.dropout(ff_out))梯度流动分析:
- 残差连接:缓解梯度消失
- 层归一化:稳定训练过程
- Dropout:防止过拟合
5. 解码器与完整模型组装
解码器需要处理自注意力和编码器-解码器注意力:
class DecoderLayer(nn.Module): def __init__(self, config): super().__init__() self.self_attn = MultiHeadAttention(config) self.src_attn = MultiHeadAttention(config) self.feed_forward = PositionwiseFeedForward(config) self.norm1 = nn.LayerNorm(config.d_model) self.norm2 = nn.LayerNorm(config.d_model) self.norm3 = nn.LayerNorm(config.d_model) self.dropout = nn.Dropout(config.dropout) def forward(self, x, memory, src_mask, tgt_mask): attn_out = self.self_attn(x, x, x, tgt_mask) x = self.norm1(x + self.dropout(attn_out)) attn_out = self.src_attn(x, memory, memory, src_mask) x = self.norm2(x + self.dropout(attn_out)) ff_out = self.feed_forward(x) return self.norm3(x + self.dropout(ff_out))最终组装完整Transformer:
class Transformer(nn.Module): def __init__(self, config): super().__init__() self.encoder = nn.ModuleList( [EncoderLayer(config) for _ in range(config.n_layers)]) self.decoder = nn.ModuleList( [DecoderLayer(config) for _ in range(config.n_layers)]) self.src_embed = Embeddings(config) self.tgt_embed = Embeddings(config) self.generator = nn.Linear(config.d_model, config.vocab_size) def encode(self, src, src_mask): x = self.src_embed(src) for layer in self.encoder: x = layer(x, src_mask) return x def decode(self, tgt, memory, src_mask, tgt_mask): x = self.tgt_embed(tgt) for layer in self.decoder: x = layer(x, memory, src_mask, tgt_mask) return x def forward(self, src, tgt, src_mask, tgt_mask): memory = self.encode(src, src_mask) output = self.decode(tgt, memory, src_mask, tgt_mask) return self.generator(output)6. 训练技巧与实战调试
实现模型只是第一步,正确的训练方法同样重要:
学习率调度采用原始论文的warmup策略:
class WarmupScheduler: def __init__(self, d_model, warmup_steps=4000): self.d_model = d_model self.warmup_steps = warmup_steps self.step_num = 0 def step(self): self.step_num += 1 return (self.d_model ** -0.5) * min( self.step_num ** -0.5, self.step_num * (self.warmup_steps ** -1.5))标签平滑可以提升模型泛化能力:
def label_smoothing_loss(pred, target, smoothing=0.1): n_class = pred.size(-1) log_pred = pred.log_softmax(dim=-1) with torch.no_grad(): true_dist = torch.zeros_like(pred) true_dist.fill_(smoothing / (n_class - 1)) true_dist.scatter_(1, target.unsqueeze(1), 1 - smoothing) return (-true_dist * log_pred).sum(dim=-1).mean()常见调试问题:
梯度爆炸:添加梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)过拟合:调整dropout率或增加正则化
训练不稳定:检查残差连接和层归一化实现
7. 模型评估与应用示例
我们构建一个简单的机器翻译任务来验证模型:
def train_epoch(model, dataloader, optimizer, scheduler): model.train() total_loss = 0 for src, tgt in dataloader: src_mask = (src != 0).unsqueeze(-2) tgt_mask = (tgt != 0).unsqueeze(-2) tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)) optimizer.zero_grad() output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :-1, :-1]) loss = label_smoothing_loss(output, tgt[:, 1:]) loss.backward() optimizer.step() scheduler.step() total_loss += loss.item() return total_loss / len(dataloader)性能优化技巧:
使用混合精度训练:
scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): output = model(src, tgt_input, src_mask, tgt_mask) loss = criterion(output, tgt_output) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()实现内存高效的注意力:
torch.backends.cuda.sdp_kernel( enable_flash=True, enable_math=False, enable_mem_efficient=True )
通过这个简化实现,我们不仅理解了Transformer的核心机制,还获得了可以扩展的代码基础。实际应用中,可以根据需求调整模型尺寸、添加更多训练技巧或整合到更大系统中。