TensorFlow从1到2(四)时尚单品识别和保存、恢复训练数据
Fashion Mnist --- 一個(gè)圖片識(shí)別的延伸案例
在TensorFlow官方新的教程中,第一個(gè)例子使用了由MNIST延伸而來(lái)的新程序。
這個(gè)程序使用一組時(shí)尚單品的圖片對(duì)模型進(jìn)行訓(xùn)練,比如T恤(T-shirt)、長(zhǎng)褲(Trouser),訓(xùn)練完成后,對(duì)于給定圖片,可以識(shí)別出單品的名稱。
程序同樣將所有圖片規(guī)范為28x28點(diǎn)陣,使用灰度圖,每個(gè)字節(jié)取值范圍0-255。時(shí)尚單品的類型,同樣也是分為10類,跟手寫(xiě)數(shù)字識(shí)別的分類維度相同。因此實(shí)際上,這個(gè)例子看起來(lái)美觀也有趣很多,但是在技術(shù)層面上,跟傳統(tǒng)的MNIST沒(méi)有區(qū)別。
不同的地方也有,首先是識(shí)別之后需要顯示的是單品名稱,而不是0-9的數(shù)字,所以程序中需要定義一個(gè)標(biāo)簽數(shù)組,并在顯示時(shí)做一個(gè)轉(zhuǎn)換:
其次,從樣本圖片中你應(yīng)當(dāng)能看出來(lái),圖片的復(fù)雜度,比手寫(xiě)數(shù)字還是高多了。從而造成的混淆和誤判,顯然也高的多。這種情況下,只使用tf.argmax()獲取確定的一個(gè)標(biāo)簽就有點(diǎn)不足了。所以在這個(gè)例子中,增加了使用直方圖,顯示所有10個(gè)預(yù)測(cè)分類中,每個(gè)分類的相似度功能。同時(shí),預(yù)測(cè)正確的,用藍(lán)色字體表示。預(yù)測(cè)結(jié)果同樣本標(biāo)注不同的,使用紅色字體表示。
完整的代碼如下:
#!/usr/bin/env python3from __future__ import absolute_import, division, print_function# TensorFlow and tf.keras import tensorflow as tf from tensorflow import keras# Helper libraries import numpy as np import matplotlib.pyplot as plt# 顯示樣本集中,指定圖片、預(yù)測(cè)信息、標(biāo)注信息 def plot_image(i, predictions_array, true_label, img):predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]plt.grid(False)plt.xticks([])plt.yticks([])plt.imshow(img, cmap=plt.cm.binary)predicted_label = tf.argmax(predictions_array)if predicted_label == true_label:color = 'blue'else:color = 'red'plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],100*np.max(predictions_array),class_names[true_label]),color=color)# 使用柱狀圖顯示預(yù)測(cè)結(jié)果數(shù)組,每一個(gè)柱狀圖,代表圖片屬于該類的可能性 def plot_value_array(i, predictions_array, true_label):predictions_array, true_label = predictions_array[i], true_label[i]plt.grid(False)plt.xticks([])plt.yticks([])thisplot = plt.bar(range(10), predictions_array, color="#777777")plt.ylim([0, 1])predicted_label = tf.argmax(predictions_array)thisplot[predicted_label].set_color('red')thisplot[true_label].set_color('blue')# 加載Fashion Mnist數(shù)據(jù)集,第一次執(zhí)行的時(shí)候會(huì)自動(dòng)從網(wǎng)上下載,這個(gè)速度會(huì)比較慢 fashion_mnist = keras.datasets.fashion_mnist (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()# 如同數(shù)字識(shí)別的0-9十類,這里也將時(shí)尚潮品分了以下十類 # 所以本質(zhì)上,這跟手寫(xiě)數(shù)字的識(shí)別是完全一致的 class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']# 數(shù)據(jù)規(guī)范化,將圖片數(shù)據(jù)轉(zhuǎn)化為0-1之間的浮點(diǎn)數(shù)字 train_images = train_images / 255.0 test_images = test_images / 255.0# 為了有一個(gè)直觀印象,我們把訓(xùn)練集前24個(gè)樣本圖片顯示在屏幕上,同時(shí)顯示圖片的標(biāo)注信息 # 你可能注意到了,我們?cè)陲@示圖片的時(shí)候,并沒(méi)有跟前面顯示手寫(xiě)字體圖片一樣,把圖片的規(guī)范化數(shù)據(jù)還原為0-255, # 這是因?yàn)閷?shí)際上mathplotlib庫(kù)可以直接接受浮點(diǎn)型的圖像數(shù)據(jù), # 我們前面首先還原規(guī)范化數(shù)據(jù),是為了讓你清楚理解原始數(shù)據(jù)的格式。 plt.figure(figsize=(8, 6)) for i in range(24):plt.subplot(4, 6, i+1)plt.grid(False)plt.xticks([])plt.yticks([])plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i]]) plt.show()# 定義神經(jīng)網(wǎng)絡(luò)模型,用了一個(gè)比較簡(jiǎn)單的模型 model = keras.Sequential([keras.layers.Flatten(input_shape=(28, 28)),keras.layers.Dense(128, activation='relu'),keras.layers.Dense(10, activation='softmax') ])# 采用指定的優(yōu)化器和損失函數(shù)編譯模型 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 訓(xùn)練模型 model.fit(train_images, train_labels, epochs=15)# 使用測(cè)試集數(shù)據(jù)評(píng)估訓(xùn)練后的模型,并顯示評(píng)估結(jié)果 test_loss, test_acc = model.evaluate(test_images, test_labels) print('\nTest accuracy:', test_acc)######### # 預(yù)測(cè)所有測(cè)試集數(shù)據(jù),用于圖形顯示結(jié)果 predictions = model.predict(test_images)# 以5行x3列顯示測(cè)試集前15個(gè)樣本的圖片和預(yù)測(cè)結(jié)果 # 正確的預(yù)測(cè)結(jié)果藍(lán)色顯示,錯(cuò)誤的預(yù)測(cè)信息會(huì)紅色顯示 # 每一張圖片的右側(cè),會(huì)顯示圖片預(yù)測(cè)的結(jié)果數(shù)組,這個(gè)數(shù)組中,數(shù)值最大的,代表最可能的分類 # 或者說(shuō),每一個(gè)數(shù)組元素,都代表圖片屬于對(duì)應(yīng)分類的可能性 num_rows = 5 num_cols = 3 num_images = num_rows*num_cols plt.figure(figsize=(2*2*num_cols, 2*num_rows)) for i in range(num_images):plt.subplot(num_rows, 2*num_cols, 2*i+1)plot_image(i, predictions, test_labels, test_images)plt.subplot(num_rows, 2*num_cols, 2*i+2)plot_value_array(i, predictions, test_labels) plt.show()############# # 演示預(yù)測(cè)單獨(dú)一幅圖片 # 從測(cè)試集獲取一幅圖 img = test_images[0] # 我們的模型是批處理進(jìn)行預(yù)測(cè)的,要求的是一個(gè)圖片的數(shù)組,所以這里擴(kuò)展一維 # 成為(1, 28, 28)這樣的形式 img = (np.expand_dims(img, 0)) # 使用模型進(jìn)行預(yù)測(cè) predictions_single = model.predict(img) # 顯示預(yù)測(cè)結(jié)果數(shù)組 print("test_images[0] prediction array:", predictions_single) # 顯示轉(zhuǎn)換為可識(shí)別類型的預(yù)測(cè)結(jié)果 print("test_images[0] prediction text:", class_names[tf.argmax(predictions_single[0])]) # 顯示原標(biāo)注 print("test_labels[0]:", class_names[test_labels[0]]) # 原圖的顯示請(qǐng)參考上面大圖的左上角第一幅,此處略程序最后還演示了使用1幅圖片數(shù)據(jù)調(diào)用模型進(jìn)行預(yù)測(cè)的方式。特別不要忘記把這一幅圖片擴(kuò)展一維再進(jìn)入模型,因?yàn)槲覀兊哪P褪鞘褂门幚矸绞竭M(jìn)行預(yù)測(cè)的,原本接受的是一個(gè)圖片的數(shù)組。
程序在第一次執(zhí)行的時(shí)候,會(huì)自動(dòng)由網(wǎng)上下載數(shù)據(jù)集,下載的網(wǎng)址在下面的顯示信息中能看到。下載完成后,數(shù)據(jù)會(huì)存放在~/.keras/datasets/fashion-mnist/文件夾。
以后再運(yùn)行程序的時(shí)候,程序就直接使用本地?cái)?shù)據(jù)運(yùn)行。執(zhí)行過(guò)程所顯示的信息類似下面:
$ ./fashion_mnist.py Epoch 1/15 60000/60000 [==============================] - 4s 68us/sample - loss: 0.4999 - accuracy: 0.8247 Epoch 2/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.3753 - accuracy: 0.8652 Epoch 3/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.3361 - accuracy: 0.8783 Epoch 4/15 60000/60000 [==============================] - 4s 64us/sample - loss: 0.3120 - accuracy: 0.8848 Epoch 5/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.2950 - accuracy: 0.8916 Epoch 6/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.2825 - accuracy: 0.8950 Epoch 7/15 60000/60000 [==============================] - 4s 64us/sample - loss: 0.2681 - accuracy: 0.9004 Epoch 8/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.2564 - accuracy: 0.9052 Epoch 9/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.2463 - accuracy: 0.9088 Epoch 10/15 60000/60000 [==============================] - 4s 64us/sample - loss: 0.2385 - accuracy: 0.9118 Epoch 11/15 60000/60000 [==============================] - 5s 79us/sample - loss: 0.2299 - accuracy: 0.9145 Epoch 12/15 60000/60000 [==============================] - 4s 72us/sample - loss: 0.2224 - accuracy: 0.9165 Epoch 13/15 60000/60000 [==============================] - 4s 65us/sample - loss: 0.2152 - accuracy: 0.9192 Epoch 14/15 60000/60000 [==============================] - 4s 64us/sample - loss: 0.2093 - accuracy: 0.9214 Epoch 15/15 60000/60000 [==============================] - 4s 64us/sample - loss: 0.2031 - accuracy: 0.9227 10000/10000 [==============================] - 0s 38us/sample - loss: 0.3361 - accuracy: 0.8889Test accuracy: 0.8889 test_images[0] prediction array: [[2.8952907e-09 4.0831842e-06 9.7278274e-08 1.6851689e-09 5.8218838e-083.0680697e-03 1.2691763e-07 1.8435927e-02 3.7783199e-08 9.7849166e-01]] test_images[0] prediction text: Ankle boot test_labels[0]: Ankle boot程序執(zhí)行中,測(cè)試集前15幅圖片的驗(yàn)證結(jié)果顯示如下:
左下角的圖片出現(xiàn)了明顯的識(shí)別錯(cuò)誤。不過(guò)話說(shuō)回來(lái),以我這種時(shí)尚盲人來(lái)說(shuō),也完全區(qū)分不出來(lái)這種樣子的涼鞋跟運(yùn)動(dòng)鞋有啥區(qū)別(手動(dòng)捂臉),當(dāng)然圖片的分辨率也是問(wèn)題之一啦。
保存和恢復(fù)訓(xùn)練數(shù)據(jù)
TensorFlow 2.0提供了兩種數(shù)據(jù)保存和恢復(fù)的方式。第一種方式是我們?cè)赥ensorFlow 1.x中經(jīng)常用的保存模型權(quán)重參數(shù)的方式。
因?yàn)樵赥ensorFlow 2.0中,我們使用了model.fit方法來(lái)代替之前使用的訓(xùn)練循環(huán),所以保存訓(xùn)練權(quán)重?cái)?shù)據(jù)是使用回調(diào)函數(shù)的方式完成的。下面舉一個(gè)例子:
這樣在每一個(gè)訓(xùn)練周期,都會(huì)將訓(xùn)練數(shù)據(jù)寫(xiě)入到文件,屏幕顯示會(huì)類似這樣:
Epoch 1/15 60000/60000 [==============================] - 4s 68us/sample - loss: 0.4999 - accuracy: 0.8247 Epoch 00001: saving model to training_data/cp.ckpt Epoch 2/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.3753 - accuracy: 0.8652 Epoch 00002: saving model to training_data/cp.ckpt Epoch 3/15 60000/60000 [==============================] - 4s 63us/sample - loss: 0.3361 - accuracy: 0.8783 Epoch 00003: saving model to training_data/cp.ckpt Epoch 4/15......對(duì)于稍大的數(shù)據(jù)集和稍微復(fù)雜的模型,訓(xùn)練的時(shí)間會(huì)非常之長(zhǎng)。通常我們都會(huì)把這種工作部署到有強(qiáng)大算力的服務(wù)器上執(zhí)行。訓(xùn)練完成,將訓(xùn)練數(shù)據(jù)保存下來(lái)。預(yù)測(cè)的時(shí)候,則并不需要很大的運(yùn)算量,就可以在普通的設(shè)備上執(zhí)行了。
還原保存的數(shù)據(jù),其實(shí)就是把fit方法這一句,替換為加載保存的數(shù)據(jù)就可以:
這種方法是比較多用的,因?yàn)楹芏嗲闆r下,我們訓(xùn)練所使用的模型,跟預(yù)測(cè)所使用的模型,會(huì)有細(xì)微的調(diào)整。這時(shí)候只載入模型的權(quán)重值,并不影響模型的微調(diào)。
此外,上面的代碼僅為示例。在實(shí)際應(yīng)用中,這種不改變文件名、只保存一組文件的形式,實(shí)際并不需要回調(diào)函數(shù),在訓(xùn)練完成后一次寫(xiě)入到文件是更好的選擇。使用回調(diào)函數(shù)通常都是為了保存每一步的訓(xùn)練結(jié)果。
保存完整模型
如果模型是比較成熟穩(wěn)定的,我們很可能喜歡完整的保存整個(gè)模型,這樣不僅操作容易,而且也省去了重新建模的工作。Keras內(nèi)置的vgg-19/resnet50等模型,實(shí)際就使用了這種方式,我們會(huì)在下一篇詳細(xì)介紹。
保存完整的模型非常簡(jiǎn)單,只要在model.fit執(zhí)行完成后,一行代碼就可以保存完整、包含權(quán)重參數(shù)的模型:
還原完整模型的話,則可以從使用keras.Sequential開(kāi)始定義模型、模型編譯都不需要,直接使用:
new_model = keras.models.load_model('fashion_mnist.h5')接著就可以使用new_model這個(gè)模型進(jìn)行預(yù)測(cè)了。
(待續(xù)...)
轉(zhuǎn)載于:https://www.cnblogs.com/andrewwang/p/10709914.html
總結(jié)
以上是生活随笔為你收集整理的TensorFlow从1到2(四)时尚单品识别和保存、恢复训练数据的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: CSS 布局与“仓库管理”的关系
- 下一篇: Android Jetpack组件之数据