1、任务介绍

在飞桨黑客松比赛的第三期,我参与了 CINN 算子开发方向的任务,完成了 one_hot 算子的开发,在这里我想分享一下我的开发过程,希望能帮助大家了解 CINN 的算子开发。

CINN 是一种在不改变模型代码的条件下加速飞桨模型运行速度的深度学习编译器。不同于深度学习框架算子,深度学习编译器算子的粒度更细,算子数目也更少,因此在算子融合和自动调优方面具有更大的优势。在对接上层框架时,编译器会将上层的框架算子进一步拆分为若干基础算子,这样做的目的一方面是为了减少算子开发的工作量,仅实现有限的基础算子便可以组合出大量的上层框架算子;另一方面便于算子融合技术在编译器中可以实现跨算子自动融合,减少最终执行时的kernel数目和访存开销,达到更好的性能;此外,结合自动调优技术使得编译器可以自动优化融合后的kernel,提升kernel性能。

我完成的是 one_hot 算子的开发任务。任务需要我们具备编译器的基础知识,了解神经网络的基本原理。如果你还学习过LLVM,开发过程会更轻松,如果没学过也没关系,我们只需参照已有的算子,学会相关 API 的使用方法,任务的重点是理解和运用 CINN IR。

2、设计文档

2.1 算子介绍

one_hot 算子接收5个参数:indiceson_valueoff_valuedepthaxisdtype

算子输出一个新张量,在新张量中,indices 指示的位置的值为 on_value,其它位置的值为 off_valueaxis 表示算子所操作的轴,depth 表示新张量在axis 轴上的长度,dtype 表示新张量的数据类型。

算子的功能描述可能比较难理解,我们可以看一些算子计算示例。

one_hot(
    indices=[0, 2, 2],
    on_value=1,
    off_value=0,
    depth=3,
    axis=0,
    dtype="float32"
)
# [[1. 0. 0.]
#  [0. 0. 0.]
#  [0. 1. 1.]]

one_hot(
    indices=[0, 2, 2],
    on_value=1,
    off_value=0,
    depth=3,
    axis=-1,
    dtype="float32"
)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 0. 1.]]

one_hot(
    indices=[0, 2, 2],
    on_value=1,
    off_value=0,
    depth=4,
    axis=-1,
    dtype="float32"
)
# [[1. 0. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 1. 0.]]

one_hot(
    indices=[0, 2, 2],
    on_value=1,
    off_value=0,
    depth=4,
    axis=0,
    dtype="float32"
)
# [[1. 0. 0]
#  [0. 0. 1]
#  [0. 0. 1]
#  [0. 0. 0]]

2.2 实现方法

CINN 的结构比较复杂,我刚开始做的时候有些无从下手。为了明确任务的工作内容,我先学习了 CINN 已有的基础算子内容,分析算子开发的共性特征。

新增一个算子主要的工作可分为前端和后端两个部分,例如我们要增加一个名为 op 的算子,需要完成一下的工作内容。

1)前端部分(cinn/frontend)

  • NetBuilder::Op函数:实现算子的前端接口。

2)后端部分(cinn/hlir/op/contrib)

  • Op 函数:实现算子的compute。
  • InferShapeForOp 函数:获取算子的结果张量的 shape。
  • InferDtypeForOp 函数:获取算子的结果张量的数据类型。
  • StrategyForOp 函数:整合算子的 compute 和 schedule。
  • 注册算子:使用CINN_REGISTER_HELPER 注册。

这些函数名称的后缀都是算子名称 op。

任务的重点是是使用 CINN IR 构造算子的 compute,其它内容可参考已有算子的形式,照葫芦画瓢就能完成。

CINN IR是CINN底层进行计算表达的IR(Intermediate Representation),在框架中扮演重要角色。其中,Expr 是 CINN IR 的主要数据类型,它可以表示数值和计算。

下面是一些 Expr 的使用例子。例子中的语句形式,也是开发one hot算子所涉及的全部CINN IR形式,目前我们了解这些就足够了。

// a+b
Expr a(1);
Expr b(1);
Expr c = a + b;

// int类型转换为float类型
Expr d = Cast::Make(common::Str2Type("float32"), a);

// 判断a与b是否相等
Expr e = EQ::Make(a, b)
    
// ?:三元表达式
Expr f = Select::Make(e, a, b)

CINN IR的完整定义可浏览CINN IR抽象语法树
飞桨专家们也提供了算子开发的视频讲解:深度学习框架开发指南 (课节10:深度学习编译器算子应用与开发介绍)。

3 代码开发

在开始代码开发之前,我们需要先阅读CINN项目贡献指南 ,文中介绍了开发环境和PR提交过程。搭建好开发环境,就可以开始编写代码了。

新增 one_hot 算子需要完成以下的工作。

1)前端部分(cinn/frontend)

  • 实现NetBuilder::OneHot函数

2)后端部分(cinn/hlir/op/contrib)

  • 实现 OneHot 函数
  • 实现 InferShapeForOneHot 函数
  • 实现 InferDtypeForOneHot 函数
  • 实现 StrategyForOneHot 函数
  • 注册算子

我们先开发算子的后端,再开发算子的前端。

3.1 算子后端

1)InferDtypeForOneHot

InferDtypeForOneHot函数的实现很简单,从算子的输入参数获得 dtype即可。

std::vector<Type> InferDtypeForOneHot(const std::vector<Type>& inputs_type, const framework::AttrMapType& attrs) {
  CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";

  std::string dtype = "float32";
  if (attrs.find("dtype") != attrs.end()) {
    dtype = absl::get<std::string>(attrs.at("dtype"));
  }

  std::vector<Type> res{common::Str2Type(dtype)};
  return res;
}

2)InferShapeForOneHot

InferShapeForOneHot函数计算输出张量的 shape,我们将 depth 插入输入张量的 shape 的 axis 轴,得到新的shape。

std::vector<framework::shape_t> InferShapeForOneHot(const std::vector<framework::shape_t>& inputs_shape,
                                                    const framework::AttrMapType& attrs) {
  CHECK_EQ(inputs_shape.size(), 3UL) << "The number of one_hot's input should be 3";

  int depth;
  int axis;

  for (auto& iter : attrs) {
    if (iter.first == "depth") {
      depth = absl::get<int>(iter.second);
    } else if (iter.first == "axis") {
      axis = absl::get<int>(iter.second);
    }
  }

  const std::vector<int>& in_shape = inputs_shape[0];
  int ndim                         = static_cast<int>(in_shape.size());
  int true_axis                    = (axis == -1) ? in_shape.size() : axis;
  int indices_index                = 0;
  std::vector<int> new_shape;

  for (int i = 0; i < ndim + 1; ++i) {
    if (i == true_axis) {
      new_shape.push_back(depth);
    } else {
      new_shape.push_back(in_shape[indices_index++]);
    }
  }

  std::vector<std::vector<int>> res{new_shape};
  return res;
}

3)OneHot

OneHot函数中实现算子的 compute,主要过程是参数检查,计算输出张量的 shape,以及使用 CINN IR 构造 compute。

对于新张量X的每个多维索引iter,将iteraxis轴删除得到另一个索引indices_indices。输入张量indices在索引indices_indices处的值,指定了新张量X在索引iter处的整个axis轴的值。
如果indices[indices_indices]iter[axis]相等,那么X[iter]的值取 on_value,否则取off_value。按照这个思路,我们就能构造出compute

ir::Tensor OneHot(const ir::Tensor& indices,
                  const ir::Tensor& on_value,
                  const ir::Tensor& off_value,
                  const int depth,
                  const int axis,
                  const Type& dtype,
                  const std::string& output_name) {
  int ndim = static_cast<int>(indices->shape.size());
  CHECK(axis == -1 || (0 <= axis && axis <= ndim)) << "one_hot only accepts `axis` in [-1, data.ndim]"
                                                   << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  CHECK(depth > 0) << "one_hot only accepts `depth > 0`"
                   << ", but got depth = " << depth;

  CHECK(on_value->shape.size() == 1U && on_value->shape[0].as_int32() == 1U) << "The shape of on_value must be [1]";
  CHECK(off_value->shape.size() == 1U && off_value->shape[0].as_int32() == 1U) << "The shape of off_value must be [1]";

  int true_axis = (axis == -1) ? ndim : axis;
  std::vector<Expr> new_shape;
  int indices_index = 0;

  for (int i = 0; i < ndim + 1; ++i) {
    if (i == true_axis) {
      new_shape.push_back(Expr(depth));
    } else {
      new_shape.push_back(indices->shape[indices_index++]);
    }
  }

  Expr on_value_cast  = ir::Cast::Make(dtype, on_value(Expr(0)));
  Expr off_value_cast = ir::Cast::Make(dtype, off_value(Expr(0)));

  ir::Tensor res = lang::Compute(
      new_shape,
      [=](const std::vector<Expr>& iter) {
        std::vector<Expr> indices_indices;

        for (size_t i = 0; i < iter.size(); i++) {
          if (static_cast<int>(i) == true_axis) {
            continue;
          }
          indices_indices.push_back(iter[i]);
        }

        Expr idx  = iter[true_axis];
        Expr elem = ir::Cast::Make(idx.type(), indices(indices_indices));
        return ir::Select::Make(ir::EQ::Make(elem, idx), on_value_cast, off_value_cast);
      },
      common::UniqName(output_name));

  return res;
}

4)StrategyForOneHot

StrategyForOneHot函数整合算子的 compute 和 schedule,这里schedule的内容与其它算子的保持相同。

std::shared_ptr<framework::OpStrategy> StrategyForOneHot(const framework::NodeAttr& attrs,
                                                         const std::vector<ir::Tensor>& inputs,
                                                         const std::vector<Type>& out_type,
                                                         const std::vector<std::vector<int>>& output_shapes,
                                                         const Target& target) {
  ...

  framework::CINNCompute one_hot_compute([=](lang::Args args, lang::RetValue* ret) {
      
    ...
        
  });

  framework::CINNSchedule one_hot_schedule([=](lang::Args args, lang::RetValue* ret) {
    
    ...
        
  });
    
  //整合算子的 compute 和 schedule
  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(one_hot_compute, one_hot_schedule, "strategy.one_hot.x86", 1);

  return strategy;
}

5)算子注册

使用CINN_REGISTER_HELPER注册算子。

CINN_REGISTER_HELPER(one_hot_ops) {
  CINN_REGISTER_OP(one_hot)
      .describe(
          "Returns a one-hot tensor where the locations repsented by indices take value `on_value`, "
          "other locations take value `off_value`.")
      .set_num_inputs(3)
      .set_num_outputs(1)
      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForOneHot)
      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForOneHot))
      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForOneHot))
      .set_support_level(4);

  return true;
}

另外还要在 cinn/hlir/op/use_ops.h 中注册算子,后端的内容就完成了。

CINN_USE_REGISTER(one_hot_ops)

3.2 算子前端

前端的工作比较简单,在NetBuilder中实现OneHot的前端接口,与其它算子类似。

Variable NetBuilder::OneHot(const Variable& indices,
                            const Variable& on_value,
                            const Variable& off_value,
                            const int depth,
                            const int axis,
                            const std::string& dtype) {
  return CustomInstr("one_hot", {indices, on_value, off_value}, {{"depth", depth}, {"axis", axis}, {"dtype", dtype}})
      .front();
}

3.3 算子单测

完成新算子的代码后,必须增加新算子的单测。算子的前端和后端都需要测试:在前端,我们测试算子的计算结果的正确性。在后端,我们测试算子代码生成的结果的正确性。

单测的内容比较模式化,我们可以模仿其它算子的单测进行编写。任务的完整代码可查看PR:add one_hot op

项目编译完成后,我们使用ctest指令运行单测。

ctest -R one_hot_test 
ctest -R net_builder_test 

在开发过程中,我们也可以通过运行单测来打印一些数据,辅助算子代码的调试。

4 总结

CINN 的基础算子开发任务的关键是使用 CINN IR 构造 compute,框架中现有的算子都是很好的学习材料。

深度学习编译器是近年来新兴的开发方向,涉及许多新颖而有趣的知识。如果你对新领域技术有好奇心,想看看业界大牛新的工作和成果,欢迎参与 CINN 的开源项目。

此文章为搬运
原项目链接

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐