dgl.udf.NodeBatch.batch_size
- NodeBatch.batch_size()
Return the number of nodes in the batch.
- Return type: int
Example
The following example uses the PyTorch backend.
>>> import dgl
>>> import torch
>>> # Instantiate a graph.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that computes the sum of the messages received for
>>> # each node and increments the result by 1.
>>> def node_udf(nodes):
...     return {'h': torch.ones(nodes.batch_size(), 1)
...                  + nodes.mailbox['m'].sum(1)}
>>> # Use the node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
tensor([[2.],
        [3.]])
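To make the `batch_size()` value and the printed result concrete, the sketch below replays the example by hand in plain Python, with no DGL or PyTorch. It assumes, as DGL's user guide describes for reduce UDFs, that nodes are grouped by in-degree (degree bucketing) and the UDF is called once per group, so `batch_size()` counts the nodes in the current group rather than in the whole graph. The helpers `bucket_by_in_degree` and `node_udf` are hypothetical stand-ins, not DGL APIs, and the sketch simply leaves nodes without incoming messages unchanged.

```python
from collections import defaultdict


def bucket_by_in_degree(src, dst, feat):
    """Hypothetical helper (not a DGL API): copy each source feature into the
    destination's mailbox, like fn.copy_u, then group receiving nodes by
    in-degree. Nodes with no incoming messages are simply not bucketed here."""
    mailbox = defaultdict(list)
    for u, v in zip(src, dst):
        mailbox[v].append(feat[u])
    buckets = defaultdict(list)
    for v, msgs in mailbox.items():
        buckets[len(msgs)].append(v)
    return buckets, mailbox


def node_udf(nodes, mailbox):
    """Plain-Python analogue of node_udf: batch_size is the number of nodes
    in the current bucket; each result is 1 + the sum of that node's messages."""
    batch_size = len(nodes)
    ones = [1.0] * batch_size
    return [one + sum(mailbox[v]) for one, v in zip(ones, nodes)]


# Same graph as the example: edges 0->1, 1->1, 1->0; every feature starts at 1.
src, dst = [0, 1, 1], [1, 1, 0]
feat = [1.0, 1.0]
buckets, mailbox = bucket_by_in_degree(src, dst, feat)

new_feat = list(feat)
batch_sizes = {}  # in-degree -> number of nodes handed to the UDF at once
for degree, nodes in sorted(buckets.items()):
    batch_sizes[degree] = len(nodes)
    for v, value in zip(nodes, node_udf(nodes, mailbox)):
        new_feat[v] = value

print(batch_sizes)  # {1: 1, 2: 1}
print(new_feat)     # [2.0, 3.0]
```

In this graph node 0 has in-degree 1 and node 1 has in-degree 2, so under the bucketing assumption each group holds a single node and the batch size is 1 on both calls; the per-node sums 1 + 1 and 1 + 2 reproduce `tensor([[2.], [3.]])`.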