1、 为 Paddle 新增 multi margin loss API

该项目为 Paddle 添加了关于计算 multimarginloss 的损失函数的 API,全部流程基于 python 开发。

2、任务介绍

该任务是第三期 paddle Hackathon 其中一项基础任务,丰富和完善飞桨的计算损失函数的 API。作为一个基础任务,开发的难度也较其他的高级任务容易上手。在此也希望能抛砖引玉,激发大家的热情,参与到飞桨开源社区的开发中,为 PaddlePaddle 添砖加瓦。

multi_margin_loss 是用于多分类问题的 Hinge loss。关于 Margin Loss, Hinge Loss, Ranking Loss 之间的区别,感兴趣的朋友可以看看这篇内容《Understanding Ranking Loss, Contrastive Loss, Margin Loss, Triplet Loss, Hinge Loss and all those confusing names》 或是这篇知乎专栏

这里我们就只关注如何实现根据公式实现这个函数即可,其公式如下

在这里插入图片描述

其中 C C C 为类别数。若 batch_size 为 1 时, x x x 是一个(C,)的 input, y y y 是一个标量的 label ( 0 ≤ y ≤ C − 1 0\leq y\leq C-1 0yC1),且 i ≠ y i\neq y i=y p p p为一个幂指数。

该损失函数对于某一个样本不是考虑样本输出与真实类别之间的误差,而是考虑对应真实类别预测与其他错误类别与猜测之间的误差。即 x [ i ] − x [ y ] x[i]-x[y] x[i]x[y] 表示的即为两者的相似关系。

优化时令正确预测 x [ y ] x[y] x[y] 越来越大,错误预测 x [ i ] x[i] x[i] 越来越小,最终两者的差值为一个 m a r g i n margin margin 的大小即可,差距更大并不会有任何奖励。这样设计的目的在于,过分的关注单个样本的准确性可能会使得整体分类的误差较大。

同时,如果有权重 w 的情况下(即对每一个类别赋予不同的权重),则公式为

在这里插入图片描述

3、设计文档

作为一个基础 API 任务,最主要的是理解公式。我在最开始的时候没有很好的理解公式,在之后开发的过程中也出现了一些小错误。

对于设计文档还需要对竞品进行调研,调研同样也是一个学习的过程。包括 Pytorch、 Tensorflow 等,调研的过程中可以发现 Pytorch 的算子库还是较为丰富, tensorflow 就没有该算子,需要用户自己去组合实现。

Pytorch 的该算子是通过 C 实现的。

在这里插入图片描述

同时还有 Pytorch 给的示例。

在这里插入图片描述

这个示例同样也帮助我更好的理解这个算子开发的具体细节。

同时为了要保证该算子的使用方法与其他算子的一致性,阅读了 PaddlePaddle 已有的损失函数的代码框架, 此外还需要注意在设计文档的部分需要测试方面的设计。设计的案例以及覆盖范围要尽可能的全面。同时由于该算子是用的已有的算子进行组合。因此不需要对算子额外进行测试。只需对输入的参数合法性判断的语句进行编写测试案例即可。

4、代码开发

该 API 使用的都是已有的算子进行组合,因此只要按照公式,在 python 上进行开发即可。

根据 Paddle 的损失函数框架,核心函数部分是在 nn.functional.multi_margin_loss 中,算子的实际功能都在这当中实现。在 nn.MultiMarginLoss 的模型中只需要在 forward 函数中调用该nn.functional.multi_margin_loss。

在这里插入图片描述

检验 API 的正确性时,可以通过构造一些测试的随机数据与 Pytorch 的结果进行对比。

此外还需要注意的就是注释和测试用例,为了更好的让他人理解和使用该 API 。在这里要感谢 paddle 的小伙伴们对我的代码反复耐心的 review,帮我找到了代码中的错误,以及完善注释和测试用例的全面性。在他们的帮助下,才正确的完成了整个代码开发(PR 链接)。

5、成果展示

使用的方法与其他算子相同, nn.functional.multi_margin_loss 使用方法如下

在这里插入图片描述

nn.MultiMarginLoss 的使用方法如下

在这里插入图片描述

6、总结

任务本身并不算难,但在这个过程中能够体验为大型开源项目贡献代码的整个流程,是很有意义也富有成就感。同时对个人的代码能力也是一种提升。对于没有太多基础但是对深度学习有兴趣、想要为开源社区做贡献的、希望加入大型项目进行开发的小伙伴,可以尝试从基础 API 的开发这类的任务入手,这个过程也能很好的熟悉一下 Paddle 的框架。日后不论是继续深入对 Paddle 的开发,或是参与到其他的项目中都有极大的帮助。

此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐