TypedLinear

class dgl.nn.pytorch.TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None)[源代码]

基类: Module

根据类型进行线性变换。

对于输入批量 \(x \in X\) 中的每个样本,应用线性变换 \(xW_t\),其中 \(t\)\(x\) 的类型。

该模块支持 “Modeling Relational Data with Graph Convolutional Networks” 提出的两种正则化方法(基分解和块对角分解)。

基正则化将 \(W_t\) 分解为

\[W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}\]

其中 \(B\) 是基的数量,\(V_b^{(l)}\) 与系数 \(a_{tb}^{(l)}\) 线性组合。

块对角分解正则化将 \(W_t\) 分解为 \(B\) 个块对角矩阵。我们将 \(B\) 称为基的数量

\[W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}\]

其中 \(B\) 是基的数量,\(Q_{tb}^{(l)}\) 是形状为 \(R^{(d^{(l+1)}/B)\times(d^{l}/B)}\) 的块基。

参数:
  • in_size (int) – 输入特征大小。

  • out_size (int) – 输出特征大小。

  • num_types (int) – 总类型数。

  • regularizer (str, 可选) –

    要使用的权重正则化器:“basis” 或 “bdd”

    • “basis” 是基分解的缩写。

    • “bdd” 是块对角分解的缩写。

    默认不应用正则化。

  • num_bases (int, 可选) – 基的数量。当指定 regularizer 时需要。通常小于 num_types。默认值: None

示例

无正则化。

>>> from dgl.nn import TypedLinear
>>> import torch
>>>
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])

带基正则化

>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])
forward(x, x_type, sorted_by_type=False)[源代码]

前向计算。

参数:
  • x (torch.Tensor) – 一个 2D 输入张量。形状: (N, D1)

  • x_type (torch.Tensor) – 一个 1D 整数张量,按一一对应关系存储 x 中元素的类型。形状: (N,)

  • sorted_by_type (bool, 可选) – 输入是否已按类型排序。对预排序输入进行前向计算可能更快。

返回值:

y – 变换后的输出张量。形状: (N, D2)

返回类型:

torch.Tensor

reset_parameters()[源代码]

重置参数