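Medical Image Segmentation in Practice: Reproducing TransUNet in PyTorch Step by Step (with Full Code and Dataset Processing)
Medical image segmentation has long been an important research direction in computer vision and plays a key role in clinical diagnosis and treatment planning. The U-Net architecture performs well on medical images, but the local receptive field of its convolutions limits how much global context it can capture. Transformers offer a way around this limitation. This article walks through a from-scratch PyTorch reproduction of TransUNet, with adjustments for the particular characteristics of medical images.
1. Environment Setup and Preparation
Before building TransUNet, make sure the development environment is configured correctly. The recommended setup is: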
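```bash
conda create -n transunet python=3.8
conda activate transunet
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# einops and matplotlib are imported by later code in this article
pip install nibabel scikit-image tqdm tensorboard einops matplotlib
```
Note: the CUDA version must be compatible with your GPU driver; check NVIDIA's official documentation for the compatibility matrix.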
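Key dependencies and what they are used for:
| Library | Required version | Main purpose |
|---|---|---|
| PyTorch | ≥1.8.0 | Core deep learning framework |
| nibabel | ≥3.2.1 | Reading medical image formats |
| scikit-image | ≥0.18.0 | Image preprocessing utilities |
| tensorboard | ≥2.5.0 | Visualizing the training process |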
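The minimum versions in the table can be checked programmatically before training starts. The following is a minimal sketch that uses only the standard library. The helper names `parse_version` and `check_min_versions` are illustrative, the keys are the usual PyPI distribution names, and the simple parser deliberately ignores local suffixes such as `+cu111` and pre-release tags.

```python
from importlib import metadata

# Minimum versions taken from the dependency table (PyPI distribution names).
MIN_VERSIONS = {
    "torch": "1.8.0",
    "nibabel": "3.2.1",
    "scikit-image": "0.18.0",
    "tensorboard": "2.5.0",
}


def parse_version(text):
    """Turn '1.9.0+cu111' into (1, 9, 0); non-numeric suffixes are dropped."""
    core = text.split("+")[0]
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if digits == "":
            break
        parts.append(int(digits))
    return tuple(parts)


def check_min_versions(requirements=MIN_VERSIONS):
    """Return {name: (installed_version_or_None, meets_minimum)}."""
    report = {}
    for name, minimum in requirements.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        ok = parse_version(installed) >= parse_version(minimum)
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_min_versions().items():
        status = "OK" if ok else "missing or too old"
        print(f"{name}: {installed} -> {status}")
```

Comparing integer tuples rather than raw strings matters here: as strings, "1.10.0" sorts before "1.8.0", which would wrongly reject a newer PyTorch.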
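For hardware, we recommend at least the following:
- GPU: NVIDIA card with ≥8 GB of memory (≥16 GB recommended for 3D medical images)
- RAM: ≥16 GB
- Storage: SSD with ≥50 GB free for dataset caching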
2. Medical Image Data Preprocessing
Medical images such as CT and MRI differ substantially from natural images and need dedicated handling:
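```python
import nibabel as nib
import numpy as np


def load_nii_file(file_path):
    """Load a medical image stored in NIfTI format."""
    img = nib.load(file_path)
    data = img.get_fdata()
    # Min-max normalize to [0, 1]; the epsilon guards against a constant volume
    data = (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-8)
    # Move the last axis to the front (channel-first layout)
    return np.transpose(data, (2, 0, 1))
```
A typical preprocessing pipeline for medical images: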
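- Resampling: bring images acquired on different devices to a common resolution
- Window width/level adjustment: highlight the contrast of specific tissues
- Normalization: remove intensity differences between scanners
- Data augmentation: rotations, flips, and similar operations to increase data diversity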
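Two of these steps, windowing and augmentation, can be sketched with NumPy and SciPy. This is an illustrative sketch rather than the article's own implementation: `apply_window` is a hypothetical helper, the default center of 40 and width of 400 are assumed example values for an abdominal soft-tissue window, and `random_rotate` and `random_flip` are one possible implementation of the helpers of the same name that the Synapse preprocessing code calls.

```python
import numpy as np
from scipy import ndimage


def apply_window(hu, center=40.0, width=400.0):
    """Clip intensities to [center - width/2, center + width/2], then rescale to [0, 1]."""
    low = center - width / 2.0
    high = center + width / 2.0
    clipped = np.clip(hu, low, high)
    return (clipped - low) / (high - low)


def random_flip(image, label):
    """Flip image and label together along one randomly chosen axis."""
    axis = int(np.random.randint(0, 2))
    # .copy() turns the flipped views into contiguous arrays
    return np.flip(image, axis=axis).copy(), np.flip(label, axis=axis).copy()


def random_rotate(image, label, angle_range=(-15, 15)):
    """Rotate image and label by the same random angle (in degrees)."""
    angle = float(np.random.uniform(angle_range[0], angle_range[1]))
    # Cubic interpolation for intensities
    image = ndimage.rotate(image, angle, reshape=False, order=3,
                           mode="constant", cval=0.0)
    # Nearest-neighbor for labels so that no new class values are invented
    label = ndimage.rotate(label, angle, reshape=False, order=0,
                           mode="constant", cval=0)
    return image, label
```

The important design choice is that the image and the label always receive the same geometric transform, while only the image uses smooth interpolation. Interpolating a label map with anything other than nearest-neighbor would produce fractional values that correspond to no class.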
For the Synapse multi-organ segmentation dataset, we suggest the following preprocessing steps:
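```python
import numpy as np
from skimage.transform import resize


def preprocess_synapse(data, label):
    # 1. Resample to a common resolution (256x256)
    data = resize(data, (256, 256), order=3, preserve_range=True)
    # Nearest-neighbor (order=0) without anti-aliasing keeps label values intact
    label = resize(label, (256, 256), order=0, preserve_range=True,
                   anti_aliasing=False).astype(np.int64)

    # 2. Intensity normalization (z-score); the epsilon avoids division by zero
    data = (data - data.mean()) / (data.std() + 1e-8)

    # 3. Random augmentation; random_rotate and random_flip are helper
    #    functions assumed to be defined elsewhere
    if np.random.rand() > 0.5:
        data, label = random_rotate(data, label, angle_range=(-15, 15))
    if np.random.rand() > 0.5:
        data, label = random_flip(data, label)

    return data, label
```
3. TransUNet Model Architecture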
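The core innovation of TransUNet is combining the local feature extraction of a CNN with the global modeling ability of a Transformer. We implement it module by module below: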
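Before looking at the modules, it helps to work out how many tokens the Transformer sees. The sketch below assumes the CNN stage downsamples by a total factor of 16, which is what the `[B, 1024, 14, 14]` shape comment in the encoder code implies for a 224x224 input; `token_grid` is a hypothetical helper name.

```python
def token_grid(image_size, downsample=16):
    """Return (grid_side, num_tokens) for a square input and a total CNN stride."""
    if image_size % downsample != 0:
        raise ValueError("image_size must be divisible by the downsampling factor")
    side = image_size // downsample
    return side, side * side


for size in (224, 256, 512):
    side, tokens = token_grid(size)
    print(f"{size}x{size} input -> {side}x{side} feature map -> {tokens} tokens")
```

A 224x224 input yields 14 x 14 = 196 tokens, while the 256x256 images produced by the preprocessing step yield 16 x 16 = 256. The length of the position embedding has to match whichever input size is actually used.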
3.1 Hybrid Encoder Implementation
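The encoder code in this subsection instantiates a `TransformerBlock` without showing its definition. One plausible version is a pre-norm block built on `nn.MultiheadAttention`. Treat it as a minimal sketch under that assumption, not necessarily the block used by the original TransUNet; `batch_first=True` also assumes PyTorch 1.9 or newer.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention and an MLP, each with a residual."""

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: [B, N, dim], where N is the number of tokens
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                  # residual around attention
        x = x + self.mlp(self.norm2(x))   # residual around the MLP
        return x
```

The block keeps the token tensor's shape unchanged, so any number of them can be stacked, which is what `TransformerEncoder` relies on.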
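```python
import torch
import torch.nn as nn
from einops import rearrange


class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.1):
        super().__init__()
        # TransformerBlock (attention + MLP with residual connections)
        # is assumed to be defined elsewhere
        self.layers = nn.ModuleList([
            TransformerBlock(dim, heads, mlp_dim, dropout)
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class HybridEncoder(nn.Module):
    def __init__(self, in_chans=3, embed_dim=768, depth=12, num_heads=12,
                 num_patches=196):
        super().__init__()
        # CNN feature extractor (a modified ResNet50, assumed defined elsewhere)
        self.cnn_backbone = ModifiedResNet50()

        # Transformer part
        self.proj = nn.Conv2d(1024, embed_dim, kernel_size=1)
        # Learnable position embedding, one vector per token (14 * 14 = 196 by default)
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.transformer = TransformerEncoder(
            dim=embed_dim,
            depth=depth,
            heads=num_heads,
            mlp_dim=3072
        )

    def forward(self, x):
        # Assumption: the backbone returns the deepest feature map plus a list of
        # per-stage skip features, because the decoder indexes features[0..3]
        features, skips = self.cnn_backbone(x)  # features: [B, 1024, 14, 14]

        # Project to the Transformer dimension
        x = self.proj(features)  # [B, 768, 14, 14]

        # Flatten to a token sequence and add the position embedding
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = x + self.pos_embedding

        # Transformer encoding
        x = self.transformer(x)

        # Restore the spatial layout
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        return x, skips
```
3.2 Decoder Implementation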
The decoder follows a U-Net-like structure but fuses in the features encoded by the Transformer:
```python
import torch
import torch.nn as nn


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels=0):
        super().__init__()
        # Transposed convolution doubles the spatial resolution
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels + skip_channels, out_channels, 3, padding=1),
            nn.GroupNorm(16, out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.GroupNorm(16, out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            # Concatenate the skip connection along the channel axis
            x = torch.cat([x, skip], dim=1)
        return self.conv(x)


class TransUNetDecoder(nn.Module):
    def __init__(self, num_classes, embed_dim=768):
        super().__init__()
        # skip_channels must match the channel counts the backbone actually exposes
        self.decoder1 = DecoderBlock(embed_dim, 512, skip_channels=512)
        self.decoder2 = DecoderBlock(512, 256, skip_channels=256)
        self.decoder3 = DecoderBlock(256, 128, skip_channels=128)
        self.decoder4 = DecoderBlock(128, 64, skip_channels=64)
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x, features):
        # features holds the CNN encoder outputs from each stage
        x = self.decoder1(x, features[3])  # 1/16 -> 1/8
        x = self.decoder2(x, features[2])  # 1/8 -> 1/4
        x = self.decoder3(x, features[1])  # 1/4 -> 1/2
        x = self.decoder4(x, features[0])  # 1/2 -> 1/1
        return self.final(x)
```
4. Training Strategy and Tips
Training medical image segmentation models calls for particular attention to the following points:
4.1 Choosing the Loss Function
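Because the decoder emits one logit channel per class and the evaluation metrics take an `argmax` over channels, a softmax-based Dice loss combined with cross-entropy over integer label maps is one natural option. The sketch below is an alternative to, not a replacement for, the sigmoid-based loss shown in this subsection; the class name `MultiClassDiceCELoss` and its defaults are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiClassDiceCELoss(nn.Module):
    """alpha * soft Dice (softmax, averaged over classes) + (1 - alpha) * cross-entropy."""

    def __init__(self, num_classes, alpha=0.5, smooth=1e-5):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.smooth = smooth
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, target):
        # logits: [B, C, H, W] raw scores; target: [B, H, W] int64 class indices
        probs = logits.softmax(dim=1)
        one_hot = F.one_hot(target, self.num_classes).permute(0, 3, 1, 2).float()

        # Sum over batch and spatial axes, keeping one value per class
        dims = (0, 2, 3)
        intersection = (probs * one_hot).sum(dims)
        denom = probs.sum(dims) + one_hot.sum(dims)
        dice = (2.0 * intersection + self.smooth) / (denom + self.smooth)

        dice_loss = 1.0 - dice.mean()
        return self.alpha * dice_loss + (1.0 - self.alpha) * self.ce(logits, target)
```

Computing Dice per class and then averaging keeps small structures from being swamped by the background, which a single global sum over all pixels would not do.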
```python
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # pred: raw logits; target: float mask with the same shape as pred
        # (binary or one-hot), which BCEWithLogitsLoss below also requires
        pred = pred.sigmoid()
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice


class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha  # weight of the Dice term
        self.dice = DiceLoss()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        return self.alpha * self.dice(pred, target) + (1 - self.alpha) * self.bce(pred, target)
```
4.2 Learning Rate Scheduling
```python
import torch


def get_optimizer(model, lr=1e-4, weight_decay=1e-4):
    # The CNN backbone gets a 10x smaller learning rate than the rest of the model
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if 'backbone' in n], 'lr': lr/10},
        {'params': [p for n, p in model.named_parameters() if 'backbone' not in n], 'lr': lr}
    ]
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)


def get_scheduler(optimizer, num_epochs):
    # Cosine decay over the whole run, down to a floor of 1e-6
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-6
    )
```
4.3 Key Training Techniques
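- Mixed-precision training: reduces memory use and speeds up training
- Gradient clipping: prevents exploding gradients
- Early stopping: stops training based on validation performance
- Model EMA: an exponential moving average of the weights improves stability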
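Early stopping from this list needs only a few lines of bookkeeping. The following is a minimal sketch, assuming a validation score where higher is better (for example mean Dice); the class name `EarlyStopping`, the `patience` and `min_delta` parameters, and the scores in the usage loop are all made up for illustration.

```python
class EarlyStopping:
    """Stop training when the validation score has not improved for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None
        self.bad_epochs = 0

    def step(self, score):
        """Record a new score (higher is better); return True if training should stop."""
        if self.best_score is None or score > self.best_score + self.min_delta:
            self.best_score = score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with made-up validation Dice scores
stopper = EarlyStopping(patience=3)
val_dice_per_epoch = [0.60, 0.68, 0.71, 0.70, 0.71, 0.69, 0.70]
for epoch, val_dice in enumerate(val_dice_per_epoch):
    if stopper.step(val_dice):
        print(f"stopping at epoch {epoch}, best Dice {stopper.best_score:.2f}")
        break
```

In a real training loop the model checkpoint would be saved whenever `best_score` is updated, so that the weights from the best epoch, not the last one, are kept.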
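```python
import torch
from torch.cuda.amp import GradScaler, autocast

# model, criterion, optimizer, scheduler, train_loader and epochs are
# assumed to have been created beforehand
scaler = GradScaler()

for epoch in range(epochs):
    model.train()
    for images, masks in train_loader:
        images, masks = images.cuda(), masks.cuda()
        optimizer.zero_grad()

        # Run the forward pass in mixed precision
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()
        # Gradient clipping: unscale first so the threshold applies to the true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # assumed threshold
        scaler.step(optimizer)
        scaler.update()

    # Step once per epoch, matching T_max=num_epochs in the scheduler
    scheduler.step()
```
5. Model Evaluation and Result Visualization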
5.1 Implementing the Evaluation Metrics
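The two metrics in this subsection, IoU and Dice, are computed from the same overlap counts and are related by Dice = 2 * IoU / (1 + IoU). A small NumPy example on a single binary mask makes the arithmetic concrete; `dice_and_iou` is a hypothetical helper used only for this illustration.

```python
import numpy as np


def dice_and_iou(pred_mask, target_mask):
    """Return (dice, iou) for two binary masks; NaN when both masks are empty."""
    pred_mask = np.asarray(pred_mask, dtype=bool)
    target_mask = np.asarray(target_mask, dtype=bool)
    intersection = np.logical_and(pred_mask, target_mask).sum()
    union = np.logical_or(pred_mask, target_mask).sum()
    total = pred_mask.sum() + target_mask.sum()
    if total == 0:
        return float("nan"), float("nan")
    return 2.0 * intersection / total, intersection / union


pred = np.array([[1, 1, 0],
                 [0, 1, 0]])
target = np.array([[1, 0, 0],
                   [0, 1, 1]])

dice, iou = dice_and_iou(pred, target)
# 2 overlapping pixels, 3 predicted, 3 true, 4 in the union
print(f"Dice = {dice:.4f}, IoU = {iou:.4f}")
print(f"2 * IoU / (1 + IoU) = {2 * iou / (1 + iou):.4f}")
```

Here the intersection is 2 pixels and the union is 4, so IoU is 0.5 and Dice is 4/6, about 0.667. Dice is never lower than IoU for the same prediction, which is worth remembering when comparing numbers reported with different metrics.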
```python
import numpy as np


def compute_iou(pred, target, num_classes):
    """Mean IoU over classes. pred: [B, C, H, W] logits; target: [B, H, W] labels."""
    ious = []
    pred = pred.argmax(1)
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().float()
        union = (pred_inds | target_inds).sum().float()
        if union == 0:
            # Class absent from both prediction and target: exclude it from the mean
            ious.append(float('nan'))
        else:
            ious.append((intersection / union).item())
    return np.nanmean(ious)


def compute_dice(pred, target, num_classes):
    """Mean Dice over classes, with the same input conventions as compute_iou."""
    dices = []
    pred = pred.argmax(1)
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().float()
        denom = (pred_inds.sum() + target_inds.sum()).float()
        if denom == 0:
            dices.append(float('nan'))
        else:
            # Convert the whole ratio to a Python float, not just the denominator
            dices.append((2 * intersection / denom).item())
    return np.nanmean(dices)
```
5.2 Result Visualization
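```python
import matplotlib.pyplot as plt


def visualize_results(image, mask, pred, save_path=None):
    """Plot input, ground truth and prediction side by side. pred: [C, H, W] scores."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(image, cmap='gray')
    axes[0].set_title('Input Image')
    axes[0].axis('off')

    axes[1].imshow(mask, cmap='jet')
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    # Take the most likely class per pixel
    axes[2].imshow(pred.argmax(0), cmap='jet')
    axes[2].set_title('Prediction')
    axes[2].axis('off')

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close()
```
In our own projects, we found that TransUNet performs particularly well on small organs such as the pancreas, which we attribute to the global context captured by the Transformer. A common tuning direction is to adjust the balance between the CNN and Transformer parts of the encoder to suit the dataset at hand.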
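When adjusting that balance, it helps to know how the size of the Transformer part scales with its depth. The estimate below is a rough sketch that assumes a standard block (multi-head attention with biases, a two-layer MLP, two LayerNorms) and ignores the position embedding and the projection layer; `transformer_block_params` is a hypothetical helper.

```python
def transformer_block_params(dim, mlp_dim):
    """Approximate parameter count of one standard Transformer block."""
    attention = 4 * dim * dim + 4 * dim       # Q, K, V and output projections, with biases
    mlp = 2 * dim * mlp_dim + mlp_dim + dim   # two linear layers, with biases
    layer_norms = 2 * 2 * dim                 # two LayerNorms, weight and bias each
    return attention + mlp + layer_norms


for depth in (4, 8, 12):
    total = depth * transformer_block_params(dim=768, mlp_dim=3072)
    print(f"depth={depth}: {total / 1e6:.1f}M parameters")
```

With `embed_dim=768` and `mlp_dim=3072`, each block has about 7.1M parameters, so the default depth of 12 gives roughly 85M. Reducing `depth` in `HybridEncoder` is therefore the most direct way to shift capacity back toward the CNN on small datasets.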