深度学习(19)神经网络与全连接层二: 测试(张量)实战
生活随笔
收集整理的這篇文章主要介紹了
深度学习(19)神经网络与全连接层二: 测试(张量)实战
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
深度學習(19)神經網絡與全連接層二: 測試(張量)實戰
- 1. 傳入測試集數據
- 2. 數據類型轉換
- 3. 創建test_db
- 4. test/evluation
- 5. 創建神經網絡
- 6. 輸出
- 7. 運行結果
- 8. 提高測試集正確率
在前向傳播的基礎上修改代碼
1. 傳入測試集數據
# x: [60k, 28, 28], [10k, 28, 28] # y: [60k], [10k] (x, y), (x_test, y_test) = datasets.mnist.load_data()2. 數據類型轉換
# 轉換數據類型 # x: [0~255] => [0~1.] x = tf.convert_to_tensor(x, dtype=tf.float32) / 255. y = tf.convert_to_tensor(y, dtype=tf.int32) x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) / 255. y_test = tf.convert_to_tensor(y_test, dtype=tf.int32)3. 創建test_db
# 創建數據集 train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(128) test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(128)4. test/evluation
# test/evluation # [w1, b1, w2, b2, w3, b3], 測試集就是要測試訓練集得到的參數 total_correct, total_num = 0, 0 for step, (x, y) in enumerate(test_db):# [b, 28, 28] = > [b, 28 * 28]x = tf.reshape(x, [-1, 28*28])5. 創建神經網絡
測試集的神經網絡結構要與訓練集的神經網絡結構一樣。
# 測試集的網絡要和訓練集的網絡一樣 # [b, 784] => [b, 256] => [b, 128] => [b, 10] h1 = tf.nn.relu(x@w1 + b1) h2 = tf.nn.relu(h1@w2 + b2) out = h2@w3 + b36. 輸出
# out: [b, 10] ~ R, 輸出out是數字識別的結果, 所以范圍在[0~9]之間, R為實數 # prob: [b, 10] ~ (0~1), prob為數字識別的概率, 所以范圍在(0~1)之間 # 將out轉換為prob prob = tf.nn.softmax(out, axis=1) # pred: 預測的結果, argmax()為最大值所在的索引,也就對應0~9 # [b, 10] => [b] # int64!!! pred = tf.argmax(prob, axis=1) pred = tf.cast(pred, dtype=tf.int32) # y: [b], 做測試時, y不需要轉換成one-hot向量, 只有做訓練的時候才需要轉換 # [b], int32 # print(pred.dtype, y.dtype) correct = tf.cast(tf.equal(pred, y), dtype=tf.int32) correct = tf.reduce_sum(correct)# total_correct: 總正確數, 因為correct為Tensor類型,所以需要轉換成int類型 # total_num: 總測試數 total_correct += int(correct) total_num += x.shape[0]7. 運行結果
可以看到,第一輪訓練完測試集的正確率僅為12.97%; 而第10輪訓練完后測試集的正確率為49.48%。
8. 提高測試集正確率
可以選擇將epoch增大,也就是將訓練的次數增多,這里由10增加到30:
for epoch in range(30): ...運行結果如下:
可以看到,30輪訓練后,正確率已經能夠達到70.2% 了,3層神經網絡的正確率在經過足夠多輪的訓練后,一般能夠達到90% 以上。
參考文獻:
[1] 龍良曲:《深度學習與TensorFlow2入門實戰》
總結
以上是生活随笔為你收集整理的深度学习(19)神经网络与全连接层二: 测试(张量)实战的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 车辆必买的4个险,车险购买方法
- 下一篇: 一用户微信转账2万转错人 对方不退还!法