【深度学习】神经网络模型特征重要性可以查看了!!!
作者:杰少
查看NN模型特征重要性的技巧
簡 介
我們都知道樹模型的特征重要性是非常容易繪制出來的,只需要直接調用樹模型自帶的API即可以得到在樹模型中每個特征的重要性,那么對于神經網絡我們該如何得到其特征重要性呢?
本篇文章我們就以LSTM為例,來介紹神經網絡中模型特征重要性的一種獲取方式。
NN模型特征重要性
01
基本思路
該策略的思想來源于:Permutation Feature Importance,我們以特征對于模型最終預測結果的變化來衡量特征的重要性。
02
實現步驟
NN模型特征重要性的獲取步驟如下:
訓練一個NN;
每次獲取一個特征列,然后對其進行隨機shuffle,使用模型對其進行預測并得到Loss;
記錄每個特征列以及其對應的Loss;
每個Loss就是該特征對應的特征重要性,如果Loss越大,說明該特征對于NN模型越加重要;反之,則越加不重要。
Code
代碼摘自:https://www.kaggle.com/cdeotte/lstm-feature-importance/notebook
import?matplotlib.pyplot?as?plt from?tqdm.notebook?import?tqdmimport?tensorflow?as?tf from?tensorflow?import?keras import?tensorflow.keras.backend?as?K from?tensorflow.keras.callbacks?import?EarlyStopping,?ModelCheckpoint from?tensorflow.keras.callbacks?import?LearningRateScheduler,?ReduceLROnPlateau from?tensorflow.keras.optimizers.schedules?import?ExponentialDecay from?sklearn.metrics?import?mean_absolute_error?as?mae from?sklearn.preprocessing?import?RobustScaler,?normalize from?sklearn.model_selection?import?train_test_split,?GroupKFold,?KFold from?IPython.display?import?displayCOMPUTE_LSTM_IMPORTANCE?=?1 ONE_FOLD_ONLY?=?1with?gpu_strategy.scope():kf?=?KFold(n_splits=NUM_FOLDS,?shuffle=True,?random_state=2021)test_preds?=?[]for?fold,?(train_idx,?test_idx)?in?enumerate(kf.split(train,?targets)):K.clear_session()print('-'*15,?'>',?f'Fold?{fold+1}',?'<',?'-'*15)X_train,?X_valid?=?train[train_idx],?train[test_idx]y_train,?y_valid?=?targets[train_idx],?targets[test_idx]#?導入已經訓練好的模型model?=?keras.models.load_model('models/XXX.h5')#?計算特征重要性if?COMPUTE_LSTM_IMPORTANCE:results?=?[]print('?Computing?LSTM?feature?importance...')for?k?in?tqdm(range(len(COLS))):if?k>0:?save_col?=?X_valid[:,:,k-1].copy()np.random.shuffle(X_valid[:,:,k-1])oof_preds?=?model.predict(X_valid,?verbose=0).squeeze()?mae?=?np.mean(np.abs(?oof_preds-y_valid?))results.append({'feature':COLS[k],'mae':mae})if?k>0:?X_valid[:,:,k-1]?=?save_col#?展示特征重要性print()df?=?pd.DataFrame(results)df?=?df.sort_values('mae')plt.figure(figsize=(10,20))plt.barh(np.arange(len(COLS)),df.mae)plt.yticks(np.arange(len(COLS)),df.feature.values)plt.title('LSTM?Feature?Importance',size=16)plt.ylim((-1,len(COLS)))plt.show()#?SAVE?LSTM?FEATURE?IMPORTANCEdf?=?df.sort_values('mae',ascending=False)df.to_csv(f'lstm_feature_importance_fold_{fold}.csv',index=False)#?ONLY?DO?ONE?FOLDif?ONE_FOLD_ONLY:?break適用情況
適用于所有的NN模型。
參考文獻
https://www.kaggle.com/cdeotte/lstm-feature-importance/notebook
Permutation Feature Importance
本站qq群851320808,加入微信群請掃碼:
總結
以上是生活随笔為你收集整理的【深度学习】神经网络模型特征重要性可以查看了!!!的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Windows平台如何查看一个dll依赖
- 下一篇: 最新版chrome 70浏览器同步、清除