tensorflow随笔-保存与读取使用模型
1、MNIST是深度學習的經典入門demo,他是由6萬張訓練圖片和1萬張測試圖片構成的,每張圖片都是2828大小(如下圖),而且都是黑白色構成(這里的黑色是一個0-1的浮點數,黑色越深表示數值越靠近1),這些圖片是采集的不同的人手寫從0到9的數字。
下面先訓練識別數字模型
再保存模型
最后,讀取保存的模型,對數字圖片進行識別。
2、保存模型
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sun Feb 3 20:28:26 2019""" from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)import tensorflow as tf import osx=tf.placeholder(tf.float32,[None,784])w=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10]))y=tf.nn.softmax(tf.matmul(x,w)+b) y_=tf.placeholder(tf.float32,[None,10]) cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)init=tf.global_variables_initializer() sess=tf.Session() sess.run(init) saver=tf.train.Saver() for i in range(1000):sampleX,sampleY=mnist.train.next_batch(100)sess.run(train_step,feed_dict={x:sampleX,y_:sampleY})print("訓練完成") print("保存生成模型...") model_dir="mnist_model" model_name="ml1" if not os.path.exists(model_dir):os.mkdir(model_dir)saver.save(sess,os.path.join(model_dir,model_name)) print("保存生成模型成功")訓練完成
保存生成模型…
保存生成模型成功
[root@VM03centos learn]# ls mnist_model
checkpoint ml1.data-00000-of-00001 ml1.index ml1.meta
[root@VM03centos learn]# ls MNISTdata
t10k-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz
[root@VM03centos learn]#
讀取數字識別模型,對某個數字圖像進行識別
讀取模型…
INFO:tensorflow:Restoring parameters from mnist_model/ml1
讀取模型完成
根據模型進行計算…
預測輸出結果:[[1.8999807e-06 9.8351490e-01 3.0815993e-03 4.3848301e-03 4.1427880e-05
1.6864968e-04 7.6594086e-05 4.5587993e-03 3.2991443e-03 8.7222963e-04]]
預測結果:1
實際結果:1
總結
以上是生活随笔為你收集整理的tensorflow随笔-保存与读取使用模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 网页java在div输出内容_JS实现读
- 下一篇: Java面试题(亲身经历)