027、Detect 检测头源码逐行解析:分类分支和回归分支的 reshape permute concat 操作
从一次诡异的mAP下降说起
上周有个读者私信我,说他在YOLOv8上改了个检测头,把分类分支和回归分支的输出维度调了一下,结果mAP直接掉了15个点。他贴了代码给我看,我一眼就发现问题出在reshape和permute的顺序上——这哥们把维度搞反了,导致分类和回归的特征图对不上号。
这种坑我当年在YOLOv5上就踩过,那时候debug了整整两天,最后发现是permute之后忘了contiguous,导致后续的view操作报错。今天咱们就把Detect检测头里这些reshape、permute、concat的骚操作掰开揉碎,看看它们到底在干什么。
检测头的整体结构:两条分支的宿命
先看YOLOv5/v8检测头的核心逻辑。输入是来自Neck的多尺度特征图,比如P3、P4、P5,每个特征图经过一个卷积层后,会分裂成两条分支:
- 分类分支:输出形状是
[batch, num_classes * num_anchors, H, W] - 回归分支:输出形状是
[batch, 4 * num_anchors, H, W](4代表x,y,w,h)
这里有个关键点:YOLOv5的anchor-based检测头里,每个特征图位置会预测多个anchor,所以通道数要乘以anchor数量。YOLOv8虽然去掉了anchor,但结构上依然保留了这种“每个位置预测多个候选框”的设计思路。
源码逐行拆解:从卷积输出到最终预测
咱们直接看ultralytics里Detect类的forward方法,我加了口语化注释,你们感受一下:
defforward(self,x):# x是个list,里面是三个尺度的特征图 [P3, P4, P5]# 每个特征图形状: [batch, channels, H, W]shape=x[0].shape# 取第一个尺度的形状,后面要用foriinrange(self.nl):# nl=3,三个检测层# 先过卷积,把通道数压缩到指定维度# 这里self.cv2是分类卷积,self.cv3是回归卷积x[i]=torch.cat([self.cv2[i](x[i]),self.cv3[i](x[i])],dim=1)# 注意!这里concat是在通道维度上拼的# 分类分支在前,回归分支在后看到这里有人会问:为什么不分别处理两个分支,非要先拼起来?这是为了后续统一做reshape和permute,减少循环次数。但代价是后面必须小心翼翼地切分。
接下来是核心操作——把特征图从4D张量变成3D的预测结果:
# 把list转成tensor,形状变成 [batch, nl, channels, H, W]# 这里nl=3,代表三个检测层x=torch.cat([xi.view(shape[0],self.no,-1)forxiinx],dim=2)# view操作:把每个特征图的H*W展平到最后一维# 结果形状: [batch, no, H*W*nl] 注意这里no = num_classes + 4 + 1(置信度)这个view操作有个隐藏陷阱:如果特征图在之前做过transpose或permute,内存布局可能不是连续的,直接view会报错。所以YOLOv5里在view之前会加一个.contiguous(),但YOLOv8的代码里没显式写,因为前面的卷积输出默认是连续的。不过你要是自己魔改网络结构,比如在中间插了个permute,那就必须手动加contiguous了。
分类分支和回归分支的分离与重塑
现在x的形状是[batch, no, total_anchors],其中total_anchors = H1W1 + H2W2 + H3*W3。接下来要把它拆成分类和回归两部分:
# 切分:前self.nc个通道是分类,后面4个是回归,最后1个是置信度# 注意这里self.no = self.nc + 4 + 1x=x.split([self.nc,4,1],dim=1)# 返回三个部分: (分类, 回归, 置信度)# 每个部分形状: [batch, 通道数, total_anchors]这里split的维度是1,也就是通道维。分类分支取前nc个通道,回归分支取中间4个通道,置信度取最后1个通道。这个顺序千万别搞反,否则分类和回归的特征图会错位。
接下来是permute操作,把维度顺序从[batch, channels, anchors]变成[batch, anchors, channels]:
# 对每个部分做permute,把anchors维度放到中间# 这样后续计算loss时可以直接按anchor索引cls_pred=x[0].permute(0,2,1).contiguous()# [batch, anchors, nc]reg_pred=x[1].permute(0,2,1).contiguous()# [batch, anchors, 4]obj_pred=x[2].permute(0,2,1).contiguous()# [batch, anchors, 1]这里为什么要permute?因为后续计算分类损失时,我们需要对每个anchor独立计算交叉熵,而交叉熵函数期望的输入形状是[batch, anchors, num_classes],所以要把anchors维度提到第二维。回归分支同理,需要[batch, anchors, 4]的形状来计算CIoU损失。
那个让我debug两天的contiguous陷阱
你们注意看,上面permute之后我加了.contiguous()。这个操作在YOLOv5的官方代码里是有的,但YOLOv8的某些版本里没写。为什么?
因为permute只是改变了张量的视图,并没有改变内存布局。如果你permute之后直接做view或者reshape,PyTorch会报错说“view requires the tensor to be contiguous”。但如果你只是做矩阵运算或者切片,不涉及view,那contiguous就不是必须的。
我当年踩的坑是这样的:在YOLOv5的Detect模块里,我自作聪明地删掉了contiguous,想着能省点显存。结果训练到一半,loss突然变成nan,排查了半天才发现是view操作报错。从那以后,我只要做了permute或transpose,后面必加contiguous,宁可多花点内存,也不给自己挖坑。
多尺度特征图的concat逻辑
回到最开始的concat操作。YOLO检测头会把三个尺度的特征图拼在一起,但注意拼的是展平后的anchor维度,而不是通道维度:
# 每个尺度的特征图先view成 [batch, no, H*W]# 然后在最后一维(anchor维)上concatx=torch.cat([xi.view(shape[0],self.no,-1)forxiinx],dim=2)这样做的目的是把所有尺度的预测结果统一到一个张量里,方便后续的NMS和loss计算。但有个细节:不同尺度的特征图大小不同,比如P3是80x80,P4是40x40,P5是20x20,展平后分别是6400、1600、400个anchor,concat之后总anchor数是8400。
这个8400就是YOLOv5/v8默认的anchor数量。如果你改了网络结构,比如增加了检测层,那这个数字会变,对应的后处理代码也要跟着改。
分类分支和回归分支的维度对齐
最后说一个容易忽略的点:分类分支和回归分支在concat之前,它们的通道数必须匹配。分类分支的输出通道是num_classes * num_anchors,回归分支是4 * num_anchors,置信度分支是1 * num_anchors。
如果你自己魔改检测头,比如想增加回归分支的输出维度(比如加一个角度预测),那必须同步修改self.no的值,并且保证split的时候维度对应。否则split出来的张量形状不匹配,后续的permute会报维度错误。
个人经验:调试检测头的三板斧
打印形状:在forward里加一行
print(x[i].shape),看每个尺度的特征图经过卷积后的形状是否符合预期。特别是当你改了self.cv2或self.cv3的卷积核数量时,这一步能快速定位维度不匹配的问题。检查contiguous:如果遇到“view requires contiguous tensor”的错误,先检查最近有没有做permute或transpose。有的话,在view之前加
.contiguous()。如果加了还报错,那可能是view的维度计算错了。验证分类和回归的对齐:在split之后,分别取分类和回归的第一个anchor,看它们的数值是否合理。比如分类分支的softmax输出应该接近均匀分布(训练初期),回归分支的x,y应该接近0.5(因为YOLO的坐标是归一化的)。如果发现分类和回归的数值范围差很多,那大概率是split的维度搞反了。
检测头是整个YOLO网络里最容易出bug的地方,但也是最值得花时间理解的部分。搞懂了reshape、permute、concat这些操作背后的维度变换逻辑,你就能自由地魔改检测头,比如加新的预测分支、改损失函数、甚至做多任务学习。下次遇到mAP下降的问题,先检查维度对不对,别像我当年那样傻乎乎地调学习率调了两天。