DropNode
- class dgl.transforms.DropNode(p=0.5)
Bases: BaseTransform
Randomly drop nodes, as described in Graph Contrastive Learning with Augmentations.
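As a rough mental model, node dropping can be pictured as follows: each node is removed independently with probability p, every edge incident to a removed node disappears with it, and the survivors are relabeled to compact IDs. The pure-Python sketch below illustrates that idea under those assumptions. It is not DGL's implementation, and the function name `drop_nodes_sketch` is hypothetical. It returns the original IDs of the surviving nodes and edges, which matches how the example below uses `ndata['h']` and `edata['h']`.

```python
import random
from typing import List, Tuple


def drop_nodes_sketch(
    num_nodes: int,
    src: List[int],
    dst: List[int],
    p: float,
    rng: random.Random,
) -> Tuple[List[int], List[int], List[Tuple[int, int]]]:
    """Hypothetical illustration of random node dropping (not DGL code).

    Returns (kept_nids, kept_eids, new_edges): original IDs of surviving
    nodes and edges, plus surviving edges relabeled to new node IDs
    0..len(kept_nids)-1.
    """
    # Each node is dropped independently with probability p
    # (rng.random() is uniform on [0, 1), so the node is kept w.p. 1 - p).
    kept_nids = [v for v in range(num_nodes) if rng.random() >= p]
    # Map each surviving original node ID to a contiguous new ID.
    new_id = {old: new for new, old in enumerate(kept_nids)}
    kept_eids: List[int] = []
    new_edges: List[Tuple[int, int]] = []
    for eid, (u, v) in enumerate(zip(src, dst)):
        # An edge survives only if both of its endpoints survive.
        if u in new_id and v in new_id:
            kept_eids.append(eid)
            new_edges.append((new_id[u], new_id[v]))
    return kept_nids, kept_eids, new_edges


if __name__ == "__main__":
    # A 5-node directed cycle; which nodes survive depends on the seed.
    nids, eids, edges = drop_nodes_sketch(
        5, [0, 1, 2, 3, 4], [1, 2, 3, 4, 0], 0.5, random.Random(0)
    )
    print(nids, eids, edges)
```

This sketch keeps surviving edges in their original order. DGL's output need not do so: the `edata['h']` printed in the example below is not sorted.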
- Parameters:
p (float, optional) – Probability that each node is dropped. Default: 0.5.
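Assuming that nodes are dropped independently and that an edge survives only when all of its endpoints survive, p determines the expected size of the output. Each node survives with probability 1 - p. A self-loop survives with probability 1 - p, and any other edge survives with probability (1 - p)². The helper `expected_sizes` below is hypothetical and only illustrates this arithmetic.

```python
from typing import List, Tuple


def expected_sizes(
    num_nodes: int, src: List[int], dst: List[int], p: float
) -> Tuple[float, float]:
    """Expected numbers of surviving nodes and edges when each node is
    dropped independently with probability p (illustrative only)."""
    keep = 1.0 - p
    exp_nodes = num_nodes * keep
    # A self-loop needs its single endpoint to survive; any other edge
    # needs both endpoints to survive.
    exp_edges = sum(keep if u == v else keep * keep for u, v in zip(src, dst))
    return exp_nodes, exp_edges


if __name__ == "__main__":
    # 5 nodes and 20 edges without self-loops, with the default p = 0.5.
    src = [i % 5 for i in range(20)]
    dst = [(i % 5 + 1 + i // 5) % 5 for i in range(20)]
    print(expected_sizes(5, src, dst, 0.5))  # (2.5, 5.0)
```

For a graph like the one in the example (5 nodes, 20 edges) with the default p, these expectations are only averages. A single random draw, such as the 3 nodes and 7 edges printed below, can deviate from them.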
Example
>>> import dgl
>>> import torch
>>> from dgl import DropNode

>>> transform = DropNode()
>>> g = dgl.rand_graph(5, 20)
>>> g.ndata['h'] = torch.arange(g.num_nodes())
>>> g.edata['h'] = torch.arange(g.num_edges())
>>> new_g = transform(g)
>>> print(new_g)
Graph(num_nodes=3, num_edges=7,
      ndata_schemes={'h': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})
>>> print(new_g.ndata['h'])
tensor([0, 1, 2])
>>> print(new_g.edata['h'])
tensor([0, 6, 14, 5, 17, 3, 11])
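The example initializes `ndata['h']` and `edata['h']` with `torch.arange`, so the printed features in `new_g` are the original IDs of the surviving nodes and edges. Nodes are dropped at random, so the counts and IDs will differ from run to run.

The hypothetical helper `check_drop_node_output` below is a pure-Python sketch, not part of DGL. It checks that such an output is consistent under two assumptions: only edges incident to dropped nodes are removed, and edge order may change. The inputs are plain lists. With DGL objects, one option is to build them from `g.edges()` and `new_g.edges()` via `.tolist()`.

```python
from typing import List, Tuple


def check_drop_node_output(
    orig_src: List[int],
    orig_dst: List[int],
    new_edges: List[Tuple[int, int]],
    node_h: List[int],
    edge_h: List[int],
) -> bool:
    """Hypothetical consistency check for a node-dropped graph whose 'h'
    features were initialized with original node / edge IDs."""
    if len(edge_h) != len(new_edges):
        return False
    kept = set(node_h)
    for (a, b), eid in zip(new_edges, edge_h):
        # New endpoints, mapped back through node_h, must equal the
        # original endpoints of the edge whose ID is stored in edge_h.
        if (node_h[a], node_h[b]) != (orig_src[eid], orig_dst[eid]):
            return False
    # Every original edge between two surviving nodes should still exist
    # (assumption: only edges touching dropped nodes are removed).
    expected = {
        eid
        for eid, (u, v) in enumerate(zip(orig_src, orig_dst))
        if u in kept and v in kept
    }
    return expected == set(edge_h)
```

Because the comparison uses sets of edge IDs, the check accepts a reordered edge list like the unsorted `edata['h']` in the example.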