从‘电池’到‘易拉罐’:手把手教你用YOLOv8训练一个垃圾分类模型(数据集已备好)
在环保科技和智慧城市快速发展的今天,计算机视觉技术正成为垃圾分类智能化的重要推手。YOLOv8作为目标检测领域的新星,以其卓越的实时性和准确性,成为开发垃圾分类系统的理想选择。本文将带您从零开始,完成一个能够识别电池、易拉罐等常见生活垃圾的智能模型训练全流程。
无论您是刚接触AI的爱好者,还是希望拓展环保科技应用的开发者,这篇实战指南都将提供清晰的操作路径。我们将使用一个精心标注的YOLO格式数据集,涵盖有害垃圾、可回收物、厨余垃圾和其他垃圾四大类别,包含电池、药品包装、易拉罐、矿泉水瓶等典型物品。
1. 环境准备与数据检查
1.1 搭建YOLOv8开发环境
YOLOv8基于PyTorch框架,环境配置相对简单。推荐使用Python 3.8+和CUDA 11.3+(如有GPU)的组合:
conda create -n yolo_env python=3.8 conda activate yolo_env pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install ultralytics验证安装是否成功:
import torch from ultralytics import YOLO print(torch.__version__) # 应显示1.12.0+ print(YOLO('yolov8n.pt')) # 测试YOLO模型加载1.2 数据集结构解析
下载并解压数据集后,您将看到如下目录结构:
垃圾数据集/ ├── images/ │ ├── train/ # 训练集图片 │ ├── val/ # 验证集图片 │ └── test/ # 测试集图片 ├── labels/ │ ├── train/ # 训练集标注 │ ├── val/ # 验证集标注 │ └── test/ # 测试集标注 └── data.yaml # 数据集配置文件使用以下代码快速检查数据集完整性:
from pathlib import Path def check_dataset(dataset_path): img_train = len(list((Path(dataset_path)/"images"/"train").glob("*.jpg"))) label_train = len(list((Path(dataset_path)/"labels"/"train").glob("*.txt"))) assert img_train == label_train, "训练集图片与标注数量不匹配" # 同样检查val和test集 print(f"数据集检查通过:训练集{img_train}张,验证集...")2. 模型训练与调优
2.1 配置data.yaml
修改data.yaml文件以适配您的数据集:
path: /path/to/垃圾数据集 train: images/train val: images/val test: images/test names: 0: 有害垃圾 1: 可回收物 2: 厨余垃圾 3: 其他垃圾2.2 启动基础训练
使用YOLOv8s(小型模型)开始训练:
from ultralytics import YOLO model = YOLO('yolov8s.pt') # 加载预训练模型 results = model.train( data='data.yaml', epochs=100, imgsz=640, batch=16, device='0' # 使用GPU 0 )关键参数说明:
| 参数 | 推荐值 | 作用 |
|---|---|---|
| epochs | 50-300 | 训练轮次 |
| patience | 20 | 早停等待轮次 |
| imgsz | 640 | 输入图像尺寸 |
| batch | 8-32 | 批处理大小 |
| lr0 | 0.01 | 初始学习率 |
2.3 垃圾分类的特殊调优策略
生活垃圾检测面临独特挑战:
小物体检测优化:
results = model.train( ... fl_gamma=1.5, # 聚焦小物体 box=7.5, # 加大box loss权重 cls=0.5 # 适当降低分类权重 )类别不平衡处理:
- 使用过采样策略:
from torch.utils.data import WeightedRandomSampler # 计算每个类别的样本权重数据增强配置:
augment: True hsv_h: 0.015 # 色相增强 hsv_s: 0.7 # 饱和度增强 hsv_v: 0.4 # 明度增强 translate: 0.1 # 平移增强
3. 模型评估与结果分析
3.1 性能指标解读
训练完成后,查看关键评估指标:
metrics = model.val() # 在验证集上评估 print(f"mAP50-95: {metrics.box.map}") # 平均精度 print(f"各类别AP: {metrics.box.maps}")典型垃圾分类模型的性能基准:
| 类别 | AP50 | 常见误检 |
|---|---|---|
| 有害垃圾 | 0.85 | 易与药品混淆 |
| 可回收物 | 0.92 | 不同材质易拉罐 |
| 厨余垃圾 | 0.78 | 形状变化大 |
| 其他垃圾 | 0.81 | 背景干扰多 |
3.2 可视化分析工具
使用YOLOv8内置工具生成分析图表:
yolo detect val model=best.pt data=data.yaml plots=True这将生成:
- 混淆矩阵(confusion_matrix.png)
- F1曲线(F1_curve.png)
- PR曲线(PR_curve.png)
提示:重点关注有害垃圾的召回率,确保电池等危险物品不被漏检
4. 模型部署与实用技巧
4.1 模型导出与优化
将PyTorch模型转换为ONNX格式以便部署:
model.export(format='onnx', dynamic=True, simplify=True)针对不同部署场景的推荐格式:
| 平台 | 推荐格式 | 优化方法 |
|---|---|---|
| 移动端 | TensorRT | FP16量化 |
| 边缘设备 | ONNX | 动态轴 |
| Web服务 | TorchScript | 脚本优化 |
4.2 实时检测代码示例
使用训练好的模型进行实时检测:
import cv2 from ultralytics import YOLO model = YOLO('best.pt') cap = cv2.VideoCapture(0) # 摄像头输入 while True: ret, frame = cap.read() results = model(frame, stream=True) for r in results: boxes = r.boxes for box in boxes: x1, y1, x2, y2 = map(int, box.xyxy[0]) cls = int(box.cls) conf = float(box.conf) label = f"{model.names[cls]} {conf:.2f}" cv2.rectangle(frame, (x1,y1), (x2,y2), (0,255,0), 2) cv2.putText(frame, label, (x1,y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2) cv2.imshow('垃圾分类检测', frame) if cv2.waitKey(1) == ord('q'): break4.3 实际应用中的优化建议
光照条件处理:
# 在推理前添加自适应直方图均衡化 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2YCrCb) y, cr, cb = cv2.split(frame) y = cv2.equalizeHist(y) frame = cv2.merge((y, cr, cb))多尺度检测策略:
results = model(frame, imgsz=[320, 640], augment=True)后处理优化:
# 使用NMS和置信度过滤 results = model(frame, conf=0.5, iou=0.45)
在部署到智能垃圾桶等实际场景时,建议添加以下功能模块:
- 检测结果的时间平滑处理
- 分类结果的逻辑校验(如易拉罐不应出现在有害垃圾)
- 用户交互反馈机制