DGL_图的创建、保存、加载
生活随笔
收集整理的這篇文章主要介紹了
DGL_图的创建、保存、加载
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
import dgl
import torch as th
from dgl.data.utils import save_graphsg1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
g1.ndata["x"] = th.ones(3, 5) # 3個節(jié)點的embedding
g1.edata['y'] = th.zeros(6, 5) # 6條邊的embedding
# 補充:添加邊的方式
# g1.add_edges(th.tensor([3, 4, 5]), 1) # three edges: 3->1, 4->1, 5->1
# g1.add_edges(4, [7, 8, 9]) # three edges: 4->7, 4->8, 4->9
# g1.add_edges([1, 2, 3], [3, 4, 5]) # three edges: 1->3, 2->4, 3->5g2 = dgl.DGLGraph()
g2.add_nodes(3)
g2.add_edges([0, 1, 2], [1, 2, 1])
g2.edata["e"] = th.ones(3, 4)graph_labels = {"graph_sizes": th.tensor([3, 3])}save_graphs("data/try1.bin", [g1, g2], graph_labels)
from dgl.data.utils import load_graphs
from dgl.data.utils import load_labels# glist, label_dict = load_graphs("data/small.bin") # glist will be [g1, g2]
glist, label_dict = load_graphs("data/try1.bin", [0]) # glist will be [g1]
graph_sizes = load_labels("data/try1.bin")print(glist)
# [DGLGraph(num_nodes=3, num_edges=6,
# ndata_schemes={'x': Scheme(shape=(5,), dtype=torch.float32)}
# edata_schemes={'y': Scheme(shape=(5,), dtype=torch.float32)})]
print(label_dict)
# {'graph_sizes': tensor([3, 3])}
print(graph_sizes)
# {'graph_sizes': tensor([3, 3])}
創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎勵來咯,堅持創(chuàng)作打卡瓜分現(xiàn)金大獎
總結(jié)
以上是生活随笔為你收集整理的DGL_图的创建、保存、加载的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Pycharm使用远程服务器运行代码
- 下一篇: 文档主题分析项目