toch_geometric 笔记:message passing GCNConv
?1?message passing介紹
? ? ? ? 將卷積算子推廣到不規則域通常表示為一個鄰域聚合(neighborhood aggregation)或消息傳遞(message passing?)方案
? ? ? ? 給定第(k-1)層點的特征,以及可能有的點與點之間邊的特征,依靠信息傳遞的GNN可以被描述成:
?
?其中表示一個可微分的可微,置換不變的函數(比如sum、mean或者max),γ和Φ表示可微分方程(比如MLP)
2 message passing 類??
????????PyG提供了message?passing基類,它通過自動處理消息傳播來幫助創建這類消息傳遞圖神經網絡。
? ? ? ? 使用者只需要定義γ(update函數)和Φ(message函數),以及聚合方式aggr(即)【aggr="add",?aggr="mean"?or?aggr="max"】即可
2.1?MessagePassing
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)?定義了聚合方式(這里是’add‘)
信息傳遞的流方向("source_to_target"?【默認】or?"target_to_source")
node_dim表示了沿著哪個軸進行傳遞
2.2?MessagePassing.propagate
MessagePassing.propagate(edge_index, size=None, **kwargs)????????開始傳播消息的初始調用。
????????獲取邊索引(edge index)和所有額外的數據,這些數據是構造消息和更新節點嵌入所需要的。
? ? ? ? propagate()不僅可以在[N,N]的鄰接方陣中傳遞消息,還可以在非方陣中傳遞消息,(比如二部圖[N,M],此時設置size=(N,M)作為額外的形參)
? ? ? ? 如果size參數設置為None,那么矩陣默認是一個方陣。
? ? ? ? 對于二部圖[N,M]來說,它有兩組互相獨立的點集,我們還需要設置x=(x_N,x_M)
2.3?MessagePassing.message(...)
? ? ? ? 類似于Φ。將信息傳遞到節點i上。?如果flow="source_to_target",那么是找所有(j,i)∈E;如果flow="target_to_source",那么找所有(i,j)屬于E。
????????可以接受最初傳遞給propagate()的任何參數。
????????此外,傳遞給propagate()的張量可以通過在變量名后面附加_i或_j,映射到各自的節點。例如,x_i(表示中心節點)、?x_j(表示鄰居節點)。
????????注意,我們通常將i稱為匯聚信息的中心節點,將j稱為相鄰節點,因為這是最常見的表示法。
2.4?MessagePassing.update(aggr_out,?...)
? ? ? ? 類比γ,對每個點i∈ V,更新它的node embedding
? ? ? ? 第一個參數是聚合輸出,同時將所有傳遞給propagate()的參數作為后續參數
3 舉例: GCN
3.1 GCN回顧
GCN層可以表示為:
?????????k-1層的鄰居節點先通過權重矩陣Θ加權,然后用中心節點和這個鄰居節點的度來進行歸一化,最后求和聚合 。
3.2 message passing 實現過程
? ? ? ? 這個方程可以劃分成以下幾個步驟
????????步驟1~3在message passing開始前就已經計算完畢了;步驟4,5則可以用MessagePassing操作來進行處理 。
3.3 代碼解析
import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='add') # "Add" aggregation (Step 5).#GCN類從MessagePssing中繼承得到的聚合方式:“add”self.lin = torch.nn.Linear(in_channels, out_channels)def forward(self, x, edge_index):# x has shape [N, in_channels] ——N個點,每個點in_channels維屬性# edge_index has shape [2, E]——E條邊,每條邊有出邊和入邊# Step 1: Add self-loops to the adjacency matrix.edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))#添加自環# Step 2: Linearly transform node feature matrix.x = self.lin(x)#對X進行線性變化# Step 3: Compute normalization.row, col = edge_index#出邊和入邊deg = degree(col, x.size(0), dtype=x.dtype)#各個點的入度(無向圖,所以入讀和出度相同)deg_inv_sqrt = deg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]#1/sqrt(di) *1/sqrt(dj)# Step 4-5: Start propagating messages.return self.propagate(edge_index, x=x, norm=norm)#進行propagate#propagate的內部會調用message(),aggregate()和update()#作為消息傳播的附加參數,我們傳遞節點嵌入x和標準化系數norm。 def message(self, x_j, norm):# x_j has shape [E, out_channels]#我們需要對相鄰節點特征x_j進行norm標準化#這里x_j為一個張量,其中包含每條邊的源節點特征,即每個節點的鄰居。# Step 4: Normalize node features.return norm.view(-1, 1) * x_j#1/sqrt(di) *1/sqrt(dj) *X_j? 之后,我們就可以用這種方法輕松調用了:
conv = GCNConv(16, 32) x = conv(x, edge_index)總結
以上是生活随笔為你收集整理的toch_geometric 笔记:message passing GCNConv的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: torch_geometric 笔记:
- 下一篇: torch_geometric 笔记:T