1 任务介绍

为了丰富飞桨框架的矩阵索引 API,我们新增 take API,它能够根据索引返回指定索引上的数据集合,调用路径为:paddle.takeTensor.take

任务要求见 [PaddlePaddle Hackathon 3] API 开发任务合集#44073。我们需要熟悉 take 算法原理和适用场景;熟悉飞桨动静态图下数学计算过程;熟练掌握 Python。

任务的难度一般,通过已有的矩阵索引 API,能够组合实现 take API,并模仿已有的索引 API 的写法就可以实现调用。

2 设计文档

2.1 任务背景

目前飞桨可由 Tensor.flatten、Tensor.index_select 和 Tensor.reshape 组合实现 take API 的功能。

其主要实现逻辑为:

  • 通过 Tensor.flatten() 将输入 x 和 index 展开成 1D Tensor。

  • 通过 Tensor.index_select(index) 按照 index 中的索引提取对应元素。

  • 通过 Tensor.reshape(index.shape) 将输出的 Tensor 形状转成 index 的形状。

我们新增 take API,将为飞桨增加新的 Tensor 索引函数,丰富飞桨 Tensor 的索引功能。

2.2 竞品情况

在调研业内竞品情况时,我们选择了主流的 PyTorch、TensorFlow 和 Numpy。调研的方面主要包括:

  • 实现方法

  • 形参设置

  • 越界处理

  • 返回值

2.2.1 PyTorch

Pytorch 中有 API torch.take(input, index)官方文档 描述如下:torch.take 返回一个新的 Tensor,其中包含给定索引处的输入元素。输入 Tensor 被视为一维张量。输出的 Tensor 与索引矩阵的形状相同。

torch/onnx/symbolic_opset*.py 中定义了 take 的 Python 实现方法:

def take(g, self, index):
    self_flattened = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
    out = index_select(g, self_flattened, 0, index)
    out = reshape_as(g, out, index)
    return out

实现方法

  • 通过 Tensor.flatten() 将输入 x 和 index 展开成 1D Tensor。

  • 通过 Tensor.index_select(index) 按照 index 中的索引提取对应元素。

  • 通过 Tensor.reshape(index.shape) 将输出的 Tensor 转成 index 的形状。

形参设置

  • input (Tensor) – the input tensor

  • index (LongTensor) – the indices into tensor

越界处理

没有额外的越界处理,越界即报错。

返回值

Tensor

2.2.2 Numpy

Numpy 中有 API numpy.take(a, indices, axis=None, out=None, mode='raise')官方文档 描述为沿轴(axis)从 ndarray 中取元素。
如果坐标轴不是None,那么这个函数的作用与“fancy”索引(使用 ndarray 对 ndarray 进行索引)相同。

实现方法

  • 当指定 axis 的时候,numpy.take 执行与 “fancy indexing” 相同的索引操作(使用数组索引数组);例如 np.take(arr, indices, axis=3) 等价于 arr[:, :, :, indices, ...]

  • 当不指定 axis 的时候,numpy.take 默认将输入展平再使用 “fancy indexing”。

  • 当提供参数 out 的时候,输出的数据将填充到 out 中。

形参设置

  • a (array_like) – The source array.

  • indices (array_like) – The indices of the values to extract.

  • axis (int) – 指定轴,默认情况下,使用展平(flattened)的输入 array。

  • out (ndarray, optional)

    • 如果提供,结果将放在这个 ndarray 中。它应该具有适当的形状和类型。请注意,如果 mode = ‘raise’,则始终缓存 out; 使用其他模式以获得更好的性能。
  • mode ({‘raise’, ‘wrap’, ‘clip’}, optional)

    指定越界索引的行为方式。

    • ‘raise’ – raise an error (default).

    • ‘wrap’ – 通过取余约束越界的 indices。

    • ‘clip’ – 将超出范围的索引约束到 [0, max_index),此模式将禁用带有负数的索引。

越界处理

如上所述,通过 mode 参数指定。

返回值

out (ndarray) - The returned array has the same type as a.

2.2.3 TensorFlow

据我们的调研情况,TensorFlow 中没有自己实现 take API,而是直接调用 numpy.taketf.experimental.numpy.take

在充分梳理三者的逻辑后,对比了三者实现的相同点以及不同点。大致包括如下几个方面:

2.2.4 对比分析

  • torch.takeindex 参数必须为 LongTensor 类型;numpy.take 直接对参数 indices 的元素取整再进行索引。

  • 在维度支持上,numpy.take 支持指定轴,torch.take 不支持。

  • 当不指定轴时,对于相同的索引矩阵,numpy.take 的执行结果等于 torch.take

  • numpy.take 支持通过 mode 参数指定索引越界的 3 种处理方式,默认直接报错;torch.take 在索引越界时直接报错。

2.3 设计思路

由于 numpy 在指定轴索引后得到的结果不能保证与 index 的 shape 一致,会破坏 take 方法的输出结果形状与 index 一致的特性。因此我们决定新增的 paddle.take 的功能与 torch.takenumpy.take 的默认形式保持一致,即,不增加 axis 参数指定索引轴;在 torch.take 的基础上增加 mode 参数提供三种 index 索引越界的处理方式。尽可能保持 take 索引方法简洁、易理解的特性。

2.3.1 参数设置

paddle.take(
  x: Tensor,
  index: Tensor,
  mode: str='raise',
  name: str=None)

注:其中添加参数 name 为了与飞桨其他 API 参数名保持一致。

2.3.2 底层 OP 设计

使用已有 API 组合实现,不再单独设计 OP。

2.4 实现方案

该 API 需要添加在飞桨 repo 的 python/paddle/tensor/math.py 文件中;并在 python/paddle/tensor/__init__.py 以及 python/paddle/__init__.py 中添加 take API,以支持 Tensor.take 和 paddle.take 的调用方式。

目前飞桨可由 Tensor.flattenTensor.index_selectTensor.reshape 组合实现该 API 的功能。

其主要实现逻辑为:

  1. 通过 Tensor.flatten() 将输入 x 和 index 展开成 1D Tensor。

  2. 根据 mode 参数对索引进行越界处理:

    • mode='raise',若索引越界,通过最后调用的 paddle.index_select 抛出错误 (默认);
    • mode='wrap',通过取余约束越界的 indices;
    • mode='clip',通过 paddle.clip 将两端超出范围的索引约束到 [0, max_index-1]。
  3. 通过 Tensor.index_select(index) 按照 index 中的索引提取对应元素。

    • numpy.taketorch.take 支持负值索引;
    • 然而 index_select 不支持,因此需要先将 index 的负值索引转为对应的正值索引。
  4. 通过 Tensor.reshape(index.shape) 将输出的 Tensor 形状转成 index 的形状。

2.5 单测方案

测试考虑的 case 如下:

  • 参数 index 数据类型必须为 paddle.int32paddle.int64 类型的 Tensor(与 paddle.index_select 一致)。

  • x 的数据类型支持 int32int64float32float64

  • index 索引越界的三种处理方式:

    • mode='raise',若索引越界,通过最后调用的 paddle.index_select 抛出错误 (默认);
    • mode='wrap',通过取余约束越界的 indices;
    • mode='clip',通过 paddle.clip 将两端超出范围的索引约束到 [0, max_index-1]。
  • 在动态图、静态图下,以及 CPU、GPU 下,都能得到正确的结果。

3 代码开发

3.1 API 开发

我个人的开发环境是 Win10 + PyCharm + Git。

依照设计文档,我们就可以开始代码开发了。take API 开发较为简单,使用现有 API 组合实现即可。

在飞桨内部工程师、社区 Committer 的 review 和帮助下,我们不断优化迭代 take API,最终成功合入。详见我们的 PR #44741

def take(x, index, mode='raise', name=None):

    if mode not in ['raise', 'wrap', 'clip']:
        raise ValueError(
            "'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode))

    if paddle.in_dynamic_mode():
        if not isinstance(index, (paddle.Tensor, Variable)):
            raise TypeError(
                "The type of 'index' must be Tensor, but got {}".format(type(index)))
        if index.dtype not in [paddle.int32, paddle.int64]:
            raise TypeError(
                "The data type of 'index' must be one of ['int32', 'int64'], but got {}".format(
                    index.dtype))

    else:
        check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take')

    input_1d = x.flatten()  # 将输入的 Tensor 展开成一维
    index_1d = index.flatten()  # 将输入的索引矩阵也展开成一维
    max_index = input_1d.shape[-1]

    if mode == 'raise':
        # This processing enables 'take' to handle negative indexes within the correct range.
        index_1d = paddle.where(index_1d < 0, index_1d + max_index, index_1d)
    elif mode == 'wrap':
        # The out of range indices are constrained by taking the remainder.
        index_1d = paddle.where(index_1d < 0,
                                index_1d % max_index, index_1d)
        index_1d = paddle.where(index_1d >= max_index,
                                index_1d % max_index, index_1d)
    elif mode == 'clip':
        # 'clip' mode disables indexing with negative numbers.
        index_1d = clip(index_1d, 0, max_index - 1)

    out = input_1d.index_select(index_1d).reshape(index.shape)

    return out

代码开发完成了,我们还要按照贡献指南的注释规范,补充和修改注释,书写中文和英文文档。

3.2 单测开发

我们以 numpy.take 作为基准,验证 API 的正确性。

根据测试文档的单测方案,我们要测试的 cases 如下:

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard


class TestTakeAPI(unittest.TestCase):
    """
    单测主类,将同类 case 的测试参数进行封装,方便继承和修改。
    """

    def set_mode(self):
        """
        设置索引越界处理模式
        """
        self.mode = 'raise'

    def set_dtype(self):
        """
        设置输入的数据类型
        """
        self.input_dtype = 'float64'
        self.index_dtype = 'int64'

    def set_input(self):
        """
        设置输入的 Tensor 形状和元素
        """
        self.input_shape = [3, 4]
        self.index_shape = [2, 3]
        self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
            self.input_dtype)
        self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype(
            self.index_dtype)

    def setUp(self):
        """
        初始化
        """
        self.set_mode()
        self.set_dtype()
        self.set_input()
        self.place = fluid.CUDAPlace(
            0) if core.is_compiled_with_cuda() else fluid.CPUPlace()

    def test_static_graph(self):
        """
        每一种 case 都需要经过动态图、静态图的测试
        """
        paddle.enable_static()
        startup_program = Program()
        train_program = Program()
        with program_guard(startup_program, train_program):
            x = fluid.data(name='input',
                           dtype=self.input_dtype,
                           shape=self.input_shape)
            index = fluid.data(name='index',
                               dtype=self.index_dtype,
                               shape=self.index_shape)
            out = paddle.take(x, index, mode=self.mode)

            exe = fluid.Executor(self.place)
            st_result = exe.run(fluid.default_main_program(),
                                feed={
                                    'input': self.input_np,
                                    'index': self.index_np
                                },
                                fetch_list=out)
            np.testing.assert_allclose(
                st_result[0],
                np.take(self.input_np, self.index_np, mode=self.mode))

    def test_dygraph(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        index = paddle.to_tensor(self.index_np)
        dy_result = paddle.take(x, index, mode=self.mode)
        np.testing.assert_allclose(
            np.take(self.input_np, self.index_np, mode=self.mode),
            dy_result.numpy())

数据类型 cases,确保预设的数据类型都能通过:

class TestTakeInt32(TestTakeAPI):
    """Test take API with data type int32"""

    def set_dtype(self):
        self.input_dtype = 'int32'
        self.index_dtype = 'int64'


class TestTakeInt64(TestTakeAPI):
    """Test take API with data type int64"""

    def set_dtype(self):
        self.input_dtype = 'int64'
        self.index_dtype = 'int64'


class TestTakeFloat32(TestTakeAPI):
    """Test take API with data type float32"""

    def set_dtype(self):
        self.input_dtype = 'float32'
        self.index_dtype = 'int64'

数据类型报错的 case,确保预设的数据类型错误有合适的报错提示:

class TestTakeTypeError(TestTakeAPI):
    """Test take Type Error"""

    def test_static_type_error(self):
        """Argument 'index' must be Tensor"""
        paddle.enable_static()
        with program_guard(Program()):
            x = fluid.data(name='input',
                           dtype=self.input_dtype,
                           shape=self.input_shape)
            self.assertRaises(TypeError, paddle.take, x, self.index_np,
                              self.mode)

    def test_dygraph_type_error(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)

    def test_static_dtype_error(self):
        """Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
        paddle.enable_static()
        with program_guard(Program()):
            x = fluid.data(name='input',
                           dtype='float64',
                           shape=self.input_shape)
            index = fluid.data(name='index',
                               dtype='float32',
                               shape=self.index_shape)
            self.assertRaises(TypeError, paddle.take, x, index, self.mode)

    def test_dygraph_dtype_error(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        index = paddle.to_tensor(self.index_np, dtype='float32')
        self.assertRaises(TypeError, paddle.take, x, index, self.mode)

正值索引越界错误提示:

class TestTakeModeRaisePos(unittest.TestCase):
    """Test positive index out of range error"""

    def set_mode(self):
        self.mode = 'raise'

    def set_dtype(self):
        self.input_dtype = 'float64'
        self.index_dtype = 'int64'

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 6]
        self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
            self.input_dtype)
        self.index_np = np.arange(-10, 20).reshape(self.index_shape).astype(
            self.index_dtype)  # positive indices are out of range

    def setUp(self):
        self.set_mode()
        self.set_dtype()
        self.set_input()
        self.place = fluid.CUDAPlace(
            0) if core.is_compiled_with_cuda() else fluid.CPUPlace()

    def test_static_index_error(self):
        """When the index is out of range,
        an error is reported directly through `paddle.index_select`"""
        paddle.enable_static()
        with program_guard(Program()):
            x = fluid.data(name='input',
                           dtype=self.input_dtype,
                           shape=self.input_shape)
            index = fluid.data(name='index',
                               dtype=self.index_dtype,
                               shape=self.index_shape)
            self.assertRaises(ValueError, paddle.index_select, x, index)

    def test_dygraph_index_error(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
        self.assertRaises(ValueError, paddle.index_select, x, index)

负值索引越界错误提示:

class TestTakeModeRaiseNeg(TestTakeModeRaisePos):
    """Test negative index out of range error"""

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 6]
        self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
            self.input_dtype)
        self.index_np = np.arange(-20, 10).reshape(self.index_shape).astype(
            self.index_dtype)  # negative indices are out of range

不同索引越界处理模式的 cases:

class TestTakeModeWrap(TestTakeAPI):
    """Test take index out of range mode"""

    def set_mode(self):
        self.mode = 'wrap'

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 8]
        self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
            self.input_dtype)
        self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
            self.index_dtype)  # Both ends of the index are out of bounds


class TestTakeModeClip(TestTakeAPI):
    """Test take index out of range mode"""

    def set_mode(self):
        self.mode = 'clip'

    def set_input(self):
        self.input_shape = [3, 4]
        self.index_shape = [5, 8]
        self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
            self.input_dtype)
        self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
            self.index_dtype)  # Both ends of the index are out of bounds


if __name__ == "__main__":
    unittest.main()

4 成果展示

至此,我们完成了 take API 的开发任务。现在,take API 已经合入 Paddle2.4 版本,详见官方文档

我们可以测试一下代码示例:

import paddle

x_int = paddle.arange(0, 12).reshape([3, 4])
x_float = x_int.astype(paddle.float64)

idx_pos = paddle.arange(4, 10).reshape([2, 3])  # positive index
idx_neg = paddle.arange(-2, 4).reshape([2, 3])  # negative index
idx_err = paddle.arange(-2, 13).reshape([3, 5])  # index out of range

paddle.take(x_int, idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
#        [[4, 5, 6],
#         [7, 8, 9]])

paddle.take(x_int, idx_neg)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
#        [[10, 11, 0 ],
#         [1 , 2 , 3 ]])

paddle.take(x_float, idx_pos)
# Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
#        [[4., 5., 6.],
#         [7., 8., 9.]])

x_int.take(idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
#        [[4, 5, 6],
#         [7, 8, 9]])

paddle.take(x_int, idx_err, mode='wrap')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
#        [[10, 11, 0 , 1 , 2 ],
#         [3 , 4 , 5 , 6 , 7 ],
#         [8 , 9 , 10, 11, 0 ]])

paddle.take(x_int, idx_err, mode='clip')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
#        [[0 , 0 , 0 , 1 , 2 ],
#         [3 , 4 , 5 , 6 , 7 ],
#         [8 , 9 , 10, 11, 11]])

5 总结

Pre-commit

本次任务总体难度一般,工作量也不大,基本完成了 take API 的需求。这个过程中,代码和文档规范是非常重要的部分,需要迭代完善。在进行开发之前,在我们的虚拟环境中先安装好 pre-commit 并按要求配置规范化格式之后,才可以提交自己的修改。

pre-commit 会在每一次的 git commit 自动触发文档规范化检查,同时自动修改,例如代码换行,这时需要再次 git addgit commit,直到通过所有的自动化检查,这是文档规范的首道防线。

我们需要按照贡献指南的流程,补充和修改注释,书写中文和英文文档。

rst 格式的文档预览

我们在代码中的注释都会最终成为英文文档,而中文文档需要我们写成 rst 格式的文档,并提交到 Paddle/docs 仓库。

然而这个格式的文档目前没有一个便捷的预览方式(我用 VS Code 编辑,没有找到合适的插件),我们 push 到 GitHub 之后可以预览一部分的效果,但不多,有些列表缩进、字段高亮和公式不能完整显示,与飞桨官方文档的最终效果有差异。

我们找到了一个 GitHub 的开源工具,它提供了将 rst 文档编译成 html 的工具:ieflex/newretaildoc

我现在把使用方法汉化如下:

clone 仓库到本地

git clone https://github.com/ieflex/newretaildoc.git
newretaildoc
|- .git
|- docs
  |- ...
  ...
  |- create_html.bat
  |- make.bat
  |- index.rst  # 这些文件不要误删
  |- license.rst
  |- Makefile
  |- README.md

安装样式库

其实没必要开个新的虚拟环境,直接安装在 base 环境即可。

 pip install sphinx
 pip install sphinx_rtd_theme

编译 rst 文档

  • 将需要编译的 rst 文档放入 docs 目录下
  • docs 目录下呼出 cmd,并激活 base 环境
  • 命令行输入 create_html.bat 运行脚本。他会把当前目录下的所有 rst 文档全部编译完成的 html 文件在在 docs 目录的 _build/html 文件夹中,即可在浏览器预览编译结果。

PS. 学有余力的大佬可以把它做成 web,或者 VS Code、浏览器的插件那就最好了

此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐