torch_geometric 笔记: 数据集Cora 简易 GNN
生活随笔
收集整理的這篇文章主要介紹了
torch_geometric 笔记: 数据集Cora 简易 GNN
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
1 獲取數據集
該數據集用于semi-supervised的節點分類任務
from torch_geometric.datasets import Planetoiddataset = Planetoid(root='/tmp/Cora', name='Cora')dataset.num_classes #7 #節點一共七個類dataset.num_features #1433 #每個點1433個特征len(dataset) #1 #只有一張圖dataset[0].is_undirected() #Truedataset[0] #Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])''' edge_index=[2, 10556]————這張圖有10556條有向邊 x=[2708, 1433]————這張圖有2708個點,每個點1433個特征 y=[2708]——每個節點的標簽(一共有7個類) '''dataset[0]['train_mask'] #tensor([ True, True, True, ..., False, False, False]) #train_mask:2708維向量,訓練集的mask向量,標識哪些節點屬于訓練集。 #val_mask:2708維向量,驗證集的mask向量,標識哪些節點屬于驗證集。 #test_mask:2708維向量,測試集的mask向量,表示哪些節點屬于測試集。?
1.1 cora 數據集??
cora數據集的點表示的是機器學習的論文, 這些論文的選擇方式使得在最終的語料庫中每篇論文都引用或被至少另一篇論文引用。
全語料庫有2708篇論文。我們得到了一個大小為 1433 個唯一詞的詞匯表。 所有文檔頻率小于 10 的單詞都被刪除。
2 簡易GCN
2.1?torch_geometric.nn中有的模型
在torch_geometric.nn — pytorch_geometric 2.0.1 documentation (pytorch-geometric.readthedocs.io)
列舉了torch_geometric.nn中有的模型
2.2 簡易模型
2.2.1 導入庫
import torch import torch.nn.functional as F from torch_geometric.nn import GCNConv2.2.2 設計模型
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16)#兩層GCN,輸入是每個節點的num_node_features維特征,輸出是16維向量self.conv2 = GCNConv(16, dataset.num_classes)#兩層GCN,輸入是16維向量,輸出是點有的類別數def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)#GCN1'''forward(x: torch.Tensor, edge_index: Union[torch.Tensor, torch_sparse.tensor.SparseTensor],edge_weight: Optional[torch.Tensor] = None) → torch.Tensor'''x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)2.2.3 訓練模型
model = Net() optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) #優化函數loss_func=F.nll_lossmodel.train() for epoch in range(200):optimizer.zero_grad() #清空上一步殘余的參數更新值out = model(data)loss = loss_func(out[data.train_mask], data.y[data.train_mask]) #計算誤差loss.backward() #清空上一步殘余的參數更新值optimizer.step()#將參數更新值施加到net的parameters上?2.2.4 測試模型
model.eval() _, pred = model(data).max(dim=1) #預測結果correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()) acc = correct / int(data.test_mask.sum()) print('Accuracy: {:.4f}'.format(acc)) #Accuracy: 0.8080總結
以上是生活随笔為你收集整理的torch_geometric 笔记: 数据集Cora 简易 GNN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: torch_geometric笔记:数据
- 下一篇: toch_geometric 笔记:me