[Algorithm Competition Study] Digital China Innovation Contest, Smart Ocean Construction - Task 4: Model Building
This section covers the model-building module of the Smart Ocean competition: how to build a model and how to tune it.
Learning objectives
Contents
- Random forest
- LightGBM model
- XGBoost model
Model training and prediction
The main steps of model training and prediction are:
(1) Import the required libraries.
(2) Preprocess the data: load the dataset and clean it, including handling missing values, normalizing continuous features, and encoding categorical features.
(3) Train the model: choose a suitable machine-learning model and fit it on the training set until it fits well.
(4) Predict: feed the data to be predicted into the trained model to obtain the predictions.
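The four steps above can be sketched end to end on a toy dataset (iris and the scikit-learn estimators here are illustrative choices, not the competition setup):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# (1)-(2) load and preprocess: scale the continuous features
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)
scaler = StandardScaler().fit(X_train)  # fit the scaler on the training split only
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

# (3) train the model
clf = RandomForestClassifier(n_estimators=100, random_state=0)
clf.fit(X_train, y_train)

# (4) predict on the held-out data
pred = clf.predict(X_test)
print(pred.shape)  # one predicted label per test sample
```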
Several commonly used classification algorithms are introduced below.
Random forest classification
Random forest parameters
Random forest is an algorithm that combines many trees through the idea of ensemble learning; its base unit is the decision tree, and it belongs in essence to the ensemble-learning branch of machine learning.
The main advantages of the random forest model are:
- good accuracy among current algorithms;
- runs efficiently on large datasets;
- handles high-dimensional input samples without dimensionality reduction;
- estimates the importance of each feature for the classification problem;
- provides an unbiased estimate of the internal generalization error during training;
- still performs well when values are missing.
Prediction with sklearn's random forest classifier:
```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

# load the dataset
iris = datasets.load_iris()
feature = iris.feature_names
X = iris.data
y = iris.target

# random forest
clf = RandomForestClassifier(n_estimators=200)
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.1, random_state=5)
clf.fit(train_X, train_y)
test_pred = clf.predict(test_X)

# inspect the feature importances
print(str(feature) + '\n' + str(clf.feature_importances_))
```

```
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
[0.09838896 0.01544017 0.34365936 0.5425115 ]
```

The model is evaluated with the F1 score:
```python
# F1 score for model evaluation
# for a binary classification problem use average='binary'
# to account for class imbalance with a weighted average use 'weighted'
# to ignore class imbalance and compute the macro average use 'macro'
score = f1_score(test_y, test_pred, average='macro')
print("random forest-macro:", score)
score = f1_score(test_y, test_pred, average='weighted')
print("random forest-weighted:", score)
```

```
random forest-macro: 0.818181818181818
random forest-weighted: 0.8
```

LightGBM model
An introduction to LightGBM can be found in this article.
The LightGBM Chinese documentation explains the hyperparameters in detail and is recommended reading.
To deal with overfitting:
- Use a smaller max_bin
- Use a smaller num_leaves
- Use min_data_in_leaf and min_sum_hessian_in_leaf
- Use bagging by setting bagging_fraction and bagging_freq
- Use feature sub-sampling by setting feature_fraction
- Use more training data
- Use lambda_l1, lambda_l2 and min_gain_to_split for regularization
- Try max_depth to avoid growing overly deep trees
For faster speed:
- Use bagging by setting bagging_fraction and bagging_freq
- Use feature sub-sampling by setting feature_fraction
- Use a smaller max_bin
- Use save_binary to speed up data loading in later runs
- Use parallel learning; see the parallel learning guide
For better accuracy:
- Use a larger max_bin (training may be slower)
- Use a smaller learning_rate with a larger num_iterations
- Use a larger num_leaves (may cause overfitting)
- Use more training data
- Try dart
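As a rough illustration, the overfitting controls above might be combined in a parameter dict like the following; every value here is an assumed placeholder, not a tuned setting:

```python
# Illustrative LightGBM parameter dict combining the overfitting controls
# listed above; the values are assumptions, not tuned settings.
params = {
    "objective": "multiclass",
    "num_class": 3,
    "max_bin": 128,                   # smaller max_bin coarsens the histograms
    "num_leaves": 31,                 # smaller num_leaves limits tree complexity
    "min_data_in_leaf": 30,           # require enough samples in every leaf
    "min_sum_hessian_in_leaf": 1e-3,
    "bagging_fraction": 0.8,          # row sub-sampling ...
    "bagging_freq": 5,                # ... re-drawn every 5 iterations
    "feature_fraction": 0.8,          # column sub-sampling
    "lambda_l1": 0.1,                 # L1 regularization
    "lambda_l2": 0.1,                 # L2 regularization
    "min_gain_to_split": 0.0,
    "max_depth": 6,                   # cap tree depth
}
print(sorted(params))                 # the keys mirror the tips above
```

Such a dict would be passed as the `params` argument of `lgb.train`.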
XGBoost model
XGBoost basics
XGBoost parameters
XGBoost parameter-tuning method
https://blog.csdn.net/han_xiaoyang/article/details/52665396
https://www.cnblogs.com/TimVerion/p/11436001.html
Cross-validation
Cross-validation is a statistical method for assessing classifier performance. The basic idea is to split the original data into groups, one part serving as the training set and the other as the validation set: the classifier is first trained on the training set, and the resulting model is then tested on the validation set to measure its performance. Common variants include simple cross-validation, K-fold cross-validation, leave-one-out cross-validation, and leave-P-out cross-validation.
1. Simple cross-validation
Simple cross-validation splits the original data into two groups, one as the training set and the other as the validation set. The classifier is trained on the training set and validated on the validation set, and the resulting classification accuracy is reported as the classifier's performance. Typically 30% of the data is held out as the test set.
2. K-fold cross-validation (K-Fold cross validation)
K-fold cross-validation splits the original data into K groups. Each subset in turn serves as the validation set while the remaining K-1 subsets form the training set, yielding K models; the average of the K validation accuracies is used as the performance measure. K is usually set to 5 or 10.
3. Leave-one-out cross-validation (Leave-One-Out Cross Validation, LOO-CV)
In leave-one-out cross-validation each training set consists of all samples except one, and the held-out sample forms the validation set. For a dataset of N samples this produces N different training sets and N different validation sets, hence N models; the average of the N validation accuracies is the classifier's performance measure.
4. Leave-P-out cross-validation
Similar to leave-one-out, this method removes P samples from the full dataset, generating all possible training and validation sets.
Example code for cross-validation
1. Simple cross-validation
```python
from sklearn.model_selection import train_test_split
from sklearn import datasets

# load the dataset
iris = datasets.load_iris()
feature = iris.feature_names
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
```

2. K-fold cross-validation
```python
from sklearn.model_selection import KFold

folds = KFold(n_splits=10, shuffle=True)  # shuffle the data before splitting
```

3. Leave-one-out cross-validation
```python
from sklearn.model_selection import LeaveOneOut

loo = LeaveOneOut()
```

4. Leave-P-out cross-validation
```python
from sklearn.model_selection import LeavePOut

lpo = LeavePOut(p=5)
```

There are also other cross-validation splitting strategies, such as stratified cross-validation based on class labels; these are mainly used to deal with imbalanced samples.
In that case the stratified sampling methods StratifiedKFold and StratifiedShuffleSplit are commonly used; they ensure that the class frequencies are preserved in each training and validation fold.
StratifiedKFold: a variant of K-fold that returns stratified folds: the class proportions in each subset are roughly the same as in the full dataset.
StratifiedShuffleSplit: a variant of ShuffleSplit that returns stratified splits, i.e. splits in which each class keeps the same proportion as in the full dataset.
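A quick sketch of the stratification guarantee: on the balanced iris data (50 samples per class), every StratifiedKFold validation fold keeps the 1:1:1 class ratio.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

X, y = load_iris(return_X_y=True)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, valid_idx in skf.split(X, y):
    # each validation fold of 30 samples holds 10 samples of every class
    print(np.bincount(y[valid_idx]))
```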
Model tuning
Tuning means adjusting the model's hyperparameters to find the setting under which the overall model performs best.
1. Grid search
Grid search is an exhaustive search: it loops over all candidate parameter combinations and keeps the one with the best result.
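A minimal grid-search sketch with scikit-learn's GridSearchCV; the grid values here are arbitrary examples:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)

# every combination in the grid is evaluated with 5-fold cross-validation
param_grid = {"n_estimators": [50, 100], "max_depth": [3, 5]}
search = GridSearchCV(RandomForestClassifier(random_state=0),
                      param_grid, cv=5, scoring="f1_macro")
search.fit(X, y)
print(search.best_params_)  # best combination found on the grid
print(search.best_score_)   # its mean cross-validated f1_macro
```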
2. Learning curve
A learning curve plots the model's accuracy on the training set and on the cross-validation set for different training-set sizes. It shows how the model behaves on new data, whether it suffers from high bias or high variance, and whether enlarging the training set would reduce overfitting.
When both the training and validation accuracies are low, bias is high and the model is most likely underfitting. In this case we can make the model more expressive, for example by building more features or reducing the regularization term; adding more data will not help.
When there is a large gap between the training and validation errors, variance is high. A training accuracy clearly above the accuracy on independent data usually indicates overfitting. Here we can enlarge the training set, reduce model complexity, increase the regularization term, or reduce the number of features through feature selection.
The ideal case is to find a setting where both bias and variance are small, i.e. the curves converge at a small error.
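The diagnosis above can be read off scikit-learn's learning_curve output, which scores the model at several training-set sizes (the estimator and sizes below are illustrative):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

X, y = load_iris(return_X_y=True)
sizes, train_scores, valid_scores = learning_curve(
    RandomForestClassifier(n_estimators=50, random_state=0), X, y,
    train_sizes=np.linspace(0.2, 1.0, 4), cv=5, scoring="accuracy")

# a persistent gap between the two columns signals high variance;
# two low, converged columns signal high bias
for n, tr, va in zip(sizes, train_scores.mean(axis=1), valid_scores.mean(axis=1)):
    print(n, round(tr, 3), round(va, 3))
```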
3. Validation curve
Unlike a learning curve, a validation curve puts a range of values of one hyperparameter on the x-axis, so that model accuracy can be compared across hyperparameter settings. As the hyperparameter changes, the model typically moves from underfitting through a good fit to overfitting, so the curve can be used to pick a suitable value and improve performance.
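The same idea in code, using scikit-learn's validation_curve over a single hyperparameter (max_depth is an arbitrary choice for illustration):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

X, y = load_iris(return_X_y=True)
depths = [1, 2, 4, 8]
train_scores, valid_scores = validation_curve(
    RandomForestClassifier(n_estimators=50, random_state=0), X, y,
    param_name="max_depth", param_range=depths, cv=5)

# low scores at the left end suggest underfitting; a widening gap at the
# right end suggests overfitting
for d, tr, va in zip(depths, train_scores.mean(axis=1), valid_scores.mean(axis=1)):
    print(d, round(tr, 3), round(va, 3))
```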
Model code example on the Smart Ocean dataset
LightGBM model
```python
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
import lightgbm as lgb
import os
import warnings
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials

all_df = pd.read_csv('group_df.csv', index_col=0)
use_train = all_df[all_df['label'] != -1]
use_test = all_df[all_df['label'] == -1]  # label == -1 marks the test set
use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]
X_train, X_verify, y_train, y_verify = train_test_split(
    use_train[use_feats], use_train['label'], test_size=0.3, random_state=0)
```

1. Feature selection based on feature importance
```python
############## feature-selection parameters ###################
selectFeatures = 200           # number of features to keep
earlyStopping = 100            # early-stopping rounds
select_num_boost_round = 1000  # boosting rounds for feature selection

# base parameters
selfParam = {
    'learning_rate': 0.01,     # learning rate
    'boosting': 'dart',        # algorithm type: gbdt, dart
    'objective': 'multiclass', # multiclass task
    'metric': 'None',
    'num_leaves': 32,
    'feature_fraction': 0.7,   # fraction of features per tree
    'bagging_fraction': 0.8,   # fraction of samples per tree
    'min_data_in_leaf': 30,    # minimum samples per leaf
    'num_class': 3,
    'max_depth': 6,            # maximum tree depth
    'num_threads': 8,          # number of LightGBM threads
    'min_data_in_bin': 30,     # minimum samples per bin
    'max_bin': 256,            # maximum number of bins
    'is_unbalance': True,      # unbalanced samples
    'train_metric': True,
    'verbose': -1,
}

# feature selection ------------------------------------------------------
def f1_score_eval(preds, valid_df):
    labels = valid_df.get_label()
    preds = np.argmax(preds.reshape(3, -1), axis=0)
    scores = f1_score(y_true=labels, y_pred=preds, average='macro')
    return 'f1_score', scores, True

train_data = lgb.Dataset(data=X_train, label=y_train, feature_name=use_feats)
valid_data = lgb.Dataset(data=X_verify, label=y_verify, reference=train_data,
                         feature_name=use_feats)
sm = lgb.train(params=selfParam, train_set=train_data,
               num_boost_round=select_num_boost_round,
               valid_sets=[valid_data], valid_names=['valid'],
               feature_name=use_feats, early_stopping_rounds=earlyStopping,
               verbose_eval=False, keep_training_booster=True,
               feval=f1_score_eval)
features_importance = {k: v for k, v in zip(
    sm.feature_name(), sm.feature_importance(iteration=sm.best_iteration))}
sort_feature_importance = sorted(features_importance.items(),
                                 key=lambda x: x[1], reverse=True)
print('total feature best score:', sm.best_score)
print('total feature importance:', sort_feature_importance)
print('select forward {} features:{}'.format(
    selectFeatures, sort_feature_importance[:selectFeatures]))
# model_feature holds the selected feature names
model_feature = [k[0] for k in sort_feature_importance[:selectFeatures]]
```

```
D:\SOFTWEAR_H\Anaconda3\lib\site-packages\lightgbm\callback.py:186: UserWarning: Early stopping is not available in dart mode
  warnings.warn('Early stopping is not available in dart mode')
total feature best score: defaultdict(<class 'collections.OrderedDict'>, {'valid': OrderedDict([('f1_score', 0.9004541298211368)])})
total feature importance: [('pos_neq_zero_speed_q_40', 1783), ('lat_lon_countvec_1_x', 1771), ('rank2_mode_lat', 1737), ('pos_neq_zero_speed_median', 1379), ('pos_neq_zero_speed_q_60', 1369), ('lat_lon_tfidf_0_x', 1251), ('pos_neq_zero_speed_q_80', 1194), ('sample_tfidf_0_x', 1168), ('w2v_9_mean', 1134), ('lat_lon_tfidf_11_x', 963), ...]
select forward 200 features:[('pos_neq_zero_speed_q_40', 1783), ('lat_lon_countvec_1_x', 1771), ('rank2_mode_lat', 1737), ...]
```

Bayesian optimization is another method commonly used during model tuning; the code below selects hyperparameters via Bayesian optimization.
##############超參數優化的超參域################### spaceParam = {'boosting': hp.choice('boosting',['gbdt','dart']),'learning_rate':hp.loguniform('learning_rate', np.log(0.01), np.log(0.05)),'num_leaves': hp.quniform('num_leaves', 3, 66, 3), 'feature_fraction': hp.uniform('feature_fraction', 0.7,1),'min_data_in_leaf': hp.quniform('min_data_in_leaf', 10, 50,5), 'num_boost_round':hp.quniform('num_boost_round',500,2000,100), 'bagging_fraction':hp.uniform('bagging_fraction',0.6,1) } # 超參數優化 --------------------------------------------------------------------------------- def getParam(param):for k in ['num_leaves', 'min_data_in_leaf','num_boost_round']:param[k] = int(float(param[k]))for k in ['learning_rate', 'feature_fraction','bagging_fraction']:param[k] = float(param[k])if param['boosting'] == 0:param['boosting'] = 'gbdt'elif param['boosting'] == 1:param['boosting'] = 'dart'# 添加固定參數param['objective'] = 'multiclass'param['max_depth'] = 7param['num_threads'] = 8param['is_unbalance'] = Trueparam['metric'] = 'None'param['train_metric'] = Trueparam['verbose'] = -1param['bagging_freq']=5param['num_class']=3 param['feature_pre_filter']=Falsereturn param def f1_score_eval(preds, valid_df):labels = valid_df.get_label()preds = np.argmax(preds.reshape(3, -1), axis=0)scores = f1_score(y_true=labels, y_pred=preds, average='macro')return 'f1_score', scores, True def lossFun(param):param = getParam(param)m = lgb.train(params=param,train_set=train_data,num_boost_round=param['num_boost_round'],valid_sets=[train_data,valid_data],valid_names=['train','valid'],feature_name=features,feval=f1_score_eval,early_stopping_rounds=earlyStopping,verbose_eval=False,keep_training_booster=True)train_f1_score = m.best_score['train']['f1_score']valid_f1_score = m.best_score['valid']['f1_score']loss_f1_score = 1 - valid_f1_scoreprint('訓練集f1_score:{},測試集f1_score:{},loss_f1_score:{}'.format(train_f1_score, valid_f1_score, loss_f1_score))return {'loss': loss_f1_score, 'params': param, 'status': STATUS_OK}features 
= model_feature train_data = lgb.Dataset(data=X_train[model_feature],label=y_train,feature_name=features) valid_data = lgb.Dataset(data=X_verify[features],label=y_verify,reference=train_data,feature_name=features)best_param = fmin(fn=lossFun, space=spaceParam, algo=tpe.suggest, max_evals=100, trials=Trials()) best_param = getParam(best_param) print('Search best param:',best_param) 訓練集f1_score:1.0,測試集f1_score:0.9238060849905194,loss_f1_score:0.07619391500948058 訓練集f1_score:0.9414337502771342,測試集f1_score:0.8878751759836653,loss_f1_score:0.11212482401633472 訓練集f1_score:1.0,測試集f1_score:0.9275451088133652,loss_f1_score:0.07245489118663484 訓練集f1_score:1.0,測試集f1_score:0.9262405937033683,loss_f1_score:0.07375940629663169 訓練集f1_score:0.9708237804866381,測試集f1_score:0.9105982243190386,loss_f1_score:0.08940177568096142 訓練集f1_score:0.9689912364726484,測試集f1_score:0.9086459359345839,loss_f1_score:0.09135406406541613 訓練集f1_score:0.9841597696688008,測試集f1_score:0.9027075194168233,loss_f1_score:0.09729248058317674 訓練集f1_score:1.0,測試集f1_score:0.9215512877825286,loss_f1_score:0.0784487122174714 訓練集f1_score:1.0,測試集f1_score:0.924555451978199,loss_f1_score:0.075444548021801 訓練集f1_score:0.998357894114157,測試集f1_score:0.9157797895654226,loss_f1_score:0.08422021043457739 訓練集f1_score:1.0,測試集f1_score:0.9225868784774544,loss_f1_score:0.07741312152254565 訓練集f1_score:1.0,測試集f1_score:0.9188521505717673,loss_f1_score:0.08114784942823272 訓練集f1_score:0.9268245763808158,測試集f1_score:0.8763935795977332,loss_f1_score:0.12360642040226677 訓練集f1_score:1.0,測試集f1_score:0.9215959099478135,loss_f1_score:0.07840409005218651 訓練集f1_score:1.0,測試集f1_score:0.9265015559936258,loss_f1_score:0.07349844400637418 訓練集f1_score:1.0,測試集f1_score:0.9143628354188641,loss_f1_score:0.0856371645811359 訓練集f1_score:1.0,測試集f1_score:0.9202754009210264,loss_f1_score:0.07972459907897356 訓練集f1_score:0.9550283459834631,測試集f1_score:0.8923546584333147,loss_f1_score:0.10764534156668526 
Each of the 100 hyperopt trials prints the train-set and test-set macro F1 along with the loss (1 − test F1), e.g. for the first trial:

train f1_score: 1.0, test f1_score: 0.9255732985564632, loss_f1_score: 0.0744267014435368

(The remaining per-trial lines are omitted here: the train F1 is almost always 1.0, and the test-set loss ranges from roughly 0.068 to 0.112 across trials.)

100%|█████████████████████████████████████████████| 100/100 [33:56<00:00, 20.36s/trial, best loss: 0.06779926514672352]
Search best param: {'bagging_fraction': 0.7310343530671259, 'boosting': 'gbdt', 'feature_fraction': 0.8644701162989126, 'learning_rate': 0.0483933201073737, 'min_data_in_leaf': 15, 'num_boost_round': 1100, 'num_leaves': 60, 'objective': 'multiclass', 'max_depth': 7, 'num_threads': 8, 'is_unbalance': True, 'metric': 'None', 'train_metric': True, 'verbose': -1, 'bagging_freq': 5, 'num_class': 3, 'feature_pre_filter': False}

After feature selection and hyperparameter optimization, the final model fixes the hyperparameters at the values found by the Bayesian search, then runs 5-fold cross-validation and averages the per-fold predictions on the test set.
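The search above is driven by hyperopt's TPE (Bayesian-style) optimizer. As a rough, self-contained sketch of the same trial loop, the snippet below runs a plain random search over a similar parameter space against a toy objective; every name here is illustrative (in the real code each trial trains LightGBM and returns 1 − macro F1), not the competition implementation.

```python
import random

def objective(params):
    # Toy loss surface with its optimum near learning_rate=0.05, num_leaves=60;
    # a stand-in for "train LightGBM with these params, return 1 - macro F1".
    return (params["learning_rate"] - 0.05) ** 2 + (params["num_leaves"] - 60) ** 2 / 1e4

def random_search(n_trials=100, seed=1024):
    rng = random.Random(seed)
    best_params, best_loss = None, float("inf")
    for _ in range(n_trials):
        # sample one candidate from a space similar to the one searched above
        params = {
            "learning_rate": rng.uniform(0.01, 0.2),
            "num_leaves": rng.randint(20, 100),
            "bagging_fraction": rng.uniform(0.5, 1.0),
        }
        loss = objective(params)
        if loss < best_loss:  # keep the best trial so far
            best_loss, best_params = loss, params
    return best_params, best_loss

best_params, best_loss = random_search()
print("Search best param:", best_params, "best loss:", best_loss)
```

hyperopt's `fmin(..., algo=tpe.suggest, max_evals=100)` replaces the blind random sampler with a model that proposes promising candidates based on past trials, but the trial/objective structure is the same.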
def f1_score_eval(preds, valid_df):
    labels = valid_df.get_label()
    preds = np.argmax(preds.reshape(3, -1), axis=0)
    scores = f1_score(y_true=labels, y_pred=preds, average='macro')
    return 'f1_score', scores, True

def sub_on_line_lgb(train_, test_, pred, label, cate_cols, split,
                    is_shuffle=True, use_cart=False, get_prob=False):
    n_class = 3
    train_pred = np.zeros((train_.shape[0], n_class))
    test_pred = np.zeros((test_.shape[0], n_class))
    n_splits = 5

    assert split in ['kf', 'skf'], '{} Not Support this type of split way'.format(split)
    if split == 'kf':
        folds = KFold(n_splits=n_splits, shuffle=is_shuffle, random_state=1024)
        kf_way = folds.split(train_[pred])
    else:
        # Unlike KFold, StratifiedKFold samples in a stratified way, so each fold
        # preserves the class proportions of the original dataset.
        folds = StratifiedKFold(n_splits=n_splits, shuffle=is_shuffle, random_state=1024)
        kf_way = folds.split(train_[pred], train_[label])

    print('Use {} features ...'.format(len(pred)))

    # Set the following parameters to the values found by Bayesian optimization
    params = {
        'learning_rate': 0.05,
        'boosting_type': 'gbdt',
        'objective': 'multiclass',
        'metric': 'None',
        'num_leaves': 60,
        'feature_fraction': 0.86,
        'bagging_fraction': 0.73,
        'bagging_freq': 5,
        'seed': 1,
        'bagging_seed': 1,
        'feature_fraction_seed': 7,
        'min_data_in_leaf': 15,
        'num_class': n_class,
        'nthread': 8,
        'verbose': -1,
        'num_boost_round': 1100,
        'max_depth': 7,
    }

    for n_fold, (train_idx, valid_idx) in enumerate(kf_way, start=1):
        print('the {} training start ...'.format(n_fold))
        train_x, train_y = train_[pred].iloc[train_idx], train_[label].iloc[train_idx]
        valid_x, valid_y = train_[pred].iloc[valid_idx], train_[label].iloc[valid_idx]
        if use_cart:
            dtrain = lgb.Dataset(train_x, label=train_y, categorical_feature=cate_cols)
            dvalid = lgb.Dataset(valid_x, label=valid_y, categorical_feature=cate_cols)
        else:
            dtrain = lgb.Dataset(train_x, label=train_y)
            dvalid = lgb.Dataset(valid_x, label=valid_y)
        clf = lgb.train(params=params,
                        train_set=dtrain,
                        # num_boost_round=3000,
                        valid_sets=[dvalid],
                        early_stopping_rounds=100,
                        verbose_eval=100,
                        feval=f1_score_eval)
        # out-of-fold predictions for the validation split ...
        train_pred[valid_idx] = clf.predict(valid_x, num_iteration=clf.best_iteration)
        # ... and fold-averaged probabilities for the test set
        test_pred += clf.predict(test_[pred], num_iteration=clf.best_iteration) / folds.n_splits

    print(classification_report(train_[label], np.argmax(train_pred, axis=1), digits=4))
    if get_prob:
        sub_probs = ['qyxs_prob_{}'.format(q) for q in ['圍網', '刺網', '拖網']]
        prob_df = pd.DataFrame(test_pred, columns=sub_probs)
        prob_df['ID'] = test_['ID'].values
        return prob_df
    else:
        test_['label'] = np.argmax(test_pred, axis=1)
        return test_[['ID', 'label']]

use_train = all_df[all_df['label'] != -1]
use_test = all_df[all_df['label'] == -1]
# use_feats = [c for c in use_train.columns if c not in ['ID', 'label']]
use_feats = model_feature
sub = sub_on_line_lgb(use_train, use_test, use_feats, 'label', [], 'kf',
                      is_shuffle=True, use_cart=False, get_prob=False)

Use 200 features ...
the 1 training start ...
Training until validation scores don't improve for 100 rounds
UserWarning: Found `num_boost_round` in params. Will use it instead of argument
[100]  valid_0's f1_score: 0.894256
[200]  valid_0's f1_score: 0.909942
[300]  valid_0's f1_score: 0.913423
[400]  valid_0's f1_score: 0.917897
[500]  valid_0's f1_score: 0.920616
Early stopping, best iteration is: [456]  valid_0's f1_score: 0.920717
the 2 training start ...
Early stopping, best iteration is: [140]  valid_0's f1_score: 0.92449
the 3 training start ...
Early stopping, best iteration is: [238]  valid_0's f1_score: 0.930614
the 4 training start ...
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is: [411]  valid_0's f1_score: 0.922153
the 5 training start ...
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is: [369]  valid_0's f1_score: 0.919843

Out-of-fold classification report on the training set:

              precision    recall  f1-score   support

           0     0.8726    0.9001    0.8861      1621
           1     0.9569    0.8949    0.9249      1018
           2     0.9586    0.9619    0.9603      4361

    accuracy                         0.9379      7000
   macro avg     0.9294    0.9190    0.9238      7000
weighted avg     0.9385    0.9379    0.9380      7000

(pandas also emits a SettingWithCopyWarning at `test_['label'] = np.argmax(test_pred, axis=1)`, because `test_` is a slice of `all_df`; taking an explicit `.copy()` when building `use_test`, or assigning via `.loc`, avoids it.)

Summary

The tuned LightGBM model, trained with 5-fold cross-validation and fold-averaged test predictions, reaches an out-of-fold macro F1 of about 0.92.
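As a closing illustration, the two mechanics the pipeline relies on — averaging per-fold class probabilities before taking the argmax, and scoring with macro F1 (the unweighted mean of per-class F1) — can be reproduced in a few self-contained lines. The data and helper names here are illustrative only, not the competition code.

```python
def average_fold_probs(fold_probs):
    """Average a list of per-fold probability matrices (rows = samples),
    then predict the class with the highest averaged probability."""
    n_folds = len(fold_probs)
    n_samples, n_classes = len(fold_probs[0]), len(fold_probs[0][0])
    avg = [[sum(f[i][c] for f in fold_probs) / n_folds for c in range(n_classes)]
           for i in range(n_samples)]
    return [max(range(n_classes), key=lambda c: row[c]) for row in avg]

def macro_f1(y_true, y_pred, n_classes=3):
    """Unweighted mean of per-class F1 scores (the competition metric)."""
    f1s = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return sum(f1s) / n_classes

# toy data: the two folds disagree on sample 3; averaging resolves the tie
fold1 = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.4, 0.5, 0.1]]
fold2 = [[0.6, 0.3, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]]
pred = average_fold_probs([fold1, fold2])
print(pred)                       # → [0, 1, 2]
print(macro_f1([0, 1, 2], pred))  # → 1.0
```

Because macro F1 weighs each class equally, it rewards getting the minority classes right — which is why the pipeline above optimizes `average='macro'` rather than accuracy on this imbalanced dataset.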