PYG教程【三】对Cora数据集进行半监督节点分类
Cora數據集
PyG包含有大量的基準數據集。初始化數據集非常簡單,數據集初始化會自動下載原始數據文件,并且會將它們處理成Data格式。
如下圖所示,Cora數據集中只有一個圖,該圖包含2708個節點,10556條邊,節點類別數為7,特征維度為1433。并且默認已經對數據集進行了劃分,分為了訓練集、驗證集和測試集。
然后看看節點特征和標簽。x為節點特征矩陣,維度為2708*1433。y為節點標簽向量,維度為2708,類別為7。
用GCN進行半監督節點分類
接下來就可以構建一個簡單的GCN模型,在Cora數據集上進行半監督節點分類。
下面的GCN模型包含兩個圖卷積層。第一層輸入維度為1433(節點特征維度),輸出為16(與第一層輸出一致),后面接上一個relu激活函數,以及dropout操作。第二層輸入維度為16,輸出為7(節點標簽數量),后接log_softmax函數進行分類。
模型構建完成后,指定訓練設備為GPU(沒有的話就用CPU),注意這里默認使用的是0號cuda。如果cuda:0被占用了的話會報錯,需要指定其他號碼的cuda才能運行。然后,分別將GCN模型以及Cora圖數據送入指定的設備。
優化器選擇Adam,學習率設置為0.01,權重衰減設置為5e-4。這些都配置好以后就可以訓練模型了,epoch設為200,每個epoch后清除上次的梯度信息,然后用nll_loss計算出訓練集上的損失,調用backward函數計算出梯度后傳回給Adam優化器進行參數更新。
最后在測試集上評估模型,計算分類正確率accuracy并顯示。
至此,就完成了Cora數據集上的節點分類任務了。
總結
以上是生活随笔為你收集整理的PYG教程【三】对Cora数据集进行半监督节点分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 人工智能技术应用学什么
- 下一篇: PYG教程【四】Node2Vec节点分类