[Algorithm Competition Learning] Meteorological and Ocean Prediction - Task 5: Model Building with SA-ConvLSTM
The model adopted in this solution is SA-ConvLSTM.

The previous two TOP solutions treated the competition as a multi-output task, building a neural network that directly outputs the 24 Nino3.4 predictions at once. The weakness of that approach is that sequence problems are usually time-dependent: a multi-output formulation effectively treats the 24 Nino3.4 predictions as completely independent, while in reality they are sequentially dependent, with each prediction influenced by the prediction at the previous time step. This TOP solution therefore adopts a Seq2Seq structure to model the sequential dependence of the output predictions.

A Seq2Seq structure consists of an Encoder and a Decoder: the Encoder encodes the input sequence into a vector, and the Decoder decodes that vector into a predicted sequence. The key to applying a Seq2Seq structure to a particular sequence problem is the cell used at each time step. As mentioned earlier, CNNs are usually used to extract spatial information and RNNs/LSTMs to extract temporal information; combining the two gives ConvLSTM, the classic model of the spatiotemporal sequence field. The SA-ConvLSTM model studied in this task improves on ConvLSTM by introducing a self-attention mechanism to strengthen the model's ability to capture long-range spatial dependencies.

Another difference from the previous two TOP solutions is that this one does not predict the Nino3.4 index directly: it predicts sst and derives the Nino3.4 index sequence from the predicted sst.
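As background for this indirect approach, here is a minimal sketch (not part of the TOP solution's code) of how a Nino3.4-style index can be derived from a gridded SST anomaly series: average over the Nino3.4 box, then take a 3-month running mean. The function name and the lat/lon slices are illustrative assumptions, chosen to mirror the slices used later in the model code; the exact indices depend on how the grid is cropped.

```python
import numpy as np

def nino34_from_sst(sst, lat_slice=slice(10, 13), lon_slice=slice(19, 30)):
    # sst: (T, H, W) monthly SST anomalies; the slices mark the Nino3.4 box
    # on the cropped 24x48 grid (assumed values, mirroring the model code below).
    box_mean = sst[:, lat_slice, lon_slice].mean(axis=(1, 2))  # spatial mean per month -> (T,)
    kernel = np.ones(3) / 3.0
    # 3-month running mean, as in the usual Nino3.4 definition; output is 2 months shorter
    return np.convolve(box_mean, kernel, mode='valid')

# Example: 26 months of predicted SST on the cropped grid -> 24 index values
dummy_sst = np.random.randn(26, 24, 48).astype(np.float32)
print(nino34_from_sst(dummy_sst).shape)  # (24,)
```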
Learning Objectives
Content Overview
- Data flattening
- Null value filling
- Dataset construction
- Evaluation function construction
- Model construction
- Model training
- Model evaluation
Code Examples
Data Processing
Data processing in this TOP solution consists of three main parts:
Data Flattening
A sliding window is used to build the dataset. This solution uses only the sst feature and only the data whose lon values lie in [90, 330], presumably to save computational resources.
```python
import numpy as np  # import added for completeness; the original extract omits it

def make_flatted(train_ds, label_ds, info, start_idx=0):
    # Only the sst feature is used
    keys = ['sst']
    label_key = 'nino'
    # Number of years
    years = info[1]
    # Number of models
    models = info[2]

    train_list = []
    label_list = []

    # Concatenate the data belonging to the same model
    for model_i in range(models):
        blocks = []

        # For each feature, take the first 12 months of every record and concatenate them,
        # keeping only the data whose lon values lie in [90, 330]
        for key in keys:
            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years,
                                  :12, :, 19: 67].reshape(-1, 24, 48, 1).data
            blocks.append(block)

        # Concatenate all features along the last dimension
        train_flatted = np.concatenate(blocks, axis=-1)

        # Take the labels of months 12-23 and concatenate them; note that the last 12 months of the
        # final year must be appended as well (together with months 12-23 of the final year they form
        # the prediction targets of the final year's first 12 months)
        label_flatted = np.concatenate([
            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,
            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data
        ], axis=0)

        train_list.append(train_flatted)
        label_list.append(label_flatted)

    return train_list, label_list

soda_info = ('soda', 100, 1)
cmip6_info = ('cmip6', 151, 15)
cmip5_info = ('cmip5', 140, 17)

soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)
cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)
cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])

# The flattened data has shape (models, sequence length, lat, lon, features), where sequence length = years x 12
np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)
```

```
((1, 1200, 24, 48, 1), (15, 1812, 24, 48, 1), (17, 1680, 24, 48, 1))
```

Null Value Filling
Fill the null values with 0.
```python
# Fill the null values in the SODA data
soda_trains = np.array(soda_trains)
soda_trains_nan = np.isnan(soda_trains)
soda_trains[soda_trains_nan] = 0
print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))
```

```
Number of null in soda_trains after fillna: 0
```

```python
# Fill the null values in the CMIP6 data
cmip6_trains = np.array(cmip6_trains)
cmip6_trains_nan = np.isnan(cmip6_trains)
cmip6_trains[cmip6_trains_nan] = 0
print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))
```

```
Number of null in cmip6_trains after fillna: 0
```

```python
# Fill the null values in the CMIP5 data
cmip5_trains = np.array(cmip5_trains)
cmip5_trains_nan = np.isnan(cmip5_trains)
cmip5_trains[cmip5_trains_nan] = 0
print('Number of null in cmip5_trains after fillna:', np.sum(np.isnan(cmip5_trains)))
```

```
Number of null in cmip5_trains after fillna: 0
```

Dataset Construction
Construct the training and validation sets. Note that each input record has sequence length 38: the input sst sequence has length 12 and the output sst sequence has length 26, and because teacher forcing is used during training (this strategy is explained in detail in the model construction part below), the input records also need to contain the ground-truth values of the output sst sequence.
```python
# Construct the training set
X_train = []
y_train = []

# Sample 100 records from each of the 17 CMIP5 models
for model_i in range(17):
    samples = np.random.choice(cmip5_trains.shape[1]-38, size=100)
    for ind in samples:
        X_train.append(cmip5_trains[model_i, ind: ind+38])
        y_train.append(cmip5_labels[model_i][ind: ind+24])

# Sample 100 records from each of the 15 CMIP6 models
for model_i in range(15):
    samples = np.random.choice(cmip6_trains.shape[1]-38, size=100)
    for ind in samples:
        X_train.append(cmip6_trains[model_i, ind: ind+38])
        y_train.append(cmip6_labels[model_i][ind: ind+24])

X_train = np.array(X_train)
y_train = np.array(y_train)

# Construct the validation set
X_valid = []
y_valid = []
samples = np.random.choice(soda_trains.shape[1]-38, size=100)
for ind in samples:
    X_valid.append(soda_trains[0, ind: ind+38])
    y_valid.append(soda_labels[0][ind: ind+24])
X_valid = np.array(X_valid)
y_valid = np.array(y_valid)

# Check the dataset shapes
X_train.shape, y_train.shape, X_valid.shape, y_valid.shape
```

```
((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))
```

```python
# Save the datasets
np.save('X_train_sample.npy', X_train)
np.save('y_train_sample.npy', y_train)
np.save('X_valid_sample.npy', X_valid)
np.save('y_valid_sample.npy', y_valid)
```

Model Building
```python
# Imports used in the modelling part (collected here for completeness; the original extract omits them)
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_squared_error
from tqdm import tqdm

# Load the datasets
X_train = np.load('../input/ai-earth-task05-samples/X_train_sample.npy')
y_train = np.load('../input/ai-earth-task05-samples/y_train_sample.npy')
X_valid = np.load('../input/ai-earth-task05-samples/X_valid_sample.npy')
y_valid = np.load('../input/ai-earth-task05-samples/y_valid_sample.npy')
X_train.shape, y_train.shape, X_valid.shape, y_valid.shape
```

```
((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))
```

```python
# Build the data pipeline
class AIEarthDataset(Dataset):
    def __init__(self, data, label):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.label = torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

batch_size = 2

trainset = AIEarthDataset(X_train, y_train)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

validset = AIEarthDataset(X_valid, y_valid)
validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)
```

Evaluation Function Construction
```python
def rmse(y_true, y_preds):
    return np.sqrt(mean_squared_error(y_pred=y_preds, y_true=y_true))

# Evaluation function
def score(y_true, y_preds):
    # Correlation skill score
    accskill_score = 0
    # RMSE
    rmse_scores = 0
    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6
    y_true_mean = np.mean(y_true, axis=0)
    y_pred_mean = np.mean(y_preds, axis=0)
    for i in range(24):
        # Pearson correlation between predictions and labels for month i
        fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))
        fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))
        cor_i = fenzi / fenmu
        accskill_score += a[i] * np.log(i+1) * cor_i
        rmse_score = rmse(y_true[:, i], y_preds[:, i])
        rmse_scores += rmse_score
    return 2/3.0 * accskill_score - rmse_scores
```
Model Construction
Unlike the multi-output neural networks built in the previous two TOP solutions, this TOP solution uses a Seq2Seq structure. For this competition the input sequence length is 12, the output sequence length is 26, and the solution stacks four hidden layers. A basic Seq2Seq structure of this kind simply unrolls a stack of recurrent cells over the input steps and then over the output steps, as sketched below.
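To make the unrolling concrete, here is a minimal, illustrative sketch with a generic recurrent cell; the class name, cell choice and sizes are stand-ins for illustration only, not the solution's actual model (which is built later in this write-up):

```python
import torch
import torch.nn as nn

class TinySeq2Seq(nn.Module):
    # Illustrative only: one recurrent layer unrolled over 12 input steps and 26 output
    # steps, feeding each prediction back as the next input (free running).
    def __init__(self, channels=1, hidden=8):
        super().__init__()
        self.hidden = hidden
        self.cell = nn.GRUCell(channels, hidden)  # stand-in for a (SA-)ConvLSTM cell
        self.head = nn.Linear(hidden, channels)   # maps the hidden state to one output frame

    def forward(self, x, input_frames=12, future_frames=26):
        # x: (N, input_frames, channels)
        h = x.new_zeros(x.size(0), self.hidden)
        outputs = []
        for t in range(input_frames + future_frames):
            # Encoder phase: feed the observed frames; decoder phase: feed the previous prediction
            inp = x[:, t] if t < input_frames else outputs[-1]
            h = self.cell(inp, h)
            outputs.append(self.head(h))
        return torch.stack(outputs[-future_frames:], dim=1)  # (N, future_frames, channels)

print(TinySeq2Seq()(torch.randn(2, 12, 1)).shape)  # torch.Size([2, 26, 1])
```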
The key to applying the Seq2Seq structure to a particular problem is the choice of cell. The cell used in this TOP solution is SA-ConvLSTM (Self-Attention ConvLSTM), proposed by researchers from Tsinghua University; the original paper is available at https://ojs.aaai.org//index.php/AAAI/article/view/6819
SA-ConvLSTM is an improvement of ConvLSTM, the classic spatiotemporal sequence model proposed by Dr. Xingjian Shi. To capture the temporal dependencies of spatial information, it adds a SAM module on top of ConvLSTM that memorizes aggregated spatial features. The original ConvLSTM paper is available at https://arxiv.org/pdf/1506.04214.pdf
LSTM is a classic time-series model. Its three-gate structure makes it perform well on tasks with long-range temporal dependencies, and compared with vanilla RNNs it effectively mitigates the vanishing-gradient problem. For a single input sample, at each time step every LSTM gate essentially applies a fully connected layer to the input vector. In this competition the input X has shape (N, T, H, W, C), so at each time step a single sample feeds spatial information of shape (H, W, C) into the LSTM. Fully connected networks are not particularly good at extracting this kind of spatial information, whereas convolutions drastically reduce the number of parameters and, by stacking layers, progressively extract more complex features. It is then natural to replace the fully connected operations in LSTM with convolutions so that the model suits spatiotemporal sequence problems. That is exactly what ConvLSTM does, and practice has shown it to be very effective.
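For reference, and for comparison with the SAM formulas given further below, the ConvLSTM gate equations can be written as follows (shown without the peephole terms of the original paper, which matches the ConvLSTM part of the cell implemented later in this write-up; $\ast$ denotes convolution and $\circ$ the Hadamard product):

$$
\begin{aligned}
& i_t = \sigma (W_{xi} \ast X_t + W_{hi} \ast H_{t-1} + b_i) \\
& f_t = \sigma (W_{xf} \ast X_t + W_{hf} \ast H_{t-1} + b_f) \\
& g_t = \tanh (W_{xg} \ast X_t + W_{hg} \ast H_{t-1} + b_g) \\
& C_t = f_t \circ C_{t-1} + i_t \circ g_t \\
& o_t = \sigma (W_{xo} \ast X_t + W_{ho} \ast H_{t-1} + b_o) \\
& H_t = o_t \circ \tanh(C_t)
\end{aligned}
$$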
However, the ConvLSTM model has two problems:
First, the receptive field of a convolutional layer is limited by its kernel size, so multiple convolutional layers have to be stacked to enlarge the receptive field and capture global features. For example, if the first convolutional layer uses a 3×3 kernel, each node of that layer only perceives the input within a 3×3 spatial range; adding another 3×3 layer on top means each node of the second layer perceives a 3×3 patch of first-layer nodes, which with stride 1 corresponds to a 5×5 range of the input (see the small receptive-field check after this list). The second layer therefore perceives a larger spatial range of the input than the first, but at the cost of more parameters. For a plain CNN, adding a layer only adds one kernel's worth of parameters, but for ConvLSTM the burden is much heavier: the extra parameters raise the risk of overfitting while the gain in performance is limited.
Second, the convolution only operates on the spatial information input at the current time step and ignores past spatial information, so it struggles to capture the temporal dependencies of the spatial information.
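The receptive-field arithmetic mentioned in the first point can be checked with a small helper (not from the original solution): the receptive field grows by (kernel − 1) × jump per layer, where jump is the product of the strides of the preceding layers.

```python
def receptive_field(layers):
    # layers: list of (kernel_size, stride) pairs, from first to last layer
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump
        jump *= s
    return rf

print(receptive_field([(3, 1)]))          # 3 -> one 3x3 conv sees a 3x3 input patch
print(receptive_field([(3, 1), (3, 1)]))  # 5 -> two stacked 3x3 convs see a 5x5 input patch
```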
Therefore, in order to capture both global and local spatial dependencies and to improve prediction on spatiotemporal tasks with large spatial extents and long horizons, the SA-ConvLSTM model adds a SAM (self-attention memory) module on top of ConvLSTM.
The SAM module introduces a new memory unit M that stores spatial information together with its temporal dependencies. SAM takes as input the hidden state $H_t$ produced by the ConvLSTM at the current time step and the memory $M_{t-1}$ from the previous time step. First, $H_t$ is passed through self-attention to obtain a feature $Z_h$; self-attention increases the weight of the parts of $H_t$ that are more related to its other parts. At the same time, $H_t$ also acts as the query in an attention step over $M_{t-1}$ to obtain a feature $Z_m$, which strengthens the parts of $M_{t-1}$ that depend more strongly on $H_t$. Concatenating $Z_h$ and $Z_m$ gives the aggregated feature $Z$, which thus contains both the information of the current time step and the global spatiotemporal memory. Finally, a gating structure borrowed from LSTM uses $Z$ to update the hidden state and the memory unit, yielding the updated hidden state $\hat{H}_t$ and the current memory $M_t$. The SAM formulas are as follows:
$$
\begin{aligned}
& i'_t = \sigma (W_{m;zi} \ast Z + W_{m;hi} \ast H_t + b_{m;i}) \\
& g'_t = \tanh (W_{m;zg} \ast Z + W_{m;hg} \ast H_t + b_{m;g}) \\
& M_t = (1 - i'_t) \circ M_{t-1} + i'_t \circ g'_t \\
& o'_t = \sigma (W_{m;zo} \ast Z + W_{m;ho} \ast H_t + b_{m;o}) \\
& \hat{H}_t = o'_t \circ M_t
\end{aligned}
$$
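The features $Z_h$ and $Z_m$ above come from standard scaled dot-product attention over the spatial positions of the flattened feature maps. As a sketch, written per position to match the `attn` helper in the code below ($Q_h, K_h, V_h$ are projections of $H_t$, $K_m, V_m$ are projections of $M_{t-1}$, and $C$ is the attention dimension, `d_attn` in the code):

$$
\begin{aligned}
& Z_{h,i} = \sum_{j=1}^{HW} \mathrm{softmax}_j\!\left(\frac{Q_{h,i}^{\top} K_{h,j}}{\sqrt{C}}\right) V_{h,j}, \qquad
  Z_{m,i} = \sum_{j=1}^{HW} \mathrm{softmax}_j\!\left(\frac{Q_{h,i}^{\top} K_{m,j}}{\sqrt{C}}\right) V_{m,j} \\
& Z = W_z \ast [Z_h; Z_m]
\end{aligned}
$$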
For more on attention and self-attention mechanisms, see the following links:
- Attention mechanisms in deep learning: https://blog.csdn.net/malefactor/article/details/78767781
- An overview of current mainstream attention methods: https://www.zhihu.com/question/68482809
Combining the two components above gives the SA-ConvLSTM model:
```python
# Attention mechanism
def attn(query, key, value):
    # query, key and value all have shape (N, C, H*W); let S = H*W
    # Scaled dot-product scores: scores(i, j) = query(i)^T key(j) / sqrt(C)
    scores = torch.matmul(query.transpose(1, 2), key / math.sqrt(query.size(1)))  # (N, S, S)
    # Attention weights
    attn = F.softmax(scores, dim=-1)
    output = torch.matmul(attn, value.transpose(1, 2))  # (N, S, C)
    return output.transpose(1, 2)  # (N, C, S)

# SAM module
class SAAttnMem(nn.Module):
    def __init__(self, input_dim, d_model, kernel_size):
        super().__init__()
        pad = kernel_size[0] // 2, kernel_size[1] // 2
        self.d_model = d_model
        self.input_dim = input_dim
        # 1x1 convolution implementing the fully connected operation Wh*Ht
        self.conv_h = nn.Conv2d(input_dim, d_model*3, kernel_size=1)
        # 1x1 convolution implementing the fully connected operation Wm*Mt-1
        self.conv_m = nn.Conv2d(input_dim, d_model*2, kernel_size=1)
        # 1x1 convolution implementing the fully connected operation Wz*[Zh, Zm]
        self.conv_z = nn.Conv2d(d_model*2, d_model, kernel_size=1)
        # Note that the output dimension must match the input dimension, i.e. input_dim
        self.conv_output = nn.Conv2d(input_dim+d_model, input_dim*3, kernel_size=kernel_size, padding=pad)

    def forward(self, h, m):
        # self.conv_h(h) gives Wh*Ht; splitting it along dim=1 into chunks of size d_model
        # yields three blocks of shape (N, d_model, H, W): Qh, Kh, Vh
        hq, hk, hv = torch.split(self.conv_h(h), self.d_model, dim=1)
        # The same procedure gives Km and Vm
        mk, mv = torch.split(self.conv_m(m), self.d_model, dim=1)
        N, C, H, W = hq.size()
        # Self-attention gives Zh
        Zh = attn(hq.view(N, C, -1), hk.view(N, C, -1), hv.view(N, C, -1))  # (N, C, S), C=d_model
        # Attention over the memory gives Zm
        Zm = attn(hq.view(N, C, -1), mk.view(N, C, -1), mv.view(N, C, -1))  # (N, C, S), C=d_model
        # Concatenate Zh and Zm and fuse them with a 1x1 convolution to get the aggregated feature Z
        Z = self.conv_z(torch.cat([Zh.view(N, C, H, W), Zm.view(N, C, H, W)], dim=1))  # (N, C, H, W), C=d_model
        # Compute i't, g't, o't
        i, g, o = torch.split(self.conv_output(torch.cat([Z, h], dim=1)), self.input_dim, dim=1)  # (N, C, H, W), C=input_dim
        i = torch.sigmoid(i)
        g = torch.tanh(g)
        # Updated memory unit Mt
        m_next = i * g + (1 - i) * m
        # Updated hidden state Ht
        h_next = torch.sigmoid(o) * m_next
        return h_next, m_next

# SA-ConvLSTM cell
class SAConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        pad = kernel_size[0] // 2, kernel_size[1] // 2
        # Convolution implementing Wx*Xt + Wh*Ht-1
        self.conv = nn.Conv2d(in_channels=input_dim+hidden_dim, out_channels=4*hidden_dim,
                              kernel_size=kernel_size, padding=pad)
        self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)

    def initialize(self, inputs):
        device = inputs.device
        N, _, H, W = inputs.size()
        # Initialize the hidden state Ht
        self.hidden_state = torch.zeros(N, self.hidden_dim, H, W, device=device)
        # Initialize the cell state ct
        self.cell_state = torch.zeros(N, self.hidden_dim, H, W, device=device)
        # Initialize the memory state Mt
        self.memory_state = torch.zeros(N, self.hidden_dim, H, W, device=device)

    def forward(self, inputs, first_step=False):
        # If this is the first time step, initialize Ht, ct and Mt
        if first_step:
            self.initialize(inputs)

        # ConvLSTM part
        # Concatenate Xt and Ht
        combined = torch.cat([inputs, self.hidden_state], dim=1)  # (N, C, H, W), C=input_dim+hidden_dim
        # Apply the convolution
        combined_conv = self.conv(combined)
        # Obtain the four gates it, ft, ot, gt
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        # Current cell state: ct = ft*ct-1 + it*gt
        self.cell_state = f * self.cell_state + i * g
        # Current hidden state: Ht = ot*tanh(ct)
        self.hidden_state = o * torch.tanh(self.cell_state)

        # SAM part: update Ht and Mt
        self.hidden_state, self.memory_state = self.sa(self.hidden_state, self.memory_state)

        return self.hidden_state
```

During training of a Seq2Seq model there are two training modes. The first is free running, the traditional way of training, where the output $\hat{y}_{t-1}$ of the previous time step is used as the input of the next time step. The problem is that early in training $\hat{y}_{t-1}$ is still far from the true label $y_{t-1}$, so feeding it back makes the subsequent outputs drift further and further away from the targets we want to predict. This motivates the second training mode, teacher forcing.

Teacher forcing feeds the true label $y_{t-1}$ directly as the input of the next time step, letting the teacher (the ground truth) keep the model from drifting off course. But the teacher cannot hold the student's hand forever and has to gradually let the student learn on its own, so scheduled sampling is used to control the probability of using the true label. Let ratio denote the scheduled sampling probability. At the start of training ratio = 1 and the model is fully guided by the teacher; as training progresses, ratio decays in some fashion (this solution uses linear decay, reducing ratio by a decay rate decay_rate at every step). At each time step a binary random value is drawn from a Bernoulli distribution with probability ratio: if it is 1, the input is the true label $y_{t-1}$, otherwise the input is $\hat{y}_{t-1}$.
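A stripped-down sketch of this sampling step (illustrative only; the real logic lives in the model's forward method below, and the starting ratio and decay rate mirror the training loop):

```python
import torch

ratio, decay_rate = 1.0, 8e-5      # same starting point and decay as in the training loop below
for step in range(3):
    ratio = max(ratio - decay_rate, 0.0)
    # One 0/1 draw per sample and per decoder step: 1 -> feed the ground-truth frame,
    # 0 -> feed the model's own previous prediction.
    mask = torch.bernoulli(ratio * torch.ones(4, 25))  # e.g. batch of 4, future_frames-1 = 25 steps
    # next_input = mask * ground_truth + (1 - mask) * previous_prediction
    print(step, round(ratio, 5), mask.mean().item())
```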
```python
# Build the SA-ConvLSTM model
class SAConvLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = len(hidden_dim)
        layers = []
        for i in range(self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]
            layers.append(SAConvLSTMCell(input_dim=cur_input_dim,
                                         hidden_dim=self.hidden_dim[i],
                                         d_attn=d_attn,
                                         kernel_size=kernel_size))
        self.layers = nn.ModuleList(layers)
        self.conv_output = nn.Conv2d(self.hidden_dim[-1], 1, kernel_size=1)

    def forward(self, input_x, device=torch.device('cuda:0'), input_frames=12, future_frames=26,
                output_frames=37, teacher_forcing=False, scheduled_sampling_ratio=0, train=True):
        # Convert the input X from shape (N, T, H, W, C) to (N, T, C, H, W)
        input_x = input_x.permute(0, 1, 4, 2, 3).contiguous()
        # Teacher forcing is only used during training
        if train:
            if teacher_forcing and scheduled_sampling_ratio > 1e-6:
                teacher_forcing_mask = torch.bernoulli(
                    scheduled_sampling_ratio * torch.ones(input_x.size(0), future_frames-1, 1, 1, 1))
            else:
                teacher_forcing = False
        else:
            teacher_forcing = False

        total_steps = input_frames + future_frames - 1
        outputs = [None] * total_steps

        # For every time step
        for t in range(total_steps):
            # During the first 12 months, use the input sample Xt of each month
            if t < input_frames:
                input_ = input_x[:, t].to(device)
            # Without teacher forcing, use the prediction of the previous time step as the current input
            elif not teacher_forcing:
                input_ = outputs[t-1]
            # With teacher forcing, use the ground truth of the previous time step as input with probability ratio
            else:
                mask = teacher_forcing_mask[:, t-input_frames].float().to(device)
                input_ = input_x[:, t].to(device) * mask + outputs[t-1] * (1-mask)

            first_step = (t == 0)
            input_ = input_.float()

            # Pass the current input through the hidden layers
            for layer_idx in range(self.num_layers):
                input_ = self.layers[layer_idx](input_, first_step=first_step)

            # Record the output of each time step
            if train or (t >= (input_frames - 1)):
                outputs[t] = self.conv_output(input_)

        outputs = [x for x in outputs if x is not None]

        # Check the length of the output sequence
        if train:
            assert len(outputs) == output_frames
        else:
            assert len(outputs) == future_frames

        # Predicted sst sequence
        outputs = torch.stack(outputs, dim=1)[:, :, 0]  # (N, 37, H, W)
        # Averaging the predicted sst over the Nino3.4 region and taking three-month means
        # gives the predicted Nino3.4 index sequence
        nino_pred = outputs[:, -future_frames:, 10:13, 19:30].mean(dim=[2, 3])  # (N, 26)
        nino_pred = nino_pred.unfold(dimension=1, size=3, step=1).mean(dim=2)   # (N, 24)
        return nino_pred

# Number of input features
input_dim = 1
# Number of hidden nodes per layer
hidden_dim = (64, 64, 64, 64)
# Number of attention nodes
d_attn = 32
# Kernel size
kernel_size = (3, 3)

model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)
print(model)
```

```
SAConvLSTM(
  (layers): ModuleList(
    (0): SAConvLSTMCell(
      (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (sa): SAAttnMem(
        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (1): SAConvLSTMCell(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (sa): SAAttnMem(
        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (2): SAConvLSTMCell(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (sa): SAAttnMem(
        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (3): SAConvLSTMCell(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (sa): SAAttnMem(
        (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
        (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
        (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (conv_output): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
)
```

Model Training
```python
# RMSE as the loss function
def RMSELoss(y_pred, y_true):
    loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()
    return loss

model_weights = './task05_model_weights.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size).to(device)
criterion = RMSELoss
optimizer = optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=0, verbose=True, min_lr=0.0001)
epochs = 5
ratio, decay_rate = 1, 8e-5
train_losses, valid_losses = [], []
scores = []
best_score = float('-inf')
preds = np.zeros((len(y_valid), 24))

for epoch in range(epochs):
    print('Epoch: {}/{}'.format(epoch+1, epochs))

    # Training
    model.train()
    losses = 0
    for data, labels in tqdm(trainloader):
        data = data.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Linear decay of the scheduled sampling ratio
        ratio = max(ratio-decay_rate, 0)
        pred = model(data, teacher_forcing=True, scheduled_sampling_ratio=ratio, train=True)
        loss = criterion(pred, labels)
        losses += loss.cpu().detach().numpy()
        loss.backward()
        optimizer.step()
    train_loss = losses / len(trainloader)
    train_losses.append(train_loss)
    print('Training Loss: {:.3f}'.format(train_loss))

    # Validation
    model.eval()
    losses = 0
    with torch.no_grad():
        for i, data in tqdm(enumerate(validloader)):
            data, labels = data
            data = data.to(device)
            labels = labels.to(device)
            pred = model(data, train=False)
            loss = criterion(pred, labels)
            losses += loss.cpu().detach().numpy()
            preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()
    valid_loss = losses / len(validloader)
    valid_losses.append(valid_loss)
    print('Validation Loss: {:.3f}'.format(valid_loss))

    s = score(y_valid, preds)
    scores.append(s)
    print('Score: {:.3f}'.format(s))

    # Save the weights of the best model
    if s > best_score:
        best_score = s
        checkpoint = {'best_score': s,
                      'state_dict': model.state_dict()}
        torch.save(checkpoint, model_weights)
```

```
Epoch: 1/5
100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]
Training Loss: 3.289
50it [00:11, 4.47it/s]
Validation Loss: 44.009
Score: -43.458
Epoch: 2/5
100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]
Training Loss: 3.084
50it [00:11, 4.33it/s]
Validation Loss: 25.011
Score: -19.966
Epoch: 3/5
100%|██████████| 1600/1600 [21:46<00:00, 1.22it/s]
Training Loss: 13.461
50it [00:12, 4.16it/s]
Validation Loss: 15.438
Score: -14.139
Epoch: 4/5
100%|██████████| 1600/1600 [21:54<00:00, 1.22it/s]
Training Loss: 17.627
50it [00:12, 3.99it/s]
Validation Loss: 15.389
Score: -22.500
Epoch: 5/5
100%|██████████| 1600/1600 [21:55<00:00, 1.22it/s]
Training Loss: 17.592
50it [00:11, 4.48it/s]
Validation Loss: 15.252
Score: -14.459
```

```python
# Plot the training/validation curves
def training_vis(train_losses, valid_losses):
    fig = plt.figure(figsize=(8, 4))
    # subplot loss
    ax1 = fig.add_subplot(121)
    ax1.plot(train_losses, label='train_loss')
    ax1.plot(valid_losses, label='val_loss')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')
    ax1.set_title('Loss on Training and Validation Data')
    ax1.legend()
    plt.tight_layout()

training_vis(train_losses, valid_losses)
```

Model Evaluation
Evaluate the model on the test set.
```python
# Load the model with the best score
checkpoint = torch.load('../input/ai-earth-model-weights/task05_model_weights.pth')
model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)
model.load_state_dict(checkpoint['state_dict'])
```

```
<All keys matched successfully>
```

```python
# Path of the test set
test_path = '../input/ai-earth-tests/'
# Path of the test set labels
test_label_path = '../input/ai-earth-tests-labels/'

import os

# Read the test data and its labels
files = os.listdir(test_path)
X_test = []
y_test = []
for file in files:
    X_test.append(np.load(test_path + file))
    y_test.append(np.load(test_label_path + file))
X_test = np.array(X_test)[:, :, :, 19: 67, :1]
y_test = np.array(y_test)
X_test.shape, y_test.shape
```

```
((103, 12, 24, 48, 1), (103, 24))
```

```python
testset = AIEarthDataset(X_test, y_test)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# Evaluate the model on the test set
model.eval()
model.to(device)
preds = np.zeros((len(y_test), 24))
for i, data in tqdm(enumerate(testloader)):
    data, labels = data
    data = data.to(device)
    labels = labels.to(device)
    pred = model(data, train=False)
    preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()
s = score(y_test, preds)
print('Score: {:.3f}'.format(s))
```

Summary
This TOP solution did not design its own model; instead it used an existing model from the spatiotemporal sequence prediction field. Another group of TOP competitors, "ailab", also used an existing model, PredRNN++. For some classic models in spatiotemporal sequence prediction, see https://www.zhihu.com/column/c_1208033701705162752
Homework
This TOP solution takes sst as the prediction target and computes the Nino3.4 index indirectly. If you have the time, try using the SA-ConvLSTM model to predict the Nino3.4 index directly; one possible starting point is sketched below.
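A possible starting point (an untested sketch, not part of the TOP solution): keep the SA-ConvLSTM stack but replace the spatial sst output head with a pooled regression head that emits one Nino3.4 value per decoding step, and train against the 24 label values directly. The class and variable names here are hypothetical.

```python
import torch
import torch.nn as nn

class NinoHead(nn.Module):
    # Hypothetical head: map a (N, hidden_dim, H, W) decoder state to a single
    # Nino3.4 value per time step, instead of predicting a full sst field.
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)      # global average over the spatial grid
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, h):                        # h: (N, hidden_dim, H, W)
        return self.fc(self.pool(h).flatten(1))  # (N, 1)

# Inside the decoding loop, one value would be collected per future step, e.g.:
# nino_pred = torch.cat([head(h_t) for h_t in decoder_states], dim=1)  # (N, future_frames)
```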