Focal Loss初步试验

Link:

看到知乎上对 Focal Loss for Dense Object Detection 这篇论文的介绍,发现是针对所谓的one-stage的检测框架的数据不平衡问题的。考虑到目前我们使用的也是one-stage的检测框架,于是就打出来看了下,本来按照自己最近的慵懒状态,肯定懒得动手。不过看的也是囫囵吞枣,自然应该有些理解错误和遗漏。

目前训练one-stage的检测,确实存在数据严重不平衡的问题,我们采取的方法都是对不同类加权重,不过这个权重的大小需要人工指定,换一个数据集就需要重新针对不同数据集的样本分布来试权重,一般都要跑3、5次才能有个让自己满意的值。目前也有一个新的针对数据分布自适应调整权重的版本,但是目前还是没有太明显效果,毕竟也是刚用上,熟悉调整一些参数还需要时间和精力。自然希望Focal Loss能更好的解决这个问题。

文章主要是挑了感兴趣的FL看了,剩下的就简单扫了扫。里面提了一下RetinaNet,这个基本one-stage的思路都是这样,尽量获得多尺度信息。这里的是每个尺度都出BBox。目前我们自己的思路是综合所有多尺度信息后再统一出BBox,之前倒是试过训练的时候做类似的工作,希望以此引导网络更好的区分不同尺度下需要关注的东西,但是最后嫌这样训练麻烦效果也没好多少没怎么用了。网络输出都是类似R-CNN,不多废话了。

以下就是FL相关的:

FL的思路就是期望解决数据分布不平衡的问题,其实更准确说的是数据的难度不平衡的问题。FCN One-stage Detector其实都可以看作滑动窗口训练,其中背景类别其实有很多是难度比较小的,比如检测啤酒在桌子上的任务,桌子的纹理颜色都是很统一的,而且难度比较小,一般光一个颜色基本就能过滤掉了(毕竟啤酒瓶子颜色的桌子,我本人是…比较难以理解的)。

Cascade的训练就很适应这样的问题,最简单的案例第一层分类器已经搞定了,后面的分类器不会受这一部分数据的困扰,或者浪费算力在训练这些数据上面。对于网络来说,最简单的案例一般都是很快就0.9的置信度了,剩下的都是难度大的。主要问题是滑动窗口没有过滤机制,背景里面一堆堆难度简单的数据每次都会再加入训练,网络还需要考虑这些案例,浪费算力(其实应该是浪费每次更新的梯度)。

当然,一般来说简单案例因为快速就搞定了,在loss里面的权重其实自动就降下来,不过这篇论文里面说到其实大部分loss还是被这些简单案例占住了。比较好奇的是这个结论是怎么得来的,等我后续再看看能不能再挖一挖细节,从结论来看应该就是架不住简单案例的数量实在太多。既然有了结论,解决的方法还是很朴素的,对应案例的那个loss再*(1-p),这样简单案例的权重会数量级的降低,实在不行还可以继续多乘几次。如此来看,如果自己的任务里面简单案例实在太多可以进一步增加论文里面的gamma。

我用Pytorch简单实现了一下FocalLoss,但愿老天保佑没有实现错。目前初步观测的现象:

  1. 前期收敛速度增加。
  2. 未观察到显著的后期准确率提升效果。
  3. 使用针对数据分布不平衡的自适应加权后,效果有显著下降。

目前我能给的初步解释是:

  1. 前期收敛速度增加可能是:普通的loss初期浪费了一部分算力在简单的数据上,focal loss可以更好的让算力跳过太简单的数据提前开始考虑难度稍微大的案例。
  2. 最后准确率提升未见明显效果,可能是自己的网络结构较为简单的结果,换用普通loss也已经训练到网络的在数据集的极限附近。显然,这个解释不够满意,而且我用的网络也不算简单,也有二十多层了,宽度也不算太小。等后续调整了网络复杂度后再对比看看。
  3. 数据分布不平衡的自适应加权后,准确率反而进一步下降的问题待查。不过即使使用普通loss也存在类似的问题,只是从评价指标上看不明显,可能的原因是少量类别过大的权重影响了网络的学习过程,过于看中少量分类的召回,影响了其他量多的类别。而量多的类别对整体准确率的影响是非常明显的。

当然,还有一个可能的解释:我的实现写错了。-_-!

Over