HeteroEmbedding

class dgl.nn.pytorch.HeteroEmbedding(num_embeddings, embedding_dim)[source]

基类: Module

创建一个异构嵌入表。

它内部包含多个字典大小不同的 torch.nn.Embedding

参数:
  • num_embeddings (dict[key, int]) – 字典的大小。键可以是字符串或字符串元组。

  • embedding_dim (int) – 每个嵌入向量的大小。

示例

>>> import dgl
>>> import torch
>>> from dgl.nn import HeteroEmbedding
>>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
>>> # Get the heterogeneous embedding table
>>> embeds = layer.weight
>>> print(embeds['user'].shape)
torch.Size([2, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([3, 4])
>>> # Get the embeddings for a subset
>>> input_ids = {'user': torch.LongTensor([0]),
...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
>>> embeds = layer(input_ids)
>>> print(embeds['user'].shape)
torch.Size([1, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([2, 4])
forward(input_ids)[source]

前向函数

参数:

input_ids (dict[key, Tensor]) – 用于检索嵌入的行 ID。它将键映射到特定于键的 ID。

返回:

检索到的嵌入。

返回类型:

dict[key, Tensor]

reset_parameters()[source]

使用 nn.init 模块中的 xavier 方法使参数均匀分布