知识蒸馏小总结
定義
知識(shí)蒸餾是一種模型壓縮方法,是一種基于“教師-學(xué)生網(wǎng)絡(luò)思想”的訓(xùn)練方法,由于其簡(jiǎn)單,有效,在工業(yè)界被廣泛應(yīng)用。
更簡(jiǎn)單的理解:用一個(gè)已經(jīng)訓(xùn)練好的模型去“教”另一個(gè)模型去學(xué)習(xí),這兩個(gè)模型通常稱為老師-學(xué)生模型。
用一個(gè)小例子來(lái)加深理解:
相關(guān)知識(shí)
pytorch中的損失函數(shù):
Softmax:將一個(gè)數(shù)值序列映射到概率空間
# Softmax import torch import torch.nn.functional as F# torch.nn是pytorch中自帶的一個(gè)函數(shù)庫(kù),里面包含了神經(jīng)網(wǎng)絡(luò)中使用的一些常用函數(shù), # 如具有可學(xué)習(xí)參數(shù)的nn.Conv2d(),nn.Linear()和不具有可學(xué)習(xí)的參數(shù)(如ReLU,pool,DropOut等)(后面這幾個(gè)是在nn.functional中) # 在圖片分類(lèi)問(wèn)題中,輸入m張圖片,輸出一個(gè)m*N的Tensor,其中N是分類(lèi)類(lèi)別總數(shù)。 # 比如輸入2張圖片,分三類(lèi),最后的輸出是一個(gè)2*3的Tensor,舉個(gè)例子: # torch.randn:用來(lái)生成隨機(jī)數(shù)字的tensor,這些隨機(jī)數(shù)字滿足標(biāo)準(zhǔn)正態(tài)分布(0~1) output = torch.randn(2, 3) print(output) # tensor([[-1.1639, 0.2698, 1.5513], # [-1.0839, 0.3102, -0.8798]]) # 第1,2行分別是第1,2張圖片的結(jié)果,假設(shè)第123列分別是貓、狗和豬的分類(lèi)得分。 # 可以看出模型認(rèn)為第一張為豬,第二張為狗。 然后對(duì)每一行使用Softmax,這樣可以得到每張圖片的概率分布。 print(F.softmax(output,dim=1)) # tensor([[0.1167, 0.1955, 0.6878], # [0.8077, 0.0990, 0.0933]])log_Softmax:在Softmax的基礎(chǔ)上進(jìn)行取對(duì)數(shù)運(yùn)算
# log_softmax print(F.log_softmax(output,dim=1)) print(torch.log(F.softmax(output,dim=1))) tensor([[-1.8601, -0.7688, -0.9655],[-0.9205, -1.1949, -1.2075]]) tensor([[-1.8601, -0.7688, -0.9655],[-0.9205, -1.1949, -1.2075]]) # 結(jié)果是一致的NLLLoss:對(duì)log_softmax和one-hot編碼進(jìn)行運(yùn)算
# NLLLoss print(F.nll_loss(torch.tensor([[-1.2, -0.03, -0.5]]), torch.tensor([0])))注:Tensor是張量,所以至少為[[]]!!!
# 通常我們結(jié)合 log_softmax 和 nll_loss一起用 output = torch.tensor([[1.2,3,2.6]]) target = torch.tensor([0]) print("output為[[1.2,3,2.6]],若target為第一個(gè),nll_loss為:",F.nll_loss(output,target)) target = torch.tensor([1]) print("output為[[1.2,3,2.6]],若target為第二個(gè),nll_loss為:",F.nll_loss(output,target)) target = torch.tensor([2]) print("output為[[1.2,3,2.6]],若target為第二個(gè),nll_loss為:",F.nll_loss(output,target))輸出結(jié)果: output為[[1.2,3,2.6]],若target為第一個(gè),nll_loss為: tensor(-1.2000) output為[[1.2,3,2.6]],若target為第二個(gè),nll_loss為: tensor(-3.) output為[[1.2,3,2.6]],若target為第二個(gè),nll_loss為: tensor(-2.6000)CrossEntropy:衡量?jī)蓚€(gè)概率分布的差別
output = torch.tensor([[1.2,3,2.6]]) log_softmax_output = F.log_softmax(output,dim=1) target = torch.tensor([0]) print(F.nll_loss(log_softmax_output,target))print(F.cross_entropy(output,target)) # 交叉熵自帶softmax輸出結(jié)果: tensor(2.4074) tensor(2.4074)圖解KD
圖中貓的圖片的one-hot編碼先輸入到Teacher網(wǎng)絡(luò)中進(jìn)行訓(xùn)練得到q’,在通過(guò)蒸餾得到q’’,最后得到soft targets,然后再把貓的圖片輸入到Student網(wǎng)絡(luò)中,得到hard targets并計(jì)算損失函數(shù),最后和來(lái)自Teacher網(wǎng)絡(luò)預(yù)測(cè)結(jié)果的損失函數(shù)相加得到最后的損失函數(shù)。
知識(shí)蒸餾過(guò)程
知識(shí)蒸餾應(yīng)用場(chǎng)景
知識(shí)蒸餾和遷移學(xué)習(xí)的基本區(qū)別
遷移學(xué)習(xí):是從一個(gè)領(lǐng)域獲取得模型應(yīng)用到別的領(lǐng)域的學(xué)習(xí)
知識(shí)蒸餾:是在同一個(gè)領(lǐng)域中,從大模型遷移到小模型上的學(xué)習(xí)
總結(jié)
- 上一篇: QQ浏览器调试解决方案
- 下一篇: Python 之圆周率 π 的计算