WeightedRandomSampler 理解了吧
WeightedRandomSampler
?
sampler = WeightedRandomSampler(samples_weight, samples_num)
train_loader = DataLoader( train_dataset, batch_size=bs, num_workers=1, sampler=sampler)
?
我的數據不平衡,使用pytorch,發現WeightedRandomSampler這個東西,網上找了一圈,有點會用了,就是上面這個用法,但是理解了很久才知道為什么這么用。
最大的問題就是不能理解WeightedRandomSampler是怎么運作的。除了官方解釋,其他也沒有找到更有用的信息了。
現在我覺得是有點理解了。
?
官方解釋是:
?
還給了例子:
?
然后不是很懂,還是不知道怎么用。感覺這個例子卻是不是很好說明問題,也是我理解能力太差,多試幾次才懂了。
?
我換一個例子如下:
list(WeightedRandomSampler([1, 9], 5, replacement=True))上面這句話反復運行,你猜怎么著?
我每次運行的結果如下:(你的結果肯定不一樣)
[1, 0, 1, 1, 1][1, 1, 1, 1, 1][1, 1, 1, 0, 1][1, 1, 1, 1, 1][1, 1, 1, 1, 1][1, 1, 0, 1, 1]有點理解了吧?
這個5代表要生成5個數,這5個數是誰呢? 取決于前面【】內的數的數量,上面【】內有2個數,根據上面[0,..,len(weights)-1],即生成的數是0-1之間的任意數。
那這5個數到底是幾,有10%的概率是0,有90%的概率是1。
理解了吧?其他參數不解釋了。
?
使用
有一種通常的用法是:(不限于此)
假設分類問題,分為3類。
sampler = WeightedRandomSampler(samples_weight,samples_num)
samples_weight的數量等于我們訓練集總樣本的數量,假設為1000。
samples_weight的每一項代表該樣本種類占總樣本的比例的倒數。
samples_num 為我們想采集多少個樣本,可以重復采集。假設為2000。
?
假設3類樣本分布比例為 貓,狗,豬 = ?0.1,0.2,0.7
Count = [0.1,0.2,0.7]
Weight = 1/Count = [10,5,1.43] 約等于[0.7,0.2,0.1]
?
samples_weight內全是 10或5或1.43,是10代表該樣本是貓...
假設samples_weight內樣子是:
[10,5,5,1.43,1.43,1.43,1.43.......,10]
10的數量最少,但是權重最大,所以達到了樣本平衡的效果。
?
所以結合上面的WeightedRandomSampler的使用:
會生成樣本總數個數即2000個數,
每個數可能是0-999之間的某個數,
每個數:(和samples_weight內數值對應)
是0的概率為 10/sum(samples_weight)
是1的概率為5/sum(samples_weight)
是2的概率為1.43/sum(samples_weight)
是3的概率為1.43/sum(samples_weight)
是4的概率為1.43/sum(samples_weight)
......
是999的概率為 10/sum(samples_weight)
?
把取出來的數字作為index,DataLoader就取用了。
?
?
end
目前的理解,難免有疏漏錯誤,還望大佬們多多指正。
?
?
?
?
?
?
?
?
?
?
?
?
?
?
?
?
?
?
《新程序員》:云原生和全面數字化實踐50位技術專家共同創作,文字、視頻、音頻交互閱讀總結
以上是生活随笔為你收集整理的WeightedRandomSampler 理解了吧的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: matlab常用代码总结
- 下一篇: SppNet 多尺度训练