HeteroLinear

class dgl.nn.pytorch.HeteroLinear(in_size, out_size, bias=True)[source]

基类:Module

对异构输入应用线性变换。

参数:
  • in_size (dict[key, int]) – 异构输入的输入特征大小。键可以是字符串或字符串元组。

  • out_size (int) – 输出特征大小。

  • bias (bool, optional) – 如果为 True,则学习一个偏置项。默认值:True

示例

>>> import dgl
>>> import torch
>>> from dgl.nn import HeteroLinear
>>> layer = HeteroLinear({'user': 1, ('user', 'follows', 'user'): 2}, 3)
>>> in_feats = {'user': torch.randn(2, 1), ('user', 'follows', 'user'): torch.randn(3, 2)}
>>> out_feats = layer(in_feats)
>>> print(out_feats['user'].shape)
torch.Size([2, 3])
>>> print(out_feats[('user', 'follows', 'user')].shape)
torch.Size([3, 3])
forward(feat)[source]

前向函数

参数:

feat (dict[key, Tensor]) – 异构输入特征。它将键映射到特征。

返回值:

变换后的特征。

返回类型:

dict[key, Tensor]