HeteroEmbedding
- class dgl.nn.pytorch.HeteroEmbedding(num_embeddings, embedding_dim)
Bases: Module
Create a heterogeneous embedding table.
It internally contains multiple torch.nn.Embedding modules with different dictionary sizes.
Examples
>>> import dgl
>>> import torch
>>> from dgl.nn import HeteroEmbedding

>>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
>>> # Get the heterogeneous embedding table
>>> embeds = layer.weight
>>> print(embeds['user'].shape)
torch.Size([2, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([3, 4])

>>> # Get the embeddings for a subset
>>> input_ids = {'user': torch.LongTensor([0]),
...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
>>> embeds = layer(input_ids)
>>> print(embeds['user'].shape)
torch.Size([1, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([2, 4])
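The idea of one torch.nn.Embedding per key, each with its own dictionary size, can be illustrated with plain PyTorch. The following is only a minimal sketch under stated assumptions, not DGL's actual implementation: the class name TinyHeteroEmbedding, the attribute names, and the "emb_<i>" key scheme are all hypothetical and chosen for illustration.

```python
import torch
import torch.nn as nn


class TinyHeteroEmbedding(nn.Module):
    """Hypothetical stand-in for illustration; not DGL's HeteroEmbedding."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        # nn.ModuleDict requires string keys, so every key (including
        # tuple keys such as edge types) is mapped to a generated string.
        self._key_of = {k: f"emb_{i}" for i, k in enumerate(num_embeddings)}
        # One embedding table per key; sizes differ, the dimension is shared.
        self.embeds = nn.ModuleDict({
            self._key_of[k]: nn.Embedding(n, embedding_dim)
            for k, n in num_embeddings.items()
        })

    @property
    def weight(self):
        # Full embedding table of every key, indexed by the original key.
        return {k: self.embeds[s].weight for k, s in self._key_of.items()}

    def forward(self, input_ids):
        # Look up embeddings only for the keys and IDs that were passed in.
        return {k: self.embeds[self._key_of[k]](ids)
                for k, ids in input_ids.items()}


layer = TinyHeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
out = layer({'user': torch.LongTensor([0]),
             ('user', 'follows', 'user'): torch.LongTensor([0, 2])})
print(out['user'].shape)                       # torch.Size([1, 4])
print(out[('user', 'follows', 'user')].shape)  # torch.Size([2, 4])
```

Keeping the tables in an nn.ModuleDict means their weights are registered as parameters of the module, so an optimizer built from layer.parameters() would update all of them together.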