Downloading the MNIST Dataset and Converting It to NumPy Arrays with Python (Source Code Walkthrough)
- First, the init_mnist function
- Next, the load_mnist function
- The complete Python script for the dataset conversion
- Displaying an MNIST image to check the data
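The most important part of the Python script that downloads the MNIST dataset and converts it to NumPy arrays is the load_mnist function: any other project that needs the dataset simply calls load_mnist. Internally the function works with a dictionary whose values are NumPy arrays, and it returns that data as two tuples, (training images, training labels) and (test images, test labels).
How is all of this implemented? Let's go through the source code line by line.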
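The first statement in load_mnist is:

```python
if not os.path.exists(save_file):
    init_mnist()
```

If the pickle file save_file (mnist.pkl) does not exist yet, meaning the data has not been downloaded and converted, init_mnist() is called.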
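This check gives load_mnist a "build once, then read from cache" behavior: the expensive download and conversion only run when mnist.pkl is missing, and every later call just unpickles the file. The sketch below isolates that pattern with a hypothetical helper, load_or_build(cache_path, build), which is not part of the original script; it assumes the built object is picklable and uses a temporary directory instead of dataset_dir.

```python
import os
import pickle
import tempfile


def load_or_build(cache_path, build):
    """Hypothetical helper (not in the original script): return the object
    cached at cache_path, calling build() and pickling its result first
    if the cache file does not exist yet."""
    if not os.path.exists(cache_path):
        obj = build()
        with open(cache_path, 'wb') as f:
            pickle.dump(obj, f, -1)
    with open(cache_path, 'rb') as f:
        return pickle.load(f)


build_calls = []


def build():
    build_calls.append(1)  # count how often the expensive step runs
    return {'train_label': [5, 0, 4]}


with tempfile.TemporaryDirectory() as tmp_dir:
    cache = os.path.join(tmp_dir, 'mnist.pkl')
    first = load_or_build(cache, build)   # builds and writes the cache
    second = load_or_build(cache, build)  # only reads the cache

print(first == second, len(build_calls))  # True 1
```

The second call never invokes build(), which is exactly why only the first run of load_mnist prints the download and conversion messages.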
First, the init_mnist function
Inside init_mnist(), the first call is to download_mnist().
```python
def init_mnist():
    download_mnist()
    dataset = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as f:
        pickle.dump(dataset, f, -1)
    print("Done!")
```

download_mnist() in turn calls _download(v).
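```python
def download_mnist():
    for v in key_file.values():
        _download(v)
```

It loops over the four file names in key_file and calls _download for each one. The key statement in _download is urllib.request.urlretrieve, which downloads url_base + file_name and saves it to the file at file_path. If that file already exists, _download returns early and skips the download.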
```python
def _download(file_name):
    file_path = dataset_dir + "/" + file_name

    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    urllib.request.urlretrieve(url_base + file_name, file_path)
    print("Done")


url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz'
}
```

Back in init_mnist(), the next call is to _convert_numpy:
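The skip-if-exists logic in _download can be tried without touching the network. This is a minimal sketch with a hypothetical helper, download_if_missing (not from the original script); to stay runnable offline it points urllib.request.urlretrieve at a local file:// URL built from a temporary file instead of url_base.

```python
import os
import pathlib
import tempfile
import urllib.request


def download_if_missing(url, file_path):
    """Hypothetical helper mirroring _download: fetch url into file_path
    unless the file already exists. Returns True if a download happened."""
    if os.path.exists(file_path):
        return False  # already on disk: skip, like the early return in _download
    urllib.request.urlretrieve(url, file_path)
    return True


with tempfile.TemporaryDirectory() as tmp_dir:
    source = pathlib.Path(tmp_dir) / 'source.bin'
    source.write_bytes(b'test bytes standing in for a .gz file')
    target = os.path.join(tmp_dir, 'downloaded.bin')

    url = source.as_uri()  # file:// URL, so no network access is needed
    first = download_if_missing(url, target)   # copies the file
    second = download_if_missing(url, target)  # file exists, so skipped
    content = pathlib.Path(target).read_bytes()

print(first, second)  # True False
```

Because the existence check runs before any network call, rerunning the original script after a successful download costs nothing.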
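```python
# the rest of init_mnist()
dataset = _convert_numpy()
print("Creating pickle file ...")
with open(save_file, 'wb') as f:
    pickle.dump(dataset, f, -1)
print("Done!")
```

Now look at _convert_numpy. It returns a dictionary, that is, a set of key-value pairs. To fill it, it calls _load_img for the two image files and _load_label for the two label files.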
```python
def _convert_numpy():
    dataset = {}
    dataset['train_img'] = _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])

    return dataset
```

Next comes _load_img. As its message print("Converting " + file_name + " to NumPy Array ...") suggests, this function converts a dataset file into a NumPy array.
Inside _load_img, gzip.open(file_path, 'rb') opens the .gz file in binary read mode; reading from the returned file object yields the decompressed bytes.
```python
def _load_img(file_name):
    file_path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print("Done")

    return data
```

The line data = np.frombuffer(f.read(), np.uint8, offset=16) turns the bytes returned by f.read() into a NumPy array whose elements are uint8, starting from byte 16. The reason for 16 is the layout of the TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
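| [offset] | [type] | [value] | [description] |
|---|---|---|---|
| 0000 | 32 bit integer | 0x00000803 (2051) | magic number |
| 0004 | 32 bit integer | 60000 | number of images |
| 0008 | 32 bit integer | 28 | number of rows |
| 0012 | 32 bit integer | 28 | number of columns |
| 0016 | unsigned byte | ?? | pixel |
| 0017 | unsigned byte | ?? | pixel |
| ........ | | | |
| xxxx | unsigned byte | ?? | pixel |

This is the layout of the training-set image file. The images are stored as grayscale values, one unsigned byte per pixel. The first 16 bytes are a header describing the dataset (four 32-bit integers), and every byte after that is pixel data, so reading the images has to start at byte 16.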
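The header can also be decoded directly, which confirms where the pixel data begins. A small sketch using the standard struct module; it relies on the IDX format storing these integers in big-endian (MSB-first) order. parse_idx3_header is a hypothetical helper name, and the header bytes are synthetic values copied from the table above.

```python
import struct


def parse_idx3_header(raw):
    """Hypothetical helper: decode the 16-byte IDX3 header as four
    big-endian ('>') unsigned 32-bit integers ('I')."""
    magic, num_images, rows, cols = struct.unpack('>IIII', raw[:16])
    return magic, num_images, rows, cols


# Synthetic header with the values listed in the table above
header = struct.pack('>IIII', 0x00000803, 60000, 28, 28)
print(len(header))                # 16
print(parse_idx3_header(header))  # (2051, 60000, 28, 28)
```

Four 4-byte integers make exactly 16 bytes, which is the offset=16 passed to np.frombuffer.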
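The next line, data = data.reshape(-1, img_size), reshapes the one-dimensional array into a two-dimensional one with img_size (784 = 28 × 28) columns, so that each row holds one image; the -1 tells NumPy to infer the number of rows (60000 for the training file). _load_img therefore returns a NumPy array, which completes the analysis of _load_img.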
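To see offset=16 and reshape(-1, img_size) working together without downloading anything, the sketch below builds a tiny fake IDX3 file in memory (three 2×2 images, so img_size is 4 here instead of 784), gzip-compresses it, and reads it back the same way _load_img does. All data is synthetic.

```python
import gzip
import io
import struct

import numpy as np

img_size = 4  # toy 2x2 images instead of MNIST's 28x28 = 784

# Fake IDX3 file: 16-byte header (magic, count, rows, cols) + 3 images x 4 pixels
pixels = bytes(range(12))
raw = struct.pack('>IIII', 2051, 3, 2, 2) + pixels
compressed = gzip.compress(raw)

# gzip.open also accepts an existing file object, here an in-memory buffer
with gzip.open(io.BytesIO(compressed), 'rb') as f:
    data = np.frombuffer(f.read(), np.uint8, offset=16)  # skip the header
data = data.reshape(-1, img_size)  # one row per image; NumPy infers the 3

print(data.shape)  # (3, 4)
print(data[1])     # [4 5 6 7]
```

With the real train-images-idx3-ubyte.gz and img_size = 784, the same two lines produce an array of shape (60000, 784).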
Back in _convert_numpy, the returned dataset is therefore a dictionary whose keys are strings and whose values are NumPy arrays.
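Returning to init_mnist(), the message print("Creating pickle file ...") shows that once dataset has been built, the function creates a pickle file. with open(save_file, 'wb') as f opens the file named save_file in binary mode for writing only. Since save_file = dataset_dir + "/mnist.pkl", this creates a .pkl file. What gets written into it? pickle.dump(dataset, f, -1) serializes the dataset object into that file; the -1 is the pickle protocol version, and a negative value selects the highest protocol available. That completes init_mnist: it returns nothing, and its result is the mnist.pkl file written to disk.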
Next, the load_mnist function
Further down, the line with open(save_file, 'rb') as f: dataset = pickle.load(f) reads the pickle file back, reconstructs the original Python object, and assigns it to dataset.
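The pickle round trip shared by init_mnist and load_mnist can be checked in memory. This sketch dumps a small synthetic dict of NumPy arrays with protocol -1 into an io.BytesIO buffer (standing in for mnist.pkl), loads it back, and checks that -1 really selected pickle.HIGHEST_PROTOCOL: for protocol 2 and above, the second byte of the stream is the protocol number.

```python
import io
import pickle

import numpy as np

# A tiny synthetic stand-in for the dict built by _convert_numpy
dataset = {
    'train_img': np.arange(8, dtype=np.uint8).reshape(2, 4),
    'train_label': np.array([5, 0], dtype=np.uint8),
}

buf = io.BytesIO()             # in-memory replacement for mnist.pkl
pickle.dump(dataset, buf, -1)  # -1 means "highest protocol available"
raw = buf.getvalue()

buf.seek(0)
restored = pickle.load(buf)    # the same call load_mnist uses

print(raw[1] == pickle.HIGHEST_PROTOCOL)                            # True
print(np.array_equal(restored['train_img'], dataset['train_img']))  # True
print(restored['train_img'].dtype)                                  # uint8
```

The arrays come back with their dtype and shape intact, which is why load_mnist can skip the gzip parsing on every later run.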
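The load_mnist parameter normalize=True normalizes the input images to values between 0 and 1. Each pixel starts as an integer in the range 0-255; after dataset[key].astype(np.float32) converts the array to floating point, dataset[key] /= 255.0 scales it into the range 0.0-1.0.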
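The order "astype(np.float32) first, then /= 255.0" matters. A minimal sketch with a synthetic 1×3 uint8 array: dividing the uint8 array in place is rejected because NumPy will not cast the floating-point result back into uint8, while converting first gives values in 0.0-1.0.

```python
import numpy as np

pixels = np.array([[0, 128, 255]], dtype=np.uint8)  # synthetic pixel values

# In-place division on uint8 fails: the float result cannot be stored
# back into uint8 under NumPy's 'same_kind' casting rule
try:
    tmp = pixels.copy()
    tmp /= 255.0
    inplace_ok = True
except TypeError:  # NumPy's UFuncTypeError is a subclass of TypeError
    inplace_ok = False

x = pixels.astype(np.float32)  # convert first, as load_mnist does
x /= 255.0                     # in-place division on float32 is allowed
print(inplace_ok)                 # False
print(x.dtype, x.min(), x.max())  # float32 0.0 1.0
```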
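If the load_mnist parameter one_hot_label is True, the labels are stored in one-hot form: an array in which only the element for the correct label is 1 and all the others are 0. This is done by calling _change_one_hot_label.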
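```python
def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1

    return T
```

If the load_mnist parameter flatten is True, each input image is stored as a one-dimensional array of 784 elements; if it is False, each image is a 1×28×28 three-dimensional array, so the whole image set has shape (N, 1, 28, 28).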
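For comparison, this sketch puts the article's loop-based one-hot conversion next to a vectorized alternative, np.eye(10)[labels], which is a swapped-in technique and not what the original script uses; both should produce the same (N, 10) float array. The last lines show the flatten=False reshape on a dummy (3, 784) array. The function names and the label values are illustrative.

```python
import numpy as np


def change_one_hot_label_loop(X):
    """Same logic as _change_one_hot_label in the article."""
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1
    return T


def one_hot_eye(labels, num_classes=10):
    """Hypothetical vectorized alternative: pick rows of the identity matrix."""
    return np.eye(num_classes)[labels]


labels = np.array([5, 0, 4], dtype=np.uint8)  # sample labels
print(one_hot_eye(labels)[0])  # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
print(np.array_equal(change_one_hot_label_loop(labels), one_hot_eye(labels)))  # True

# flatten=False: each 784-pixel row becomes a 1x28x28 image
flat = np.zeros((3, 784), dtype=np.uint8)
print(flat.reshape(-1, 1, 28, 28).shape)  # (3, 1, 28, 28)
```

The loop is easier to read; the indexing version avoids a Python-level loop over all 60000 labels.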
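Finally, load_mnist returns the data held in the dataset dictionary, whose keys are train_img, train_label, test_img and test_label and whose values are the NumPy arrays converted from the .gz dataset files. It does not return the dictionary itself but two tuples: (dataset['train_img'], dataset['train_label']) and (dataset['test_img'], dataset['test_label']).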
```python
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

    if not flatten:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
```

That completes the analysis of load_mnist. Here is the full code for downloading the MNIST dataset and converting it to NumPy arrays with Python:
The complete Python script for the dataset conversion
```python
# coding: utf-8
try:
    import urllib.request
except ImportError:
    raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np


url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz'
}

# Downloaded .gz files and the pickle cache are stored next to this script
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784


def _download(file_name):
    file_path = dataset_dir + "/" + file_name

    if os.path.exists(file_path):
        return  # already downloaded

    print("Downloading " + file_name + " ... ")
    urllib.request.urlretrieve(url_base + file_name, file_path)
    print("Done")


def download_mnist():
    for v in key_file.values():
        _download(v)


def _load_label(file_name):
    file_path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
        # label files have an 8-byte header (magic number, number of items)
        labels = np.frombuffer(f.read(), np.uint8, offset=8)
    print("Done")

    return labels


def _load_img(file_name):
    file_path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
        # image files have a 16-byte header (magic number, count, rows, columns)
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)  # one row of 784 pixels per image
    print("Done")

    return data


def _convert_numpy():
    dataset = {}
    dataset['train_img'] = _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])

    return dataset


def init_mnist():
    download_mnist()
    dataset = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as f:
        pickle.dump(dataset, f, -1)  # -1: highest available pickle protocol
    print("Done!")


def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1

    return T


def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset

    Parameters
    ----------
    normalize : normalize the image pixel values to 0.0-1.0
    one_hot_label : if True, the labels are returned as one-hot arrays,
        i.e. arrays such as [0,0,1,0,0,0,0,0,0,0]
    flatten : whether to flatten each image into a one-dimensional array

    Returns
    -------
    (training images, training labels), (test images, test labels)
    """
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

    if not flatten:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])


if __name__ == '__main__':
    init_mnist()
```

Displaying an MNIST image to check the data
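First, call the load_mnist function written above:

```python
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
```

This gives x_train, t_train, x_test and t_test, which are NumPy arrays, not dictionaries. With flatten=True and normalize=False, x_train has shape (60000, 784) and holds uint8 pixel values, and t_train holds the 60000 training labels.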
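To look at the first training example, img = x_train[0] reads the first image and label = t_train[0] reads the first label. Printing the label shows that the first image in the dataset is the digit 5.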
The image is displayed with the img_show function, which uses Image.fromarray to convert the NumPy array into a PIL image object and then shows it.
```python
import sys, os
sys.path.append(os.pardir)  # allow importing files from the parent directory
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image


def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()


(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

img = x_train[0]
label = t_train[0]
print(label)  # 5

print(img.shape)  # (784,)
img = img.reshape(28, 28)  # restore the original image shape
print(img.shape)  # (28, 28)

img_show(img)
```

Output:
```text
Downloading train-images-idx3-ubyte.gz ...
Done
Downloading train-labels-idx1-ubyte.gz ...
Done
Downloading t10k-images-idx3-ubyte.gz ...
Done
Downloading t10k-labels-idx1-ubyte.gz ...
Done
Converting train-images-idx3-ubyte.gz to NumPy Array ...
Done
Converting train-labels-idx1-ubyte.gz to NumPy Array ...
Done
Converting t10k-images-idx3-ubyte.gz to NumPy Array ...
Done
Converting t10k-labels-idx1-ubyte.gz to NumPy Array ...
Done
Creating pickle file ...
Done!
5
(784,)
(28, 28)

Process finished with exit code 0
```
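If img_show cannot open an image viewer (for example on a headless server), printing the digit as text is a quick alternative for checking the data. The helper below, img_to_ascii, is hypothetical and not part of the original code; it assumes pixel values in 0-255 (normalize=False), so with normalize=True the threshold would need to be around 0.5. The example uses a synthetic image; with the real data you would call print(img_to_ascii(x_train[0])).

```python
import numpy as np


def img_to_ascii(flat_img, threshold=128):
    """Hypothetical helper: render a flattened 784-pixel image as 28 lines
    of text, '#' for pixels >= threshold and '.' for the rest."""
    img = np.asarray(flat_img).reshape(28, 28)
    return '\n'.join(''.join('#' if p >= threshold else '.' for p in row)
                     for row in img)


# Synthetic "digit": a vertical bar in rows 4-23, columns 12-15
fake = np.zeros((28, 28), dtype=np.uint8)
fake[4:24, 12:16] = 255
art = img_to_ascii(fake.reshape(784))
print(art)
```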