PyG Tutorial [4]: Node2Vec Node Classification and Visualization
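This article shows how to quickly implement Node2Vec node classification with PyTorch Geometric and visualize the results.
The process has four steps:
- Load the graph data (Cora is used here as an example)
- Create the Node2Vec model
- Train and test
- Reduce the embeddings to 2-D with t-SNE and visualize them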
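The parameters of Node2Vec are:
- edge_index (LongTensor): the graph connectivity, given as a [2, num_edges] tensor of edge indices (COO format), not a dense adjacency matrix
- embedding_dim (int): the dimension of each node's embedding
- walk_length (int): the length of each random walk
- context_size (int): the window size used to extract positive samples from a walk
- walks_per_node (int, optional): the number of random walks started from each node
- p (float, optional): the return parameter; larger values make the walk less likely to step straight back to the node it just came from
- q (float, optional): the in-out parameter; q > 1 biases the walk toward nodes close to the previous node (BFS-like), q < 1 toward nodes farther away (DFS-like)
- num_negative_samples (int, optional): the number of negative samples drawn for each positive sample
- sparse (bool, optional): if True, gradients with respect to the embedding matrix are sparse, which is why the training code uses torch.optim.SparseAdam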
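To make p, q, walk_length and context_size more concrete, here is a minimal plain-Python sketch of the second-order biased random walk described in the node2vec paper, plus a sliding-window split of a walk into context windows. It is an illustration under stated assumptions, not PyG's own implementation: the graph is unweighted and stored as a dict of neighbor sets, a walk of walk_length steps contains walk_length + 1 nodes, and the helper names (transition_weights, biased_walk, context_windows) are hypothetical.

```python
import random

# Hypothetical helpers for illustration only (not PyG API).
# Graph format assumed here: dict mapping node -> set of neighbors (unweighted).


def transition_weights(adj, prev, cur, p, q):
    """Unnormalized node2vec weights for the next step from `cur`, having arrived from `prev`."""
    weights = {}
    for x in sorted(adj[cur]):
        if x == prev:            # distance(prev, x) == 0: step back
            weights[x] = 1.0 / p
        elif x in adj[prev]:     # distance(prev, x) == 1: stay near prev
            weights[x] = 1.0
        else:                    # distance(prev, x) == 2: move outward
            weights[x] = 1.0 / q
    return weights


def biased_walk(adj, start, walk_length, p=1.0, q=1.0, rng=None):
    """Walk `walk_length` steps, returning walk_length + 1 nodes (fewer at a dead end)."""
    if rng is None:
        rng = random.Random()
    walk = [start]
    for _ in range(walk_length):
        cur = walk[-1]
        if not adj[cur]:
            break  # dead end: no neighbors to move to
        if len(walk) == 1:
            # First step: there is no previous node, so pick a neighbor uniformly
            walk.append(rng.choice(sorted(adj[cur])))
            continue
        w = transition_weights(adj, walk[-2], cur, p, q)
        nodes = list(w)
        walk.append(rng.choices(nodes, weights=[w[x] for x in nodes], k=1)[0])
    return walk


def context_windows(walk, context_size):
    """Split a walk into overlapping windows of `context_size` consecutive nodes."""
    return [walk[i:i + context_size] for i in range(len(walk) - context_size + 1)]


# Small example graph: triangle 0-1-2 plus the edge 2-3
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
print(transition_weights(adj, prev=0, cur=2, p=2.0, q=0.5))  # {0: 0.5, 1: 1.0, 3: 2.0}
walk = biased_walk(adj, start=0, walk_length=5, p=2.0, q=0.5, rng=random.Random(0))
print(walk, context_windows(walk, context_size=3))
```

With p = 2 and q = 0.5, stepping outward from node 2 to node 3 gets four times the weight of stepping back to node 0 (2.0 vs 0.5), which is the DFS-like bias described for q < 1.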
The code is as follows:

```python
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec

# Load the Cora citation graph (downloaded to `root` on first use)
dataset = Planetoid(root='G:/torch_geometric_datasets', name='Cora')
data = dataset[0]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,
                 context_size=10, walks_per_node=10,
                 num_negative_samples=1, sparse=True).to(device)

# Yields batches of (positive random walks, negative random walks).
# On Windows, num_workers > 0 requires running the script under an
# `if __name__ == '__main__':` guard; otherwise set num_workers=0.
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)

# Older PyTorch versions accept torch.optim.SparseAdam(model.parameters(), lr=0.01);
# newer versions need the parameters passed as a list. This article uses PyTorch 1.7.1.
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)


def train():
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in loader:
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


@torch.no_grad()
def test():
    model.eval()
    z = model()  # embeddings of all nodes
    # Fit a (logistic regression) classifier on the train_mask nodes,
    # then report its accuracy on the test_mask nodes
    acc = model.test(z[data.train_mask], data.y[data.train_mask],
                     z[data.test_mask], data.y[data.test_mask],
                     max_iter=150)
    return acc


for epoch in range(1, 101):
    loss = train()
    acc = test()
    print(f'Epoch:{epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')


@torch.no_grad()
def plot_points(colors):
    model.eval()
    z = model(torch.arange(data.num_nodes, device=device))
    # Project the 128-d embeddings down to 2-D with t-SNE
    z = TSNE(n_components=2).fit_transform(z.cpu().numpy())
    y = data.y.cpu().numpy()

    plt.figure(figsize=(8, 8))
    for i in range(dataset.num_classes):
        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])
    plt.axis('off')
    plt.show()


colors = ['#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700']
plot_points(colors)
```

The output is as follows:
```
Epoch:01, Loss: 8.0661, Acc: 0.1570
Epoch:02, Loss: 6.0309, Acc: 0.1800
Epoch:03, Loss: 4.9328, Acc: 0.2050
Epoch:04, Loss: 4.1206, Acc: 0.2400
Epoch:05, Loss: 3.4587, Acc: 0.2760
Epoch:06, Loss: 2.9389, Acc: 0.2950
Epoch:07, Loss: 2.5340, Acc: 0.3220
Epoch:08, Loss: 2.2042, Acc: 0.3410
Epoch:09, Loss: 1.9404, Acc: 0.3700
Epoch:10, Loss: 1.7295, Acc: 0.4050
Epoch:11, Loss: 1.5594, Acc: 0.4340
Epoch:12, Loss: 1.4231, Acc: 0.4660
Epoch:13, Loss: 1.3143, Acc: 0.4850
Epoch:14, Loss: 1.2242, Acc: 0.5100
Epoch:15, Loss: 1.1539, Acc: 0.5310
Epoch:16, Loss: 1.0997, Acc: 0.5560
Epoch:17, Loss: 1.0559, Acc: 0.5760
Epoch:18, Loss: 1.0199, Acc: 0.6020
Epoch:19, Loss: 0.9921, Acc: 0.6120
Epoch:20, Loss: 0.9671, Acc: 0.6190
Epoch:21, Loss: 0.9487, Acc: 0.6300
Epoch:22, Loss: 0.9335, Acc: 0.6390
Epoch:23, Loss: 0.9203, Acc: 0.6480
Epoch:24, Loss: 0.9106, Acc: 0.6580
Epoch:25, Loss: 0.8994, Acc: 0.6630
Epoch:26, Loss: 0.8924, Acc: 0.6600
Epoch:27, Loss: 0.8858, Acc: 0.6610
Epoch:28, Loss: 0.8792, Acc: 0.6670
Epoch:29, Loss: 0.8731, Acc: 0.6800
Epoch:30, Loss: 0.8697, Acc: 0.6830
Epoch:31, Loss: 0.8652, Acc: 0.6850
Epoch:32, Loss: 0.8618, Acc: 0.6840
Epoch:33, Loss: 0.8586, Acc: 0.6920
Epoch:34, Loss: 0.8550, Acc: 0.6900
Epoch:35, Loss: 0.8523, Acc: 0.6820
Epoch:36, Loss: 0.8507, Acc: 0.6800
Epoch:37, Loss: 0.8483, Acc: 0.6870
Epoch:38, Loss: 0.8469, Acc: 0.6930
Epoch:39, Loss: 0.8449, Acc: 0.6950
Epoch:40, Loss: 0.8433, Acc: 0.6920
Epoch:41, Loss: 0.8422, Acc: 0.6980
Epoch:42, Loss: 0.8398, Acc: 0.6960
Epoch:43, Loss: 0.8401, Acc: 0.6930
Epoch:44, Loss: 0.8374, Acc: 0.6930
Epoch:45, Loss: 0.8377, Acc: 0.6990
Epoch:46, Loss: 0.8363, Acc: 0.6970
Epoch:47, Loss: 0.8354, Acc: 0.7060
Epoch:48, Loss: 0.8339, Acc: 0.7130
Epoch:49, Loss: 0.8333, Acc: 0.7060
Epoch:50, Loss: 0.8340, Acc: 0.7090
Epoch:51, Loss: 0.8332, Acc: 0.7090
Epoch:52, Loss: 0.8325, Acc: 0.7090
Epoch:53, Loss: 0.8321, Acc: 0.7070
Epoch:54, Loss: 0.8316, Acc: 0.7160
Epoch:55, Loss: 0.8317, Acc: 0.7100
Epoch:56, Loss: 0.8297, Acc: 0.7130
Epoch:57, Loss: 0.8309, Acc: 0.7140
Epoch:58, Loss: 0.8296, Acc: 0.7230
Epoch:59, Loss: 0.8296, Acc: 0.7230
Epoch:60, Loss: 0.8276, Acc: 0.7190
Epoch:61, Loss: 0.8287, Acc: 0.7120
Epoch:62, Loss: 0.8294, Acc: 0.7120
Epoch:63, Loss: 0.8272, Acc: 0.7050
Epoch:64, Loss: 0.8286, Acc: 0.7040
Epoch:65, Loss: 0.8283, Acc: 0.7090
Epoch:66, Loss: 0.8278, Acc: 0.7110
Epoch:67, Loss: 0.8274, Acc: 0.7140
Epoch:68, Loss: 0.8283, Acc: 0.7190
Epoch:69, Loss: 0.8269, Acc: 0.7160
Epoch:70, Loss: 0.8271, Acc: 0.7210
Epoch:71, Loss: 0.8260, Acc: 0.7190
Epoch:72, Loss: 0.8273, Acc: 0.7130
Epoch:73, Loss: 0.8252, Acc: 0.7150
Epoch:74, Loss: 0.8264, Acc: 0.7120
Epoch:75, Loss: 0.8250, Acc: 0.7160
Epoch:76, Loss: 0.8253, Acc: 0.7190
Epoch:77, Loss: 0.8244, Acc: 0.7220
Epoch:78, Loss: 0.8263, Acc: 0.7220
Epoch:79, Loss: 0.8271, Acc: 0.7180
Epoch:80, Loss: 0.8253, Acc: 0.7110
Epoch:81, Loss: 0.8260, Acc: 0.7080
Epoch:82, Loss: 0.8246, Acc: 0.7140
Epoch:83, Loss: 0.8256, Acc: 0.7170
Epoch:84, Loss: 0.8257, Acc: 0.7210
Epoch:85, Loss: 0.8256, Acc: 0.7190
Epoch:86, Loss: 0.8244, Acc: 0.7170
Epoch:87, Loss: 0.8254, Acc: 0.7240
Epoch:88, Loss: 0.8249, Acc: 0.7170
Epoch:89, Loss: 0.8252, Acc: 0.7160
Epoch:90, Loss: 0.8243, Acc: 0.7010
Epoch:91, Loss: 0.8254, Acc: 0.7050
Epoch:92, Loss: 0.8249, Acc: 0.7030
Epoch:93, Loss: 0.8249, Acc: 0.7110
Epoch:94, Loss: 0.8233, Acc: 0.6990
Epoch:95, Loss: 0.8243, Acc: 0.6990
Epoch:96, Loss: 0.8248, Acc: 0.7140
Epoch:97, Loss: 0.8240, Acc: 0.7090
Epoch:98, Loss: 0.8247, Acc: 0.7100
Epoch:99, Loss: 0.8255, Acc: 0.7060
Epoch:100, Loss: 0.8242, Acc: 0.7160
```
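The output shows that the training loss keeps decreasing in later epochs, though only slightly (from 0.8340 at epoch 50 to 0.8242 at epoch 100), while test accuracy stops improving after roughly epoch 50 and fluctuates between about 0.70 and 0.72. This may indicate some overfitting.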
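The Loss column in the log is the unsupervised skip-gram objective with negative sampling, not a classification loss. As we understand PyG's model.loss, it pairs the first node of each window with the other nodes in that window, averages -log σ(z_u·z_v) over these positive pairs, and adds the average of -log(1 - σ(z_u·z_v)) over negative pairs. The plain-Python sketch below illustrates that form of objective on hypothetical toy embeddings; the function names and data are assumptions for illustration, not PyG code.

```python
import math

EPS = 1e-15  # keeps log() finite if a probability reaches exactly 0


def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def neg_sampling_loss(emb, pos_pairs, neg_pairs):
    """Mean -log sigmoid(u.v) over positive pairs plus mean -log(1 - sigmoid(u.v)) over negatives."""
    pos = [-math.log(sigmoid(dot(emb[u], emb[v])) + EPS) for u, v in pos_pairs]
    neg = [-math.log(1.0 - sigmoid(dot(emb[u], emb[v])) + EPS) for u, v in neg_pairs]
    return sum(pos) / len(pos) + sum(neg) / len(neg)


# Hypothetical 2-D embeddings: (0, 1) co-occur in a walk window, (0, 2) is a negative sample
untrained = {0: [0.0, 0.0], 1: [0.0, 0.0], 2: [0.0, 0.0]}
trained = {0: [2.0, 0.0], 1: [2.0, 0.0], 2: [-2.0, 0.0]}
print(neg_sampling_loss(untrained, [(0, 1)], [(0, 2)]))  # ≈ 1.386 (= 2 ln 2)
print(neg_sampling_loss(trained, [(0, 1)], [(0, 2)]))    # ≈ 0.036
```

With all-zero embeddings every probability is 0.5, so the loss is 2 ln 2; pulling positive pairs together and pushing the negative pair apart lowers it. Because this objective only scores walk co-occurrence, its value and the downstream classification accuracy need not move together.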