pytorch 训练过程acc_【图节点分类】10分钟就学会的图节点分类教程,基于pytorch和dgl...
生活随笔
收集整理的這篇文章主要介紹了
pytorch 训练过程acc_【图节点分类】10分钟就学会的图节点分类教程,基于pytorch和dgl...
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
圖神經網絡中最流行和廣泛采用的任務之一就是節點分類,其中訓練集/驗證集/測試集中的每個節點從一組預定義的類別中分配一個真實類別。
為了對節點進行分類,圖神經網絡利用節點自身的特征,以及相鄰節點和邊的特征進行消息傳遞。消息傳遞可以重復多次,以聚合來自更大范圍的鄰居節點的信息。
dgl框架為我們提供了一些內置的圖卷積模塊,可以執行一輪的消息傳遞。
在本文中,我們使用dgl.nn.pytorch的SAGEConv模塊,該模塊來自這篇論文GraphSAGE:Inductive Representation Learning on Large Graphs
通常對于圖上的深度學習模型,我們需要一個多層圖神經網絡,在這里我們進行多輪的消息傳遞。這可以通過如下方式堆疊圖卷積模塊來實現。
1 構造GNN模型
先導入必要包(本文dgl 版本為 0.5.2)
import dgl.nn as dglnn
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import * 構造一個兩層的gnn模型
class SAGE(nn.Module):def __init__(self, in_feats, hid_feats, out_feats, dropout=0.2):super().__init__()self.conv1 = dglnn.SAGEConv( in_feats=in_feats, out_feats=hid_feats, feat_drop=0.2, aggregator_type='gcn')self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, feat_drop=0.2, aggregator_type='mean')self.dropout = nn.Dropout(dropout)def forward(self, graph, inputs):# inputs 是節點的特征 [N, in_feas]h = self.conv1(graph, inputs)h = self.dropout(F.relu(h))h = self.conv2(graph, h)return h 注意,我們不僅可以將上面的模型用于節點分類,還可以獲取節點的特征表示為了其他下游任務,如邊分類/回歸、鏈接預測或圖分類。
2 數據集與數據分析
dataset = CoraGraphDataset() # Cora citation network dataset
graph = dataset[0]
graph = dgl.remove_self_loop(graph) # 消除自環
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1) print("圖的節點數和邊數: ", graph.num_nodes(), graph.num_edges())
print("訓練集節點數:", train_mask.sum().item())
print("驗證集集節點數:", valid_mask.sum().item())
print("測試集節點數:", test_mask.sum().item())
print("節點特征維數:", n_features)
print("標簽類目數:", n_labels)隨機抽200個節點并畫圖展示:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt G = graph.to_networkx()
res = np.random.randint(0, high=G.number_of_nodes(), size=(200))k = G.subgraph(res)
pos = nx.spring_layout(k)plt.figure()
nx.draw(k, pos=pos, node_size=8 )
plt.savefig('cora.jpg', dpi=600)
plt.show()3 訓練模型與評估
def evaluate(model, graph, features, labels, mask):model.eval()with torch.no_grad():logits = model(graph, features)logits = logits[mask]labels = labels[mask]_, indices = torch.max(logits, dim=1)correct = torch.sum(indices == labels)return correct.item() * 1.0 / len(labels)model = SAGE(in_feats=n_features, hid_feats=128, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())# 開始訓練
best_val_acc = 0
for epoch in range(200): print('Epoch {}'.format(epoch))model.train()# 用所有的節點進行前向傳播logits = model(graph, node_features)# 計算損失loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])# 計算驗證集accuracyacc = evaluate(model, graph, node_features, node_labels, valid_mask)# backward propagationopt.zero_grad()loss.backward()opt.step()print('loss = {:.4f}'.format(loss.item()))if acc > best_val_acc:best_val_acc = acctorch.save(model.state_dict(), 'save_model/best_model.pth')print("current val acc = {}, best val acc = {}".format(acc, best_val_acc))測試集評估
model.load_state_dict(torch.load("save_model/best_model.pth"))
acc = evaluate(model, graph, node_features, node_labels, test_mask)
print("test accuracy: ", acc)完結:-) 覺得有用記得雙擊點贊呀!
總結
以上是生活随笔為你收集整理的pytorch 训练过程acc_【图节点分类】10分钟就学会的图节点分类教程,基于pytorch和dgl...的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 中国拍卖最贵的画是谁画的呢?
- 下一篇: 冰天雪地