LabelPropagation
- class dgl.nn.pytorch.utils.LabelPropagation(k, alpha, norm_type='sym', clamp=True, normalize=False, reset=False)[source]
Bases:
Module
来自 Learning from Labeled and Unlabeled Data with Label Propagation 的标签传播
\[\mathbf{Y}^{(t+1)} = \alpha \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)}\]其中无标签数据初始设为零,并通过传播从有标签数据推断得出。\(\alpha\) 是一个权重参数,用于平衡更新后的标签和初始标签。\(\tilde{A}\) 表示归一化邻接矩阵。
- 参数:
k (int) – 传播步数。
alpha (float) – 范围在 [0, 1] 的 \(\alpha\) 系数。
norm_type (str, optional) –
应用于邻接矩阵的归一化类型,必须是以下选项之一
row
: 行归一化邻接矩阵,表示为 \(D^{-1}A\)sym
: 对称归一化邻接矩阵,表示为 \(D^{-1/2}AD^{-1/2}\)
默认值:‘sym’。
clamp (bool, optional) – 一个布尔标志,指示传播后是否将标签限制在 [0, 1] 范围内。默认值:True。
normalize (bool, optional) – 一个布尔标志,指示传播后是否应用行归一化。默认值:False。
reset (bool, optional) – 一个布尔标志,指示是否在每次传播后重置已知标签。默认值:False。
示例
>>> import torch >>> import dgl >>> from dgl.nn import LabelPropagation
>>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True) >>> g = dgl.rand_graph(5, 10) >>> labels = torch.tensor([0, 2, 1, 3, 0]).long() >>> mask = torch.tensor([0, 1, 1, 1, 0]).bool() >>> new_labels = label_propagation(g, labels, mask)
- forward(g, labels, mask=None)[source]
计算标签传播过程。
- 参数:
g (DGLGraph) – 输入图。
labels (torch.Tensor) –
输入的节点标签。支持以下三种情况。
对于多类分类中的节点类别标签,形状为 \((N, 1)\) 或 \((N,)\) 的 LongTensor,其中 \(N\) 是节点数。
对于多类分类中节点类别标签的独热编码,形状为 \((N, C)\) 的 LongTensor,其中 \(C\) 是类别数。
对于多标签二分类中的节点标签,形状为 \((N, L)\) 的 LongTensor,其中 \(L\) 是标签数。
mask (torch.Tensor) – 形状为 \((N,)\) 的布尔指示器,True 表示有标签节点。默认值:None,表示所有节点都有标签。
- 返回值:
传播后的节点标签,形状为 \((N, D)\),类型为 float,其中 \(D\) 是类别数或标签数。
- 返回类型:
torch.Tensor