从数据集准备到模型训练:一步步教你用P2PNet跑通SHHA人群计数数据集
2026/6/2 23:11:39 网站建设 项目流程

从数据集准备到模型训练:一步步教你用P2PNet跑通SHHA人群计数数据集

人群计数技术在智慧城市、公共安全等领域具有广泛应用价值。P2PNet作为近年来提出的创新性人群计数模型,以其端到端的点预测能力和轻量化结构受到研究者关注。本文将带您从零开始,完整实现P2PNet在SHHA数据集上的训练流程,涵盖数据准备、环境配置、训练调参等关键环节,特别针对实际项目中容易遇到的路径配置、版本兼容等问题提供解决方案。

1. 环境准备与数据获取

1.1 基础环境配置

P2PNet基于PyTorch框架实现,推荐使用Python 3.8+和CUDA 11.3以上环境。以下是核心依赖的安装命令:

pip install torch==2.1.2 torchvision==0.16.2 pip install opencv-python pandas matplotlib tensorboardX

常见环境问题排查:

  • CUDA版本不匹配:通过nvidia-sminvcc --version检查驱动与运行时版本
  • Pillow版本问题:新版Pillow中ANTIALIAS已被移除,需替换为Resampling.LANCZOS
  • Torchvision API变更:若遇到_new_empty_tensor导入错误,可注释相关代码或降级torchvision版本

1.2 SHHA数据集准备

SHHA(ShanghaiTech Part A)是人群计数领域的基准数据集,包含482张高密度人群图像。建议按以下结构组织数据:

SHHA/ ├── images/ │ ├── IMG_1.jpg │ └── ... ├── txt/ │ ├── GT_IMG_1.txt │ └── ... └── splits/ ├── train.list └── test.list

每个GT文件包含对应图像的标注点坐标,格式示例:

104.5 203.2 78.1 156.7 ...

提示:原始数据集可能使用.mat格式标注,需提前转换为txt格式。可使用以下Python代码片段转换:

import scipy.io mat = scipy.io.loadmat('GT_IMG_1.mat') points = mat['image_info'][0][0][0][0][0] # SHHA特定结构 np.savetxt('GT_IMG_1.txt', points, fmt='%.1f')

2. 数据预处理实战

2.1 自动生成数据列表文件

P2PNet训练需要提供包含图像-标注对路径的列表文件。以下脚本可自动生成训练/测试集列表:

import os def generate_data_list(dataset_path, output_file): image_files = [f for f in os.listdir(f"{dataset_path}/images") if f.endswith('.jpg')] with open(output_file, 'w') as f: for img_file in image_files: img_path = os.path.join(dataset_path, 'images', img_file) txt_file = f"GT_{os.path.splitext(img_file)[0]}.txt" txt_path = os.path.join(dataset_path, 'txt', txt_file) if os.path.exists(txt_path): f.write(f"{img_path} {txt_path}\n") else: print(f"Warning: Missing annotation for {img_file}") # 示例用法 generate_data_list('/data/SHHA/train', 'train.list') generate_data_list('/data/SHHA/test', 'test.list')

关键参数说明:

  • dataset_path:包含images和txt子目录的数据集根目录
  • output_file:生成的列表文件路径(如train.list)

2.2 数据增强策略

P2PNet原始论文采用以下增强组合:

  • 随机水平翻转(p=0.5)
  • 颜色抖动(亮度0.2,对比度0.2,饱和度0.2)
  • 随机裁剪(512×512)

可通过修改datasets.py中的__getitem__方法调整增强策略。推荐保留原始增强方案以获得最佳性能。

3. 模型训练全流程

3.1 训练参数解析

P2PNet的核心训练命令如下:

python train.py \ --data_root /path/to/SHHA \ --dataset_file SHHA \ --epochs 3500 \ --lr_drop 3500 \ --batch_size 8 \ --lr 0.0001 \ --lr_backbone 0.00001 \ --eval_freq 1 \ --output_dir ./logs \ --checkpoints_dir ./weights

关键参数说明:

参数推荐值作用
lr1e-4主学习率
lr_backbone1e-5Backbone学习率
lr_drop3500学习率衰减epoch
eval_freq1每N个epoch验证一次
batch_size8根据GPU显存调整

3.2 训练监控与调优

通过TensorBoard可实时监控训练过程:

tensorboard --logdir=./logs --port=6006

重点观察指标:

  • train_loss:应平稳下降,最终收敛到0.2-0.3
  • val_mae:验证集平均绝对误差,SHHA上通常能达到60-70
  • lr:学习率变化曲线

常见训练问题处理:

  • Loss震荡:减小batch_size或降低学习率
  • 过拟合:增加数据增强或提前停止训练
  • 显存不足:减小batch_size或使用梯度累积

4. 模型推理与部署

4.1 测试集评估

使用训练好的模型进行评估:

python run_test.py \ --weight_path ./weights/best.pth \ --output_dir ./results \ --dataset_path /path/to/SHHA/test

输出包括:

  • 预测密度图(_density.jpg)
  • 点预测结果(_points.txt)
  • 可视化标注(_vis.jpg)

4.2 自定义数据推理

对新的图像进行预测:

from models import build_model from PIL import Image model = build_model(args) checkpoint = torch.load('weights/best.pth', map_location='cpu') model.load_state_dict(checkpoint['model']) img = Image.open('custom_image.jpg').convert('RGB') points = model.predict(img) # 获取预测点坐标

4.3 性能优化技巧

  1. TorchScript导出:将模型转换为TorchScript提升推理速度

    traced_model = torch.jit.trace(model, example_input) traced_model.save("p2pnet.pt")
  2. ONNX转换:支持跨平台部署

    torch.onnx.export(model, dummy_input, "p2pnet.onnx")
  3. TensorRT加速:对CUDA核心进行优化,可获得2-3倍速度提升

实际部署中发现,输入分辨率对推理速度影响显著。将图像缩放至800×600左右可在精度和速度间取得较好平衡。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询