DGL 外部函数接口 (FFI)
我们都喜欢 Python,因为它易于操作。我们也都喜欢 C 语言,因为它快速、可靠且有类型。为了兼顾这两者的优点,DGL 主要用 Python 编写,以便快速原型开发,同时将性能关键部分下放到 C 语言。因此,DGL 开发者经常面临编写 C 例程并通过一种称为 外部函数接口 (FFI) 的机制将其暴露给 Python 的情况。
市面上有许多 FFI 解决方案。在 DGL 中,我们希望它对于关键用例来说简单、直观且高效。这就是为什么当我们偶然发现 TVM 项目中的 FFI 解决方案时,我们立即为之倾倒的原因。它利用了函数式编程的思想,因此只暴露了几十个 C API,而新的 API 可以在此基础上构建。
我们决定(厚颜无耻地)借鉴这个想法。例如,要定义一个暴露给 Python 的 C API,只需几行代码
// file: calculator.cc (put it in dgl/src folder)
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
using namespace dgl::runtime;
DGL_REGISTER_GLOBAL("calculator.MyAdd")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int a = args[0];
int b = args[1];
*rv = a + b;
});
编译并构建库。在 Python 端,在 dgl/python/dgl/
目录下创建一个名为 calculator.py
的文件
# file: calculator.py
from ._ffi.function import _init_api
def add(a, b):
# MyAdd has been registered via `_ini_api` call below
return MyAdd(a, b)
_init_api("dgl.calculator")
这里的技巧是 FFI 系统首先屏蔽了函数参数的类型信息,因此所有的 C 函数调用都可以通过一个 C API(DGLFuncCall
)。类型信息在函数体中通过静态转换检索,并且我们会进行运行时类型检查以确保类型转换正确。只要函数调用不是太轻量级(上面这个例子实际上不是个好例子),这种来回的开销是可以忽略不计的。TVM 的 PackedFunc 文档有更多细节。
定义新类型
DGLArgs
和 DGLRetValue
只支持有限数量的类型
数值类型:int, float, double, …
string
函数 (以 PackedFunc 的形式)
NDArray
尽管有限,但上述类型系统非常强大,因为它支持函数作为一等公民。例如,如果你想返回多个值,可以返回一个 PackedFunc,它根据整数索引返回每个值。然而,在许多情况下,仍然需要新类型来简化开发过程
参数/返回值是集合的组合(例如,字典嵌套字典嵌套列表)。
有时我们只是想有一个“结构”的概念(例如,给定一个苹果,通过
apple.color
获取它的颜色)。
为了实现这一点,我们引入了 Object 类型系统。例如,要定义一个新类型 Calculator
// file: calculator.cc
#include <dgl/packed_func_ext.h>
using namespace runtime;
class CalculatorObject : public Object {
public:
std::string brand;
int price;
void VisitAttrs(AttrVisitor *v) final {
v->Visit("brand", &brand);
v->Visit("price", &price);
}
static constexpr const char* _type_key = "Calculator";
DGL_DECLARE_OBJECT_TYPE_INFO(CalculatorObject, Object);
};
// This is to define a reference class (the wrapper of an object shared pointer).
// A minimal implementation is as follows, but you could define extra methods.
class Calculator : public ObjectRef {
public:
const CalculatorObject* operator->() const {
return static_cast<const CalculatorObject*>(obj_.get());
}
using ContainerType = CalculatorObject;
};
DGL_REGISTER_GLOBAL("calculator.CreateCaculator")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string brand = args[0];
int price = args[1];
auto o = std::make_shared<CalculatorObject>();
o->brand = brand;
o->price = price;
*rv = o;
}
在 Python 端
# file: calculator.py
from dgl._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
@register_object
class Calculator(ObjectBase):
@staticmethod
def create(brand, price):
# invoke a C API, the return value is of `Calculator` type
return CreateCalculator(brand, price)
_init_api("dgl.calculator")
然后我们可以简单地通过以下方式创建 Calculator
对象
calc = Calculator.create("casio", 100)
这个对象的优点在于,它定义了一个访问者模式,这本质上是一种反射机制,用于获取其内部属性。例如,你可以简单地访问计算器的品牌属性并打印出来。
print(calc.brand)
print(calc.price)
由于字符串键查找,反射确实有点慢。为了加速它,可以定义一个属性访问 API
// file: calculator.cc
DGL_REGISTER_GLOBAL("calculator.CaculatorGetBrand")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
Calculator calc = args[0];
*rv = calc->brand;
}
容器
容器也是对象。例如,下面的 C API 接受一个整数列表并返回它们的总和
// in file: calculator.cc
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.Sum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// All the DGL supported values are represented as a ValueObject, which
// contains a data field.
List<Value> values = args[0];
int sum = 0;
for (int i = 0; i < values.size(); ++i) {
sum += static_cast<int>(values[i]->data);
}
}
调用这个 API 很简单——只需传递一个 Python 整数列表。DGL FFI 将自动把 Python 的 list/tuple/dictionary 转换为相应的对象类型。
# in file: calculator.py
from ._ffi.function import _init_api
Sum([0, 1, 2, 3, 4, 5])
_init_api("dgl.calculator")
容器中的元素可以是任何对象,这使得容器可以组合。下面是一个接受计算器列表并打印出它们价格的 API
// in file: calculator.cc
#include <iostream>
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.PrintCalculators")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<Calculator> calcs = args[0];
for (int i = 0; i < calcs.size(); ++i) {
std::cout << calcs[i]->price << std::endl;
}
}
请注意,容器不适用于在 C API 之间传递大量项集合。在这些情况下会相当慢。建议先进行基准测试。作为替代方案,对于大量的数值集合,请使用 NDArray;对于许多 DGLGraph
,请使用 dgl.batch
将它们批处理成一个单独的 DGLGraph
。