生活随笔
收集整理的這篇文章主要介紹了
TFboys:使用Tensorflow搭建深层网络分类器
小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
前言
根據(jù)官方文檔整理而來(lái)的,主要是對(duì)Iris數(shù)據(jù)集進(jìn)行分類(lèi)。使用tf.contrib.learn.tf.contrib.learn快速搭建一個(gè)深層網(wǎng)絡(luò)分類(lèi)器,
步驟
導(dǎo)入csv數(shù)據(jù)搭建網(wǎng)絡(luò)分類(lèi)器訓(xùn)練網(wǎng)絡(luò)計(jì)算測(cè)試集正確率對(duì)新樣本進(jìn)行分類(lèi)
數(shù)據(jù)
Iris數(shù)據(jù)集包含150行數(shù)據(jù),有三種不同的Iris品種分類(lèi)。每一行數(shù)據(jù)給出了四個(gè)特征信息和一個(gè)分類(lèi)信息。
現(xiàn)在已經(jīng)將數(shù)據(jù)分為訓(xùn)練集和測(cè)試集
- A training set of 120 samples (iris_training.csv)
- A test set of 30 samples (iris_test.csv)
網(wǎng)絡(luò)搭建
1. 首先,導(dǎo)入tensorflow 和 numpy
from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport tensorflow as tfimport numpy as np
2. 導(dǎo)入數(shù)據(jù)
# 定義數(shù)據(jù)地址IRIS_TRAINING = "iris_training.csv"IRIS_TEST = "iris_test.csv"# 導(dǎo)入數(shù)據(jù)training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32)test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)
load_csv_with_header() 有三個(gè)參數(shù)
- filename, 數(shù)據(jù)地址
- target_dtype, 目標(biāo)值的numpy datatype(iris的目標(biāo)值是0,1,2,所以是np.int)
- features_dtype, 特征值的numpy datatype .
3. 搭建網(wǎng)絡(luò)結(jié)構(gòu)
# 每行數(shù)據(jù)4個(gè)特征,都是real-value的feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]# 3層DNN,3分類(lèi)問(wèn)題classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="iris_model")
參數(shù)解釋
- feature_columns 特征值
- hidden_units=[10, 20, 10]. 3個(gè)隱藏層,包含的隱藏神經(jīng)元依次是10, 20, 10
- n_classes 類(lèi)別個(gè)數(shù)
- model_dir 模型保存地址
4. 訓(xùn)練數(shù)據(jù)
classifier.fit(x=training_set.data, y=training_set.target, steps=2000)
steps 為訓(xùn)練次數(shù)
5. 計(jì)算準(zhǔn)確率
accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]print('Accuracy: {0:f}'.format(accuracy_score))
運(yùn)行結(jié)果是
Accuracy: 0.966667
6. 對(duì)新樣本進(jìn)行預(yù)測(cè)
# Classify two new flower samples.new_samples = np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)y = list(classifier.predict(new_samples, as_iterable=True))print('Predictions: {}'.format(str(y)))
運(yùn)行結(jié)果為:
Prediction: [1 2]
完整代碼
from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport tensorflow as tfimport numpy as npIRIS_TRAINING = "iris_training.csv"IRIS_TEST = "iris_test.csv"training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32)test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="iris_model")classifier.fit(x=training_set.data, y=training_set.target, steps=2000)accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]print('Accuracy: {0:f}'.format(accuracy_score))new_samples = np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)y = list(classifier.predict(new_samples, as_iterable=True))print('Predictions: {}'.format(str(y)))
參考
- tf.contrib.learn Quickstart
- tf.contrib.learn API
原文地址: http://www.datalearner.com/blog/1051488938031745
總結(jié)
以上是生活随笔為你收集整理的TFboys:使用Tensorflow搭建深层网络分类器的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
如果覺(jué)得生活随笔網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔推薦給好友。