dgl.function

这个子包包含 DGL 提供的所有 内置函数。内置函数是 DGL 推荐的表达不同类型 第二章:消息传递 计算 (即通过 update_all()) 或从节点特征计算边特征 (即通过 apply_edges()) 的方式。内置函数以符号方式描述节点和边的计算,不涉及实际计算,因此 DGL 可以分析并将它们映射到高效的底层内核。以下是一些示例

import dgl
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph
g.ndata['h'] = th.randn((g.num_nodes(), 10)) # each node has feature size 10
g.edata['w'] = th.randn((g.num_edges(), 1))  # each edge has feature size 1
# collect features from source nodes and aggregate them in destination nodes
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
# multiply source node features with edge weights and aggregate them in destination nodes
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
# compute edge embedding by multiplying source and destination node embeddings
g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))

fn.copy_ufn.u_mul_efn.u_mul_v 是内置消息函数,而 fn.sumfn.max 是内置归约函数。DGL 的惯例是使用 uve 分别表示源节点、目标节点和边。例如,copy_u 告诉 DGL 将源节点数据复制作为消息;u_mul_e 告诉 DGL 将源节点特征与边特征相乘。

要定义一元消息函数(例如 copy_u),请指定一个输入特征名称和一个输出消息名称。要定义二元消息函数(例如 u_mul_e),请指定两个输入特征名称和一个输出消息名称。在计算过程中,消息函数将读取指定名称下的数据,执行计算,并使用输出名称返回结果。例如,上面的 fn.u_mul_e('h', 'w', 'm') 与以下用户自定义函数相同

def udf_u_mul_e(edges):
   return {'m' : edges.src['h'] * edges.data['w']}

要定义归约函数,需要指定一个输入消息名称和一个输出节点特征名称。例如,上面的 fn.max('m', 'h_max') 与以下用户自定义函数相同

def udf_max(nodes):
   return {'h_max' : th.max(nodes.mailbox['m'], 1)[0]}

所有二元消息函数都支持 广播 (broadcasting),这是一种将元素级操作扩展到不同形状的张量输入上的机制。DGL 通常遵循 NumPyPyTorch 的标准广播语义。以下是一些示例

import dgl
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph

# case 1
g.ndata['h'] = th.randn((g.num_nodes(), 10))
g.edata['w'] = th.randn((g.num_edges(), 1))
# OK, valid broadcasting between feature shapes (10,) and (1,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new']  # shape: (g.num_nodes(), 10)

# case 2
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 10))
# OK, valid broadcasting between feature shapes (5, 10) and (10,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new']  # shape: (g.num_nodes(), 5, 10)

# case 3
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 5))
# NOT OK, invalid broadcasting between feature shapes (5, 10) and (5,)
# shapes are aligned from right
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))

# case 3
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1))
# OK, valid broadcasting between feature shapes (1, 10) and (5, 1)
g.apply_edges(fn.u_add_v('h1', 'h2', 'x'))  # apply_edges also supports broadcasting
g.edata['x']  # shape: (g.num_edges(), 5, 10)

# case 4
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10, 128))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1, 128))
# OK, u_dot_v supports broadcasting but requires the last dimension to match
g.apply_edges(fn.u_dot_v('h1', 'h2', 'x'))
g.edata['x']  # shape: (g.num_edges(), 5, 10, 1)

DGL 内置函数

这里列出了所有 DGL 内置函数的速查表。

类别

函数

备注

一元消息函数

copy_u

copy_e

二元消息函数

u_add_v, u_sub_v, u_mul_v, u_div_v, u_dot_v

u_add_e, u_sub_e, u_mul_e, u_div_e, u_dot_e

v_add_u, v_sub_u, v_mul_u, v_div_u, v_dot_u

v_add_e, v_sub_e, v_mul_e, v_div_e, v_dot_e

e_add_u, e_sub_u, e_mul_u, e_div_u, e_dot_u

e_add_v, e_sub_v, e_mul_v, e_div_v, e_dot_v

归约函数

max

min

sum

mean

消息函数

copy_u(u, out)

使用源节点特征计算消息的内置消息函数。

copy_e(e, out)

使用边特征计算消息的内置消息函数。

u_add_v(lhs_field, rhs_field, out)

通过对 u 和 v 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_sub_v(lhs_field, rhs_field, out)

通过对 u 和 v 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_mul_v(lhs_field, rhs_field, out)

通过对 u 和 v 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_div_v(lhs_field, rhs_field, out)

通过对 u 和 v 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_add_e(lhs_field, rhs_field, out)

通过对 u 和 e 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_sub_e(lhs_field, rhs_field, out)

通过对 u 和 e 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_mul_e(lhs_field, rhs_field, out)

通过对 u 和 e 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_div_e(lhs_field, rhs_field, out)

通过对 u 和 e 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_add_u(lhs_field, rhs_field, out)

通过对 v 和 u 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_sub_u(lhs_field, rhs_field, out)

通过对 v 和 u 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_mul_u(lhs_field, rhs_field, out)

通过对 v 和 u 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_div_u(lhs_field, rhs_field, out)

通过对 v 和 u 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_add_e(lhs_field, rhs_field, out)

通过对 v 和 e 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_sub_e(lhs_field, rhs_field, out)

通过对 v 和 e 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_mul_e(lhs_field, rhs_field, out)

通过对 v 和 e 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_div_e(lhs_field, rhs_field, out)

通过对 v 和 e 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_add_u(lhs_field, rhs_field, out)

通过对 e 和 u 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_sub_u(lhs_field, rhs_field, out)

通过对 e 和 u 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_mul_u(lhs_field, rhs_field, out)

通过对 e 和 u 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_div_u(lhs_field, rhs_field, out)

通过对 e 和 u 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_add_v(lhs_field, rhs_field, out)

通过对 e 和 v 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_sub_v(lhs_field, rhs_field, out)

通过对 e 和 v 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_mul_v(lhs_field, rhs_field, out)

通过对 e 和 v 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_div_v(lhs_field, rhs_field, out)

通过对 e 和 v 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_dot_v(lhs_field, rhs_field, out)

通过对 u 和 v 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

u_dot_e(lhs_field, rhs_field, out)

通过对 u 和 e 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_dot_e(lhs_field, rhs_field, out)

通过对 v 和 e 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

v_dot_u(lhs_field, rhs_field, out)

通过对 v 和 u 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_dot_u(lhs_field, rhs_field, out)

通过对 e 和 u 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

e_dot_v(lhs_field, rhs_field, out)

通过对 e 和 v 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。

归约函数

sum(msg, out)

通过求和聚合消息的内置归约函数。

max(msg, out)

通过求最大值聚合消息的内置归约函数。

min(msg, out)

通过求最小值聚合消息的内置归约函数。

mean(msg, out)

通过求平均值聚合消息的内置归约函数。