TypedLinear
- class dgl.nn.pytorch.TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None)[源代码]
基类:
Module
根据类型进行线性变换。
对于输入批量 \(x \in X\) 中的每个样本,应用线性变换 \(xW_t\),其中 \(t\) 是 \(x\) 的类型。
该模块支持 “Modeling Relational Data with Graph Convolutional Networks” 提出的两种正则化方法(基分解和块对角分解)。
基正则化将 \(W_t\) 分解为
\[W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}\]其中 \(B\) 是基的数量,\(V_b^{(l)}\) 与系数 \(a_{tb}^{(l)}\) 线性组合。
块对角分解正则化将 \(W_t\) 分解为 \(B\) 个块对角矩阵。我们将 \(B\) 称为基的数量
\[W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}\]其中 \(B\) 是基的数量,\(Q_{tb}^{(l)}\) 是形状为 \(R^{(d^{(l+1)}/B)\times(d^{l}/B)}\) 的块基。
- 参数:
示例
无正则化。
>>> from dgl.nn import TypedLinear >>> import torch >>> >>> x = torch.randn(100, 32) >>> x_type = torch.randint(0, 5, (100,)) >>> m = TypedLinear(32, 64, 5) >>> y = m(x, x_type) >>> print(y.shape) torch.Size([100, 64])
带基正则化
>>> x = torch.randn(100, 32) >>> x_type = torch.randint(0, 5, (100,)) >>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4) >>> y = m(x, x_type) >>> print(y.shape) torch.Size([100, 64])