从数据集准备到模型训练:一步步教你用P2PNet跑通SHHA人群计数数据集
人群计数技术在智慧城市、公共安全等领域具有广泛应用价值。P2PNet作为近年来提出的创新性人群计数模型,以其端到端的点预测能力和轻量化结构受到研究者关注。本文将带您从零开始,完整实现P2PNet在SHHA数据集上的训练流程,涵盖数据准备、环境配置、训练调参等关键环节,特别针对实际项目中容易遇到的路径配置、版本兼容等问题提供解决方案。
1. 环境准备与数据获取
1.1 基础环境配置
P2PNet基于PyTorch框架实现,推荐使用Python 3.8+和CUDA 11.3以上环境。以下是核心依赖的安装命令:
pip install torch==2.1.2 torchvision==0.16.2 pip install opencv-python pandas matplotlib tensorboardX常见环境问题排查:
- CUDA版本不匹配:通过
nvidia-smi和nvcc --version检查驱动与运行时版本 - Pillow版本问题:新版Pillow中
ANTIALIAS已被移除,需替换为Resampling.LANCZOS - Torchvision API变更:若遇到
_new_empty_tensor导入错误,可注释相关代码或降级torchvision版本
1.2 SHHA数据集准备
SHHA(ShanghaiTech Part A)是人群计数领域的基准数据集,包含482张高密度人群图像。建议按以下结构组织数据:
SHHA/ ├── images/ │ ├── IMG_1.jpg │ └── ... ├── txt/ │ ├── GT_IMG_1.txt │ └── ... └── splits/ ├── train.list └── test.list每个GT文件包含对应图像的标注点坐标,格式示例:
104.5 203.2 78.1 156.7 ...提示:原始数据集可能使用.mat格式标注,需提前转换为txt格式。可使用以下Python代码片段转换:
import scipy.io mat = scipy.io.loadmat('GT_IMG_1.mat') points = mat['image_info'][0][0][0][0][0] # SHHA特定结构 np.savetxt('GT_IMG_1.txt', points, fmt='%.1f')2. 数据预处理实战
2.1 自动生成数据列表文件
P2PNet训练需要提供包含图像-标注对路径的列表文件。以下脚本可自动生成训练/测试集列表:
import os def generate_data_list(dataset_path, output_file): image_files = [f for f in os.listdir(f"{dataset_path}/images") if f.endswith('.jpg')] with open(output_file, 'w') as f: for img_file in image_files: img_path = os.path.join(dataset_path, 'images', img_file) txt_file = f"GT_{os.path.splitext(img_file)[0]}.txt" txt_path = os.path.join(dataset_path, 'txt', txt_file) if os.path.exists(txt_path): f.write(f"{img_path} {txt_path}\n") else: print(f"Warning: Missing annotation for {img_file}") # 示例用法 generate_data_list('/data/SHHA/train', 'train.list') generate_data_list('/data/SHHA/test', 'test.list')关键参数说明:
dataset_path:包含images和txt子目录的数据集根目录output_file:生成的列表文件路径(如train.list)
2.2 数据增强策略
P2PNet原始论文采用以下增强组合:
- 随机水平翻转(p=0.5)
- 颜色抖动(亮度0.2,对比度0.2,饱和度0.2)
- 随机裁剪(512×512)
可通过修改datasets.py中的__getitem__方法调整增强策略。推荐保留原始增强方案以获得最佳性能。
3. 模型训练全流程
3.1 训练参数解析
P2PNet的核心训练命令如下:
python train.py \ --data_root /path/to/SHHA \ --dataset_file SHHA \ --epochs 3500 \ --lr_drop 3500 \ --batch_size 8 \ --lr 0.0001 \ --lr_backbone 0.00001 \ --eval_freq 1 \ --output_dir ./logs \ --checkpoints_dir ./weights关键参数说明:
| 参数 | 推荐值 | 作用 |
|---|---|---|
| lr | 1e-4 | 主学习率 |
| lr_backbone | 1e-5 | Backbone学习率 |
| lr_drop | 3500 | 学习率衰减epoch |
| eval_freq | 1 | 每N个epoch验证一次 |
| batch_size | 8 | 根据GPU显存调整 |
3.2 训练监控与调优
通过TensorBoard可实时监控训练过程:
tensorboard --logdir=./logs --port=6006重点观察指标:
- train_loss:应平稳下降,最终收敛到0.2-0.3
- val_mae:验证集平均绝对误差,SHHA上通常能达到60-70
- lr:学习率变化曲线
常见训练问题处理:
- Loss震荡:减小batch_size或降低学习率
- 过拟合:增加数据增强或提前停止训练
- 显存不足:减小batch_size或使用梯度累积
4. 模型推理与部署
4.1 测试集评估
使用训练好的模型进行评估:
python run_test.py \ --weight_path ./weights/best.pth \ --output_dir ./results \ --dataset_path /path/to/SHHA/test输出包括:
- 预测密度图(_density.jpg)
- 点预测结果(_points.txt)
- 可视化标注(_vis.jpg)
4.2 自定义数据推理
对新的图像进行预测:
from models import build_model from PIL import Image model = build_model(args) checkpoint = torch.load('weights/best.pth', map_location='cpu') model.load_state_dict(checkpoint['model']) img = Image.open('custom_image.jpg').convert('RGB') points = model.predict(img) # 获取预测点坐标4.3 性能优化技巧
TorchScript导出:将模型转换为TorchScript提升推理速度
traced_model = torch.jit.trace(model, example_input) traced_model.save("p2pnet.pt")ONNX转换:支持跨平台部署
torch.onnx.export(model, dummy_input, "p2pnet.onnx")TensorRT加速:对CUDA核心进行优化,可获得2-3倍速度提升
实际部署中发现,输入分辨率对推理速度影响显著。将图像缩放至800×600左右可在精度和速度间取得较好平衡。