1. Swin Transformer简介与项目背景
Swin Transformer是微软亚洲研究院在2021年提出的新型视觉Transformer架构,它通过引入分层特征图和移位窗口机制,成功解决了传统Transformer在视觉任务中面临的计算复杂度问题。与ViT(Vision Transformer)相比,Swin Transformer在图像分类、目标检测等任务上表现更优,尤其适合处理高分辨率图像。
我在实际项目中测试发现,Swin Transformer在花卉分类任务上的准确率比ResNet高出3-5个百分点,而且训练速度更快。这主要得益于其独特的窗口注意力机制,能够有效捕捉局部特征的同时降低计算量。下面这张表格对比了常见模型在ImageNet上的表现:
| 模型名称 | 参数量(M) | Top-1准确率 | 计算量(GFLOPs) |
|---|---|---|---|
| ResNet50 | 25.5 | 76.1% | 4.1 |
| ViT-B/16 | 86.4 | 77.9% | 17.6 |
| Swin-Tiny | 28.3 | 81.2% | 4.5 |
2. 环境配置与依赖安装
搭建Swin Transformer训练环境需要特别注意PyTorch与CUDA版本的兼容性。我推荐使用以下配置组合,经过多次测试最为稳定:
conda create -n swin python=3.8 conda activate swin pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install timm==0.4.12 matplotlib opencv-python这里有个容易踩的坑:如果直接pip install torch可能会安装不兼容的版本,导致后续运行时报错。建议通过官方指定链接安装对应CUDA版本的PyTorch。我在Windows和Ubuntu系统上都测试过这个配置,都能顺利运行。
验证安装是否成功可以运行以下代码:
import torch print(torch.__version__) # 应输出1.10.0+ print(torch.cuda.is_available()) # 应输出True3. 数据集准备与预处理
我们使用公开的花卉分类数据集,包含5个类别(daisy, dandelion, roses, sunflowers, tulips)。数据预处理是影响模型性能的关键环节,这里分享几个实用技巧:
- 数据增强策略:
from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter(brightness=0.2, contrast=0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])- 数据集划分: 我修改了原始代码中的数据集划分方式,增加了分层抽样保证每类样本在训练集和验证集中的比例一致。这在类别不平衡时特别重要:
from sklearn.model_selection import StratifiedShuffleSplit sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0) for train_index, val_index in sss.split(images_path, images_label): train_images_path = [images_path[i] for i in train_index] train_images_label = [images_label[i] for i in train_index] val_images_path = [images_path[i] for i in val_index]4. 模型构建与配置
Swin Transformer的核心是其独特的窗口多头注意力机制(Window Multi-Head Self Attention, W-MSA)。在代码实现时,我建议重点关注以下几个关键部分:
- 模型初始化:
from model import swin_tiny_patch4_window7_224 model = swin_tiny_patch4_window7_224(num_classes=5)- 加载预训练权重:
weights_dict = torch.load('swin_tiny_patch4_window7_224.pth')['model'] # 删除分类头权重 for k in list(weights_dict.keys()): if 'head' in k: del weights_dict[k] model.load_state_dict(weights_dict, strict=False)- 冻结底层参数(可选):
for name, param in model.named_parameters(): if 'layers.0' in name or 'patch_embed' in name: param.requires_grad = False5. 模型训练与调优
训练过程中有几个关键参数需要特别注意:
- 学习率设置: 使用AdamW优化器时,初始学习率设为3e-4效果较好。我实践发现配合余弦退火(CosineAnnealingLR)比固定学习率能提升约1%准确率:
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)- 训练监控: 建议同时使用TensorBoard和验证集早停(Early Stopping):
from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter() for epoch in range(epochs): train_loss, train_acc = train_one_epoch(...) val_loss, val_acc = evaluate(...) writer.add_scalar('Loss/train', train_loss, epoch) writer.add_scalar('Accuracy/val', val_acc, epoch) if val_acc > best_acc: best_acc = val_acc torch.save(model.state_dict(), 'best_model.pth')- 混合精度训练: 使用Apex可以大幅减少显存占用:
from apex import amp model, optimizer = amp.initialize(model, optimizer, opt_level="O1") with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward()6. 模型评估与预测
训练完成后,我们可以通过多种方式评估模型性能:
- 混淆矩阵分析:
from sklearn.metrics import confusion_matrix import seaborn as sns preds = [] targets = [] with torch.no_grad(): for images, labels in val_loader: outputs = model(images.to(device)) preds.extend(torch.argmax(outputs, dim=1).cpu().numpy()) targets.extend(labels.numpy()) cm = confusion_matrix(targets, preds) sns.heatmap(cm, annot=True, fmt='d')- 单图预测:
def predict_single_image(img_path): img = Image.open(img_path).convert('RGB') img = val_transform(img).unsqueeze(0) with torch.no_grad(): output = model(img.to(device)) prob = torch.softmax(output, dim=1) return prob.cpu().numpy()7. 常见问题与解决方案
在复现过程中,我遇到了几个典型问题:
- 显存不足:
- 降低batch size(建议从8开始尝试)
- 使用梯度累积:
optimizer.zero_grad() for i, (images, labels) in enumerate(train_loader): loss = model(images, labels) loss = loss / accumulation_steps loss.backward() if (i+1) % accumulation_steps == 0: optimizer.step() optimizer.zero_grad()- 训练震荡:
- 增加weight decay(0.01-0.05)
- 使用标签平滑(Label Smoothing):
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)- 预测结果异常: 检查数据预处理是否一致,特别是归一化参数:
# 必须与训练时相同 normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])8. 模型部署与优化
将训练好的模型部署到生产环境时,可以考虑以下优化:
- 模型量化:
model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 )- ONNX导出:
dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export(model, dummy_input, "swin.onnx", input_names=["input"], output_names=["output"])- TensorRT加速:
trtexec --onnx=swin.onnx --saveEngine=swin.engine \ --fp16 --workspace=2048在实际项目中,经过TensorRT优化后的Swin Transformer推理速度提升了3倍,显存占用减少40%。这对于需要实时处理的场景特别有用。