dropout和bagging_Dropout的Bagging思想以及使用要点
一:Dropout的原理及bagging思想
1:Dropout原理
Dropout是深度學習中應對過擬合問題簡單且有效的一種正則化方法。原理很簡單:在訓練階段,在每一個Epoch中都以一定比例隨機的丟棄網絡中的一些神經元,如圖一所示,使得每次訓練的模型包含的神經元都不同。這種方式使得網絡權重在更新過程中不依賴隱藏節點之間的固定關系(隱藏層節點的固定關系可能會影響參數更新過程),同時使得網絡不會對某一個特定的神經元過分敏感,從而提高了網絡的泛化能力。圖一:Dropout原理圖
2:Dropout的Bagging思想
從《百面深度學習》這本書中的相關內容介紹領悟到,Dropout這種以一定比例隨機丟棄神經元的方式是一種Bagging的思想:神經網絡通過Dropout層以一定比例隨即的丟棄神經元,使得每次訓練的網絡模型都不相同,多個Epoch下來相當于訓練了多個模型,同時每一個模型都參與了對最終結果的投票,從而提高了模型的泛化能力。在此注意,Dropout與Bagging有一點不同是:Bagging的各個模型之間是相互獨立的,而Dropout各個模型之間是共享權重的。bagging是利用相同數據訓練多個模型,然后將各個模型的結果投票或者加權取平均等。
二:Dropout使用要點
1:Dropout參數設置介紹
首先說一下Dropout的參數:官方文檔中是這樣介紹的:p: probability of an element to be zeroed. Default: 0.5
inplace: If set to ``True``, will do this operation in-place. Default: ``False``第一個參數p代表Dropout率:即一個神經元被丟棄的概率,相反一個神經元被保留下來的概率即1-p;當p設置為1時,表示所有的神經元都被丟棄,輸出全為0。
第二個參數inplace,是布爾量,默認為'False',當inplace = True時,這個Dropout操作會作用在tensor自身上;當inplace = False時,tensor自身則不會改變。設置為True的好處是:上層網絡傳遞下來的tensor直接進行修改,可以節省運算內存,不用多儲存變量,所以顯存不夠的可以設置成'True'。
為了便于大家理解,請看下面代碼:
import torch
import torch.nn as nn
d = nn.Dropout(p = 0.5, inplace = True)
input = torch.randn(4, 3)
d(input)
input
#input輸出結果
tensor([[-0.0000, 0.0000, 0.0000],
[ 0.0000, -0.1043, 1.8617],
[-0.0000, -0.0000, 0.0000],
[-0.0000, -0.0000, -0.4744]])
import torch
import torch.nn as nn
d = nn.Dropout(p = 0.5, inplace = False)
input = torch.randn(4, 3)
d(input)
input
#input輸出結果
tensor([[ 1.1641, 1.6885, -0.5561],
[ 0.4439, -0.3091, -0.7204],
[-0.8396, 0.2921, -2.7595],
[ 1.7675, -0.6382, 1.6372]])
2:Dropout層在訓練和測試中的區別
在本文的第二段提到了,Dropout層利用了Bagging思想,在測試階段,每一個模型都參與了測試結果的投票,這也就是測試階段并沒有設置Dropout率,而是將此功能關閉了,那么在網絡中是如何實現的呢?
以pytorch框架為例說明,訓練過程中,網絡的模式會設置成train,即model.train(),在測試過程中網絡的模式會設置成eval,即model.eval()。當設時成model.eval()時,表示這種模式關閉了Dropout功能,除此之外也改變了BatchNorm層。其中,測試時關閉Dropout功能我們上面解釋了,對BatchNorm的改變是因為:BatchNorm訓練時是將以minibatch的形式送入網絡的,而測試是單個圖片進行的,沒有minibatch這個概念,所以測試過程中利用的是全部數據的統計結果,即均值和標準差,具體可參見。
3:Dropout層的復現問題
由于Dropout以一定比例隨機丟棄神經元這種模式的存在,使得網絡每次的訓練結果都不同,那么結果即無法復現。解決辦法:可以通過設置隨機種子來保證網絡每次丟棄固定的神經元,使結果得以復現。代碼如下:
seed = 3
torch.manual_seed(seed)
以上內容僅代表個人觀點,有不足之處還請指正。
總結
以上是生活随笔為你收集整理的dropout和bagging_Dropout的Bagging思想以及使用要点的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 讯飞AIUI+唤醒,导致唤醒监听报错10
- 下一篇: php querylist 404,Qu