dgl.function
这个子包包含 DGL 提供的所有 内置函数。内置函数是 DGL 推荐的表达不同类型 第二章:消息传递 计算 (即通过 update_all()
) 或从节点特征计算边特征 (即通过 apply_edges()
) 的方式。内置函数以符号方式描述节点和边的计算,不涉及实际计算,因此 DGL 可以分析并将它们映射到高效的底层内核。以下是一些示例
import dgl
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph
g.ndata['h'] = th.randn((g.num_nodes(), 10)) # each node has feature size 10
g.edata['w'] = th.randn((g.num_edges(), 1)) # each edge has feature size 1
# collect features from source nodes and aggregate them in destination nodes
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
# multiply source node features with edge weights and aggregate them in destination nodes
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
# compute edge embedding by multiplying source and destination node embeddings
g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))
fn.copy_u
、fn.u_mul_e
、fn.u_mul_v
是内置消息函数,而 fn.sum
和 fn.max
是内置归约函数。DGL 的惯例是使用 u
、v
和 e
分别表示源节点、目标节点和边。例如,copy_u
告诉 DGL 将源节点数据复制作为消息;u_mul_e
告诉 DGL 将源节点特征与边特征相乘。
要定义一元消息函数(例如 copy_u
),请指定一个输入特征名称和一个输出消息名称。要定义二元消息函数(例如 u_mul_e
),请指定两个输入特征名称和一个输出消息名称。在计算过程中,消息函数将读取指定名称下的数据,执行计算,并使用输出名称返回结果。例如,上面的 fn.u_mul_e('h', 'w', 'm')
与以下用户自定义函数相同
def udf_u_mul_e(edges):
return {'m' : edges.src['h'] * edges.data['w']}
要定义归约函数,需要指定一个输入消息名称和一个输出节点特征名称。例如,上面的 fn.max('m', 'h_max')
与以下用户自定义函数相同
def udf_max(nodes):
return {'h_max' : th.max(nodes.mailbox['m'], 1)[0]}
所有二元消息函数都支持 广播 (broadcasting),这是一种将元素级操作扩展到不同形状的张量输入上的机制。DGL 通常遵循 NumPy 和 PyTorch 的标准广播语义。以下是一些示例
import dgl
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph
# case 1
g.ndata['h'] = th.randn((g.num_nodes(), 10))
g.edata['w'] = th.randn((g.num_edges(), 1))
# OK, valid broadcasting between feature shapes (10,) and (1,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new'] # shape: (g.num_nodes(), 10)
# case 2
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 10))
# OK, valid broadcasting between feature shapes (5, 10) and (10,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new'] # shape: (g.num_nodes(), 5, 10)
# case 3
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 5))
# NOT OK, invalid broadcasting between feature shapes (5, 10) and (5,)
# shapes are aligned from right
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
# case 3
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1))
# OK, valid broadcasting between feature shapes (1, 10) and (5, 1)
g.apply_edges(fn.u_add_v('h1', 'h2', 'x')) # apply_edges also supports broadcasting
g.edata['x'] # shape: (g.num_edges(), 5, 10)
# case 4
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10, 128))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1, 128))
# OK, u_dot_v supports broadcasting but requires the last dimension to match
g.apply_edges(fn.u_dot_v('h1', 'h2', 'x'))
g.edata['x'] # shape: (g.num_edges(), 5, 10, 1)
DGL 内置函数
这里列出了所有 DGL 内置函数的速查表。
类别 |
函数 |
备注 |
---|---|---|
一元消息函数 |
|
|
|
||
二元消息函数 |
|
|
|
||
|
||
|
||
|
||
|
||
归约函数 |
|
|
|
||
|
||
|
消息函数
|
使用源节点特征计算消息的内置消息函数。 |
|
使用边特征计算消息的内置消息函数。 |
|
通过对 u 和 v 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 v 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 v 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 v 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 e 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 e 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 e 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 e 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 u 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 u 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 u 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 u 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 e 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 e 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 e 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 e 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 u 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 u 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 u 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 u 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 v 的特征进行元素级加法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 v 的特征进行元素级减法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 v 的特征进行元素级乘法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 v 的特征进行元素级除法运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 v 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 u 和 e 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 e 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 v 和 u 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 u 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
|
通过对 e 和 v 的特征进行元素级点积运算来计算边上的消息的内置消息函数,如果特征形状相同;否则先将特征广播到新的形状再进行元素级运算。 |
归约函数
|
通过求和聚合消息的内置归约函数。 |
|
通过求最大值聚合消息的内置归约函数。 |
|
通过求最小值聚合消息的内置归约函数。 |
|
通过求平均值聚合消息的内置归约函数。 |