深度学习之早停策略EarlyStopping以及保存测试集准确率最高的模型ModelCheckpoint
在訓練神經網絡時,如果epochs設置的過多,導致最終結束時測試集上模型的準確率比較低,而我們卻想保存準確率最高時候的模型參數,這就需要用到Early Stopping以及ModelCheckpoint。
一.早停策略之EarlyStopping
EarlyStopping是用于提前停止訓練的callbacks,callbacks用于指定在每個epoch開始和結束的時候進行哪種特定操作。簡而言之,就是可以達到當測試集上的loss不再減小(即減小的程度小于某個閾值)的時候停止繼續訓練。
1.EarlyStopping的原理
1.將數據分為訓練集和測試集
2.每個epoch結束后(或每N個epoch后): 在測試集上獲取測試結果,隨著epoch的增加,如果在測試集上發現測試誤差上升,則停止訓練;
3.將停止之后的權重作為網絡的最終參數。
這兒就有一個疑惑,在平常模型訓練時,會發現模型的loss值有時會出現降低再上升再下降的情況,難道只要再上升的時候就要停止嘛?上升之后再下降有可能會得到更低的loss值,那么如果只要上升就停止的話,就會得不償失。現實肯定不是這樣的!不能根據一兩次的連續降低就判斷不再提高。一般的做法是,在訓練的過程中,記錄到目前為止最好的測試集精度,當連續10次epoch(或者更多次)沒達到最佳精度時,則可以認為精度不再提高了。
看圖直觀感受一下:
2.EarlyStopping的優缺點
優點:只運行一次梯度下降,我們就可以找出w的較小值,中間值和較大值。而無需嘗試L2正則化超級參數lambda的很多值。
缺點:不能獨立地處理以上兩個問題,使得要考慮的東西變得復雜
3.參數解釋
tf.keras.callbacks.EarlyStopping(monitor="acc",min_delta=0,patience=0,verbose=0,mode="max",baseline=None,restore_best_weights=False, )1.monitor: 監控的數據接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情況下如果有驗證集,就用’val_acc’或者’val_loss’。
2.mode: 就’auto’, ‘min’, ‘,max’三個可能。如果知道是要上升還是下降,建議設置一下。例如監控的是’acc’,那么就設置為’max’。
3.min_delta:增大或減小的閾值,只有大于這個部分才算作改善(監控的數據不同,變大變小就不確定)。這個值的大小取決于monitor,也反映了你的容忍程度。
4.patience:能夠容忍多少個epoch內都沒有改善。patience的大小和learning rate直接相關。在learning rate設定的情況下,前期先訓練幾次觀察抖動的epoch number,patience設置的值應當稍大于epoch number。在learning rate變化的情況下,建議要略小于最大的抖動epoch number。
5.baseline:監控數據的基線值,如果在訓練過程中,模型訓練結果相比于基線值沒有什么改善的話,就停止訓練。
二.ModelCheckpoint
函數原型:
tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)參數解釋
1.filename:字符串,保存模型的路徑,filepath可以是格式化的字符串,里面的占位符將會被epoch值和傳入on_epoch_end的logs關鍵字所填入。
例如:filepath = “weights_{epoch:03d}-{val_loss:.4f}.h5”,則會生成對應epoch和測試集loss的多個文件。
2.monitor:需要監視的值,通常為:val_acc 、 val_loss 、 acc 、 loss四種。
3.verbose:信息展示模式,0或1。為1表示輸出epoch模型保存信息,默認為0表示不輸出該信息。
4.save_best_only:當設置為True時,將只保存在測試集上性能最好的模型。
5.mode:‘auto’,‘min’,‘max’之一,在save_best_only=True時決定性能最佳模型的評判準則,例如,當監測值為val_acc時,模式應為max,當檢測值為val_loss時,模式應為min。在auto模式下,評價準則由被監測值的名字自動推斷。
6.save_weights_only:若設置為True,則只保存模型權重,否則將保存整個模型(包括模型結構,配置信息等)。
7.period:CheckPoint之間的間隔的epoch數。
三.樣例示范
from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStoppingearlystopper = EarlyStopping(monitor='loss', patience=1, verbose=1,mode = 'min')checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=0,save_best_only=True,save_weights_only=True,mode = 'max') train_model = model.fit(train_ds,epochs=epochs,validation_data=test_ds,callbacks=[earlystopper, checkpointer]#<-看這兒)努力加油a啊
參考鏈接:
https://blog.csdn.net/zwqjoy/article/details/86677030 https://blog.csdn.net/zengNLP/article/details/94589469 創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的深度学习之早停策略EarlyStopping以及保存测试集准确率最高的模型ModelCheckpoint的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度学习之基于InceptionV3实现
- 下一篇: mac全选文字的快捷键_MACBOOK最