Wide Deep模型的理解及实战(Tensorflow)
目錄
?
一、背景
二、概述
三、模型原理
3.1、Wide模型
3.2、Deep模型
3.3、Wide和Deep模型的協同訓練
四、系統介紹
4.1、系統簡介
4.2、系統流程
五、tensorflow實戰
5.1數據集介
5.2 代碼
一、背景
Google于?2016 年在DLRS上發表了一篇文章:2016-Wide & Deep Learning for Recommender Systems,模型的核心思想是結合線性模型的記憶能力(memorization)和 DNN 模型的泛化能力(generalization),在訓練過程中同時優化 2 個模型的參數,從而達到整體模型的預測能力最優。
記憶(memorization)即從歷史數據中發現item或者特征之間的相關性。
泛化(generalization)即相關性的傳遞,發現在歷史數據中很少或者沒有出現的新的特征組合。
二、概述
一個推薦系統可被看作是一個搜索排序系統,這一系統的查詢輸入是一組(用戶,上下文信息)的集合,輸出是一個排序后的物品列表。接收到一個查詢請求,推薦系統會在數據庫中找到相關的物品并根據特定的目的進行排序,常見的目的有用戶點擊、購買。
與廣義的搜索排序問題一樣,推薦系統也面臨這樣一個挑戰:如何同時實現記憶與泛化(memorization and generalization)。記憶可不嚴格地定義為,學習并利用歷史數據中的高頻共現物體或特征所具有的關系。另一方面,泛化指這一關系的轉移能力(transitivity),以及發現在歷史數據中罕見或未曾出現過的新的特征組合?;谟洃浀耐扑]通常更集中于用戶歷史行為所涵蓋的特定主題(topical)。而基于泛化的推薦會傾向于提升推薦物品的多樣性。
在工業界的大規模線上推薦和排序系統中,廣義線性模型如邏輯回歸被廣泛應用,因為它們簡單、可擴展(scalable)且具有可解釋性。這些模型常?;讵殶峋幋a(one-hot encoding)的二進制化的稀疏特征進行訓練。例如,二進制特征“user_installed_app=netflix”在用戶安裝了Netflix時為1。模型的記憶能力可通過系數特征的外積變換有效實現,如AND(user_installed_app=netflix, impression_app=pandora)在用戶安裝了Netflix并隨后見到過Pandora時值為1。這闡明了一對特征的共現是如何與目標標簽產生關聯的。使用更粗粒度的特征可實現模型的泛化,如AND(user_installed_category=video, impression_category=music),但常常是需要手工做特征工程的。外積變換的一個局限是,它無法對訓練集中未出現過的“查詢-物品特征對”進行泛化。
基于嵌入技術的模型,如因子分解機(factorization machines)、深度神經網絡等,可對未見過的“查詢-物品特征對”進行泛化,這是通過對一個低維稠密嵌入向量的學習來實現的,且這種模型依賴于較少的特征工程。然而,在“查詢-物品矩陣”是稀疏、高維的時候,其低維表示是難以有效學習的,例如用戶有特定偏好或物品的吸引力較窄(niche/narrow appeal)時。在這種情形下,這一“查詢-物品對”應與大部分“查詢-物品對”無關,但稠密嵌入會對所有“查詢-物品對”給出非零的預測,因而會導致過于泛化并給出不相關的推薦。另一方面,線性模型使用外積特征變換可依賴相當少的參數來記住這些“例外規則”。
本文中,我們呈現Wide&Deep學習框架以在一個模型中同時達到記憶和泛化,這是通過協同訓練一個線性模型模塊和一個神經網絡模塊完成的,如下圖所示。
在wide &deep模型中包括兩個部分,分別為Wide模型和Deep模型,Wide模型如上圖左邊部分,Deep模型如上圖右邊部分。
wide &deep模型的思想來源是根據人腦有不斷記憶并且泛化的過程,這里講寬線形模型和深度神經網絡模型相結合,汲取各自優勢形成了wide &deep模型,以用于推薦排序。
wide &deep模型旨在使得訓練得到的模型能夠同時獲得記憶和泛化的能力:
- 記憶(memorization)即從歷史數據中發現item或者特征之間的相關性。這里通過大量的特征交叉產生特征交互作用的“記憶”,高效且可解釋。但要泛化需要更多的特征工程。
- 泛化(generalization)即相關性的傳遞,發現在歷史數據中很少或者沒有出現的新的特征組合。這里通過Embedding的方法,使用低維稠密特征的輸入,可以更好的泛化訓練樣本中未出現的交叉特征。
三、模型原理
3.1、Wide模型
wide部分是一個廣義線性模型,具有著的形式,如上圖(左)所示。是模型的預測,是個特征對應的向量,是模型參數,是模型偏差。最終在的基礎上增加sigmoid函數座位最終的輸出,其實就是一個LR模型。
特征及和包括原始輸入特征(raw input features)和變化得到的特征(transformed features)。其中一個最為重要的變換是外積變換(cross-product transformation),它被定義為:
其中,是一個bool型變量,在第個變換包含第i個特征時為1,否則為0。對于二進制特征,當且僅當所有組成特征(“gender=female” 和 “language=en”)都為1時值為1,否則為0。這將捕獲到二進制特征間的交互,并向廣義線性模型中添加非線性項。
?
3.2、Deep模型
deep部分是一個前饋神經網絡,如上圖(右)所示。對于類別型特征,原始輸入是特征字符串(如“language=en”)。這些稀疏、高維的類別型特征首先被轉化為低維稠密實值向量,通常被稱為嵌入向量。嵌入向量的維度通常在O(10)到O(100)間。在模型訓練階段,嵌入向量被隨機初始化并根據最小化最終的損失函數來學習向量參數。這些低維稠密嵌入向量隨后在神經網絡的前饋通路中被輸入到隱藏層中。特別地,每個隱藏層進行了如下計算:
其中:f是激活函數,通常為ReLU,是層的序號。
?
3.3、Wide和Deep模型的協同訓練
wide模塊和deep模塊的組合依賴于對其輸出的對數幾率(log odds)的加權求和作為預測,隨后這一預測值被輸入到一個一般邏輯損失函數(logistic loss function)中進行聯合訓練。需注意的是,聯合訓練(joint training)和拼裝(ensemble)是有區別的。在拼裝時,獨立模型是分別訓練的,它們的預測結果是在最終推斷結果時組合在一起的,而不是在訓練的時候。作為對比,聯合訓練是在訓練環節同時優化wide模型、deep模型以及它們總和的權重。在模型大小上也有不同:對拼裝而言,由于訓練是解耦的,獨立模型常常需要比較大(如更多的特征和變換)來達到足夠合理的準確度用于模型拼裝;作為對比,聯合訓練中wide部分只需要補充deep部分的弱點,即只需要少量的外積特征變換而不是全部的wide模型。
對Wide&Deep模型的聯合訓練通過同時對輸出向模型的wide和deep兩部分進行梯度的反向傳播(back propagating)來實現的,這其中應用了小批量隨機優化(mini-batch stochastic optimization)的技術。在實驗中,我們使用了FTRL(Follow-the-regularized-leader)算法以及使用正則化來優化wide部分的模型,并使用AdaGrad優化deep部分。
組合的模型如圖1(中)所示。對于邏輯回歸問題,模型的預測是:
四、系統介紹
4.1、系統簡介
在下圖中呈現了一個對推薦系統的概覽。用戶訪問應用商店時,將生成一個請求,它可能包括多種用戶和上下文特征。推薦系統返回一個APP列表(也可稱為"印象"/impression),用戶可對這一列表進行點擊、購買等行為。這些用戶行為和查詢、"印象",會被記錄在日志中以作為學習訓練使用的訓練數據。
考慮到數據庫中有超過一百萬的APP,在服務延遲需求(通常為O(10)ms)下,難以對每個請求窮盡所有APP的分值計算。因此,在收到請求時的第一步是召回(retrieval)。召回系統通過使用一組信號,返回一個與當前請求最匹配的物品小列表,通常這一組信號是一組機器學習的模型和人工指定的規則的組合。在縮小候選池后,排序系統通過物品的分值進行排序。這些分支通??杀硎緸镻(y|x),在給定特征x時用戶行為標簽為y的概率,特征x包括了用戶特征(如國籍,語言,人口統計數據),上下文特征(如設備,小時數,星期數)以及印象(impression)特征(如APP的“年齡”,APP的歷史統計數據)。
4.2、系統流程
應用推薦流程的實現由三部分組成:數據生成,模型訓練和模型服務,如下圖所示。
數據生成
在這一步中,在一段時間中的用戶和app的印象(impression)數據被用來生成訓練數據集。每項數據對應于一個印象(impression)。標簽是app獲取與否:如果被印記(impressed)的APP得到了用戶的安裝則為1,否則為0。
模型訓練
我們在實驗中使用的模型結構如下圖所示。在訓練中,輸入層接收輸入數據和詞匯表(vocabularies),并生成稀疏、稠密特征以及標簽。wide部分由對用戶安裝app和印象(impression)app的外積變換構成。而對于deep部分,模型對每個類別型特征學習了一個32維的嵌入向量。我們將所有的潛入向量與稠密特征拼接在一起,得到了一個約1200維的稠密向量。這一拼接得到的向量隨后被輸入到三個ReLU層,并最終通過邏輯輸出單元(logistic output unit)。
Wide&Deep模型基于超5000億樣本進行訓練。每當有一組新的訓練數據到達時,模型都需要被重新訓練。然而,每次都進行重計算所需的計算量巨大,且會延遲數據抵達到模型更新的時效性。為了解決這一挑戰,我們實現了一個熱啟動系統,它將使用舊模型的潛入向量、線性模型權重等參數對新模型進行初始化。
在將模型加載到服務中前,會做一次預運行,來確保它不會導致線上服務出錯。我們的“靠譜測試”(sanity check)通過對新舊模型的經驗化驗證完成。
Embedding維度大小的建議,從經驗上來講,Embedding層維度大小可以用公式來定:
n是原始維度上特征不同取值的個數,k是一個常數,通常小于10.
線上應用
一旦模型完成了訓練和驗證,我們就將它加載到模型服務中。對于每個請求,該服務將收到一組召回系統提供的候選app列表以及用戶特征來計算評分。隨后,app按照評分從高到低排序,我們將向用戶按照這一順序展現app列表。評分通過一個Wide&Deep模型的前向推導(forward inference)來計算。
為了將請求響應速度優化至10ms,我們實現了多線程并行來優化服務性能,通過更小批次的并行,來替換先前在一個批次的推導步驟中對所有候選app進行評分。
?
五、tensorflow實戰
5.1數據集介紹
該數據集由Barry Becker從1994人口普查數據庫中提取得到。
預測任務是確定一個人年薪是否超過50K。
原文鏈接:http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names
實例數目
48842個連續或離散的實例。其中訓練集32561個,測試集16281個。
45222個因有未知量而被移除的實例。其中訓練集30162個,測試集15060個。
屬性數目
6個連續變量,8個名詞性屬性。
屬性信息
年齡:連續值;
工作類別:私人、自由職業非公司、自由職業公司、聯邦政府、地方政府、州政府、無薪、無工作經驗;
一個州內觀測人數:連續值;
教育程度: Bachelors(學士), Some-college(大學未畢業), 11th(高二), HS-grad(高中畢業),Prof-school(職業學校), Assoc-acdm(大學???#xff09;, Assoc-voc(準職業學位), 9th(初三),7th-8th(初中一、二年級), 12th(高三), Masters(碩士), 1st-4th(小學1-4年級), 10th(高一), Doctorate(博士), 5th-6th(小學5、6年級), Preschool(幼兒園).
教育時間:連續值;
婚姻狀態: Married-civ-spouse(已婚平民配偶), Divorced(離婚), Never-married(未婚), Separated(分居), Widowed(喪偶), Married-spouse-absent(已婚配偶異地), arried-AF-spouse(已婚軍屬)
職業:Tech-support(技術支持), Craft-repair(手工藝維修), Other-service(其他職業),Sales(銷售), Exec-managerial(執行主管), Prof-specialty(專業技術),Handlers-cleaners(勞工保潔), Machine-op-inspct(機械操作), Adm-clerical(管理文書),Farming-fishing(農業捕撈), Transport-moving(運輸), Priv-house-serv(家政服務),Protective-serv(保安), Armed-Forces(軍人)
家庭角色:Wife(妻子), Own-child(孩子), Husband(丈夫), Not-in-family(離家), Other-relative(其他關系), Unmarried(未婚)
種族: White(白人), Asian-Pac-Islander(亞裔、太平洋島裔), Amer-Indian-Eskimo(美洲印第安裔、愛斯基摩裔), Other(其他), Black(非遺)
性別: Female(女), Male(男)
資本收益:連續值
資本虧損:連續值
每周工作時長:連續值
原國籍:United-States(美國), Cambodia(柬埔寨), England(英國), Puerto-Rico(波多黎各),Canada(加拿大), Germany(德國), Outlying-US(Guam-USVI-etc) (美國海外屬地), India(印度),Japan(日本), Greece(希臘), South(南美), China(中國), Cuba(古巴), Iran(伊朗), Honduras(洪都拉斯), Philippines(菲律賓), Italy(意大利), Poland(波蘭), Jamaica(牙買加),Vietnam(越南), Mexico(墨西哥), Portugal(葡萄牙), Ireland(愛爾蘭), France(法國),Dominican-Republic(多米尼加共和國), Laos(老撾), Ecuador(厄瓜多爾), Taiwan(臺灣), Haiti(海地), Columbia(哥倫比亞), Hungary(匈牙利), Guatemala(危地馬拉),Nicaragua(尼加拉瓜), Scotland(蘇格蘭), Thailand(泰國), Yugoslavia(南斯拉夫), El-Salvador(薩爾瓦多), Trinadad&Tobago(特立尼達和多巴哥), Peru(秘魯), Hong(香港), Holand-Netherlands(荷蘭)
類別:>50K,<=50K
?
5.2 代碼
這里采用tensorflow高級API-Estimators進行訓練
5.2.1 導入包
import numpy as np import tensorflow as tf import pandas as pd import random import math import refrom sklearn import preprocessing from os import path, listdir from sklearn.datasets import load_svmlight_files from sklearn.model_selection import train_test_split from sklearn import metrics from tensorflow.contrib import layersfrom sklearn import metricsimport time import datetimeimport os os.environ["CUDA_VISIBLE_DEVICES"]="0"import tensorflow as tf5.2.2 數據準備
關于tf.feature_column的可以查閱博文:https://blog.csdn.net/Andy_shenzl/article/details/105145865
# 定義輸入樣本格式 _CSV_COLUMNS = ['age', 'workclass', 'fnlwgt', 'education', 'education_num','marital_status', 'occupation', 'relationship', 'race', 'gender','capital_gain', 'capital_loss', 'hours_per_week', 'native_country','income_bracket' ] _CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],[0], [0], [0], [''], ['']] _NUM_EXAMPLES = {'train': 32561,'validation': 16281, }"""Builds a set of wide and deep feature columns.""" def build_model_columns():# 1. 特征處理,包括:連續特征、離散特征、轉換特征、交叉特征等# 連續特征 (其中在Wide和Deep組件都會用到)age = tf.feature_column.numeric_column('age')education_num = tf.feature_column.numeric_column('education_num')capital_gain = tf.feature_column.numeric_column('capital_gain')capital_loss = tf.feature_column.numeric_column('capital_loss')hours_per_week = tf.feature_column.numeric_column('hours_per_week')# 離散特征education = tf.feature_column.categorical_column_with_vocabulary_list('education', ['Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college','Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school','5th-6th', '10th', '1st-4th', 'Preschool', '12th'])marital_status = tf.feature_column.categorical_column_with_vocabulary_list('marital_status', ['Married-civ-spouse', 'Divorced', 'Married-spouse-absent','Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])relationship = tf.feature_column.categorical_column_with_vocabulary_list('relationship', ['Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried','Other-relative'])workclass = tf.feature_column.categorical_column_with_vocabulary_list('workclass', ['Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov','Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])# 離散hash bucket特征occupation = tf.feature_column.categorical_column_with_hash_bucket('occupation', hash_bucket_size=1000)# 特征Transformationsage_buckets = tf.feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])# 2. 設定Wide層特征"""Wide部分使用了規范化后的連續特征、離散特征、交叉特征"""# 基本特征列base_columns = [# 全是離散特征education, marital_status, relationship, workclass, occupation,age_buckets,]# 交叉特征列crossed_columns = [tf.feature_column.crossed_column(['education', 'occupation'], hash_bucket_size=1000),tf.feature_column.crossed_column([age_buckets, 'education', 'occupation'], hash_bucket_size=1000)]# wide特征列wide_columns = base_columns + crossed_columns# 3. 設定Deep層特征"""Deep層主要針對離散特征進行處理,其中處理方式有:1. Sparse Features -> Embedding vector -> 串聯(連續特征),其中Embedding Values隨機初始化。2. 另外一種處理離散特征的方法是:one-hot和multi-hot representation. 此方法適用于低維度特征,其中embedding是通用的做法其中:采用embedding_column(embedding)和indicator_column(multi-hot)API"""# deep特征列deep_columns = [age,education_num,capital_gain,capital_loss,hours_per_week,tf.feature_column.indicator_column(workclass),tf.feature_column.indicator_column(education),tf.feature_column.indicator_column(marital_status),tf.feature_column.indicator_column(relationship),# embedding特征tf.feature_column.embedding_column(occupation, dimension=8)]return wide_columns, deep_columns5.2.3 定義輸入
def input_fn(data_file, num_epochs, shuffle, batch_size):"""為Estimator創建一個input function"""#assert判斷,為False時執行后面語句assert tf.gfile.Exists(data_file), "{0} not found.".format(data_file)def parse_csv(line):print("Parsing", data_file)# tf.decode_csv會把csv文件轉換成Tensor。其中record_defaults用于指明每一列的缺失值用什么填充。columns = tf.decode_csv(line, record_defaults=_CSV_COLUMN_DEFAULTS)features = dict(zip(_CSV_COLUMNS, columns))#pop函數提取labellabels = features.pop('income_bracket')# tf.equal(x, y) 返回一個bool類型Tensor, 表示x == y, element-wisereturn features, tf.equal(labels, '>50K') dataset = tf.data.TextLineDataset(data_file).map(parse_csv, num_parallel_calls=5)'''使用 tf.data.Dataset.map,我們可以很方便地對數據集中的各個元素進行預處理。map接收一個函數,Dataset中的每個元素都會被當作這個函數的輸入,并將函數返回值作為新的Dataset因為輸入元素之間時獨立的,所以可以在多個 CPU 核心上并行地進行預處理。num_parallel_calls 參數的最優值取決于你的硬件、訓練數據的特質(比如:它的 size、shape)、map 函數的計算量 和 CPU 上同時進行的其它處理。比較簡單的一個設置方法是:將 num_parallel_calls 設置為 CPU 的核心數。例如,CPU 有四個核心時,將 num_parallel_calls 設置為 4 將會很高效。相反,如果 num_parallel_calls 大于 CPU 的核心數,將導致低效的調度,導致輸入管道的性能下降。也可以設置shuffleshuffle的功能為打亂dataset中的元素,它有一個參數buffersize,表示打亂時使用的buffer的大小,建議舍的不要太小,一般是1000'''dataset = dataset.repeat(num_epochs)#repeat的功能就是將整個序列重復多次,主要用來處理機器學習中的epoch,假設原先的數據是一個epoch,使用repeat(2)就可以將之變成2個epochdataset = dataset.batch(batch_size)'''batch是機器學習中批量梯度下降法(Batch Gradient Descent, BGD)的概念,在每次梯度下降的時候取batch-size的數據量做平均來獲取梯度下降方向,例如我們將batch-size設為2,那么每次iterator都會得到2個數據'''iterator = dataset.make_one_shot_iterator()batch_features, batch_labels = iterator.get_next()return batch_features, batch_labels5.2.4 模型準備
tensorFlow-estimator模塊
# Wide & Deep Model def build_estimator(model_dir, model_type):"""Build an estimator appropriate for the given model type."""wide_columns, deep_columns = build_model_columns()hidden_units = [100, 50]# Create a tf.estimator.RunConfig to ensure the model is run on CPU, which# trains faster than GPU for this model.run_config = tf.estimator.RunConfig().replace(session_config=tf.ConfigProto(device_count={'GPU': 0}))if model_type == 'wide':return tf.estimator.LinearClassifier(model_dir=model_dir,feature_columns=wide_columns,config=run_config)elif model_type == 'deep':return tf.estimator.DNNClassifier(model_dir=model_dir,feature_columns=deep_columns,hidden_units=hidden_units,config=run_config)else:return tf.estimator.DNNLinearCombinedClassifier(model_dir=model_dir,linear_feature_columns=wide_columns,dnn_feature_columns=deep_columns,dnn_hidden_units=hidden_units,config=run_config)?
5.2.5 模型訓練
?
# 模型路徑 model_type = 'widedeep' model_dir = '/Users/admin/Desktop/model/推薦算法/widedeep'# Wide & Deep 聯合模型 model = build_estimator(model_dir, model_type)# ## 4)模型訓練# In[11]:# 訓練參數 train_epochs = 10 batch_size = 5000 train_file = '/Users/admin/Desktop/model/推薦算法/widedeep/adult.data' test_file = '/Users/admin/Desktop/model/推薦算法/widedeep/adult.test'# 6. 開始訓練 for n in range(train_epochs):# 模型訓練model.train(input_fn=lambda: input_fn(train_file, train_epochs, True, batch_size))# 模型評估results = model.evaluate(input_fn=lambda: input_fn(test_file, 1, False, batch_size))# 打印評估結果print("Results at epoch {0}".format((n+1) * train_epochs))print('-'*30)for key in sorted(results):print("{0:20}: {1:.4f}".format(key, results[key]))模型最后一次輸出結果
Parsing /Users/admin/Desktop/model/推薦算法/widedeep/adult.data INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Restoring parameters from /Users/admin/Desktop/model/推薦算法/widedeep/model.ckpt-1254 INFO:tensorflow:Saving checkpoints for 1255 into /Users/admin/Desktop/model/推薦算法/widedeep/model.ckpt. INFO:tensorflow:loss = 8.967002, step = 1255 INFO:tensorflow:Saving checkpoints for 1320 into /Users/admin/Desktop/model/推薦算法/widedeep/model.ckpt. INFO:tensorflow:Loss for final step: 1.0456264. Parsing /Users/admin/Desktop/model/推薦算法/widedeep/adult.test WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool. WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool. INFO:tensorflow:Starting evaluation at 2020-03-27-06:13:54 INFO:tensorflow:Restoring parameters from /Users/admin/Desktop/model/推薦算法/widedeep/model.ckpt-1320 INFO:tensorflow:Finished evaluation at 2020-03-27-06:13:56 INFO:tensorflow:Saving dict for global step 1320: accuracy = 1.0, accuracy_baseline = 1.0, auc = 1.0, auc_precision_recall = 0.0, average_loss = 0.0018178094, global_step = 1320, label/mean = 0.0, loss = 7.3989387, prediction/mean = 0.001807458 Results at epoch 100 ------------------------------ accuracy : 1.0000 accuracy_baseline : 1.0000 auc : 1.0000 auc_precision_recall: 0.0000 average_loss : 0.0018 global_step : 1320.0000 label/mean : 0.0000 loss : 7.3989 prediction/mean : 0.0018?
2020/04/02 更
六、報錯提示
報錯1:
- Field 0 in record 0 is not a valid int32: 5.0
? ? 發現數據類型有些問題,因為我的數據中有浮點數,但如果默認值設置的全是0,TensorFlow根據默認值推測數據應當全是INT32類型,而不是float32。
? ? 解決方案:
? ? 將默認值由1改為0.0即可,這樣TensorFlow就會推斷所有值均為浮點數,所有的數值都可識別了。
報錯2:? ?
- InvalidArgumentError : Shape in shape_and_slice spec [,] does not match the shape stored in checkpoi
? ? -解決方法:
? ? classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,hidden_units=[10, 20, 10],n_classes=2,model_dir="/tmp/iris_model")
? ? 這里的model_dir給改了。。因為tensorflow默認會先去找已經訓練過的模型
?
?
參考:
https://arxiv.org/abs/1606.07792
總結
以上是生活随笔為你收集整理的Wide Deep模型的理解及实战(Tensorflow)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 怎么判断安卓解锁是否成功
- 下一篇: linux 内核list head,Li