WeightBasis

class dgl.nn.pytorch.utils.WeightBasis(shape, num_bases, num_outputs)[source]

基类: Module

基于 Modeling Relational Data with Graph Convolutional Networks 的基分解

可以描述如下:

\[W_o = \sum_{b=1}^B a_{ob} V_b\]

每个权重输出 \(W_o\) 本质上是基变换 \(V_b\) 的线性组合,其系数为 \(a_{ob}\)

这对于大型参数矩阵来说是一种有用的正则化形式。因此,权重输出的数量通常大于基的数量。

参数:
  • shape (tuple[int]) – 基参数的形状。

  • num_bases (int) – 基的数量。

  • num_outputs (int) – 输出的数量。

forward()[source]

前向计算

返回:

weight – 组成的权重张量,形状为 (num_outputs,) + shape

返回类型:

torch.Tensor