手把手教你用PyTorch复现STANet:从LEVIR-CD数据集下载到模型训练全流程
2026/5/26 5:23:01 网站建设 项目流程

手把手教你用PyTorch复现STANet:从LEVIR-CD数据集下载到模型训练全流程

遥感图像变化检测是计算机视觉领域的重要应用之一,能够自动识别地表随时间发生的变化。STANet(Spatial-Temporal Attention Network)作为该领域的创新模型,通过引入时空自注意力机制,显著提升了变化检测的精度。本文将带你从零开始,完成STANet模型的完整复现过程,包括环境配置、数据处理、模型训练和结果评估等关键步骤。

1. 环境准备与依赖安装

复现STANet的第一步是搭建合适的开发环境。推荐使用Python 3.8+和PyTorch 1.8+的组合,这是经过验证的稳定版本搭配。

首先创建并激活conda环境:

conda create -n stanet python=3.8 -y conda activate stanet

安装核心依赖包:

pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html pip install opencv-python numpy scikit-learn tqdm tensorboard

对于GPU加速,确保你的CUDA版本与PyTorch版本兼容。可以通过以下命令检查CUDA是否可用:

import torch print(torch.cuda.is_available()) # 应返回True print(torch.version.cuda) # 显示CUDA版本

2. 获取与处理LEVIR-CD数据集

LEVIR-CD是一个专门用于建筑物变化检测的大规模数据集,包含637对高分辨率遥感图像(1024×1024像素),时间跨度为5-14年。

数据集下载与解压:

wget https://www.dropbox.com/s/xxx/LEVIR-CD.zip # 替换为实际下载链接 unzip LEVIR-CD.zip -d ./data

数据集通常包含三个子集:

  • train:训练图像对(445对)
  • val:验证图像对(64对)
  • test:测试图像对(128对)

建议的数据预处理流程:

  1. 图像裁剪:将大图分割为256×256的小块,便于模型处理
  2. 数据增强:应用旋转、翻转等操作增加样本多样性
  3. 归一化:将像素值缩放到[0,1]范围

以下是预处理代码示例:

import cv2 import numpy as np from skimage.util import view_as_windows def crop_image(img, patch_size=256, stride=256): patches = view_as_windows(img, (patch_size, patch_size, 3), step=stride) return patches.reshape(-1, patch_size, patch_size, 3) # 示例:处理单张图像 img = cv2.imread('data/train/A/1.png') / 255.0 patches = crop_image(img) print(f"生成{len(patches)}个图像块")

3. STANet模型架构解析与实现

STANet的核心创新在于其空间-时间注意力模块(STA),能够有效捕捉遥感图像中的时空依赖关系。模型主要由以下组件构成:

  1. 双流编码器:分别处理两个时间点的图像
  2. STA模块:计算空间和时间注意力权重
  3. 解码器:将特征图上采样回原始分辨率

关键模型实现代码:

import torch.nn as nn class STA_Module(nn.Module): def __init__(self, in_channels): super().__init__() self.conv_q = nn.Conv2d(in_channels, in_channels//8, 1) self.conv_k = nn.Conv2d(in_channels, in_channels//8, 1) self.conv_v = nn.Conv2d(in_channels, in_channels, 1) self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x1, x2): batch_size, C, H, W = x1.size() # 计算查询、键、值 q1 = self.conv_q(x1).view(batch_size, -1, H*W).permute(0,2,1) k2 = self.conv_k(x2).view(batch_size, -1, H*W) v2 = self.conv_v(x2).view(batch_size, -1, H*W) # 计算注意力权重 energy = torch.bmm(q1, k2) attention = torch.softmax(energy, dim=-1) # 应用注意力 out = torch.bmm(v2, attention.permute(0,2,1)) out = out.view(batch_size, C, H, W) return self.gamma*out + x1

4. 模型训练与超参数调优

训练STANet需要仔细设置超参数,以下是一组经过验证的推荐配置:

超参数推荐值说明
学习率0.001使用Adam优化器
batch_size8根据GPU显存调整
训练轮数100可早停
损失函数BCE+Dice组合损失
输入尺寸256×256匹配数据预处理

训练脚本示例:

from torch.utils.data import DataLoader from torch.optim import Adam from model import STANet # 初始化模型和优化器 model = STANet(in_channels=3).cuda() optimizer = Adam(model.parameters(), lr=0.001) # 自定义组合损失 def criterion(pred, target): bce_loss = nn.BCEWithLogitsLoss()(pred, target) pred_sigmoid = torch.sigmoid(pred) dice_loss = 1 - (2.*(pred_sigmoid*target).sum() + 1e-5) / (pred_sigmoid.sum() + target.sum() + 1e-5) return bce_loss + dice_loss # 训练循环 for epoch in range(100): model.train() for img1, img2, label in train_loader: img1, img2, label = img1.cuda(), img2.cuda(), label.cuda() optimizer.zero_grad() output = model(img1, img2) loss = criterion(output, label) loss.backward() optimizer.step()

5. 常见问题与解决方案

在实际复现过程中,可能会遇到以下典型问题:

  1. 显存不足错误

    • 降低batch_size(可降至4或2)
    • 使用混合精度训练
    • 尝试梯度累积技术
  2. 训练指标波动大

    • 检查学习率是否过高
    • 增加batch_size
    • 添加更多的数据增强
  3. 模型收敛慢

    • 尝试学习率预热
    • 检查数据预处理是否正确
    • 使用预训练编码器

梯度累积示例代码:

accum_steps = 4 # 累积4个batch的梯度 for i, (img1, img2, label) in enumerate(train_loader): # 前向传播和损失计算 loss = criterion(model(img1, img2), label) # 反向传播(累积梯度) loss = loss / accum_steps loss.backward() # 每accum_steps步更新一次参数 if (i+1) % accum_steps == 0: optimizer.step() optimizer.zero_grad()

6. 模型评估与结果可视化

使用测试集评估模型性能时,建议计算以下指标:

  • 精确度(Precision)
  • 召回率(Recall)
  • F1分数
  • IoU(交并比)

评估代码框架:

from sklearn.metrics import precision_score, recall_score, f1_score def evaluate(model, test_loader): model.eval() total_pred, total_true = [], [] with torch.no_grad(): for img1, img2, label in test_loader: output = model(img1.cuda(), img2.cuda()) pred = (torch.sigmoid(output) > 0.5).float() total_pred.append(pred.cpu()) total_true.append(label) pred_all = torch.cat(total_pred) true_all = torch.cat(total_true) precision = precision_score(true_all, pred_all) recall = recall_score(true_all, pred_all) f1 = f1_score(true_all, pred_all) print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

结果可视化对于理解模型性能至关重要。可以使用以下代码生成变化检测图:

import matplotlib.pyplot as plt def visualize(img1, img2, pred, true): fig, axes = plt.subplots(1, 4, figsize=(20,5)) axes[0].imshow(img1) # 时间点1 axes[1].imshow(img2) # 时间点2 axes[2].imshow(pred, cmap='gray') # 预测变化 axes[3].imshow(true, cmap='gray') # 真实变化 plt.show()

在实际项目中,STANet的表现很大程度上取决于数据质量和训练技巧。建议先在小批量数据上验证流程的正确性,再扩展到整个数据集。训练过程中使用TensorBoard监控损失和指标变化,可以帮助及时发现训练问题。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询