FeatMask

class dgl.transforms.FeatMask(p=0.5, node_feat_names=None, edge_feat_names=None)[源码]

基类:BaseTransform

随机遮蔽节点和边特征张量的列,如图对比学习与增强中所述。

参数:
  • p (float, optional) – 遮蔽特征张量列的概率。默认值:0.5

  • node_feat_names (list[str], optional) – 要被遮蔽的节点特征张量的名称。默认值:None,表示不遮蔽任何节点特征张量。

  • edge_feat_names (list[str], optional) – 要被遮蔽的边特征的名称。默认值:None,表示不遮蔽任何边特征张量。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch
>>> from dgl import FeatMask

情况 1:遮蔽同构图的节点和边特征张量。

>>> transform = FeatMask(node_feat_names=['h'], edge_feat_names=['w'])
>>> g = dgl.rand_graph(5, 10)
>>> g.ndata['h'] = torch.ones((g.num_nodes(), 10))
>>> g.edata['w'] = torch.ones((g.num_edges(), 10))
>>> g = transform(g)
>>> print(g.ndata['h'])
tensor([[0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
        [0., 0., 1., 1., 0., 0., 1., 1., 1., 0.]])
>>> print(g.edata['w'])
tensor([[1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 1., 0., 1., 0., 0., 0., 1.]])

情况 2:遮蔽异构图的节点和边特征张量。

>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
... })
>>> g.ndata['h'] = {'game': torch.ones(2, 5), 'player': torch.ones(3, 5)}
>>> g.edata['w'] = {('user', 'follows', 'user'): torch.ones(2, 5)}
>>> print(g.ndata['h']['game'])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
>>> print(g.edata['w'][('user', 'follows', 'user')])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
>>> g = transform(g)
>>> print(g.ndata['h']['game'])
tensor([[1., 1., 0., 1., 0.],
        [1., 1., 0., 1., 0.]])
>>> print(g.edata['w'][('user', 'follows', 'user')])
tensor([[0., 1., 0., 1., 0.],
        [0., 1., 0., 1., 0.]])