DGL 外部函数接口 (FFI)

我们都喜欢 Python,因为它易于操作。我们也都喜欢 C 语言,因为它快速、可靠且有类型。为了兼顾这两者的优点,DGL 主要用 Python 编写,以便快速原型开发,同时将性能关键部分下放到 C 语言。因此,DGL 开发者经常面临编写 C 例程并通过一种称为 外部函数接口 (FFI) 的机制将其暴露给 Python 的情况。

市面上有许多 FFI 解决方案。在 DGL 中,我们希望它对于关键用例来说简单、直观且高效。这就是为什么当我们偶然发现 TVM 项目中的 FFI 解决方案时,我们立即为之倾倒的原因。它利用了函数式编程的思想,因此只暴露了几十个 C API,而新的 API 可以在此基础上构建。

我们决定(厚颜无耻地)借鉴这个想法。例如,要定义一个暴露给 Python 的 C API,只需几行代码

// file: calculator.cc (put it in dgl/src folder)
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>

using namespace dgl::runtime;

DGL_REGISTER_GLOBAL("calculator.MyAdd")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    int a = args[0];
    int b = args[1];
    *rv = a + b;
  });

编译并构建库。在 Python 端,在 dgl/python/dgl/ 目录下创建一个名为 calculator.py 的文件

# file: calculator.py
from ._ffi.function import _init_api

def add(a, b):
  # MyAdd has been registered via `_ini_api` call below
  return MyAdd(a, b)

_init_api("dgl.calculator")

这里的技巧是 FFI 系统首先屏蔽了函数参数的类型信息,因此所有的 C 函数调用都可以通过一个 C API(DGLFuncCall)。类型信息在函数体中通过静态转换检索,并且我们会进行运行时类型检查以确保类型转换正确。只要函数调用不是太轻量级(上面这个例子实际上不是个好例子),这种来回的开销是可以忽略不计的。TVM 的 PackedFunc 文档有更多细节。

定义新类型

DGLArgsDGLRetValue 只支持有限数量的类型

  • 数值类型:int, float, double, …

  • string

  • 函数 (以 PackedFunc 的形式)

  • NDArray

尽管有限,但上述类型系统非常强大,因为它支持函数作为一等公民。例如,如果你想返回多个值,可以返回一个 PackedFunc,它根据整数索引返回每个值。然而,在许多情况下,仍然需要新类型来简化开发过程

  • 参数/返回值是集合的组合(例如,字典嵌套字典嵌套列表)。

  • 有时我们只是想有一个“结构”的概念(例如,给定一个苹果,通过 apple.color 获取它的颜色)。

为了实现这一点,我们引入了 Object 类型系统。例如,要定义一个新类型 Calculator

// file: calculator.cc
#include <dgl/packed_func_ext.h>
using namespace runtime;
class CalculatorObject : public Object {
 public:
  std::string brand;
  int price;

  void VisitAttrs(AttrVisitor *v) final {
    v->Visit("brand", &brand);
    v->Visit("price", &price);
  }

  static constexpr const char* _type_key = "Calculator";
  DGL_DECLARE_OBJECT_TYPE_INFO(CalculatorObject, Object);
};

// This is to define a reference class (the wrapper of an object shared pointer).
// A minimal implementation is as follows, but you could define extra methods.
class Calculator : public ObjectRef {
 public:
  const CalculatorObject* operator->() const {
    return static_cast<const CalculatorObject*>(obj_.get());
  }
  using ContainerType = CalculatorObject;
};

DGL_REGISTER_GLOBAL("calculator.CreateCaculator")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  std::string brand = args[0];
  int price = args[1];
  auto o = std::make_shared<CalculatorObject>();
  o->brand = brand;
  o->price = price;
  *rv = o;
}

在 Python 端

# file: calculator.py
from dgl._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api

@register_object
class Calculator(ObjectBase):
  @staticmethod
  def create(brand, price):
    # invoke a C API, the return value is of `Calculator` type
    return CreateCalculator(brand, price)

_init_api("dgl.calculator")

然后我们可以简单地通过以下方式创建 Calculator 对象

calc = Calculator.create("casio", 100)

这个对象的优点在于,它定义了一个访问者模式,这本质上是一种反射机制,用于获取其内部属性。例如,你可以简单地访问计算器的品牌属性并打印出来。

print(calc.brand)
print(calc.price)

由于字符串键查找,反射确实有点慢。为了加速它,可以定义一个属性访问 API

// file: calculator.cc
DGL_REGISTER_GLOBAL("calculator.CaculatorGetBrand")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  Calculator calc = args[0];
  *rv = calc->brand;
}

容器

容器也是对象。例如,下面的 C API 接受一个整数列表并返回它们的总和

// in file: calculator.cc
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.Sum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  // All the DGL supported values are represented as a ValueObject, which
  //   contains a data field.
  List<Value> values = args[0];
  int sum = 0;
  for (int i = 0; i < values.size(); ++i) {
    sum += static_cast<int>(values[i]->data);
  }
}

调用这个 API 很简单——只需传递一个 Python 整数列表。DGL FFI 将自动把 Python 的 list/tuple/dictionary 转换为相应的对象类型。

# in file: calculator.py
from ._ffi.function import _init_api

Sum([0, 1, 2, 3, 4, 5])

_init_api("dgl.calculator")

容器中的元素可以是任何对象,这使得容器可以组合。下面是一个接受计算器列表并打印出它们价格的 API

// in file: calculator.cc
#include <iostream>
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.PrintCalculators")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  List<Calculator> calcs = args[0];
  for (int i = 0; i < calcs.size(); ++i) {
    std::cout << calcs[i]->price << std::endl;
  }
}

请注意,容器不适用于在 C API 之间传递大量项集合。在这些情况下会相当慢。建议先进行基准测试。作为替代方案,对于大量的数值集合,请使用 NDArray;对于许多 DGLGraph,请使用 dgl.batch 将它们批处理成一个单独的 DGLGraph