MetaPath2Vec

class dgl.nn.pytorch.MetaPath2Vec(g, metapath, window_size, emb_dim=128, negative_size=5, sparse=True)[source]

基类: Module

metapath2vec 模块,来自论文 metapath2vec: Heterogeneous Networks 的可扩展表示学习

为了实现高效优化,我们在训练过程中利用了负采样技术。对于元路径中的每个节点,我们重复将其视为中心节点,并在上下文窗口大小内采样附近的正节点,并在所有元路径中的所有类型的节点中抽取负样本。然后,我们可以使用中心-上下文配对节点和上下文-负配对节点来更新网络。

参数:
  • g (DGLGraph) – 用于学习节点嵌入的图。不允许两种不同的规范边类型 (utype, etype, vtype) 具有相同的 etype

  • metapath (list[str]) – 字符串形式的边类型序列。它通过按顺序组合多个边类型来定义一种新的边类型。注意,起始节点类型和结束节点类型通常是相同的。

  • window_size (int) – 在随机游走 w 中,如果 i - window_size <= j <= i + window_size,则节点 w[j] 被认为是靠近节点 w[i] 的。

  • emb_dim (int, 可选) – 每个嵌入向量的大小。默认值: 128

  • negative_size (int, 可选) – 每个正样本使用的负样本数量。默认值: 5

  • sparse (bool, 可选) – 如果为 True,相对于可学习权重的梯度将是稀疏的。默认值: True

node_embed

所有节点的嵌入表

类型:

nn.Embedding

local_to_global_nid

从类型特定节点 ID 到全局节点 ID 的映射

类型:

dict[str, list]

示例

>>> import torch
>>> import dgl
>>> from torch.optim import SparseAdam
>>> from torch.utils.data import DataLoader
>>> from dgl.nn.pytorch import MetaPath2Vec
>>> # Define a model
>>> g = dgl.heterograph({
...     ('user', 'uc', 'company'): dgl.rand_graph(100, 1000).edges(),
...     ('company', 'cp', 'product'): dgl.rand_graph(100, 1000).edges(),
...     ('company', 'cu', 'user'): dgl.rand_graph(100, 1000).edges(),
...     ('product', 'pc', 'company'): dgl.rand_graph(100, 1000).edges()
... })
>>> model = MetaPath2Vec(g, ['uc', 'cu'], window_size=1)
>>> # Use the source node type of etype 'uc'
>>> dataloader = DataLoader(torch.arange(g.num_nodes('user')), batch_size=128,
...                         shuffle=True, collate_fn=model.sample)
>>> optimizer = SparseAdam(model.parameters(), lr=0.025)
>>> for (pos_u, pos_v, neg_v) in dataloader:
...     loss = model(pos_u, pos_v, neg_v)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Get the embeddings of all user nodes
>>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])
>>> user_emb = model.node_embed(user_nids)
forward(pos_u, pos_v, neg_v)[source]

计算一批正样本和负样本的损失

参数:
  • pos_u (torch.Tensor) – 正中心节点

  • pos_v (torch.Tensor) – 正上下文节点

  • neg_v (torch.Tensor) – 负上下文节点

返回:

损失值

返回类型:

torch.Tensor

reset_parameters()[source]

重新初始化可学习参数