机器学习入门(07)— MNIST 数据集手写数字的识别
和求解機器學習問題的步驟(分成學習和推理兩個階段進行)一樣,使用神經網絡解決問題時,也需要首先使用訓練數據(學習數據)進行權重參數的學習;進行推理時,使用剛才學習到的參數,對輸入數據進行分類。
1. MNIST 數據集
MNIST 數據集是由 0 到 9 的數字圖像構成的(圖3-24)。訓練圖像有 6 萬張,測試圖像有1 萬張,這些圖像可以用于學習和推理。
MNIST 數據集的一般使用方法是,先用訓練圖像進行學習,再用學習到的模型度量能在多大程度上對測試圖像進行正確的分類。
MNIST 的圖像數據是 28 像素 × 28 像素的灰度圖像(1 通道),各個像素的取值在0 到 255 之間。每個圖像數據都相應地標有“7”“2”“1”等標簽。
2. 代碼實現
2.1 下載并讀取數據
mnist.py 代碼實現
# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img': 'train-images-idx3-ubyte.gz','train_label': 'train-labels-idx1-ubyte.gz','test_img': 't10k-images-idx3-ubyte.gz','test_label': 't10k-labels-idx1-ubyte.gz'
}current_dir = os.path.dirname(os.path.abspath(__file__))
save_file = os.path.join(current_dir, "mnist.pkl")train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def download_mnist():for data_name in key_file.values():file_path = os.path.join(current_dir, data_name)if os.path.exists(file_path):print("{} exists, return".format(data_name))returnprint("download {} start...".format(data_name))urllib.request.urlretrieve(url_base + data_name, file_path)print("download {} end...".format(data_name))def _load_label(file_name):file_path = os.path.join(current_dir, file_name)print("Converting {} to NumPy Array ...".format(file_name))with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Converting Done")return labelsdef _load_img(file_name):file_path = os.path.join(current_dir, file_name)print("Converting {} to NumPy Array ...".format(file_name))with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Converting Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] = _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label']) dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, one_hot_label=False, flatten=True):"""讀入MNIST數據集Parameters----------normalize : 將圖像的像素值正規化為0.0~1.0one_hot_label : one_hot_label為True的情況下,標簽作為one-hot數組返回one-hot數組是指[0,0,1,0,0,0,0,0,0,0]這樣的數組flatten : 是否將圖像展開為一維數組Returns-------(訓練圖像, 訓練標簽), (測試圖像, 測試標簽)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()
load_mnist 函數以“( 訓練圖像, 訓練標簽),( 測試圖像,測試標簽)”的形式返回讀入的 MNIST 數據。此外,還可以像
load_mnist(normalize=True, flatten=True, one_hot_label=False)
這樣,設置 3 個參數。
- 第 1 個參數
normalize設置是否將輸入圖像正規化為 0.0~1.0 的值。如果將該參數設置為False,則輸入圖像的像素會保持原來的 0~255。 - 第 2 個參數
flatten設置是否展開輸入圖像(變成一維數組)。如果將該參數設置為False,則輸入圖像為 1 × 28 × 28 的三維數組;若設置為True,則輸入圖像會保存為由 784 個元素構成的一維數組。 - 第 3 個參數
one_hot_label設置是否將標簽保存為onehot表示(one-hot representation)。one-hot表示是僅正確解標簽為 1,其余皆為 0 的數組,就像 [0,0,1,0,0,0,0,0,0,0] 這樣。當one_hot_label為False時,只是像 7、2 這樣簡單保存正確解標簽;當one_hot_label為True時,標簽則保存為one-hot表示。
2.2 顯示數據
mnist_show.py 代碼實現:
import numpy as np
from PIL import Image
from mnist import load_mnistdef img_show(img):pil_img = Image.fromarray(np.uint8(img))pil_img.show()(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)img = x_train[0]
label = t_train[0]
print(label) # 5print(img.shape) # (784,)
img = img.reshape(28, 28) # 把圖像的形狀變為原來的尺寸
print(img.shape) # (28, 28)img_show(img)
需要注意的是,flatten=True 時讀入的圖像是以一列(一維)NumPy 數組的形式保存的。因此,顯示圖像時,需要把它變為原來的 28 像素× 28像素的形狀。可以通過reshape() 方法的參數指定期望的形狀,更改 NumPy數組的形狀。
此外,還需要把保存為 NumPy 數組的圖像數據轉換為 PIL 用的數據對象,這個轉換處理由Image.fromarray() 來完成。
2.3 神經網絡推理
對這個 MNIST 數據集實現神經網絡的推理處理。神經網絡的輸入層有 784 個神經元,輸出層有 10 個神經元。輸入層的 784 這個數字來源于圖像大小的 28 × 28 = 784,輸出層的 10 這個數字來源于10 類別分類(數字0 到9,共10 類別)。
此外,這個神經網絡有 2 個隱藏層,第 1 個隱藏層有 50 個神經元,第 2 個隱藏層有 100 個神經元。這個 50 和 100 可以設置為任何值。
neuralnet_mnist.py 代碼實現
# coding: utf-8import pickle
import numpy as npfrom mnist import load_mnistdef sigmoid(x):return 1 / (1 + np.exp(-x))def softmax(x):if x.ndim == 2:x = x.Tx = x - np.max(x, axis=0)y = np.exp(x) / np.sum(np.exp(x), axis=0)return y.Tx = x - np.max(x) # 溢出對策return np.exp(x) / np.sum(np.exp(x))def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)return x_test, t_testdef init_network():with open("sample_weight.pkl", 'rb') as f:network = pickle.load(f)return networkdef predict(network, x):W1, W2, W3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, W2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, W3) + b3y = softmax(a3)return yx, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):y = predict(network, x[i])p = np.argmax(y) # 獲取概率最高的元素的索引if p == t[i]:accuracy_cnt += 1print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
init_network() 會讀入保存在 pickle 文件 sample_weight.pkl 中的學習到的權重參數A 。這個文件中以字典變量的形式保存了權重和偏置參數。
首先獲得 MNIST 數據集,生成網絡。接著,用 for 語句逐一取出保存在 x 中的圖像數據,用 predict() 函數進行分類。
predict() 函數以 NumPy 數組的形式輸出各個標簽對應的概率。比如輸出 [0.1, 0.3, 0.2, …, 0.04] 的數組,該數組表示“0”的概率為0.1,“1”的概率為0.3,等等。然后,我們取出這個概率列表中的最大值的索引(第幾個元素的概率最高),作為預測結果。
可以用 np.argmax(x) 函數取出數組中的最大值的索引,np.argmax(x) 將獲取被賦給參數 x 的數組中的最大值元素的索引。最后,比較神經網絡所預測的答案和正確解標簽,將回答正確的概率作為識別精度。
執行代碼輸出結果是:
Accuracy:0.9352
這表示有93.52%的數據被正確分類了。
在這個例子中,我們把 load_mnist 函數的參數 normalize 設置成了True 。將normalize 設置成 True 后,函數內部會進行轉換,將圖像的各個像素值除以 255,使得數據的值在0.0~1.0 的范圍內。
像這樣把數據限定到某個范圍內的處理稱為正規化(normalization )或者叫歸一化處理。
此外,對神經網絡的輸入數據進行某種既定的轉換稱為預處理(pre-processing )。這里,作為對輸入圖像的一種預處理,我們進行了歸一化處理。
2.4 批處理
參考:《深度學習入門:基于Python的理論與實現》
總結
以上是生活随笔為你收集整理的机器学习入门(07)— MNIST 数据集手写数字的识别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 中国的毒品都是从哪里来的?
- 下一篇: 吉他弦多少钱啊?