为 PaddlePaddle 新增矩阵索引 API paddle.take
在 Paddle 框架中新增 take API,调用路径为:paddle.take 和 Tensor.take。
1 任务介绍
为了丰富飞桨框架的矩阵索引 API,我们新增 take API,它能够根据索引返回指定索引上的数据集合,调用路径为:paddle.take
和 Tensor.take
。
任务要求见 [PaddlePaddle Hackathon 3] API 开发任务合集#44073。我们需要熟悉 take 算法原理和适用场景;熟悉飞桨动静态图下数学计算过程;熟练掌握 Python。
任务的难度一般,通过已有的矩阵索引 API,能够组合实现 take API,并模仿已有的索引 API 的写法就可以实现调用。
2 设计文档
2.1 任务背景
目前飞桨可由 Tensor.flatten、Tensor.index_select 和 Tensor.reshape 组合实现 take API 的功能。
其主要实现逻辑为:
-
通过
Tensor.flatten()
将输入 x 和 index 展开成 1D Tensor。 -
通过
Tensor.index_select(index)
按照 index 中的索引提取对应元素。 -
通过
Tensor.reshape(index.shape)
将输出的 Tensor 形状转成 index 的形状。
我们新增 take API,将为飞桨增加新的 Tensor 索引函数,丰富飞桨 Tensor 的索引功能。
2.2 竞品情况
在调研业内竞品情况时,我们选择了主流的 PyTorch、TensorFlow 和 Numpy。调研的方面主要包括:
-
实现方法
-
形参设置
-
越界处理
-
返回值
2.2.1 PyTorch
Pytorch 中有 API torch.take(input, index)
。官方文档 描述如下:torch.take
返回一个新的 Tensor,其中包含给定索引处的输入元素。输入 Tensor 被视为一维张量。输出的 Tensor 与索引矩阵的形状相同。
在 torch/onnx/symbolic_opset*.py 中定义了 take
的 Python 实现方法:
def take(g, self, index):
self_flattened = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
out = index_select(g, self_flattened, 0, index)
out = reshape_as(g, out, index)
return out
实现方法
-
通过
Tensor.flatten()
将输入 x 和 index 展开成 1D Tensor。 -
通过
Tensor.index_select(index)
按照 index 中的索引提取对应元素。 -
通过
Tensor.reshape(index.shape)
将输出的 Tensor 转成 index 的形状。
形参设置
-
input
(Tensor) – the input tensor -
index
(LongTensor) – the indices into tensor
越界处理
没有额外的越界处理,越界即报错。
返回值
Tensor
2.2.2 Numpy
Numpy 中有 API numpy.take(a, indices, axis=None, out=None, mode='raise')
。官方文档 描述为沿轴(axis)从 ndarray 中取元素。
如果坐标轴不是None
,那么这个函数的作用与“fancy”索引(使用 ndarray 对 ndarray 进行索引)相同。
实现方法
-
当指定
axis
的时候,numpy.take
执行与 “fancy indexing” 相同的索引操作(使用数组索引数组);例如np.take(arr, indices, axis=3)
等价于arr[:, :, :, indices, ...]
。 -
当不指定
axis
的时候,numpy.take
默认将输入展平再使用 “fancy indexing”。 -
当提供参数
out
的时候,输出的数据将填充到out
中。
形参设置
-
a
(array_like) – The source array. -
indices
(array_like) – The indices of the values to extract. -
axis
(int) – 指定轴,默认情况下,使用展平(flattened)的输入 array。 -
out
(ndarray, optional)- 如果提供,结果将放在这个 ndarray 中。它应该具有适当的形状和类型。请注意,如果 mode = ‘raise’,则始终缓存 out; 使用其他模式以获得更好的性能。
-
mode
({‘raise’, ‘wrap’, ‘clip’}, optional)指定越界索引的行为方式。
-
‘raise’ – raise an error (default).
-
‘wrap’ – 通过取余约束越界的 indices。
-
‘clip’ – 将超出范围的索引约束到 [0, max_index),此模式将禁用带有负数的索引。
-
越界处理
如上所述,通过 mode 参数指定。
返回值
out (ndarray) - The returned array has the same type as a
.
2.2.3 TensorFlow
据我们的调研情况,TensorFlow 中没有自己实现 take
API,而是直接调用 numpy.take
:tf.experimental.numpy.take。
在充分梳理三者的逻辑后,对比了三者实现的相同点以及不同点。大致包括如下几个方面:
2.2.4 对比分析
-
torch.take
的index
参数必须为 LongTensor 类型;numpy.take
直接对参数indices
的元素取整再进行索引。 -
在维度支持上,
numpy.take
支持指定轴,torch.take
不支持。 -
当不指定轴时,对于相同的索引矩阵,
numpy.take
的执行结果等于torch.take
。 -
numpy.take
支持通过mode
参数指定索引越界的 3 种处理方式,默认直接报错;torch.take
在索引越界时直接报错。
2.3 设计思路
由于 numpy 在指定轴索引后得到的结果不能保证与 index 的 shape 一致,会破坏 take 方法的输出结果形状与 index 一致的特性。因此我们决定新增的 paddle.take
的功能与 torch.take
和 numpy.take
的默认形式保持一致,即,不增加 axis 参数指定索引轴;在 torch.take
的基础上增加 mode 参数提供三种 index 索引越界的处理方式。尽可能保持 take 索引方法简洁、易理解的特性。
2.3.1 参数设置
paddle.take(
x: Tensor,
index: Tensor,
mode: str='raise',
name: str=None)
注:其中添加参数 name
为了与飞桨其他 API 参数名保持一致。
2.3.2 底层 OP 设计
使用已有 API 组合实现,不再单独设计 OP。
2.4 实现方案
该 API 需要添加在飞桨 repo 的 python/paddle/tensor/math.py
文件中;并在 python/paddle/tensor/__init__.py
以及 python/paddle/__init__.py
中添加 take
API,以支持 Tensor.take 和 paddle.take
的调用方式。
目前飞桨可由 Tensor.flatten
、Tensor.index_select
和 Tensor.reshape
组合实现该 API 的功能。
其主要实现逻辑为:
-
通过
Tensor.flatten()
将输入 x 和 index 展开成 1D Tensor。 -
根据 mode 参数对索引进行越界处理:
mode='raise'
,若索引越界,通过最后调用的paddle.index_select
抛出错误 (默认);mode='wrap'
,通过取余约束越界的 indices;mode='clip'
,通过paddle.clip
将两端超出范围的索引约束到 [0, max_index-1]。
-
通过
Tensor.index_select(index)
按照 index 中的索引提取对应元素。numpy.take
和torch.take
支持负值索引;- 然而
index_select
不支持,因此需要先将 index 的负值索引转为对应的正值索引。
-
通过
Tensor.reshape(index.shape)
将输出的 Tensor 形状转成 index 的形状。
2.5 单测方案
测试考虑的 case 如下:
-
参数
index
数据类型必须为paddle.int32
和paddle.int64
类型的 Tensor(与paddle.index_select
一致)。 -
x
的数据类型支持int32
,int64
,float32
,float64
。 -
index
索引越界的三种处理方式:mode='raise'
,若索引越界,通过最后调用的paddle.index_select
抛出错误 (默认);mode='wrap'
,通过取余约束越界的 indices;mode='clip'
,通过paddle.clip
将两端超出范围的索引约束到 [0, max_index-1]。
-
在动态图、静态图下,以及 CPU、GPU 下,都能得到正确的结果。
3 代码开发
3.1 API 开发
我个人的开发环境是 Win10 + PyCharm + Git。
依照设计文档,我们就可以开始代码开发了。take API 开发较为简单,使用现有 API 组合实现即可。
在飞桨内部工程师、社区 Committer 的 review 和帮助下,我们不断优化迭代 take API,最终成功合入。详见我们的 PR #44741
def take(x, index, mode='raise', name=None):
if mode not in ['raise', 'wrap', 'clip']:
raise ValueError(
"'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode))
if paddle.in_dynamic_mode():
if not isinstance(index, (paddle.Tensor, Variable)):
raise TypeError(
"The type of 'index' must be Tensor, but got {}".format(type(index)))
if index.dtype not in [paddle.int32, paddle.int64]:
raise TypeError(
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format(
index.dtype))
else:
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take')
input_1d = x.flatten() # 将输入的 Tensor 展开成一维
index_1d = index.flatten() # 将输入的索引矩阵也展开成一维
max_index = input_1d.shape[-1]
if mode == 'raise':
# This processing enables 'take' to handle negative indexes within the correct range.
index_1d = paddle.where(index_1d < 0, index_1d + max_index, index_1d)
elif mode == 'wrap':
# The out of range indices are constrained by taking the remainder.
index_1d = paddle.where(index_1d < 0,
index_1d % max_index, index_1d)
index_1d = paddle.where(index_1d >= max_index,
index_1d % max_index, index_1d)
elif mode == 'clip':
# 'clip' mode disables indexing with negative numbers.
index_1d = clip(index_1d, 0, max_index - 1)
out = input_1d.index_select(index_1d).reshape(index.shape)
return out
代码开发完成了,我们还要按照贡献指南的注释规范,补充和修改注释,书写中文和英文文档。
3.2 单测开发
我们以 numpy.take
作为基准,验证 API 的正确性。
根据测试文档的单测方案,我们要测试的 cases 如下:
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
class TestTakeAPI(unittest.TestCase):
"""
单测主类,将同类 case 的测试参数进行封装,方便继承和修改。
"""
def set_mode(self):
"""
设置索引越界处理模式
"""
self.mode = 'raise'
def set_dtype(self):
"""
设置输入的数据类型
"""
self.input_dtype = 'float64'
self.index_dtype = 'int64'
def set_input(self):
"""
设置输入的 Tensor 形状和元素
"""
self.input_shape = [3, 4]
self.index_shape = [2, 3]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-4, 2).reshape(self.index_shape).astype(
self.index_dtype)
def setUp(self):
"""
初始化
"""
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
def test_static_graph(self):
"""
每一种 case 都需要经过动态图、静态图的测试
"""
paddle.enable_static()
startup_program = Program()
train_program = Program()
with program_guard(startup_program, train_program):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
out = paddle.take(x, index, mode=self.mode)
exe = fluid.Executor(self.place)
st_result = exe.run(fluid.default_main_program(),
feed={
'input': self.input_np,
'index': self.index_np
},
fetch_list=out)
np.testing.assert_allclose(
st_result[0],
np.take(self.input_np, self.index_np, mode=self.mode))
def test_dygraph(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np)
dy_result = paddle.take(x, index, mode=self.mode)
np.testing.assert_allclose(
np.take(self.input_np, self.index_np, mode=self.mode),
dy_result.numpy())
数据类型 cases,确保预设的数据类型都能通过:
class TestTakeInt32(TestTakeAPI):
"""Test take API with data type int32"""
def set_dtype(self):
self.input_dtype = 'int32'
self.index_dtype = 'int64'
class TestTakeInt64(TestTakeAPI):
"""Test take API with data type int64"""
def set_dtype(self):
self.input_dtype = 'int64'
self.index_dtype = 'int64'
class TestTakeFloat32(TestTakeAPI):
"""Test take API with data type float32"""
def set_dtype(self):
self.input_dtype = 'float32'
self.index_dtype = 'int64'
数据类型报错的 case,确保预设的数据类型错误有合适的报错提示:
class TestTakeTypeError(TestTakeAPI):
"""Test take Type Error"""
def test_static_type_error(self):
"""Argument 'index' must be Tensor"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
self.assertRaises(TypeError, paddle.take, x, self.index_np,
self.mode)
def test_dygraph_type_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
self.assertRaises(TypeError, paddle.take, x, self.index_np, self.mode)
def test_static_dtype_error(self):
"""Data type of argument 'index' must be in [paddle.int32, paddle.int64]"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype='float64',
shape=self.input_shape)
index = fluid.data(name='index',
dtype='float32',
shape=self.index_shape)
self.assertRaises(TypeError, paddle.take, x, index, self.mode)
def test_dygraph_dtype_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype='float32')
self.assertRaises(TypeError, paddle.take, x, index, self.mode)
正值索引越界错误提示:
class TestTakeModeRaisePos(unittest.TestCase):
"""Test positive index out of range error"""
def set_mode(self):
self.mode = 'raise'
def set_dtype(self):
self.input_dtype = 'float64'
self.index_dtype = 'int64'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 6]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-10, 20).reshape(self.index_shape).astype(
self.index_dtype) # positive indices are out of range
def setUp(self):
self.set_mode()
self.set_dtype()
self.set_input()
self.place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
def test_static_index_error(self):
"""When the index is out of range,
an error is reported directly through `paddle.index_select`"""
paddle.enable_static()
with program_guard(Program()):
x = fluid.data(name='input',
dtype=self.input_dtype,
shape=self.input_shape)
index = fluid.data(name='index',
dtype=self.index_dtype,
shape=self.index_shape)
self.assertRaises(ValueError, paddle.index_select, x, index)
def test_dygraph_index_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.input_np)
index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
self.assertRaises(ValueError, paddle.index_select, x, index)
负值索引越界错误提示:
class TestTakeModeRaiseNeg(TestTakeModeRaisePos):
"""Test negative index out of range error"""
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 6]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 10).reshape(self.index_shape).astype(
self.index_dtype) # negative indices are out of range
不同索引越界处理模式的 cases:
class TestTakeModeWrap(TestTakeAPI):
"""Test take index out of range mode"""
def set_mode(self):
self.mode = 'wrap'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds
class TestTakeModeClip(TestTakeAPI):
"""Test take index out of range mode"""
def set_mode(self):
self.mode = 'clip'
def set_input(self):
self.input_shape = [3, 4]
self.index_shape = [5, 8]
self.input_np = np.arange(0, 12).reshape(self.input_shape).astype(
self.input_dtype)
self.index_np = np.arange(-20, 20).reshape(self.index_shape).astype(
self.index_dtype) # Both ends of the index are out of bounds
if __name__ == "__main__":
unittest.main()
4 成果展示
至此,我们完成了 take API 的开发任务。现在,take API 已经合入 Paddle2.4 版本,详见官方文档
我们可以测试一下代码示例:
import paddle
x_int = paddle.arange(0, 12).reshape([3, 4])
x_float = x_int.astype(paddle.float64)
idx_pos = paddle.arange(4, 10).reshape([2, 3]) # positive index
idx_neg = paddle.arange(-2, 4).reshape([2, 3]) # negative index
idx_err = paddle.arange(-2, 13).reshape([3, 5]) # index out of range
paddle.take(x_int, idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[4, 5, 6],
# [7, 8, 9]])
paddle.take(x_int, idx_neg)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[10, 11, 0 ],
# [1 , 2 , 3 ]])
paddle.take(x_float, idx_pos)
# Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
# [[4., 5., 6.],
# [7., 8., 9.]])
x_int.take(idx_pos)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[4, 5, 6],
# [7, 8, 9]])
paddle.take(x_int, idx_err, mode='wrap')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
# [[10, 11, 0 , 1 , 2 ],
# [3 , 4 , 5 , 6 , 7 ],
# [8 , 9 , 10, 11, 0 ]])
paddle.take(x_int, idx_err, mode='clip')
# Tensor(shape=[3, 5], dtype=int32, place=Place(cpu), stop_gradient=True,
# [[0 , 0 , 0 , 1 , 2 ],
# [3 , 4 , 5 , 6 , 7 ],
# [8 , 9 , 10, 11, 11]])
5 总结
Pre-commit
本次任务总体难度一般,工作量也不大,基本完成了 take API 的需求。这个过程中,代码和文档规范是非常重要的部分,需要迭代完善。在进行开发之前,在我们的虚拟环境中先安装好 pre-commit 并按要求配置规范化格式之后,才可以提交自己的修改。
pre-commit 会在每一次的 git commit
自动触发文档规范化检查,同时自动修改,例如代码换行,这时需要再次 git add
和 git commit
,直到通过所有的自动化检查,这是文档规范的首道防线。
我们需要按照贡献指南的流程,补充和修改注释,书写中文和英文文档。
rst 格式的文档预览
我们在代码中的注释都会最终成为英文文档,而中文文档需要我们写成 rst 格式的文档,并提交到 Paddle/docs 仓库。
然而这个格式的文档目前没有一个便捷的预览方式(我用 VS Code 编辑,没有找到合适的插件),我们 push 到 GitHub 之后可以预览一部分的效果,但不多,有些列表缩进、字段高亮和公式不能完整显示,与飞桨官方文档的最终效果有差异。
我们找到了一个 GitHub 的开源工具,它提供了将 rst
文档编译成 html
的工具:ieflex/newretaildoc 。
我现在把使用方法汉化如下:
clone 仓库到本地
git clone https://github.com/ieflex/newretaildoc.git
newretaildoc
|- .git
|- docs
|- ...
...
|- create_html.bat
|- make.bat
|- index.rst # 这些文件不要误删
|- license.rst
|- Makefile
|- README.md
安装样式库
其实没必要开个新的虚拟环境,直接安装在 base 环境即可。
pip install sphinx
pip install sphinx_rtd_theme
编译 rst 文档
- 将需要编译的
rst
文档放入docs
目录下 - 在
docs
目录下呼出 cmd,并激活 base 环境 - 命令行输入
create_html.bat
运行脚本。他会把当前目录下的所有rst
文档全部编译完成的 html 文件在在docs
目录的_build/html
文件夹中,即可在浏览器预览编译结果。
PS. 学有余力的大佬可以把它做成 web,或者 VS Code、浏览器的插件那就最好了
此文章为搬运
原项目链接
更多推荐
所有评论(0)