节点嵌入
- class dgl.nn.pytorch.sparse_emb.NodeEmbedding(num_embeddings, embedding_dim, name, init_func=None, device=None, partition=None)[source]
基类:
object
用于存储节点嵌入的类。
此类针对大规模节点嵌入的训练进行了优化。它以稀疏方式更新嵌入,并且可以扩展到包含数百万个节点的图。它还支持在(单台机器上的)多个 GPU 上进行分区以获得更多加速。它不支持跨机器分区。
目前,DGL 提供了两个适用于此 NodeEmbedding 类的优化器:
SparseAdagrad
和SparseAdam
。此实现基于 torch.distributed 包。它依赖于 pytorch 默认的分布式进程组来收集多进程信息,并使用
torch.distributed.TCPStore
在多个 GPU 进程之间共享元数据信息。它使用本地地址 '127.0.0.1:12346' 来初始化 TCPStore。注意:对 NodeEmbedding 的支持是实验性的。
- 参数:
示例
在启动多个 GPU 进程之前
>>> def initializer(emb): th.nn.init.xavier_uniform_(emb) return emb
在每个训练进程中
>>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer) >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001) >>> for blocks in dataloader: ... ... ... feats = emb(nids, gpu_0) ... loss = F.sum(feats + 1, 0) ... loss.backward() ... optimizer.step()