如何从TensorFlow的mnist数据集导出手写体数字图片
在TensorFlow的官方入門課程中,多次用到mnist數據集。
mnist數據集是一個數字手寫體圖片庫,但它的存儲格式并非常見的圖片格式,所有的圖片都集中保存在四個擴展名為idx3-ubyte的二進制文件。
如果我們想要知道大名鼎鼎的mnist手寫體數字都長什么樣子,就需要從mnist數據集中導出手寫體數字圖片。了解這些手寫體的總體形狀,也有助于加深我們對TensorFlow入門課程的理解。
下面先給出通過TensorFlow api接口導出mnist手寫體數字圖片的python代碼,再對代碼進行分析。代碼在win7下測試通過,linux環境也可以參考本處代碼。
(非常良心的注釋和打印有木有)
#!/usr/bin/python3.5 # -*- coding: utf-8 -*-import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_datafrom PIL import Image# 聲明圖片寬高 rows = 28 cols = 28# 要提取的圖片數量 images_to_extract = 8000# 當前路徑下的保存目錄 save_dir = "./mnist_digits_images"# 讀入mnist數據 mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)# 創建會話 sess = tf.Session()# 獲取圖片總數 shape = sess.run(tf.shape(mnist.train.images)) images_count = shape[0] pixels_per_image = shape[1]# 獲取標簽總數 shape = sess.run(tf.shape(mnist.train.labels)) labels_count = shape[0]# mnist.train.labels是一個二維張量,為便于后續生成數字圖片目錄名,有必要一維化(后來發現只要把數據集的one_hot屬性設為False,mnist.train.labels本身就是一維) #labels = sess.run(tf.argmax(mnist.train.labels, 1)) labels = mnist.train.labels# 檢查數據集是否符合預期格式 if (images_count == labels_count) and (shape.size == 1):print ("數據集總共包含 %s 張圖片,和 %s 個標簽" % (images_count, labels_count))print ("每張圖片包含 %s 個像素" % (pixels_per_image))print ("數據類型:%s" % (mnist.train.images.dtype))# mnist圖像數據的數值范圍是[0,1],需要擴展到[0,255],以便于人眼觀看if mnist.train.images.dtype == "float32":print ("準備將數據類型從[0,1]轉為binary[0,255]...")for i in range(0,images_to_extract):for n in range(pixels_per_image):if mnist.train.images[i][n] != 0:mnist.train.images[i][n] = 255# 由于數據集圖片數量龐大,轉換可能要花不少時間,有必要打印轉換進度if ((i+1)%50) == 0:print ("圖像浮點數值擴展進度:已轉換 %s 張,共需轉換 %s 張" % (i+1, images_to_extract))# 創建數字圖片的保存目錄for i in range(10):dir = "%s/%s/" % (save_dir,i)if not os.path.exists(dir):print ("目錄 ""%s"" 不存在!自動創建該目錄..." % dir)os.makedirs(dir)# 通過python圖片處理庫,生成圖片indices = [0 for x in range(0, 10)]for i in range(0,images_to_extract):img = Image.new("L",(cols,rows))for m in range(rows):for n in range(cols):img.putpixel((n,m), int(mnist.train.images[i][n+m*cols]))# 根據圖片所代表的數字label生成對應的保存路徑digit = labels[i]path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])indices[digit] += 1img.save(path)# 由于數據集圖片數量龐大,保存過程可能要花不少時間,有必要打印保存進度if ((i+1)%50) == 0:print ("圖片保存進度:已保存 %s 張,共需保存 %s 張" % (i+1, images_to_extract))else:print ("圖片數量和標簽數量不一致!")上述代碼的實現思路如下:
1.讀入mnist手寫體數據;
2.把數據的值從[0,1]浮點范圍轉化為黑白格式(背景為0-黑色,前景為255-白色);
3.根據mnist.train.labels的內容,生成數字索引,也就是建立每一張圖片和其所代表數字的關聯,由此創建對應的保存目錄;
4.循環遍歷mnist.train.images,把每張圖片的像素數據賦值給python圖片處理庫PIL的Image類實例,再調用Image類的save方法把圖片保存在第3步驟中創建的對應目錄。
?
在運行上述代碼之前,你需要確保本地已經安裝python的圖片處理庫PIL,pip安裝命令如下:
pip3 install Pillow
或 pip install Pillow,取決于你的pip版本。
?
上述python代碼運行后,在當前目錄下會生成mnist_digits_images目錄,在該目錄下,可以看到如下內容:可以看到,我們成功地生成了黑底白字的數字圖片。
如果仔細觀察這些圖片,會看到一些肉眼也難以分辨的數字,譬如:
上面這幾個數字是2。想不到吧?
下面這兩個是5(看起來更像6):
這個是7:(7長這樣?有句MMP不知當講不當講)
猜猜下面這個是什么:
這是大寫的L?不是。
有點像1,是1嗎?也不是。
倒立拉粑的7?sorry,又猜錯了。
實話告訴您,它是2!一開始我也是不相信的,知道真相的那一刻我下巴差點掉下來!
總結
以上是生活随笔為你收集整理的如何从TensorFlow的mnist数据集导出手写体数字图片的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 中国新中产家庭“清洁观”:能躺着不站着,
- 下一篇: numpy 中ravel()和flatt