dgl.udf.NodeBatch.nodes

NodeBatch.nodes()

Return the nodes in the batch.

Returns:

NID – The IDs of the nodes in the batch. \(NID[i]\) is the ID of the i-th node.

Return type:

Tensor

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch
>>> # Instantiate a graph and set a feature 'h'.
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> g.ndata['h'] = torch.ones(2, 1)
>>> # Define a UDF that, for each node, adds the node's own ID to the
>>> # sum of the messages it received.
>>> def node_udf(nodes):
...     # nodes.nodes() is a tensor of shape (N,),
...     # nodes.mailbox['m'] is a tensor of shape (N, D, 1),
...     # where N is the number of nodes in the batch, D is the number
...     # of messages received per node for this node batch.
...     return {'h': nodes.nodes().unsqueeze(-1).float()
...                  + nodes.mailbox['m'].sum(1)}
>>> # Use node UDF in message passing.
>>> import dgl.function as fn
>>> g.update_all(fn.copy_u('h', 'm'), node_udf)
>>> g.ndata['h']
tensor([[1.],
        [3.]])
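
To see where the two output values come from, the same computation can be traced by hand. The block below is a plain-Python sketch rather than DGL code: it assumes that nodes receiving the same number of messages end up in the same batch, and the names `mailbox`, `buckets` and `batch_nids` are illustrative stand-ins, not part of the DGL API.

```python
from collections import defaultdict

# Same toy graph as in the example above: edges 0->1, 1->1, 1->0.
src = [0, 1, 1]
dst = [1, 1, 0]
h = [1.0, 1.0]  # initial feature 'h' of nodes 0 and 1

# Each edge delivers the source node's feature to its destination,
# which is what copying 'h' into message 'm' amounts to here.
mailbox = defaultdict(list)
for u, v in zip(src, dst):
    mailbox[v].append(h[u])

# Assumption for this sketch: nodes with the same number of received
# messages (D) are grouped into one batch.
buckets = defaultdict(list)
for nid, msgs in mailbox.items():
    buckets[len(msgs)].append(nid)

new_h = list(h)
for num_msgs, batch_nids in sorted(buckets.items()):
    # batch_nids plays the role of nodes.nodes() for this batch.
    for nid in batch_nids:
        new_h[nid] = float(nid) + sum(mailbox[nid])

print(new_h)  # [1.0, 3.0]
```

Node 0 receives one message of value 1 and has ID 0, giving 1; node 1 receives two messages of value 1 and has ID 1, giving 3.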