深度学习的模型是怎么训练/优化出来的
以典型的分類問題為例,來梳理模型的訓(xùn)練過程。訓(xùn)練的過程就是問題發(fā)現(xiàn)的過程,一次訓(xùn)練是為下一步迭代做好指引。
1.數(shù)據(jù)準(zhǔn)備
準(zhǔn)備:
數(shù)據(jù)標(biāo)注前的標(biāo)簽體系設(shè)定要合理
用于標(biāo)注的數(shù)據(jù)集需要無偏、全面、盡可能均衡
標(biāo)注過程要審核
整理數(shù)據(jù)集
將各個標(biāo)簽的數(shù)據(jù)放于不同的文件夾中,并統(tǒng)計各個標(biāo)簽的數(shù)目
如:第一列是路徑,最后一列是圖片數(shù)目。
PS:可能會存在某些標(biāo)簽樣本很少/多,記下來模型效果不好就怨它。
樣本均衡,樣本不會絕對均衡,差不多就行了
如:控制最大類/最小類<(delta),(delta=5),最后一列為均衡的目標(biāo)值。
切分樣本集
如:90%用于訓(xùn)練,10%留著測試,比例自己定。訓(xùn)練集合,對于弱勢類要重采樣,最后的圖片列表要shuffle;測試集合就不用重采樣了。
訓(xùn)練中要保證樣本均衡,學(xué)習(xí)到弱勢類的特征,測試過程要反應(yīng)真實的數(shù)據(jù)集分布。
第一列是圖片路徑,后面幾列是標(biāo)簽(多任務(wù))。
按需要的格式生成tfrecord
按照train.list和validation.list生成需要的格式。生成和解析tfrecord的代碼要根據(jù)具體情況編寫。
2.訓(xùn)練
預(yù)處理,根據(jù)自己的喜好,編寫預(yù)處理策略。
preprocessing的方法,變換方案諸如:隨機(jī)裁剪、隨機(jī)變換框、添加光照飽和度、修改壓縮系數(shù)、各種縮放方案、多尺度等。進(jìn)而,減均值除方差或歸一化到[-1,1],將float類型的Tensor送入網(wǎng)絡(luò)。
這一步的目的是:讓網(wǎng)絡(luò)接受的訓(xùn)練樣本盡可能多樣,不要最后出現(xiàn)原圖沒問題,改改分辨率或?qū)捀弑染凸蛄说那闆r。
網(wǎng)絡(luò)設(shè)計,基礎(chǔ)網(wǎng)絡(luò)的選擇和Loss的設(shè)計。
基礎(chǔ)網(wǎng)絡(luò)的選擇和問題的復(fù)雜程度息息相關(guān),用ResNet18可以解決的沒必要用101;還有一些SE、GN等模塊加上去有沒有提升也可以去嘗試。
Loss的設(shè)計,一般問題的抽象就是設(shè)計Loss數(shù)據(jù)公式的過程。比如多任務(wù)中的各個任務(wù)權(quán)重配比,centorLoss可以讓特征分布更緊湊,SmoothL1Loss更平滑避免梯度爆炸等。
優(yōu)化算法
一般來說,只要時間足夠,Adam和SGD+Momentum可以達(dá)到的效果差異不大。用框架提供的理論上最好的優(yōu)化策略就是了。
訓(xùn)練過程
finetune網(wǎng)絡(luò),我習(xí)慣分兩步:首先訓(xùn)練fc層,迭代幾個epoch后保存模型;然后基于得到的模型,訓(xùn)練整個網(wǎng)絡(luò),一般迭代40-60個epoch可以得到穩(wěn)定的結(jié)果。
total_loss會一直下降的,過程中可以評測下模型在測試集上的表現(xiàn)。真正的loss往往包括兩部分。后面total_loss的下降主要是正則項的功勞了。
3.評估模型
1.混淆矩陣必不可少
混淆矩陣可以發(fā)現(xiàn)哪些類是難區(qū)分的。基于混淆矩陣可以得到各類的準(zhǔn)召,進(jìn)而可以得到哪些類比較差。
如:列為真值,行為檢測的值。
| gt/pl | 靴子 | 單鞋 | 運(yùn)動 | 休閑 | 棉鞋 | 雪地靴 | 帆布 | 拖鞋 | 涼鞋 | 雨鞋 |
|---|---|---|---|---|---|---|---|---|---|---|
| 靴子 | 4524 | 45 | 39 | 79 | 12 | 59 | 5 | 6 | 0 | 20 |
| 單鞋 | 51 | 4088 | 15 | 44 | 115 | 9 | 18 | 80 | 43 | 6 |
| 運(yùn)動 | 38 | 6 | 817 | 247 | 0 | 2 | 18 | 8 | 1 | 0 |
| 休閑 | 53 | 47 | 171 | 806 | 17 | 8 | 118 | 15 | 1 | 2 |
| 棉鞋 | 12 | 110 | 5 | 15 | 424 | 55 | 2 | 32 | 1 | 1 |
| 雪地靴 | 53 | 6 | 5 | 10 | 73 | 628 | 0 | 13 | 2 | 1 |
| 帆布鞋 | 5 | 28 | 16 | 158 | 1 | 1 | 515 | 17 | 3 | 4 |
| 拖鞋 | 6 | 139 | 1 | 12 | 33 | 3 | 18 | 2316 | 60 | 6 |
| 涼鞋 | 7 | 69 | 3 | 6 | 0 | 0 | 2 | 55 | 633 | 1 |
| 雨鞋 | 26 | 6 | 1 | 3 | 0 | 1 | 2 | 5 | 1 | 499 |
進(jìn)而可得:
| label | 召回 | 精度 |
|---|---|---|
| 靴子 | 0.9446648569638756 | 0.947434554973822 |
| 單鞋 | 0.9147460281942269 | 0.8996478873239436 |
| 運(yùn)動 | 0.7185576077396658 | 0.7614165890027959 |
| 休閑 | 0.6510500807754442 | 0.5840579710144927 |
| ... | ... | ... |
PS:運(yùn)動-休閑容易混淆。
2.抽樣看測試數(shù)據(jù)
從測試數(shù)據(jù)中每類抽1000張,把它們的模型結(jié)果放在不同的文件夾下。對于分析問題還是很有效的,為什么它會分錯,要拿出來看看!
有些確實是人工標(biāo)錯了。
3.CAM
通過CAM可以查看網(wǎng)絡(luò)究竟學(xué)到了什么(是不是學(xué)錯了)。對于細(xì)粒度問題就不用分析CAM了,一般7x7的特征圖本來就很小了,根本就看不出細(xì)節(jié)學(xué)到了什么,只能粗略看看部位定位是否準(zhǔn)確。
也可以一定程度上幫助理解為什么網(wǎng)絡(luò)會搞錯,比如下面的單鞋被誤判為了拖鞋。
總結(jié)
以上是生活随笔為你收集整理的深度学习的模型是怎么训练/优化出来的的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 《地下城与勇士:起源》邪龙的精髓肩甲介绍
- 下一篇: 《活侠传》战术等级提升攻略-活侠传战术等