深度学习之基于CNN实现汉字版手写数字识别(Chinese-Mnist)
Mnist數據集是深度學習入門的數據集,昨天發現了Chinese-Mnist數據集,與Mnist數據集類似,只不過是漢字數字,例如‘一’、‘二’、‘三’等,本次實驗利用自己搭建的CNN網絡實現Chinese版的手寫數字識別。
1.導入庫
import tensorflow as tf import matplotlib.pyplot as plt import os,PIL,pathlib import numpy as np import pandas as pd import warnings from tensorflow import keraswarnings.filterwarnings("ignore")#忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用來正常顯示中文標簽 plt.rcParams['axes.unicode_minus'] = False # 用來正常顯示負號 os.environ['TF_CPP_MIN_LOG_LEVEL']='2'2.數據加載
原數據中包括15000張圖片,如下所示:
原數據并沒有將各類數據分開,而是給出了一個csv文件:
在進行訓練之前將圖片分類,首先對數據的標簽進行切片
統計每張圖片的具體路徑:
#訓練數據的具體路徑 img_dir = "E:/tmp/.keras/datasets/chinese_mnist/data/data/input" train_image_paths = [] for row in train.itertuples():suite_id = row[1]sample_id = row[2]code = row[3]train_image_paths.append(img_dir+"_"+str(suite_id)+"_"+str(sample_id)+"_"+str(code)+".jpg") #對圖片路徑進行切片 train_path_ds = tf.data.Dataset.from_tensor_slices(train_image_paths)train_image_paths結果如下:
E:/tmp/.keras/datasets/chinese_mnist/data/data/input_1_1_10.jpg讀取圖片并進行預處理,然后切片
#圖片預處理 def preprocess_image(image):image = tf.image.decode_jpeg(image,channels = 3)image = tf.image.resize(image,[height,width])return image / 255.0 def load_and_preprocess_image(path):image = tf.io.read_file(path)return preprocess_image(image) #根據路徑讀取圖片并進行預處理 train_image_ds = train_path_ds.map(load_and_preprocess_image,num_parallel_calls=tf.data.experimental.AUTOTUNE)將train_image_ds與train_label_ds組合在一起
image_label_ds = tf.data.Dataset.zip((train_image_ds,train_label_ds))顯示圖片:
for i in range(20):plt.subplot(4, 5, i + 1)num +=1plt.xticks([])plt.yticks([])plt.grid(False)# 顯示圖片images = plt.imread(train_image_paths[i])plt.imshow(images)# 顯示標簽plt.xlabel(train_image_label[i])plt.show()在并未對數據進行shuffle之前,如下所示:
原數據中一共15000張圖片,分為15類,每類1000張,并按照順序排列,因此需要對數據進行打亂。
按照8:2的比例劃分訓練集與測試集
train_ds = image_label_ds.take(12000).shuffle(2000) test_ds = image_label_ds.skip(12000).shuffle(3000)超參數的設置
height = 64 width = 64 batch_size = 128 epochs = 50對訓練集與測試集進行batch_size 劃分
train_ds = train_ds.batch(batch_size)#設置batch_size train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) test_ds = test_ds.batch(batch_size) test_ds = test_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)再次檢查圖片,看看是否被打亂順序:
plt.figure(figsize=(8, 8))for images, labels in train_ds.take(1):# print(images.shape)for i in range(12):ax = plt.subplot(4, 3, i + 1)plt.imshow(images[i])plt.title(labels[i].numpy()) # 使用.numpy()將張量轉換為 NumPy 數組plt.axis("off")break plt.show()
順序已被打亂,初始目標完成。
3.網絡搭建&&編譯
model = tf.keras.Sequential([tf.keras.layers.Conv2D(filters=32,kernel_size=(3,3),padding="same",activation="relu",input_shape=[64, 64, 3]),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Flatten(),tf.keras.layers.Dense(64, activation="relu"),tf.keras.layers.Dense(15, activation="softmax") ])model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy']) model.summary() history = model.fit(train_ds,validation_data=test_ds,epochs = epochs )經過50次epochs,訓練結果如下:
準確率達到了100%
4.混淆矩陣的繪制
模型加載:
model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/chinese_mnist/model.h5")標簽列表如下所示:
all_label_names = ['零','一','二','三','四','五','六','七','八','九','十','百','千','萬','億']繪制混淆矩陣
from sklearn.metrics import confusion_matrix import seaborn as sns import pandas as pd# 繪制混淆矩陣 all_label_names = ['零','一','二','三','四','五','六','七','八','九','十','百','千','萬','億'] def plot_cm(labels, pre):conf_numpy = confusion_matrix(labels, pre) # 根據實際值和預測值繪制混淆矩陣conf_df = pd.DataFrame(conf_numpy, index=all_label_names,columns=all_label_names) # 將data和all_label_names制成DataFrameplt.figure(figsize=(8, 7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu") # 將data繪制為混淆矩陣plt.title('混淆矩陣', fontsize=15)plt.ylabel('真實值', fontsize=14)plt.xlabel('預測值', fontsize=14)plt.show()model = tf.keras.models.load_model("E:/Users/yqx/PycharmProjects/chinese_mnist/model.h5")test_pre = [] test_label = [] for images, labels in test_ds:for image, label in zip(images, labels):img_array = tf.expand_dims(image, 0) # 增加一個維度pre = model.predict(img_array) # 預測結果test_pre.append(all_label_names[np.argmax(pre)]) # 將預測結果傳入列表test_label.append(all_label_names[label.numpy()]) # 將真實結果傳入列表 plot_cm(test_label, test_pre) # 繪制混淆矩陣#
總結:本次實驗最復雜的就是標簽處理那一塊,只有處理好這一步驟,才能正確的將圖片和標簽劃分到一起。實驗數據只有15000張,而Mnist數據集有70000張,雖然本次的模型準確率達到了100%,但是仍有可能在別的圖片預測錯誤。
努力加油a啊
總結
以上是生活随笔為你收集整理的深度学习之基于CNN实现汉字版手写数字识别(Chinese-Mnist)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 音悦台高清mv下载_音悦台没有了去哪看m
- 下一篇: 机器学习之决策树的原理及sklearn实