DeepWalk

class dgl.nn.pytorch.DeepWalk(g, emb_dim=128, walk_length=40, window_size=5, neg_weight=1, negative_size=5, fast_neg=True, sparse=True)[source]

Bases: Module

DeepWalk module from DeepWalk: Online Learning of Social Representations

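For a given graph, it learns node representations from scratch by maximizing the similarity of neighboring node pairs (positive node pairs) and minimizing the similarity of other random node pairs (negative node pairs).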
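The objective described above can be illustrated with a small pure-Python sketch of a skip-gram loss with negative sampling. This is not the module's implementation: the helper names (`sigmoid`, `dot`, `skipgram_loss`), the toy embeddings, and the choice to average each term are assumptions made for illustration only.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def skipgram_loss(emb, pos_pairs, neg_pairs, neg_weight=1.0):
    """Hypothetical helper: mean positive loss plus weighted mean negative loss."""
    # Positive pairs are pushed toward a high dot product (high similarity).
    pos = sum(-math.log(sigmoid(dot(emb[i], emb[j]))) for i, j in pos_pairs)
    # Negative pairs are pushed toward a low dot product (low similarity).
    neg = sum(-math.log(sigmoid(-dot(emb[i], emb[j]))) for i, j in neg_pairs)
    return pos / len(pos_pairs) + neg_weight * neg / len(neg_pairs)


# Toy 2-dimensional embeddings for three nodes (illustrative values only).
emb = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [-1.0, 0.0]}

# Nodes 0 and 1 are similar and node 2 is dissimilar, so this labeling fits well.
good = skipgram_loss(emb, pos_pairs=[(0, 1)], neg_pairs=[(0, 2)])
# Swapping the roles of the pairs contradicts the embeddings and costs more.
bad = skipgram_loss(emb, pos_pairs=[(0, 2)], neg_pairs=[(0, 1)])
```

In this sketch the loss is lower when the embeddings agree with the positive and negative labels, which is the behavior training is meant to encourage.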

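Parameters:
  • g (DGLGraph) – Graph for learning node embeddings

  • emb_dim (int, optional) – Size of each embedding vector. Default: 128

  • walk_length (int, optional) – Number of nodes in a random walk sequence. Default: 40

  • window_size (int, optional) – In a random walk w, a node w[j] is considered close to a node w[i] if i - window_size <= j <= i + window_size. Default: 5

  • neg_weight (float, optional) – Weight of the loss term for negative samples in the total loss. Default: 1.0

  • negative_size (int, optional) – Number of negative samples to use for each positive sample. Default: 5

  • fast_neg (bool, optional) – If True, negative node pairs are sampled within a batch of random walks. Default: True

  • sparse (bool, optional) – If True, gradients with respect to the learnable weights will be sparse. Default: True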
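To make the window_size condition concrete, the following sketch enumerates the node pairs that count as close within a single walk. The function name `window_pairs` is hypothetical, and skipping the case j == i (a node paired with itself) is an assumption of this sketch rather than something the parameter description states.

```python
def window_pairs(walk, window_size):
    """Hypothetical helper: list (w[i], w[j]) with i - window_size <= j <= i + window_size."""
    pairs = []
    for i, center in enumerate(walk):
        # Clamp the window to the bounds of the walk.
        lo = max(0, i - window_size)
        hi = min(len(walk) - 1, i + window_size)
        for j in range(lo, hi + 1):
            if j != i:  # assumption: a node is not paired with itself
                pairs.append((center, walk[j]))
    return pairs


walk = [3, 7, 2, 9]  # node IDs visited by one random walk
pairs = window_pairs(walk, window_size=1)
# pairs == [(3, 7), (7, 3), (7, 2), (2, 7), (2, 9), (9, 2)]
```

A larger window_size yields more positive pairs per walk, at the cost of treating more distant nodes as close.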

node_embed

Embedding table of the nodes

Type:

nn.Embedding

Example

>>> import torch
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import DeepWalk
>>> from torch.optim import SparseAdam
>>> from torch.utils.data import DataLoader
>>> from sklearn.linear_model import LogisticRegression
>>> dataset = CoraGraphDataset()
>>> g = dataset[0]
>>> model = DeepWalk(g)
>>> dataloader = DataLoader(torch.arange(g.num_nodes()), batch_size=128,
...                         shuffle=True, collate_fn=model.sample)
>>> optimizer = SparseAdam(model.parameters(), lr=0.01)
>>> num_epochs = 5
>>> for epoch in range(num_epochs):
...     for batch_walk in dataloader:
...         loss = model(batch_walk)
...         optimizer.zero_grad()
...         loss.backward()
...         optimizer.step()
>>> train_mask = g.ndata['train_mask']
>>> test_mask = g.ndata['test_mask']
>>> X = model.node_embed.weight.detach()
>>> y = g.ndata['label']
>>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy())
>>> clf.score(X[test_mask].numpy(), y[test_mask].numpy())

forward(batch_walk)[source]

Compute the loss for a batch of random walks

Parameters:

batch_walk (torch.Tensor) – Random walks in the form of node ID sequences. The Tensor is of shape (batch_size, walk_length)

Returns:

Loss value

Return type:

torch.Tensor
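As a rough picture of what a (batch_size, walk_length) batch of walks contains, the sketch below generates uniform random walks over a plain adjacency dictionary. It is a stand-in for illustration only: `sample_walks`, the toy graph, and the decision to raise an error on nodes without neighbors are assumptions, and the result is a list of lists rather than a torch.Tensor.

```python
import random


def sample_walks(adj, seeds, walk_length, rng):
    """Hypothetical helper: one uniform random walk of walk_length nodes per seed."""
    walks = []
    for seed in seeds:
        walk = [seed]
        while len(walk) < walk_length:
            neighbors = adj[walk[-1]]
            if not neighbors:
                # Assumption of this sketch: dead ends are not supported.
                raise ValueError(f"node {walk[-1]} has no neighbors")
            walk.append(rng.choice(neighbors))
        walks.append(walk)
    return walks


# Toy undirected 4-cycle: 0 - 1 - 2 - 3 - 0
adj = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
seeds = [0, 1, 2]  # one walk starts from each seed node
batch_walk = sample_walks(adj, seeds, walk_length=5, rng=random.Random(0))
# len(batch_walk) == batch_size (3); each row holds walk_length (5) node IDs
```

Each row starts at its seed node and every consecutive pair of entries is an edge of the graph, which is the structure the loss computation relies on.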

reset_parameters()[source]

Reinitialize learnable parameters