手把手教你用PyTorch复现STANet:从LEVIR-CD数据集下载到模型训练全流程
遥感图像变化检测是计算机视觉领域的重要应用之一,能够自动识别地表随时间发生的变化。STANet(Spatial-Temporal Attention Network)作为该领域的创新模型,通过引入时空自注意力机制,显著提升了变化检测的精度。本文将带你从零开始,完成STANet模型的完整复现过程,包括环境配置、数据处理、模型训练和结果评估等关键步骤。
1. 环境准备与依赖安装
复现STANet的第一步是搭建合适的开发环境。推荐使用Python 3.8+和PyTorch 1.8+的组合,这是经过验证的稳定版本搭配。
首先创建并激活conda环境:
conda create -n stanet python=3.8 -y conda activate stanet安装核心依赖包:
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html pip install opencv-python numpy scikit-learn tqdm tensorboard对于GPU加速,确保你的CUDA版本与PyTorch版本兼容。可以通过以下命令检查CUDA是否可用:
import torch print(torch.cuda.is_available()) # 应返回True print(torch.version.cuda) # 显示CUDA版本2. 获取与处理LEVIR-CD数据集
LEVIR-CD是一个专门用于建筑物变化检测的大规模数据集,包含637对高分辨率遥感图像(1024×1024像素),时间跨度为5-14年。
数据集下载与解压:
wget https://www.dropbox.com/s/xxx/LEVIR-CD.zip # 替换为实际下载链接 unzip LEVIR-CD.zip -d ./data数据集通常包含三个子集:
- train:训练图像对(445对)
- val:验证图像对(64对)
- test:测试图像对(128对)
建议的数据预处理流程:
- 图像裁剪:将大图分割为256×256的小块,便于模型处理
- 数据增强:应用旋转、翻转等操作增加样本多样性
- 归一化:将像素值缩放到[0,1]范围
以下是预处理代码示例:
import cv2 import numpy as np from skimage.util import view_as_windows def crop_image(img, patch_size=256, stride=256): patches = view_as_windows(img, (patch_size, patch_size, 3), step=stride) return patches.reshape(-1, patch_size, patch_size, 3) # 示例:处理单张图像 img = cv2.imread('data/train/A/1.png') / 255.0 patches = crop_image(img) print(f"生成{len(patches)}个图像块")3. STANet模型架构解析与实现
STANet的核心创新在于其空间-时间注意力模块(STA),能够有效捕捉遥感图像中的时空依赖关系。模型主要由以下组件构成:
- 双流编码器:分别处理两个时间点的图像
- STA模块:计算空间和时间注意力权重
- 解码器:将特征图上采样回原始分辨率
关键模型实现代码:
import torch.nn as nn class STA_Module(nn.Module): def __init__(self, in_channels): super().__init__() self.conv_q = nn.Conv2d(in_channels, in_channels//8, 1) self.conv_k = nn.Conv2d(in_channels, in_channels//8, 1) self.conv_v = nn.Conv2d(in_channels, in_channels, 1) self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x1, x2): batch_size, C, H, W = x1.size() # 计算查询、键、值 q1 = self.conv_q(x1).view(batch_size, -1, H*W).permute(0,2,1) k2 = self.conv_k(x2).view(batch_size, -1, H*W) v2 = self.conv_v(x2).view(batch_size, -1, H*W) # 计算注意力权重 energy = torch.bmm(q1, k2) attention = torch.softmax(energy, dim=-1) # 应用注意力 out = torch.bmm(v2, attention.permute(0,2,1)) out = out.view(batch_size, C, H, W) return self.gamma*out + x14. 模型训练与超参数调优
训练STANet需要仔细设置超参数,以下是一组经过验证的推荐配置:
| 超参数 | 推荐值 | 说明 |
|---|---|---|
| 学习率 | 0.001 | 使用Adam优化器 |
| batch_size | 8 | 根据GPU显存调整 |
| 训练轮数 | 100 | 可早停 |
| 损失函数 | BCE+Dice | 组合损失 |
| 输入尺寸 | 256×256 | 匹配数据预处理 |
训练脚本示例:
from torch.utils.data import DataLoader from torch.optim import Adam from model import STANet # 初始化模型和优化器 model = STANet(in_channels=3).cuda() optimizer = Adam(model.parameters(), lr=0.001) # 自定义组合损失 def criterion(pred, target): bce_loss = nn.BCEWithLogitsLoss()(pred, target) pred_sigmoid = torch.sigmoid(pred) dice_loss = 1 - (2.*(pred_sigmoid*target).sum() + 1e-5) / (pred_sigmoid.sum() + target.sum() + 1e-5) return bce_loss + dice_loss # 训练循环 for epoch in range(100): model.train() for img1, img2, label in train_loader: img1, img2, label = img1.cuda(), img2.cuda(), label.cuda() optimizer.zero_grad() output = model(img1, img2) loss = criterion(output, label) loss.backward() optimizer.step()5. 常见问题与解决方案
在实际复现过程中,可能会遇到以下典型问题:
显存不足错误
- 降低batch_size(可降至4或2)
- 使用混合精度训练
- 尝试梯度累积技术
训练指标波动大
- 检查学习率是否过高
- 增加batch_size
- 添加更多的数据增强
模型收敛慢
- 尝试学习率预热
- 检查数据预处理是否正确
- 使用预训练编码器
梯度累积示例代码:
accum_steps = 4 # 累积4个batch的梯度 for i, (img1, img2, label) in enumerate(train_loader): # 前向传播和损失计算 loss = criterion(model(img1, img2), label) # 反向传播(累积梯度) loss = loss / accum_steps loss.backward() # 每accum_steps步更新一次参数 if (i+1) % accum_steps == 0: optimizer.step() optimizer.zero_grad()6. 模型评估与结果可视化
使用测试集评估模型性能时,建议计算以下指标:
- 精确度(Precision)
- 召回率(Recall)
- F1分数
- IoU(交并比)
评估代码框架:
from sklearn.metrics import precision_score, recall_score, f1_score def evaluate(model, test_loader): model.eval() total_pred, total_true = [], [] with torch.no_grad(): for img1, img2, label in test_loader: output = model(img1.cuda(), img2.cuda()) pred = (torch.sigmoid(output) > 0.5).float() total_pred.append(pred.cpu()) total_true.append(label) pred_all = torch.cat(total_pred) true_all = torch.cat(total_true) precision = precision_score(true_all, pred_all) recall = recall_score(true_all, pred_all) f1 = f1_score(true_all, pred_all) print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")结果可视化对于理解模型性能至关重要。可以使用以下代码生成变化检测图:
import matplotlib.pyplot as plt def visualize(img1, img2, pred, true): fig, axes = plt.subplots(1, 4, figsize=(20,5)) axes[0].imshow(img1) # 时间点1 axes[1].imshow(img2) # 时间点2 axes[2].imshow(pred, cmap='gray') # 预测变化 axes[3].imshow(true, cmap='gray') # 真实变化 plt.show()在实际项目中,STANet的表现很大程度上取决于数据质量和训练技巧。建议先在小批量数据上验证流程的正确性,再扩展到整个数据集。训练过程中使用TensorBoard监控损失和指标变化,可以帮助及时发现训练问题。