ResNet34实战:用自定义数据集(比如猫狗分类)快速验证你的模型是否真的有效
2026/6/10 19:21:53 网站建设 项目流程

ResNet34实战:从零构建猫狗分类器的快速验证指南

当我们需要验证一个深度学习模型的有效性时,最直接的方式不是阅读论文,而是动手实现它。ResNet34作为计算机视觉领域的经典网络,其残差结构设计巧妙,训练稳定,非常适合作为入门者验证模型能力的起点。本文将带你用PyTorch框架,在Kaggle猫狗数据集上快速搭建并验证ResNet34模型。

1. 环境准备与数据加载

在开始之前,确保你的Python环境已安装PyTorch和Torchvision。如果你使用GPU加速训练,还需要安装CUDA版本的PyTorch:

pip install torch torchvision

Kaggle猫狗数据集包含25,000张图片,其中12,500张猫和12,500张狗。我们可以使用Torchvision的ImageFolder自动加载这种结构化的数据集:

from torchvision import datasets, transforms data_transforms = { 'train': transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), 'val': transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), } image_datasets = { x: datasets.ImageFolder(f'data/dogs-vs-cats/{x}', data_transforms[x]) for x in ['train', 'val'] } dataloaders = { x: torch.utils.data.DataLoader( image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val'] }

提示:数据增强是防止过拟合的有效手段,RandomResizedCrop和RandomHorizontalFlip可以增加训练数据的多样性。

2. ResNet34模型架构解析与调整

ResNet34的核心在于残差块(Residual Block)的设计,它通过shortcut连接解决了深层网络梯度消失的问题。标准的ResNet34是为ImageNet的1000类分类设计的,我们需要修改最后的全连接层以适应二分类任务:

import torch.nn as nn from torchvision import models model = models.resnet34(pretrained=True) num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 2) # 修改为二分类输出 # 如果使用GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = model.to(device)

模型的主要组件包括:

  • 初始卷积层:7x7卷积,64个输出通道,步长2
  • 最大池化层:3x3池化窗口,步长2
  • 四个残差块组:分别包含3,4,6,3个残差块
  • 全局平均池化:将特征图降维到1x1
  • 全连接层:最终的分类器

3. 训练策略与超参数设置

训练深度学习模型需要精心设置超参数和学习率策略。以下是一个典型的训练配置:

import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

训练过程中需要监控的关键指标:

指标说明理想范围
训练损失模型在训练集上的误差应持续下降
验证准确率模型在验证集上的分类准确率应逐步提高
训练/验证差距过拟合程度的指示器差距不宜过大

训练循环的基本结构如下:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25): for epoch in range(num_epochs): # 训练阶段 model.train() running_loss = 0.0 for inputs, labels in dataloaders['train']: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) # 验证阶段 model.eval() val_loss = 0.0 corrects = 0 with torch.no_grad(): for inputs, labels in dataloaders['val']: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) val_loss += loss.item() * inputs.size(0) _, preds = torch.max(outputs, 1) corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / len(image_datasets['train']) epoch_val_loss = val_loss / len(image_datasets['val']) epoch_acc = corrects.double() / len(image_datasets['val']) print(f'Epoch {epoch}/{num_epochs-1}') print(f'Train Loss: {epoch_loss:.4f} Val Loss: {epoch_val_loss:.4f}') print(f'Val Acc: {epoch_acc:.4f}') scheduler.step() return model

4. 模型评估与可视化分析

训练完成后,我们需要评估模型的实际表现。混淆矩阵是分类任务中最直观的评估工具之一:

from sklearn.metrics import confusion_matrix import seaborn as sns import matplotlib.pyplot as plt def plot_confusion_matrix(model, dataloader): model.eval() all_preds = [] all_labels = [] with torch.no_grad(): for inputs, labels in dataloader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, preds = torch.max(outputs, 1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) cm = confusion_matrix(all_labels, all_preds) plt.figure(figsize=(8,6)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Cat', 'Dog'], yticklabels=['Cat', 'Dog']) plt.xlabel('Predicted') plt.ylabel('Actual') plt.show() plot_confusion_matrix(model, dataloaders['val'])

注意:良好的模型应该在混淆矩阵的对角线上有较高的数值,而非对角线上的数值应尽可能小。

除了整体准确率,我们还应该关注以下指标:

  • 精确率(Precision):预测为正类中实际为正类的比例
  • 召回率(Recall):实际为正类中被正确预测的比例
  • F1分数:精确率和召回率的调和平均

这些指标可以通过sklearn轻松计算:

from sklearn.metrics import classification_report print(classification_report(all_labels, all_preds, target_names=['Cat', 'Dog']))

5. 模型优化与调参技巧

当基础模型表现不佳时,可以考虑以下优化策略:

  1. 学习率调整

    • 初始学习率太大可能导致震荡,太小则收敛缓慢
    • 使用学习率调度器(如StepLR或ReduceLROnPlateau)动态调整
  2. 正则化技术

    • Dropout:在全连接层前添加Dropout层
    • 权重衰减:在优化器中设置weight_decay参数
    • 早停(Early Stopping):验证集性能不再提升时停止训练
  3. 数据增强扩展

    train_transforms = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.RandomRotation(20), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])
  4. 模型微调策略

    • 只训练最后几层:冻结前面的卷积层,仅训练全连接层
    • 分层学习率:不同层使用不同的学习率
    • 渐进解冻:逐步解冻更多层进行训练
# 分层设置学习率示例 optimizer = optim.SGD([ {'params': model.conv1.parameters(), 'lr': 0.0001}, {'params': model.layer1.parameters(), 'lr': 0.0005}, {'params': model.layer2.parameters(), 'lr': 0.001}, {'params': model.fc.parameters(), 'lr': 0.01} ], momentum=0.9)

6. 实际应用与模型部署

训练好的模型可以保存下来供后续使用:

torch.save(model.state_dict(), 'resnet34_cat_dog.pth') # 加载模型 model = models.resnet34(pretrained=False) model.fc = nn.Linear(model.fc.in_features, 2) model.load_state_dict(torch.load('resnet34_cat_dog.pth')) model.eval()

在实际应用中,我们需要处理单张图片的预测:

from PIL import Image def predict_image(image_path): img = Image.open(image_path) img = data_transforms['val'](img).unsqueeze(0) img = img.to(device) with torch.no_grad(): output = model(img) _, pred = torch.max(output, 1) return 'Dog' if pred.item() == 1 else 'Cat' # 测试单张图片 print(predict_image('test_cat.jpg'))

对于生产环境部署,可以考虑以下方案:

  • TorchScript:将模型转换为TorchScript格式,提高推理效率
  • ONNX:转换为开放神经网络交换格式,实现跨框架部署
  • Flask/Django:构建Web API服务
  • 移动端:使用PyTorch Mobile部署到iOS/Android设备
# 转换为TorchScript示例 example_input = torch.rand(1, 3, 224, 224).to(device) traced_script_module = torch.jit.trace(model, example_input) traced_script_module.save("resnet34_cat_dog_script.pt")

在Kaggle猫狗数据集上的实践表明,经过适当调参的ResNet34可以在验证集上达到约97%的准确率。这个过程中最重要的是理解模型每个组件的作用,并通过实验验证各种技术对最终效果的影响。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询