[机器学习] gcForest 官方代码详解
1.介紹
gcForest v1.1.1是gcForest的一個(gè)官方托管在GitHub上的版本,是由Ji Feng(Deep Forest的paper的作者之一)維護(hù)和開(kāi)發(fā),該版本支持Python3.5,且有類似于Scikit-Learn的API接口風(fēng)格,在該項(xiàng)目中提供了一些調(diào)用例子,目前支持的基分類器有RandomForestClassifier,XGBClassifer,ExtraTreesClassifier,LogisticRegression,SGDClassifier如果采用XGBoost的基分類器還可以使用GPU
本文采用的是v1.1.1版本,github地址https://github.com/kingfengji/gcForest
如果想增加其他基分類器,可以在模塊中的lib/gcforest/estimators/__init__.py中添加
使用該模塊需要依賴安裝如下模塊:
- argparse
- joblib
- keras
- psutil
- scikit-learn>=0.18.1
- scipy
- simplejson
- tensorflow
- xgboost
2.API調(diào)用樣例
?
這里先列出gcForest提供的API接口:
-
fit_tranform(X_train,y_train) 是gcForest模型最后一層每個(gè)估計(jì)器預(yù)測(cè)的概率concatenated的結(jié)果
-
fit_transform(X_train,y_train,X_test=x_test,y_test=y_test) 測(cè)試數(shù)據(jù)的準(zhǔn)確率在訓(xùn)練的過(guò)程中也會(huì)被記錄下來(lái)
-
set_keep_model_mem(False) 如果你的緩存不夠,把該參數(shù)設(shè)置成False(默認(rèn)為True),如果設(shè)置成False,你需要使用fit_transform(X_train,y_train,X_test=x_test,y_test=y_test)來(lái)評(píng)估你的模型
-
predict(X_test) # 模型預(yù)測(cè)
-
transform(X_test)
代碼主要分為兩部分:examples文件夾下是主代碼.py和配置文件.json;libs文件夾下是代碼中用到的庫(kù)
主代碼的實(shí)現(xiàn)
最簡(jiǎn)單的調(diào)用gcForest的方式如下:
# 導(dǎo)入必要的模塊 from gcforest.gcforest import GCForest# 初始化一個(gè)gcForest對(duì)象 gc = GCForest(config) # config是一個(gè)字典結(jié)構(gòu)# gcForest模型最后一層每個(gè)估計(jì)器預(yù)測(cè)的概率concatenated的結(jié)果 X_train_enc = gc.fit_transform(X_train,y_train)# 測(cè)試集的預(yù)測(cè) y_pred = gc.predict(X_test)?
lib庫(kù)的詳解
gcforest.py 整個(gè)框架的實(shí)現(xiàn)
fgnet.py 多粒度部分,FineGrained的實(shí)現(xiàn)
cascade/cascade_classifier 級(jí)聯(lián)分類器的實(shí)現(xiàn)
datasets/.... 包含一系列數(shù)據(jù)集的定義
estimator/... 包含決策樹(shù)在進(jìn)行評(píng)估用到的函數(shù)(多種分類器的預(yù)估)
layer/... 包含不同的層操作,如連接、池化、滑窗等
utils/.. 包含各種功能函數(shù),譬如計(jì)算準(zhǔn)確率、win_vote、win_avg、get_windows等
?
json配置文件的詳解
參數(shù)介紹
- max_depth: 決策樹(shù)最大深度。默認(rèn)為"None",決策樹(shù)在建立子樹(shù)的時(shí)候不會(huì)限制子樹(shù)的深度這樣建樹(shù)時(shí),會(huì)使每一個(gè)葉節(jié)點(diǎn)只有一個(gè)類別,或是達(dá)到min_samples_split。一般來(lái)說(shuō),數(shù)據(jù)少或者特征少的時(shí)候可以不管這個(gè)值。如果模型樣本量多,特征也多的情況下,推薦限制這個(gè)最大深度,具體的取值取決于數(shù)據(jù)的分布。常用的可以取值10-100之間。
- estimators表示選擇的分類器
- n_estimators 為森林里的樹(shù)的數(shù)量
- n_jobs: int (default=1)
The number of jobs to run in parallel for any Random Forest fit and predict.
If -1, then the number of jobs is set to the number of cores.
訓(xùn)練的配置,分三類情況:
支持的基本分類器:
RandomForestClassifier
XGBClassifier
ExtraTreesClassifier
LogisticRegression
SGDClassifier
你可以通過(guò)下述方式手動(dòng)添加任何分類器:
lib/gcforest/estimators/__init__.py滑動(dòng)窗口的大小: {[d/16], [d/8], [d/4]},d代表輸入特征的數(shù)量;
"look_indexs_cycle": [
[0, 1],
[2, 3],
[4, 5]]
代表級(jí)聯(lián)多粒度的方式,第一層級(jí)聯(lián)0、1森林的輸出,第二層級(jí)聯(lián)2、3森林的輸出,第三層級(jí)聯(lián)4、5森林的輸出
3.MNIST樣例
下面我們使用MNIST數(shù)據(jù)集來(lái)演示gcForest的使用及代碼的詳細(xì)說(shuō)明:
# 導(dǎo)入必要的模塊import argparse # 命令行參數(shù)調(diào)用模塊 import numpy as np import sys from keras.datasets import mnist # MNIST數(shù)據(jù)集 import pickle from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import accuracy_score sys.path.insert(0, "lib")from gcforest.gcforest import GCForest from gcforest.utils.config_utils import load_jsondef parse_args():'''解析終端命令行參數(shù)(model)'''parser = argparse.ArgumentParser()parser.add_argument("--model", dest="model", type=str, default=None, help="gcfoest Net Model File")args = parser.parse_args()return argsdef get_toy_config():'''生成級(jí)聯(lián)結(jié)構(gòu)的相關(guān)結(jié)構(gòu)'''config = {}ca_config = {}ca_config["random_state"] = 0ca_config["max_layers"] = 100ca_config["early_stopping_rounds"] = 3ca_config["n_classes"] = 10ca_config["estimators"] = []ca_config["estimators"].append({"n_folds": 5, "type": "XGBClassifier", "n_estimators": 10, "max_depth": 5,"objective": "multi:softprob", "silent": True, "nthread": -1, "learning_rate": 0.1} )ca_config["estimators"].append({"n_folds": 5, "type": "RandomForestClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})ca_config["estimators"].append({"n_folds": 5, "type": "ExtraTreesClassifier","n_estimators": 10, "max_depth": None, "n_jobs": -1})ca_config["estimators"].append({"n_folds": 5, "type": "LogisticRegression"})config["cascade"] = ca_configreturn config# get_toy_config()生成的結(jié)構(gòu),如下所示:''' { "cascade": {"random_state": 0,"max_layers": 100,"early_stopping_rounds": 3,"n_classes": 10,"estimators": [{"n_folds":5,"type":"XGBClassifier","n_estimators":10,"max_depth":5,"objective":"multi:softprob", "silent":true, "nthread":-1, "learning_rate":0.1},{"n_folds":5,"type":"RandomForestClassifier","n_estimators":10,"max_depth":null,"n_jobs":-1},{"n_folds":5,"type":"ExtraTreesClassifier","n_estimators":10,"max_depth":null,"n_jobs":-1},{"n_folds":5,"type":"LogisticRegression"}] } } '''if __name__ == "__main__":args = parse_args()if args.model is None:config = get_toy_config()else:config = load_json(args.model)gc = GCForest(config)# 如果模型消耗太大內(nèi)存,可以使用如下命令使得gcforest不保存在內(nèi)存中# gc.set_keep_model_in_mem(False), 默認(rèn)情況下是True.(X_train, y_train), (X_test, y_test) = mnist.load_data()# X_train, y_train = X_train[:2000], y_train[:2000]# np.newaxis相當(dāng)于增加了一個(gè)維度X_train = X_train[:, np.newaxis, :, :]X_test = X_test[:, np.newaxis, :, :]X_train_enc = gc.fit_transform(X_train, y_train)# X_enc是gcForest模型最后一層每個(gè)估計(jì)器預(yù)測(cè)的概率concatenated的結(jié)果# X_enc.shape =# (n_datas, n_estimators * n_classes): 如果是級(jí)聯(lián)結(jié)構(gòu)# (n_datas, n_estimators * n_classes, dimX, dimY): 如果只有多粒度掃描結(jié)構(gòu)# 可以在fit_transform方法中加入X_test,y_test,這樣測(cè)試數(shù)據(jù)的準(zhǔn)確率在訓(xùn)練的過(guò)程中# 也會(huì)被記錄下來(lái)。# X_train_enc, X_test_enc = gc.fit_transform(X_train, y_train, X_test=X_test, y_test=y_test)# 注意: 如果設(shè)置了gc.set_keep_model_in_mem(True),必須使用# gc.fit_transform(X_train, y_train, X_test=X_test, y_test=y_test)# 評(píng)估模型# 測(cè)試集預(yù)測(cè)與評(píng)估y_pred = gc.predict(X_test)acc = accuracy_score(y_test, y_pred)print("Test Accuracy of GcForest = {:.2f} %".format(acc * 100))# 可以使用gcForest得到的X_enc數(shù)據(jù)進(jìn)行其他模型的訓(xùn)練比如xgboost/RF# 數(shù)據(jù)的concatX_test_enc = gc.transform(X_test)X_train_enc = X_train_enc.reshape((X_train_enc.shape[0], -1))X_test_enc = X_test_enc.reshape((X_test_enc.shape[0], -1))X_train_origin = X_train.reshape((X_train.shape[0], -1))X_test_origin = X_test.reshape((X_test.shape[0], -1))X_train_enc = np.hstack((X_train_origin, X_train_enc))X_test_enc = np.hstack((X_test_origin, X_test_enc))print("X_train_enc.shape={}, X_test_enc.shape={}".format(X_train_enc.shape,X_test_enc.shape))# 訓(xùn)練一個(gè)RFclf = RandomForestClassifier(n_estimators=1000, max_depth=None, n_jobs=-1)clf.fit(X_train_enc, y_train)y_pred = clf.predict(X_test_enc)acc = accuracy_score(y_test, y_pred)print("Test Accuracy of Other classifier using gcforest's X_encode = {:.2f} %".format(acc * 100))# 模型寫入pickle文件with open("test.pkl", "wb") as f:pickle.dump(gc, f, pickle.HIGHEST_PROTOCOL)# 加載訓(xùn)練的模型with open("test.pkl", "rb") as f:gc = pickle.load(f)y_pred = gc.predict(X_test)acc = accuracy_score(y_test, y_pred)print("Test Accuracy of GcForest (save and load) = {:.2f} %".format(acc * 100))這里需要注意的是gcForest不但可以對(duì)傳統(tǒng)的結(jié)構(gòu)化的2維數(shù)據(jù)建模,還可以對(duì)非結(jié)構(gòu)化的數(shù)據(jù)比如圖像,序列化的文本數(shù)據(jù),音頻數(shù)據(jù)等進(jìn)行建模,但要注意數(shù)據(jù)維度的設(shè)定:
-
如果僅使用級(jí)聯(lián)結(jié)構(gòu),X_train,X_test對(duì)于2-D數(shù)組其維度為(n_samples,n_features);3-D或4-D數(shù)組會(huì)自動(dòng)reshape為2-D,例如MNIST數(shù)據(jù)(60000,28,28)會(huì)reshape為(60000,784),(60000,3,28,28)會(huì)reshape為(60000,2352)。
-
如果使用多粒度掃描結(jié)構(gòu),X_train,X_test必須是4—D的數(shù)組,圖像數(shù)據(jù)其維度是(n_samples,n_channels,n_height,n_width);序列數(shù)據(jù)其維度為(n_smaples,n_features,seq_len,1),例如對(duì)于IMDB數(shù)據(jù),n_features為1,對(duì)于音頻MFCC特征,其n_features可以為13,26等。
上述代碼可以通過(guò)兩種方式運(yùn)行:
- 一種方式是通過(guò)json文件定義模型結(jié)構(gòu),比如級(jí)聯(lián)森林結(jié)構(gòu),只需要寫一個(gè)json文件如代碼中顯示的結(jié)構(gòu),然后通過(guò)命令行運(yùn)行python examples/demo_mnist.py --model examples/demo_mnist-gc.json就可以完成訓(xùn)練;如果既使用多粒度掃面又使用級(jí)聯(lián)結(jié)構(gòu),那么需要同時(shí)把多粒度掃描的結(jié)構(gòu)定義出來(lái)。
- 定義好的json可以通過(guò)模塊中的load_json()方法加載,然后作為參數(shù)初始化模型,如下:
- 另一種方式是直接通過(guò)Python代碼定義模型結(jié)構(gòu),實(shí)際上模型結(jié)構(gòu)就是一個(gè)字典數(shù)據(jù)結(jié)構(gòu),即是上述代碼中的get_toy_config()方法。
?
?
?
總結(jié)
以上是生活随笔為你收集整理的[机器学习] gcForest 官方代码详解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: [机器学习]PMML预测模型标记语言
- 下一篇: 2022最新短视频API解析接口源码(北