Distilling the Knowledge in a Neural Network — Paper Notes (Knowledge Distillation)
arXiv-2015
In NIPS Deep Learning Workshop, 2014
Table of Contents
- 1 Background and Motivation
- 2 Conceptual block
- 3 Knowledge Distilling
  - 3.1 hard target
  - 3.2 soft target
  - 3.3 Learning logits vs. learning softmax + T
  - 3.4 Advantages of softmax + T over raw logits
  - 3.5 Cost function
- 4 Dataset
- 5 Experiments
- 6 References
- 7 Appendix
  - A. How the softmax output changes with temperature
  - B. Knowledge distillation (MNIST) code
    - B.1 teacher network
    - B.2 student network
These notes only cover the classification-related part of "Distilling the Knowledge in a Neural Network"; for more related papers, see 《Paper》.
1 Background and Motivation
A very simple way to improve model performance is to
train many different models on the same data and then to average their predictions
Drawbacks:
- using an ensemble at prediction time is too cumbersome
- it may be too computationally expensive to deploy to a large number of users, especially when the individual models are large neural networks
Caruana et al. demonstrated that compressing an ensemble into a single model is feasible
(demonstrate convincingly that the knowledge acquired by a large ensemble of models can be transferred to a single small model)
The authors use knowledge distillation (a new compression method) to carry out this ensemble-to-single-model transfer.
2 Conceptual block
1) There is a common misconception about the knowledge a model has learned: it is often identified with the model's trained parameter values. This narrow view held back knowledge distillation for a while, because as soon as the network architecture changes, the "knowledge" stored in those parameters can no longer be reused directly. The authors propose a more abstract view: the knowledge is the mapping the network has learned from input vectors to output vectors.

Under this view the knowledge is no longer tied to a specific architecture, which makes it possible for a small network to learn from a large one.

2) Another misconception is that the training objective should match the true objective as closely as possible. In practice, models are trained to perform as well as possible on the training set, while what we actually care about is generalization to new data. It would of course be ideal to train models directly for generalization, but that is essentially impossible because the information needed about generalization is not available. In knowledge distillation, however, the generalization behavior learned by the large model can be transferred quite naturally to the small model: since the large model generalizes well, a small model taught by it will generalize much better than the same small model trained from scratch.

How is the large model's generalization ability transferred to the small model? Through soft targets: the large network's softmax output (the usual softmax with a temperature) is used as the training label; this is the soft target, and the small network's softmax output is trained to match it. The hard target is simply the original dataset label. The advantages of soft targets over hard targets are summarized in the slide above.

Why do soft targets carry information about generalization? My understanding is that, compared with hard targets, soft targets encode much more information about the relationships between classes.
3 Knowledge Distilling
3.1 hard target
Let us first look at how the hard-target (ordinary softmax) output is computed.
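The original post shows this as an image; the formula in question is the standard softmax over the logits $z_i$:

$$
q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}
$$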
A more intuitive illustration (from a Zhihu post):
3.2 soft target
Now look at the effect of the soft target (softmax with temperature T).
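The soft target uses the softened softmax from the paper, with temperature $T$ (at $T = 1$ it reduces to the ordinary softmax above):

$$
q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}
$$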
The x-axis is the temperature $T$ and the y-axis is the soft-target output $q_i$ (see the plot code in Appendix A): as $T$ increases, the output distribution becomes flatter.

3.3 Learning logits vs. learning softmax + T

In the high-temperature limit, making the student match the teacher's softened softmax output is equivalent to making it match the teacher's logits: the gradient of the distillation loss with respect to a student logit $z_i$ is proportional to $z_i - v_i$ (with $v_i$ the corresponding teacher logit), which is exactly $\frac{\partial}{\partial z_i}\left[\tfrac{1}{2}(z_i - v_i)^2\right] = z_i - v_i$.
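For reference, the corresponding derivation in Section 2.1 of the paper: let $v_i$ be the teacher logits (producing soft targets $p_i$) and $z_i$ the student logits (producing $q_i$), both at temperature $T$. The gradient of the cross-entropy $C$ with respect to a student logit is

$$
\frac{\partial C}{\partial z_i}
= \frac{1}{T}\,(q_i - p_i)
= \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)
\approx \frac{1}{N T^{2}}\,(z_i - v_i),
$$

where the approximation holds when $T$ is large compared with the logits and the logits are zero-mean ($N$ is the number of classes). Up to the $1/(N T^{2})$ factor this is exactly the gradient of $\tfrac{1}{2}(z_i - v_i)^2$, i.e. logit matching.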
3.4 Advantages of softmax + T over raw logits
Given that matching the logits is just a special case of matching softmax + T, what advantages does learning softmax + T have over learning the logits directly?
The authors summarize as follows:
- logits are almost completely unconstrained by the cost function used for training the cumbersome model so they could be very noisy
- very negative logits may convey useful information about the knowledge acquired by the cumbersome model
3.5 Cost function
The small network's loss function is shown below.
When learning the generalization behavior from the large network, a relatively large T is used (the larger T is, the less confident the distribution; if the student can still discriminate between classes under these low-confidence targets, it will perform even better at test time when T = 1, much like training with extra weight on). When learning from the true labels, T = 1 is used.
Combining the true labels with the soft targets and using a weighted sum of the two as the training objective gives better results. The objective then takes the form below, where values of λ smaller than 1 work better.
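The objective referred to here appears as an image in the original post. Written out in my own notation (not necessarily the image's), with hard label $y$, teacher logits $v$, student logits $z$, cross-entropy $H$, and $\lambda$ weighting the hard-target term as described above:

$$
L \;=\; \lambda\, H\!\big(y,\ \mathrm{softmax}(z)\big)
\;+\; (1-\lambda)\, T^{2}\, H\!\big(\mathrm{softmax}(v/T),\ \mathrm{softmax}(z/T)\big)
$$

The $T^{2}$ factor on the soft-target term follows the paper's recommendation: since the soft-target gradients scale as $1/T^{2}$, multiplying them by $T^{2}$ keeps the two terms' gradients at comparable magnitudes.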
4 Dataset
MNIST
5 Experiments
Network architectures:
- Large network: 2 hidden layers of 1200 units each, trained on the 55,000 training examples with dropout.
- Small network 1 (baseline): 2 hidden layers of 800 units each, no regularization, trained directly on the hard labels.
- Small network 1 (distilled): 2 hidden layers of 800 units each, no regularization, trained by distillation from the large network with T = 20.

Test-error counts:
- Large network: 67
- Small network 1 (baseline): 146
- Small network 1 (distilled): 74
Generalization experiment

To probe the small network's generalization, the authors removed every image of the digit 3 from the transfer set (the dataset used to train the small network; it can be smaller than the large network's training set, and it can even be unlabeled). The small network therefore never sees a 3 during training. Nevertheless, at test time it achieves 98.6% accuracy on the digit 3. Moreover, even when the transfer set contains only images of 7s and 8s, the small model's error rate is only 13.2%. The small network really does inherit generalization ability from the large one.

Q1: In Section 3 of the paper, how should the bias adjustment used in these experiments be understood?
6 References
【1】【論文導讀】Hinton - Distilling the Knowledge in a Neural Network
【2】手打例子一步一步帶你看懂softmax函數以及相關求導過程
【3】知識蒸餾(Distillation)相關論文閱讀(1)——Distilling the Knowledge in a Neural Network(以及代碼復現)
7 Appendix
A. How the softmax output changes with temperature
import math
import numpy as np
import matplotlib.pyplot as plt

# Treat 0.9, 0.07 and 0.03 as the logits of three classes and plot their
# softened softmax probabilities as the temperature T grows.
T = np.arange(1, 20, 1)
y1 = math.e**(0.9/T) / (math.e**(0.07/T) + math.e**(0.9/T) + math.e**(0.03/T))
y2 = math.e**(0.07/T) / (math.e**(0.07/T) + math.e**(0.9/T) + math.e**(0.03/T))
y3 = math.e**(0.03/T) / (math.e**(0.07/T) + math.e**(0.9/T) + math.e**(0.03/T))
plt.plot(T, y1)
plt.plot(T, y2)
plt.plot(T, y3)
plt.legend(["0.9", "0.07", "0.03"])  # legend
plt.grid()                           # grid
# plt.savefig('1.png')
plt.show()
B. Knowledge distillation (MNIST) code
- Code source (TensorFlow version): akimach/tensorflow-distillation-examples. It can also be downloaded here: https://pan.baidu.com/s/1vDud4Iws_xnDxRqRnpyR-g (extraction code: cemy)
- Background reading:
  - 《Tensorflow | 莫煩 》learning notes
  - 【Keras-MLP】MNIST
  - 【TensorFlow-MLP】MNIST

The MNIST training set contains 60,000 images; only 55,000 are used for training here because the input_data loader splits off the remaining 5,000 as validation data.
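A quick way to verify this split once the dataset has been loaded with the same input_data call used in the code below (num_examples is a standard attribute of these DataSet objects):

```python
# After: mnist = input_data.read_data_sets(..., one_hot=True)
print(mnist.train.num_examples)       # 55000
print(mnist.validation.num_examples)  # 5000
print(mnist.test.num_examples)        # 10000
```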
B.1 teacher network
Two hidden layers of 1200 units each, 55,000 training examples, trained with dropout (keep_prob = 0.5).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
%matplotlib inline
random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)
# load the MNIST dataset
mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)
Hyper-parameters and layer/helper definitions:
# hyper-parameters
n_epochs = 50
batch_size = 50
num_nodes_h1 = 1200
num_nodes_h2 = 1200
learning_rate = 0.001

# number of batches: 55000 training images / 50 = 1100
n_batches = len(mnist.train.images) // batch_size

# weight variable W
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

# bias variable b
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# softmax with temperature T
def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp), axis=axis, keep_dims=True)
    return _softmax
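A quick sanity check of softmax_with_temperature with arbitrary made-up logits (not part of the original code): a larger temp should give a flatter distribution.

```python
# Arbitrary logits purely for checking the helper above.
example_logits = tf.constant([[4.0, 1.0, 0.2]])
with tf.Session() as sess:
    print(sess.run(softmax_with_temperature(example_logits, temp=1.0)))  # peaked
    print(sess.run(softmax_with_temperature(example_logits, temp=5.0)))  # flatter
```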
Network architecture:
# data
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)# drop out
# input to hidden layer 1
W_h1 = weight_variable([784, num_nodes_h1])# 784,1200
b_h1 = bias_variable([num_nodes_h1])# 1200
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1) # relu(wx+b)
h1_drop = tf.nn.dropout(h1, keep_prob) # drop out
# hidden layer 1 to hidden layer 2
W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])# 1200,1200
b_h2 = bias_variable([num_nodes_h2])# 1200
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)# relu(wx+b)
h2_drop = tf.nn.dropout(h2, keep_prob) # drop out
# hidden layer 2 to output layer
W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output
y = tf.nn.softmax(logits) # hard target
y_soft_target = softmax_with_temperature(logits, temp=2.0) # soft target
loss = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
Train with mini-batches, save the model after each epoch, and record the training loss and the train/test accuracy:
saver = tf.train.Saver()
losses = []
accs = []
test_accs = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):  # epochs
        x_shuffle, y_shuffle = shuffle(mnist.train.images, mnist.train.labels)
        for i in range(n_batches):  # mini-batches
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y = x_shuffle[start:end], y_shuffle[start:end]
            sess.run(train_step, feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})
        train_loss = sess.run(loss, feed_dict={x: batch_x, y_: batch_y, keep_prob: 0.5})
        train_accuracy = sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y, keep_prob: 1.0})
        test_accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
        print("Epoch : %i, train loss : %f, Accuracy: %f, Test accuracy: %f" % (
            epoch + 1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_teacher/",
                   global_step=epoch + 1)  # only the most recent checkpoints are kept
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
print("... completed!")
output
Epoch : 1, train loss : 0.737658, Accuracy: 0.880000, Test accuracy: 0.870400
Epoch : 2, train loss : 0.761208, Accuracy: 0.900000, Test accuracy: 0.877700
Epoch : 3, train loss : 0.589437, Accuracy: 0.920000, Test accuracy: 0.890600
Epoch : 4, train loss : 0.643363, Accuracy: 0.900000, Test accuracy: 0.899900
Epoch : 5, train loss : 0.616038, Accuracy: 0.900000, Test accuracy: 0.900900
Epoch : 6, train loss : 0.611822, Accuracy: 0.860000, Test accuracy: 0.907100
Epoch : 7, train loss : 0.644078, Accuracy: 0.860000, Test accuracy: 0.909100
Epoch : 8, train loss : 0.402896, Accuracy: 0.960000, Test accuracy: 0.911100
Epoch : 9, train loss : 0.572901, Accuracy: 0.960000, Test accuracy: 0.907900
Epoch : 10, train loss : 0.517088, Accuracy: 0.900000, Test accuracy: 0.914600
Epoch : 11, train loss : 0.410240, Accuracy: 0.960000, Test accuracy: 0.914300
Epoch : 12, train loss : 0.945823, Accuracy: 0.800000, Test accuracy: 0.916200
Epoch : 13, train loss : 0.579927, Accuracy: 0.900000, Test accuracy: 0.917000
Epoch : 14, train loss : 0.503660, Accuracy: 0.860000, Test accuracy: 0.918300
Epoch : 15, train loss : 0.532867, Accuracy: 0.940000, Test accuracy: 0.918600
Epoch : 16, train loss : 0.430909, Accuracy: 0.940000, Test accuracy: 0.920300
Epoch : 17, train loss : 0.507866, Accuracy: 0.920000, Test accuracy: 0.920600
Epoch : 18, train loss : 0.453426, Accuracy: 0.920000, Test accuracy: 0.925200
Epoch : 19, train loss : 0.689311, Accuracy: 0.920000, Test accuracy: 0.926600
Epoch : 20, train loss : 0.379545, Accuracy: 0.940000, Test accuracy: 0.926100
Epoch : 21, train loss : 0.431786, Accuracy: 0.920000, Test accuracy: 0.926800
Epoch : 22, train loss : 0.401257, Accuracy: 0.960000, Test accuracy: 0.927300
Epoch : 23, train loss : 0.587902, Accuracy: 0.960000, Test accuracy: 0.928600
Epoch : 24, train loss : 0.620417, Accuracy: 0.880000, Test accuracy: 0.927400
Epoch : 25, train loss : 0.365211, Accuracy: 0.940000, Test accuracy: 0.929500
Epoch : 26, train loss : 0.427130, Accuracy: 0.960000, Test accuracy: 0.930300
Epoch : 27, train loss : 0.253452, Accuracy: 0.900000, Test accuracy: 0.930800
Epoch : 28, train loss : 0.427312, Accuracy: 0.920000, Test accuracy: 0.930900
Epoch : 29, train loss : 0.419188, Accuracy: 0.900000, Test accuracy: 0.933100
Epoch : 30, train loss : 0.268312, Accuracy: 0.940000, Test accuracy: 0.933800
Epoch : 31, train loss : 0.346375, Accuracy: 0.920000, Test accuracy: 0.933500
Epoch : 32, train loss : 0.292108, Accuracy: 0.960000, Test accuracy: 0.933000
Epoch : 33, train loss : 0.436444, Accuracy: 0.960000, Test accuracy: 0.935100
Epoch : 34, train loss : 0.278850, Accuracy: 0.940000, Test accuracy: 0.934900
Epoch : 35, train loss : 0.277737, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 36, train loss : 0.425431, Accuracy: 0.940000, Test accuracy: 0.937300
Epoch : 37, train loss : 0.359413, Accuracy: 0.940000, Test accuracy: 0.937800
Epoch : 38, train loss : 0.338502, Accuracy: 0.960000, Test accuracy: 0.937600
Epoch : 39, train loss : 0.433313, Accuracy: 0.880000, Test accuracy: 0.937100
Epoch : 40, train loss : 0.529199, Accuracy: 0.860000, Test accuracy: 0.938700
Epoch : 41, train loss : 0.657401, Accuracy: 0.920000, Test accuracy: 0.938500
Epoch : 42, train loss : 0.491150, Accuracy: 0.920000, Test accuracy: 0.938600
Epoch : 43, train loss : 0.334091, Accuracy: 0.940000, Test accuracy: 0.940200
Epoch : 44, train loss : 0.298908, Accuracy: 0.940000, Test accuracy: 0.941000
Epoch : 45, train loss : 0.303939, Accuracy: 0.940000, Test accuracy: 0.939800
Epoch : 46, train loss : 0.378838, Accuracy: 0.940000, Test accuracy: 0.939500
Epoch : 47, train loss : 0.323622, Accuracy: 0.920000, Test accuracy: 0.941700
Epoch : 48, train loss : 0.280403, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 49, train loss : 0.390651, Accuracy: 0.920000, Test accuracy: 0.942800
Epoch : 50, train loss : 0.614632, Accuracy: 0.900000, Test accuracy: 0.941700
... completed!
Visualize the training loss:
# plot the training loss curve
plt.title("Loss of teacher")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()
Visualize the training and test accuracy:
# plot how the training and test accuracy change over epochs
plt.title("Accuracy of teacher")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()
Save the training loss and the train/test accuracy:
# save the training loss/accuracy and the test accuracy
np.save("loss_teacher.npy", np.array(losses))
np.save("acc_train_teacher.npy", np.array(accs))
np.save("acc_test_teacher.npy", np.array(test_accs))
Save the teacher network's soft targets. We pick a checkpoint from one of the better epochs; below, the 48th epoch is used.
# save the soft targets produced by the epoch-48 checkpoint
_soft_targets = []
with tf.Session() as sess:
    saver.restore(sess, "./model_teacher/-48")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        batch_x = mnist.train.images[start:end]
        soft_target = sess.run(y_soft_target, feed_dict={x: batch_x, keep_prob: 1.0})
        _soft_targets.append(soft_target)
Check the shape of _soft_targets and reshape it:
np.shape(_soft_targets)  # (1100, 50, 10) = (n_batches, batch_size, n_classes)
soft_targets = np.c_[_soft_targets].reshape(55000, 10)  # reshape to (55000, 10)
Compare the soft targets with the hard targets:
print(soft_targets[:2])
print(mnist.train.labels[:2])  # compare the labels with the soft-target predictions above
output
[[5.2621812e-03 6.1693429e-03 1.5207376e-01 6.1155759e-02 1.4845385e-02
  4.8464271e-03 3.6828788e-03 6.0641229e-01 2.9818511e-02 1.1573344e-01]
 [2.4089564e-03 2.6752956e-03 1.8253580e-02 8.5861373e-01 3.0618338e-04
  1.7423177e-02 9.3506598e-05 3.6187540e-03 8.3464541e-02 1.3142269e-02]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]
Save the teacher network's soft targets to disk so the student network can load them:
np.save('soft-targets.npy', soft_targets)
Check its shape:
np.load(file="soft-targets.npy").shape
output
(55000, 10)
B.2 student network
The differences from the teacher network are the hidden-layer width (1200 → 600 here; the paper uses 800) and the loss function; everything else is the same.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
%matplotlib inline
random.seed(123)
np.random.seed(123)
tf.set_random_seed(123)
mnist = input_data.read_data_sets("/root/userfolder/Experiment/MNIST_data/", one_hot=True)
Load the teacher network's soft targets:
soft_targets = np.load(file="soft-targets.npy")
print(np.shape(soft_targets))
output
(55000, 10)
Hyper-parameter settings and the definitions of W, b, and softmax with temperature:
n_epochs = 50
batch_size = 50
num_nodes_h1 = 600 # Before 800
num_nodes_h2 = 600 # Before 800
learning_rate = 0.001
n_batches = len(mnist.train.images) // batch_size
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def softmax_with_temperature(logits, temp=1.0, axis=1, name=None):
    logits_with_temp = logits / temp
    _softmax = tf.exp(logits_with_temp) / tf.reduce_sum(tf.exp(logits_with_temp),
                                                        axis=axis, keep_dims=True)
    return _softmax
Network design:
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
soft_target_ = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
T = tf.placeholder(tf.float32)
W_h1 = weight_variable([784, num_nodes_h1])
b_h1 = bias_variable([num_nodes_h1])
h1 = tf.nn.relu(tf.matmul(x, W_h1) + b_h1)
h1_drop = tf.nn.dropout(h1, keep_prob)
W_h2 = weight_variable([num_nodes_h1, num_nodes_h2])
b_h2 = bias_variable([num_nodes_h2])
h2 = tf.nn.relu(tf.matmul(h1_drop, W_h2) + b_h2)
h2_drop = tf.nn.dropout(h2, keep_prob)  # dropout is still used in the student
W_output = tf.Variable(tf.zeros([num_nodes_h2, 10]))
b_output = tf.Variable(tf.zeros([10]))
logits = tf.matmul(h2_drop, W_output) + b_output
y = tf.nn.softmax(logits)
y_soft_target = softmax_with_temperature(logits, temp=T)
loss_hard_target = -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])
loss_soft_target = -tf.reduce_sum(soft_target_ * tf.log(y_soft_target),
                                  reduction_indices=[1])
# both terms are scaled by T**2 here; the paper's motivation for the T**2 factor
# is to keep the soft-target gradients (which scale as 1/T**2) at a comparable magnitude
loss = tf.reduce_mean(tf.square(T) * loss_hard_target + tf.square(T) * loss_soft_target)
train_step = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
Start training: train at a high temperature (T = 2), test at T = 1:
saver = tf.train.Saver()
losses = []
accs = []
test_accs = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        x_shuffle, y_shuffle, soft_targets_shuffle \
            = shuffle(mnist.train.images, mnist.train.labels, soft_targets)
        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            batch_x, batch_y, batch_soft_targets \
                = x_shuffle[start:end], y_shuffle[start:end], soft_targets_shuffle[start:end]
            sess.run(train_step, feed_dict={x: batch_x, y_: batch_y,
                                            soft_target_: batch_soft_targets, keep_prob: 0.5, T: 2.0})
        train_loss = sess.run(loss, feed_dict={x: batch_x, y_: batch_y,
                                               soft_target_: batch_soft_targets,
                                               keep_prob: 0.5, T: 2.0})  # high-temperature training
        train_accuracy = sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y, keep_prob: 1.0, T: 1.0})
        test_accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels,
                                                      keep_prob: 1.0, T: 1.0})  # T = 1 at test time
        print("Epoch : %i, Loss : %f, Accuracy: %f, Test accuracy: %f" % (
            epoch + 1, train_loss, train_accuracy, test_accuracy))
        saver.save(sess, "/root/userfolder/Experiment/tensorflow-distillation-examples/model_student/",
                   global_step=epoch + 1)
        losses.append(train_loss)
        accs.append(train_accuracy)
        test_accs.append(test_accuracy)
print("... completed!")
Output. Note that the student ends up outperforming the teacher at test time:
Epoch : 1, Loss : 7.137307, Accuracy: 0.860000, Test accuracy: 0.868200
Epoch : 2, Loss : 5.926404, Accuracy: 0.940000, Test accuracy: 0.892200
Epoch : 3, Loss : 5.597841, Accuracy: 0.920000, Test accuracy: 0.901400
Epoch : 4, Loss : 5.938632, Accuracy: 0.920000, Test accuracy: 0.913000
Epoch : 5, Loss : 5.872798, Accuracy: 0.920000, Test accuracy: 0.915800
Epoch : 6, Loss : 5.436497, Accuracy: 0.920000, Test accuracy: 0.919300
Epoch : 7, Loss : 5.455486, Accuracy: 0.880000, Test accuracy: 0.924100
Epoch : 8, Loss : 4.402141, Accuracy: 0.980000, Test accuracy: 0.927100
Epoch : 9, Loss : 5.413333, Accuracy: 0.960000, Test accuracy: 0.929700
Epoch : 10, Loss : 4.503023, Accuracy: 0.960000, Test accuracy: 0.931900
Epoch : 11, Loss : 4.971416, Accuracy: 0.960000, Test accuracy: 0.934800
Epoch : 12, Loss : 6.448879, Accuracy: 0.880000, Test accuracy: 0.937300
Epoch : 13, Loss : 6.164934, Accuracy: 0.920000, Test accuracy: 0.939000
Epoch : 14, Loss : 5.904130, Accuracy: 0.880000, Test accuracy: 0.940200
Epoch : 15, Loss : 5.206109, Accuracy: 0.940000, Test accuracy: 0.941200
Epoch : 16, Loss : 4.704682, Accuracy: 0.960000, Test accuracy: 0.942000
Epoch : 17, Loss : 4.707399, Accuracy: 0.940000, Test accuracy: 0.943000
Epoch : 18, Loss : 4.608377, Accuracy: 0.940000, Test accuracy: 0.944000
Epoch : 19, Loss : 6.394137, Accuracy: 0.900000, Test accuracy: 0.944600
Epoch : 20, Loss : 4.419221, Accuracy: 0.980000, Test accuracy: 0.944900
Epoch : 21, Loss : 4.322970, Accuracy: 0.960000, Test accuracy: 0.946800
Epoch : 22, Loss : 3.958002, Accuracy: 0.960000, Test accuracy: 0.946400
Epoch : 23, Loss : 4.949951, Accuracy: 0.960000, Test accuracy: 0.947600
Epoch : 24, Loss : 5.640293, Accuracy: 0.900000, Test accuracy: 0.947100
Epoch : 25, Loss : 4.615621, Accuracy: 0.940000, Test accuracy: 0.948300
Epoch : 26, Loss : 4.853579, Accuracy: 0.940000, Test accuracy: 0.948600
Epoch : 27, Loss : 4.839081, Accuracy: 0.960000, Test accuracy: 0.949700
Epoch : 28, Loss : 4.525964, Accuracy: 0.940000, Test accuracy: 0.950600
Epoch : 29, Loss : 5.636992, Accuracy: 0.940000, Test accuracy: 0.950700
Epoch : 30, Loss : 4.566214, Accuracy: 0.980000, Test accuracy: 0.951200
Epoch : 31, Loss : 4.846083, Accuracy: 0.960000, Test accuracy: 0.951300
Epoch : 32, Loss : 4.274162, Accuracy: 0.980000, Test accuracy: 0.951700
Epoch : 33, Loss : 4.423202, Accuracy: 0.960000, Test accuracy: 0.951800
Epoch : 34, Loss : 4.516046, Accuracy: 0.940000, Test accuracy: 0.952200
Epoch : 35, Loss : 3.987510, Accuracy: 0.940000, Test accuracy: 0.952900
Epoch : 36, Loss : 4.587525, Accuracy: 0.940000, Test accuracy: 0.953200
Epoch : 37, Loss : 4.149089, Accuracy: 0.960000, Test accuracy: 0.953300
Epoch : 38, Loss : 4.955534, Accuracy: 0.940000, Test accuracy: 0.953900
Epoch : 39, Loss : 5.080862, Accuracy: 0.960000, Test accuracy: 0.954700
Epoch : 40, Loss : 5.033619, Accuracy: 0.900000, Test accuracy: 0.954500
Epoch : 41, Loss : 5.110637, Accuracy: 0.940000, Test accuracy: 0.954100
Epoch : 42, Loss : 5.486012, Accuracy: 0.940000, Test accuracy: 0.954300
Epoch : 43, Loss : 4.117889, Accuracy: 0.980000, Test accuracy: 0.955800
Epoch : 44, Loss : 3.833005, Accuracy: 0.940000, Test accuracy: 0.955900
Epoch : 45, Loss : 4.636988, Accuracy: 0.960000, Test accuracy: 0.954500
Epoch : 46, Loss : 5.074997, Accuracy: 0.940000, Test accuracy: 0.955700
Epoch : 47, Loss : 4.291631, Accuracy: 0.960000, Test accuracy: 0.954800
Epoch : 48, Loss : 4.045475, Accuracy: 0.960000, Test accuracy: 0.956500
Epoch : 49, Loss : 4.960283, Accuracy: 0.920000, Test accuracy: 0.957400
Epoch : 50, Loss : 5.411842, Accuracy: 0.940000, Test accuracy: 0.956300
... completed!
Visualize the training loss:
plt.title("Loss of student")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.plot(range(1, len(losses)+1), losses, label='train_loss')
plt.legend()
plt.show()
Visualize the training and test accuracy:
plt.title("Accuracy of student")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.plot(range(1, len(accs)+1), accs, label='Training')
plt.plot(range(1, len(test_accs)+1), test_accs, label='Test')
plt.legend()
plt.show()
Check the accuracy of one of the saved checkpoints:
with tf.Session() as sess:
    saver.restore(sess, "./model_student/-49")
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
output
INFO:tensorflow:Restoring parameters from ./model_student/-49
0.9574
Save the student's loss and accuracy:
np.save("loss_student.npy", np.array(losses))
np.save("acc_student.npy", np.array(accs))
np.save("acc_test_student.npy", np.array(test_accs))