CuGraphGATConv

class dgl.nn.pytorch.conv.CuGraphGATConv(in_feats, out_feats, num_heads, feat_drop=0.0, negative_slope=0.2, residual=False, activation=None, bias=True)[source]

基类: CuGraphBaseConv

来自 Graph Attention Networks 的图注意力层,通过 cugraph-ops 加速了稀疏聚合。

有关数学模型,请参阅 dgl.nn.pytorch.conv.GATConv

此模块依赖于 pylibcugraphops 包,可以通过 conda install -c nvidia pylibcugraphops=23.04 安装。pylibcugraphops 23.04 需要 Python 3.8.x 或 3.10.x 版本。

注意

这是一项实验性功能。

参数:
  • in_feats (int) – 输入特征大小。

  • out_feats (int) – 输出特征大小。

  • num_heads (int) – 多头注意力中的头数。

  • feat_drop (float, 可选) – 特征的 Dropout 比率。默认值:0

  • negative_slope (float, 可选) – LeakyReLU 负斜率的角度。默认值:0.2

  • residual (bool, 可选) – 如果为 True,则使用残差连接。默认值:False

  • activation (可调用激活函数/层None, 可选。) – 如果不为 None,则将激活函数应用于更新后的节点特征。默认值:None

  • bias (bool, 可选) – 如果为 True,则学习偏置项。默认值:True

示例

>>> import dgl
>>> import torch
>>> from dgl.nn import CuGraphGATConv
>>> device = 'cuda'
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
>>> g = dgl.add_self_loop(g)
>>> feat = torch.ones(6, 10).to(device)
>>> conv = CuGraphGATConv(10, 2, num_heads=3).to(device)
>>> res = conv(g, feat)
>>> res
tensor([[[ 0.2340,  1.9226],
        [ 1.6477, -1.9986],
        [ 1.1138, -1.9302]],
        [[ 0.2340,  1.9226],
        [ 1.6477, -1.9986],
        [ 1.1138, -1.9302]],
        [[ 0.2340,  1.9226],
        [ 1.6477, -1.9986],
        [ 1.1138, -1.9302]],
        [[ 0.2340,  1.9226],
        [ 1.6477, -1.9986],
        [ 1.1138, -1.9302]],
        [[ 0.2340,  1.9226],
        [ 1.6477, -1.9986],
        [ 1.1138, -1.9302]],
        [[ 0.2340,  1.9226],
        [ 1.6477, -1.9986],
        [ 1.1138, -1.9302]]], device='cuda:0', grad_fn=<ViewBackward0>)
forward(g, feat, max_in_degree=None)[source]

前向计算。

参数:
  • g (DGLGraph) – 图。

  • feat (torch.Tensor) – 形状为 \((N, D_{in})\) 的输入特征。

  • max_in_degree (int) – 目标节点的最大入度。它仅在 g 是一个 DGLBlock (即二部图) 时有效。当 g 是由邻居采样器生成时,该值应设置为相应的 fanout。如果未给定,max_in_degree 将会即时计算。

返回:

形状为 \((N, H, D_{out})\) 的输出特征,其中 \(H\) 是头数,\(D_{out}\) 是输出特征的大小。

返回类型:

torch.Tensor

reset_parameters()[source]

重新初始化可学习参数。