pytorch 学习: STGCN
1 main.ipynb
1.1 導入庫
import random import torch import numpy as np import pandas as pd from sklearn.preprocessing import StandardScaler from load_data import * from utils import * from stgcn import *1.2 隨機種子
# 1.2 random seeds -- fix every RNG so runs are reproducible
SEED = 2021
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # no-op when CUDA is unavailable
torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
# 1.3 cpu or gpu
# 1.3 choose the compute device: prefer the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1.4 file path
# 1.4 file paths
matrix_path = "dataset/W_228.csv"  # adjacency matrix, 228 x 228 (228 observation stations)
data_path = "dataset/V_228.csv"    # speed data, 12672 x 228 (12672 = 288 slots/day * 44 days)
save_path = "save/model.pt"        # where the best model checkpoint is written
# 1.5 hyper-parameters
# 1.5 hyper-parameters
day_slot = 288        # 5-minute slots per day (24 h * 12 slots/hour)
n_train, n_val, n_test = 34, 5, 5  # days used for train / validation / test splits
n_his = 12            # history window length (in 5-minute slots)
n_pred = 3            # predict the 3rd time slot after the window
n_route = 228         # number of road segments (graph nodes)
Ks, Kt = 3, 3         # spatial / temporal convolution kernel sizes
blocks = [[1, 32, 64], [64, 32, 128]]  # channel sizes of the two ST-Conv blocks
drop_prob = 0         # dropout probability
batch_size = 50
epochs = 50
lr = 1e-3
# 1.6 graph preprocessing
# 1.6 graph preprocessing
W = load_matrix(matrix_path)   # adjacency matrix as an ndarray (load_data.py)
L = scaled_laplacian(W)        # rescaled graph Laplacian, ndarray (utils.py)
Lk = cheb_poly(L, Ks)          # Chebyshev polynomials of L, shape [Ks, n, n]
Lk = torch.Tensor(Lk.astype(np.float32)).to(device)  # ndarray -> Tensor on the target device
# 1.7 normalisation
# 1.7 split the raw data and z-score it; statistics are fitted on the
# training split only and reused for validation and test
train, val, test = load_data(data_path, n_train * day_slot, n_val * day_slot)
scaler = StandardScaler()
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)
# 1.8 building x / y samples
# 1.8 build (history window, label) pairs for each split
# (data_transform lives in load_data.py)
x_train, y_train = data_transform(train, n_his, n_pred, day_slot, device)
x_val, y_val = data_transform(val, n_his, n_pred, day_slot, device)
x_test, y_test = data_transform(test, n_his, n_pred, day_slot, device)
# 1.9 DataLoader
dataLoader部分見:pytorch筆記:Dataloader_UQI-LIUWJ的博客-CSDN博客
# 1.9 wrap the tensors into datasets / loaders; only the training loader shuffles
train_data = torch.utils.data.TensorDataset(x_train, y_train)
train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)
val_data = torch.utils.data.TensorDataset(x_val, y_val)
val_iter = torch.utils.data.DataLoader(val_data, batch_size)
test_data = torch.utils.data.TensorDataset(x_test, y_test)
test_iter = torch.utils.data.DataLoader(test_data, batch_size)
# for x, y in train_iter:  x.size() -> torch.Size([50, 1, 12, 228])
#                          y.size() -> torch.Size([50, 228])
# (the original note printed x.size() twice; the second shape is y's)
# 1.10 loss function
# 1.10 loss: mean squared error
loss = nn.MSELoss()
# 1.11 model
# 1.11 instantiate the STGCN model (stgcn.py) on the chosen device
model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob).to(device)
# 1.12 optimiser
# 1.12 optimiser: RMSprop over all model parameters
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
# 1.13 LR scheduler
# 1.13 learning-rate schedule: multiply the lr by 0.7 every 5 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)
# 1.14 training and checkpointing
# 1.14 training loop: train for `epochs` epochs and keep the checkpoint with
# the lowest validation loss seen so far.
min_val_loss = np.inf
for epoch in range(1, epochs + 1):
    l_sum, n = 0.0, 0  # running (loss * batch size) and sample count
    model.train()
    for x, y in train_iter:
        # x: [batch_size, 1, n_his, n_route]; flatten prediction to [batch_size, n_route]
        y_pred = model(x).view(len(x), -1)
        l = loss(y_pred, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        # MSELoss averages over the batch, so multiply back by the batch size
        # to accumulate a correctly weighted epoch average
        l_sum += l.item() * y.shape[0]
        n += y.shape[0]  # total samples seen this epoch (n_day * n_slot of the train split)
    scheduler.step()  # decay the learning rate once per epoch
    # validation loss under the current parameters (utils.evaluate_model)
    val_loss = evaluate_model(model, loss, val_iter)
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        torch.save(model.state_dict(), save_path)  # checkpoint the best model
    print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)
# Sample run: train loss falls from 0.2373 (epoch 1) to 0.1014 (epoch 50);
# the best validation loss, 0.14172, is reached at epoch 16, after which
# validation loss slowly drifts up (~0.146) while train loss keeps falling.
# 1.15 reload the best checkpoint
# 1.15 rebuild the network and load the best parameters found during training
best_model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob).to(device)
best_model.load_state_dict(torch.load(save_path))
# 1.16 evaluation
# 1.16 final evaluation on the held-out test split (utils.py helpers)
l = evaluate_model(best_model, loss, test_iter)
MAE, MAPE, RMSE = evaluate_metric(best_model, test_iter, scaler)
print("test loss:", l, "\nMAE:", MAE, ", MAPE:", MAPE, ", RMSE:", RMSE)
# Sample output:
#   test loss: 0.13690029052052186
#   MAE: 2.2246220055150383 , MAPE: 0.051902304533065484 , RMSE: 3.995202803143325
# 2 load_data.py
2.1 庫函數導入
import torch import numpy as np import pandas as pd2.2 load_matrix
def load_matrix(file_path):
    """Read a header-less CSV adjacency matrix and return it as a float ndarray."""
    return pd.read_csv(file_path, header=None).values.astype(float)
# 2.2 load_data
def load_data(file_path, len_train, len_val):
    """Split the raw speed CSV into train / val / test along the time axis.

    file_path: CSV with one row per time slot and one column per road segment.
    len_train: number of leading rows forming the training split.
    len_val:   number of rows right after the training rows used for
               validation; everything remaining becomes the test split.
    Returns (train, val, test) as float ndarrays.
    """
    df = pd.read_csv(file_path, header=None).values.astype(float)
    train = df[:len_train]                        # first n_train days
    val = df[len_train:len_train + len_val]       # middle n_val days
    test = df[len_train + len_val:]               # last n_test days
    return train, val, test
# 2.3 data_transform
def data_transform(data, n_his, n_pred, day_slot, device):
    """Slice one split into (history window, label) training pairs.

    data:     [n_day * day_slot, n_route] array of speed readings.
    n_his:    length of the history window fed to the model.
    n_pred:   the label is the n_pred-th slot after the window's end.
    day_slot: time slots per day; windows never cross a day boundary.
    device:   target device for the returned tensors.
    Returns x of shape [n_day * n_slot, 1, n_his, n_route] and
    y of shape [n_day * n_slot, n_route].
    """
    n_day = len(data) // day_slot
    n_route = data.shape[1]
    # number of windows that fit inside a single day
    n_slot = day_slot - n_his - n_pred + 1
    x = np.zeros([n_day * n_slot, 1, n_his, n_route])
    y = np.zeros([n_day * n_slot, n_route])
    for i in range(n_day):
        for j in range(n_slot):
            t = i * n_slot + j         # output index of this window
            s = i * day_slot + j       # first time slot of the history window
            e = s + n_his              # one past the last history slot
            x[t, :, :, :] = data[s:e].reshape(1, n_his, n_route)
            y[t] = data[e + n_pred - 1]  # label: the n_pred-th slot after the window
    return torch.Tensor(x).to(device), torch.Tensor(y).to(device)
# 3 utils.py
3.1 庫函數導入
import torch import numpy as np3.2?scaled_laplacian
計算標準化圖拉普拉斯矩陣
def scaled_laplacian(A):
    """Return the rescaled graph Laplacian 2L/lambda_max - I.

    A is an (n, n) weighted adjacency matrix.  The Laplacian D - A is first
    symmetrically normalised as D^(-1/2) (D - A) D^(-1/2) (entries involving
    zero-degree nodes are left untouched), then rescaled so its spectrum
    lies in [-1, 1], as the Chebyshev polynomial approximation requires.
    """
    n = A.shape[0]
    d = np.sum(A, axis=1)  # node degrees
    L = np.diag(d) - A     # combinatorial Laplacian D - A
    for i in range(n):
        for j in range(n):
            if d[i] > 0 and d[j] > 0:
                # symmetric normalisation: D^(-1/2) L D^(-1/2)
                L[i, j] /= np.sqrt(d[i] * d[j])
    lam = np.linalg.eigvals(L).max().real  # largest eigenvalue of the normalised Laplacian
    return 2 * L / lam - np.eye(n)         # (2 / lambda_max) L - I_n
# 3.3 cheb_poly
切比雪夫多項式近似的圖卷積項(零階卷積、一階卷積、二階卷積。。。)
def cheb_poly(L, Ks):
    """Return the first Ks Chebyshev polynomials of L as a [Ks, n, n] array.

    T0 = I, T1 = L, and T_k(L) = 2 L T_{k-1}(L) - T_{k-2}(L); these act as
    the graph-convolution kernels of orders 0 .. Ks-1.
    """
    n = L.shape[0]
    LL = [np.eye(n), L[:]]  # T0 and T1
    for i in range(2, Ks):
        # Chebyshev recurrence: T_k = 2 L T_{k-1} - T_{k-2}
        LL.append(np.matmul(2 * L, LL[-1]) - LL[-2])
    return np.asarray(LL)
# 3.4 evaluate_model
計算模型的損失函數值
def evaluate_model(model, loss, data_iter):
    """Return the sample-weighted average loss of `model` over `data_iter`."""
    model.eval()
    l_sum, n = 0.0, 0
    with torch.no_grad():  # inference only: no autograd bookkeeping
        for x, y in data_iter:
            y_pred = model(x).view(len(x), -1)
            l = loss(y_pred, y)
            # the loss averages over the batch; re-weight by batch size so the
            # final figure is a true per-sample average
            l_sum += l.item() * y.shape[0]
            n += y.shape[0]
    return l_sum / n
# 3.5 evaluate_metric
def evaluate_metric(model, data_iter, scaler):
    """Compute MAE, MAPE and RMSE of `model` over `data_iter`.

    Predictions and labels are mapped back to the original scale with
    `scaler.inverse_transform` before the errors are measured.
    """
    model.eval()
    with torch.no_grad():
        mae, mape, mse = [], [], []
        for x, y in data_iter:
            # undo the z-score normalisation and flatten to one long vector
            y = scaler.inverse_transform(y.cpu().numpy()).reshape(-1)
            y_pred = scaler.inverse_transform(
                model(x).view(len(x), -1).cpu().numpy()).reshape(-1)
            d = np.abs(y - y_pred)
            mae += d.tolist()          # |pred - y|
            # NOTE(review): divides by the true value, so a zero speed would
            # yield inf/nan -- assumed not to occur in this dataset; confirm.
            mape += (d / y).tolist()   # |pred - y| / y
            mse += (d ** 2).tolist()   # (pred - y)^2
        MAE = np.array(mae).mean()
        MAPE = np.array(mape).mean()
        RMSE = np.sqrt(np.array(mse).mean())
        return MAE, MAPE, RMSE
# 4 stgcn.py
4.1 庫函數導入
import math import torch import torch.nn as nn import torch.nn.init as init import torch.nn.functional as F4.2 align
用于殘差連接中 x 分支的計算（把殘差分支的通道數對齊到輸出通道數）
Pad見pytorch筆記:torch.nn.functional.pad_UQI-LIUWJ的博客-CSDN博客
class align(nn.Module):
    """Match the channel count of a residual branch to `c_out`.

    c_in > c_out:  project down with a 1x1 convolution.
    c_in < c_out:  zero-pad the missing channels.
    c_in == c_out: pass the input through unchanged.
    Only the channel axis (dim 1) is ever modified.
    """

    def __init__(self, c_in, c_out):
        super(align, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        if c_in > c_out:
            self.conv1x1 = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        # x: [batch, c_in, T, n_route]
        if self.c_in > self.c_out:
            return self.conv1x1(x)
        if self.c_in < self.c_out:
            # F.pad spec runs last-dim-first: (W, H, channels, batch) pairs,
            # so only the channel axis grows by c_out - c_in
            return F.pad(x, [0, 0, 0, 0, 0, self.c_out - self.c_in, 0, 0])
        return x
# 4.3 temporal_conv_layer
時間卷積
class temporal_conv_layer(nn.Module):
    """Temporal convolution (along the time axis) with a residual connection.

    kt:    temporal kernel size; the output loses kt-1 time steps.
    c_in:  input channels, c_out: output channels.
    act:   "GLU" (gated linear unit), "sigmoid", or anything else -> ReLU.
    """

    def __init__(self, kt, c_in, c_out, act="relu"):
        super(temporal_conv_layer, self).__init__()
        self.kt = kt
        self.act = act
        self.c_out = c_out
        # residual branch: bring the input to c_out channels for H(x) = F(x) + x
        self.align = align(c_in, c_out)
        if self.act == "GLU":
            # one conv produces both halves: the value P and the gate Q,
            # hence c_out * 2 output channels; kernel (kt, 1) convolves kt
            # consecutive time steps of each observation point
            self.conv = nn.Conv2d(c_in, c_out * 2, (kt, 1), 1)
        else:
            self.conv = nn.Conv2d(c_in, c_out, (kt, 1), 1)

    def forward(self, x):
        # x: [batch, c_in, T, n_route]; drop the first kt-1 steps of the
        # residual so it lines up with the (shorter) convolution output
        x_in = self.align(x)[:, :, self.kt - 1:, :]
        if self.act == "GLU":  # fix: source had a garbled literal `"GLU:` here
            x_conv = self.conv(x)  # [batch, 2*c_out, T-kt+1, n_route]
            # (P + residual) * sigmoid(Q): the gate decides which of the
            # c_out output features matter at each position
            return (x_conv[:, :self.c_out, :, :] + x_in) * torch.sigmoid(
                x_conv[:, self.c_out:, :, :])
        if self.act == "sigmoid":
            return torch.sigmoid(self.conv(x) + x_in)
        return torch.relu(self.conv(x) + x_in)
# 4.3 spatio_conv_layer
空間卷積(交通預測論文筆記:Spatio-Temporal Graph Convolutional Networks: A Deep Learning Frameworkfor Traffic Forecast_UQI-LIUWJ的博客-CSDN博客)
kaiming分布:pytorch學習:xavier分布和kaiming分布_UQI-LIUWJ的博客-CSDN博客
einsum:python 筆記:愛因斯坦求和 einsum_UQI-LIUWJ的博客-CSDN博客
class spatio_conv_layer(nn.Module):
    """Chebyshev graph convolution over the road network.

    ks: Chebyshev order (number of polynomial kernels).
    c:  channel count (unchanged by this layer).
    Lk: [ks, n, n] tensor of Chebyshev polynomials of the scaled Laplacian.
    """

    def __init__(self, ks, c, Lk):
        super(spatio_conv_layer, self).__init__()
        self.Lk = Lk
        self.theta = nn.Parameter(torch.FloatTensor(c, c, ks))  # per-order channel-mixing weights
        self.b = nn.Parameter(torch.FloatTensor(1, c, 1, 1))    # per-channel bias
        self.reset_parameters()

    def reset_parameters(self):
        # same scheme nn.Linear / nn.Conv2d use: Kaiming-uniform for the
        # weight, uniform with bound 1/sqrt(fan_in) for the bias
        init.kaiming_uniform_(self.theta, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.theta)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.b, -bound, bound)

    def forward(self, x):
        # x: [batch, c, T, n]; Lk: [ks, n, n]
        # x_c[b,i,t,k,n] = sum_m Lk[k,n,m] * x[b,i,t,m]  -- i.e. T_k(L) x
        x_c = torch.einsum("knm,bitm->bitkn", self.Lk, x)
        # mix the ks polynomial orders and the input channels with theta,
        # then add the bias: the theta-weighted sum of the Chebyshev terms
        x_gc = torch.einsum("iok,bitkn->botn", self.theta, x_c) + self.b
        return torch.relu(x_gc + x)  # residual connection, then ReLU
# The two einsum steps in forward are unpacked in the notes below.
先看第一條:x_c = torch.einsum("knm,bitm->bitkn", self.Lk, x)
self.Lk是Ks個n_route*n_route的矩陣。
????????每個矩陣我們可以想成不同階的反應圖信息的矩陣(吸收了一階鄰居,兩階鄰居,三階鄰居。。。。Ks-1階鄰居信息之后的矩陣),其中(i,j)表示i對j的影響
x我們可以這么想:batch_size*n_his 個 C[1]*n_route的矩陣,其中每一列是一條路徑交通預測值的編碼
簡化起見,可以看成("nm,im->in")
假設有3條邊,每條邊用一個四維向量表示它的交通狀態
Lk:
x:
("nm,im->in")——結果的第(i,n)個元素,是,對所有的m,Lk中第(n,m)個元素和x中第(i,m)個元素乘積的和。
也就是,表示第n條邊的第i維交通狀態,等于所有邊第i維的交通狀態(x[i][m]),乘以這一條邊對于第n條邊的影響(Lk[n][m]),然后求和
??在切比雪夫近似的圖卷積里面,這一條einsum相當于
x_gc = torch.einsum("iok,bitkn->botn", self.theta, x_c) + self.b#x_c:[batch_size,c[1],n_his,Ks,n_route]#theta [c[1],c[1],ks] #x_gc:[batch_size,c[1],n_his,n_route]再看這一條
相當于 ('iok,ikn->on')
iok 的部分是Ks個 c[1]*c[1]的權重矩陣
ikn的部分是ks個c[1]*n_route的矩陣,表示每條邊的速度編碼
也即是說,最后結論里面,第n個點的第o維表示交通狀態的內容,等于各個theta中第o列分別乘以各個x_c中第n列的內容,然后求和
這個對應的是切比雪夫圖卷積中乘θ再求和的內容?
4.4?st_conv_block
class st_conv_block(nn.Module):
    """One spatio-temporal block: gated temporal conv -> graph conv ->
    temporal conv -> layer norm -> dropout.

    ks: spatial kernel size; kt: temporal kernel size.
    n:  number of road segments (n_route).
    c:  channel sizes [c_in, c_mid, c_out], e.g. one entry of
        blocks = [[1, 32, 64], [64, 32, 128]].
    p:  dropout probability.
    Lk: Chebyshev kernels of the scaled graph Laplacian.
    """

    def __init__(self, ks, kt, n, c, p, Lk):
        super(st_conv_block, self).__init__()
        self.tconv1 = temporal_conv_layer(kt, c[0], c[1], "GLU")  # gated temporal conv
        self.sconv = spatio_conv_layer(ks, c[1], Lk)
        self.tconv2 = temporal_conv_layer(kt, c[1], c[2])  # plain (ReLU) temporal conv
        self.ln = nn.LayerNorm([n, c[2]])
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        # x: [batch, c[0], T, n]
        x_t1 = self.tconv1(x)    # [batch, c[1], T-(kt-1), n] -- GLU gates the time steps
        x_s = self.sconv(x_t1)   # same shape: the graph conv keeps all dims
        x_t2 = self.tconv2(x_s)  # [batch, c[2], T-2*(kt-1), n]
        # LayerNorm normalises over the trailing [n, c[2]] axes, so move the
        # channel axis last, normalise, and move it back
        x_ln = self.ln(x_t2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.dropout(x_ln)
# 4.5 fully_conv_layer
class fully_conv_layer(nn.Module):
    """1x1 convolution collapsing c input channels down to a single channel."""

    def __init__(self, c):
        super(fully_conv_layer, self).__init__()
        # in-channels c, out-channels 1, kernel size 1x1
        self.conv = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        # x: [batch, c, T, n] -> [batch, 1, T, n]
        return self.conv(x)
# 4.6 output_layer
?
class output_layer(nn.Module):
    """Map the last ST block's features to one prediction per road segment.

    c: input channel count (bs[1][2]).
    T: time steps remaining after the two ST blocks (n_his - 4*(kt-1), e.g.
       12 - 4*2 = 4 with the defaults).
    n: number of road segments.
    """

    def __init__(self, c, T, n):
        super(output_layer, self).__init__()
        # a (T, 1) kernel consumes the whole remaining time axis at once
        self.tconv1 = temporal_conv_layer(T, c, c, "GLU")
        self.ln = nn.LayerNorm([n, c])
        self.tconv2 = temporal_conv_layer(1, c, c, "sigmoid")
        self.fc = fully_conv_layer(c)

    def forward(self, x):
        # x: [batch, c, T, n]
        x_t1 = self.tconv1(x)  # [batch, c, 1, n]: time axis fully collapsed
        # normalise over [n, c] with the channel axis moved last, then restore
        x_ln = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x_t2 = self.tconv2(x_ln)
        return self.fc(x_t2)   # [batch, 1, 1, n]
4.7 STGCN
class STGCN(nn.Module):
    """Spatio-Temporal Graph Convolutional Network for traffic forecasting.

    ks: spatial kernel size; kt: temporal kernel size.
    bs: channel configuration, e.g. [[1, 32, 64], [64, 32, 128]].
    T:  history length n_his (how many past slots feed the prediction).
    n:  number of road segments (n_route).
    Lk: Chebyshev kernels of the scaled graph Laplacian.
    p:  dropout probability.
    """

    def __init__(self, ks, kt, bs, T, n, Lk, p):
        super(STGCN, self).__init__()
        self.st_conv1 = st_conv_block(ks, kt, n, bs[0], p, Lk)  # first ST block
        self.st_conv2 = st_conv_block(ks, kt, n, bs[1], p, Lk)  # second ST block
        # each ST block shortens the time axis by 2*(kt-1), hence T - 4*(kt-1)
        self.output = output_layer(bs[1][2], T - 4 * (kt - 1), n)

    def forward(self, x):
        # x: [batch, bs[0][0], n_his, n_route]
        x_st1 = self.st_conv1(x)   # [batch, bs[0][2], shorter T, n_route]
        x_st2 = self.st_conv2(x_st1)
        return self.output(x_st2)  # [batch, 1, 1, n_route]
# Summary
以上是生活随笔為你收集整理的pytorch 学习: STGCN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python 笔记:爱因斯坦求和 ein
- 下一篇: ntu 课程笔记 :MAS714(7)