Sequential
- class dgl.nn.pytorch.utils.Sequential(*args)[source]
基类:
Sequential
用于堆叠图神经网络模块的序列容器
DGL 支持两种模式:1) 在同一图上顺序应用 GNN 模块;或 2) 在给定图列表上顺序应用。在第二种情况下,图的数量等于此容器内模块的数量。
- 参数:
*args – torch.nn.Module 的子模块,将按照它们在构造函数中传递的顺序添加到容器中。
示例
以下示例使用 PyTorch 后端。
模式 1:在同一图上顺序应用 GNN 模块
>>> import torch >>> import dgl >>> import torch.nn as nn >>> import dgl.function as fn >>> from dgl.nn.pytorch import Sequential >>> class ExampleLayer(nn.Module): >>> def __init__(self): >>> super().__init__() >>> def forward(self, graph, n_feat, e_feat): >>> with graph.local_scope(): >>> graph.ndata['h'] = n_feat >>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) >>> n_feat += graph.ndata['h'] >>> graph.apply_edges(fn.u_add_v('h', 'h', 'e')) >>> e_feat += graph.edata['e'] >>> return n_feat, e_feat >>> >>> g = dgl.DGLGraph() >>> g.add_nodes(3) >>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2]) >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer()) >>> n_feat = torch.rand(3, 4) >>> e_feat = torch.rand(9, 4) >>> net(g, n_feat, e_feat) (tensor([[39.8597, 45.4542, 25.1877, 30.8086], [40.7095, 45.3985, 25.4590, 30.0134], [40.7894, 45.2556, 25.5221, 30.4220]]), tensor([[80.3772, 89.7752, 50.7762, 60.5520], [80.5671, 89.3736, 50.6558, 60.6418], [80.4620, 89.5142, 50.3643, 60.3126], [80.4817, 89.8549, 50.9430, 59.9108], [80.2284, 89.6954, 50.0448, 60.1139], [79.7846, 89.6882, 50.5097, 60.6213], [80.2654, 90.2330, 50.2787, 60.6937], [80.3468, 90.0341, 50.2062, 60.2659], [80.0556, 90.2789, 50.2882, 60.5845]]))
模式 2:在不同图上顺序应用 GNN 模块
>>> import torch >>> import dgl >>> import torch.nn as nn >>> import dgl.function as fn >>> import networkx as nx >>> from dgl.nn.pytorch import Sequential >>> class ExampleLayer(nn.Module): >>> def __init__(self): >>> super().__init__() >>> def forward(self, graph, n_feat): >>> with graph.local_scope(): >>> graph.ndata['h'] = n_feat >>> graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h')) >>> n_feat += graph.ndata['h'] >>> return n_feat.view(graph.num_nodes() // 2, 2, -1).sum(1) >>> >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05)) >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2)) >>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8)) >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer()) >>> n_feat = torch.rand(32, 4) >>> net([g1, g2, g3], n_feat) tensor([[209.6221, 225.5312, 193.8920, 220.1002], [250.0169, 271.9156, 240.2467, 267.7766], [220.4007, 239.7365, 213.8648, 234.9637], [196.4630, 207.6319, 184.2927, 208.7465]])