9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral


9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral
文章图片
新智元推荐
来源:知乎
作者:杨朔
【新智元导读】本文介绍一篇最新发表在ICLR2021Oral上的少样本学习工作 , 他们尝试从数据分布估计的角度去缓解少样本学习中的过拟合现象 , 并提出通过分布矫正(估计)的方式弥合这种差距 。
链接:https://openreview.net/forum?id=JWOiYxMG92s
代码:
https://github.com/ShuoYang-1998/ICLR2021-Oral_Distribution_Calibration
简介
从极少量样本中学习到泛化性能良好的模型是很困难的 , 因为极少的样本形成的数据分布往往与真实数据分布相差较大 , 在偏斜的数据分布上训练模型会导致严重的过拟合现象并严重破坏模型的泛化能力(见图1) 。
在本文中我们尝试从数据分布估计的角度去缓解少样本学习中的过拟合现象 , 利用一个样本去估计该类别的整体数据分布 , 如果该分布估计足够准确 , 也许可以弥合少样本学习和传统多样本学习的差距 。

9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral
文章图片
方法
直接从一个样本中估计整体数据分布是非常困难的 , 需要很强的先验去约束分布估计的过程 。 我们观察到如果假设每一个类别的特征都服从高斯分布 , 那么相似类别的分布统计量相似度非常高 , 如表1 。

9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral
文章图片
从直观的角度理解 , 一个类别的mean代表该类别的generalappearance , variance代表该类别某属性的变化范围(颜色、形状、姿势等) 。
而相似的类别(如猫和老虎)具有相似的整体外观和相似的属性变化范围 。 受此启发 , 我们提出了通过迁移基类(baseclass)的分布统计量的方式对少样本类别的数据分布做‘矫正’(calibration) 。
具体来说 , 我们首先为每一个baseclassi计算一个mean和covariance:
计算好的

储存起来当作baseclass分布先验 。 然后在进行少样本分类时我们利用baseclass的分布先验去修正少样本类别的数据分布:
得到修正后的少样本类别的分布

后 , 我们便可以从修正后的分布中直接采样:
然后利用采样得到的数据和supportset共同训练分类器:
至此 , 该算法结束 。
流程如图:

9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral
文章图片
实验
我们的算法无需任何可训练参数 , 可以建立在任何已有的特征提取器和分类器之上 , 并极大的提高模型的泛化能力 。
代码已开源 , 核心代码只有9行(evaluate_DC.py中的第10-19行) 。
我们的方法搭配最简单的线性分类器便可以达到非常高的1-shot分类性能 。
实验结果如图:

9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral
文章图片
我们分布估计的可视化如图:

9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral
文章图片
总结
在本工作中我们思考了少样本学习和多样本学习的核心差距 , 并提出通过分布矫正(估计)的方式弥合这种差距 。
在该工作的后续期刊拓展版本中我们从generalizationerrorbound的角度为‘基于数据分布估计的少样本学习’这一类方法建立了理论框架 , 并证明了当数据分布足够准确时 , 少样本学习和多样本学习的泛化误差等价 。
9行代码提高少样本学习泛化能力,代码已开源|ICLR2021 Oral】知乎链接:https://zhuanlan.zhihu.com/p/344531704