数据挖掘—网格搜索2
數(shù)據(jù)挖掘—網(wǎng)格搜索2
- 1、分析交叉驗證的結(jié)果
- 2、網(wǎng)格的條件參數(shù)
- 3、使用不同的交叉驗證策略進行網(wǎng)格搜索
- (1) 傳入交叉驗證分離器
- (2)嵌套交叉驗證
- (3)交叉驗證與網(wǎng)格搜索并行
1、分析交叉驗證的結(jié)果
將交叉驗證的結(jié)果可視化通常有助于理解模型泛化能力對所搜索參數(shù)的依賴關(guān)系。由于運行網(wǎng)格搜索的計算成本相當高,所以通常最好從相對比較稀疏且較小的網(wǎng)格開始搜索。然后我們可以檢查交叉驗證網(wǎng)格搜索的結(jié)果,可能也會擴展搜索范圍。網(wǎng)格搜索的結(jié)果可以在 cv_results_ 屬性中找到,它是一個字典,其中保存了搜索的所有內(nèi)容。你可以在下面的輸出中看到,它里面包含許多細節(jié),最好將其轉(zhuǎn)換成 pandas 數(shù)據(jù)框后再查看:
import pandas as pd results=pd.DataFrame(grid_search.cv_results_) display(results[:5])
results 中每一行對應(yīng)一種特定的參數(shù)設(shè)置。對于每種參數(shù)設(shè)置,交叉驗證所有劃分的結(jié)果都被記錄下來,所有劃分的平均值和標準差也被記錄下來。由于我們搜索的是一個二維參數(shù)網(wǎng)格( C 和 gamma ),所以最適合用熱圖可視化(見圖 5-8)。我們首先提取平均驗證分數(shù),然后改變分數(shù)數(shù)組的形狀,使其坐標軸分別對``應(yīng)于 C 和 gamma :
import numpy as np scores=np.array(results.mean_test_score).reshape(6,6) import seaborn as sns import matplotlib.pyplot as plt sns.heatmap(scores,xticklabels=param_grid['gamma'],yticklabels=param_grid['C'],annot=True)
你可以看到, SVC 對參數(shù)設(shè)置非常敏感。對于許多種參數(shù)設(shè)置,精度都在 40% 左右,這是非常糟糕的;對于其他參數(shù)設(shè)置,精度約為 96%。我們可以從這張圖中看出以下幾點。首先,我們調(diào)的參數(shù)對于獲得良好的性能非常重要。這兩個參數(shù)( C 和 gamma)都很重要,因為調(diào)節(jié)它們可以將精度從 40% 提高到96%。此外,在我們選擇的參數(shù)范圍中也可以看到輸出發(fā)生了顯著的變化。樣重要的是要注意,參數(shù)的范圍要足夠大:每個參數(shù)的最佳取值能位于圖像的邊界上。
2、網(wǎng)格的條件參數(shù)
在某些情況下,嘗試所有參數(shù)的所有可能組合(正如GridSearchCV 所做的那樣)并不是一個好主意。例如, SVC 有一個 kernel 參數(shù),根據(jù)所選擇的 kernel (內(nèi)核),其他參數(shù)也是與之相關(guān)的。如果kernel=‘linear’ ,那么模型是線性的,只用到 C 參數(shù)。如果kernel=‘rbf’ ,則需要使用 C 和 gamma 兩個參(但用不到類似 degree 的其他參數(shù))。在這種情況下,搜索 C 、 gamma 和 kernel 所有可能的組合沒有意義:如果kernel=‘linear’ ,那么 gamma 是用不到的,嘗試 gamma 的不同取值將會浪費時間。為了處理這種**“條件”(conditional)參數(shù)**, GridSearchCV 的 param_grid 可以是字典組成的列表(a list of dictionaries)。列表中的每個字典可擴展為一個獨立的網(wǎng)格。包含內(nèi)核與參數(shù)的網(wǎng)格搜索可能如下所示。
param_grid=[{'kernel':['rbf'],'C':[0.001,0.01,0.1,1,10,100],'gamma':[0.001,0.01,0.1,1,10,100]},{'kernel':['linear'],'C':[0.001,0.01,0.1,1,10,100]}] grid_search=GridSearchCV(SVC(),param_grid,cv=5) grid_search.fit(X_train,y_train) grid_search.score(X_test,y_test)0.9736842105263158
grid_search.best_score_0.9732142857142857
grid_search.best_params_{‘C’: 100, ‘gamma’: 0.01}
results=pd.DataFrame(grid_search.cv_results_) display(results.T)3、使用不同的交叉驗證策略進行網(wǎng)格搜索
(1) 傳入交叉驗證分離器
與 cross_val_score 類似, GridSearchCV 對分類問題默認使用分層 k 折交叉驗證StratifiedKFold ,對回歸問題默認使用 k 折交叉驗證KFold 。
但是,你可以傳入任何交叉驗證分離器作為 GridSearchCV 的cv 參數(shù):
cv=loo 留一
SuffleSplit
StratifiedShuffleSplit
GroupKFold:每個組都整體的出現(xiàn)在訓(xùn)練集或者測試集中
(2)嵌套交叉驗證
在前面的例子中,我們先介紹了將數(shù)據(jù)單次劃分為訓(xùn)練集、驗證集與測試集,然后介紹了先將數(shù)據(jù)劃分為訓(xùn)練集和測試集,再在訓(xùn)練集上進行交叉驗證。但前面在使用GridSearchCV 時,我們?nèi)匀粚?shù)據(jù)單次劃分為訓(xùn)練集和測試集這可能會導(dǎo)致結(jié)果不穩(wěn)定,也讓我們過于依賴數(shù)據(jù)的此次劃分。我們可以再深入一點不是只將原始數(shù)據(jù)一次劃分為訓(xùn)練集和測試集,而是使用交叉驗證進行多次劃分,這就是所謂的嵌套交叉驗證(nested cross validation)。在嵌套交叉驗證中,有一個外層循環(huán),遍歷將數(shù)據(jù)劃分為訓(xùn)練集和測試集的所有劃分。對于每種劃分都運行一次網(wǎng)格搜索(對于外層循環(huán)的每種劃分可能會得到不同的最佳參數(shù))。然后,對于每種外層劃分,利用最佳參數(shù)設(shè)置計算得到測試集分數(shù)。
這一過程的結(jié)果是由分數(shù)組成的列表——不是一個模型,也不是一種參數(shù)設(shè)置。這些分數(shù)告訴我們在網(wǎng)格找到的最佳參數(shù)下模型的泛化能力好壞。由于嵌套交叉驗證不提供可用于新數(shù)據(jù)的模型,所以在尋找可用于未來數(shù)據(jù)的預(yù)測模型時很少用到它。但是,它對于評估給定模型在特定數(shù)據(jù)集上的效果很有用。
scores = cross_val_score(GridSearchCV(SVC(), param_grid, cv=5),iris.data, iris.target, cv=5) print("Cross-validation scores: ", scores) print("Mean cross-validation score: ", scores.mean())Cross-validation scores: [ 0.967 1. 0.967 0.967 1. ]
Mean cross-validation score: 0.98
嵌套交叉驗證的結(jié)果可以總結(jié)為“ SVC 在 iris 數(shù)據(jù)集上的交叉驗證平均精度為 98%”——不多也不少。
這里我們在內(nèi)層循環(huán)和外層循環(huán)中都使用了分層 5 折交叉驗證。由于 param_grid 包含 36種參數(shù)組合,所以需要構(gòu)建 36×5×5 = 900 個模型,導(dǎo)致嵌套交叉驗證過程的代價很高。
(3)交叉驗證與網(wǎng)格搜索并行
雖然在許多參數(shù)上運行網(wǎng)格搜索和在大型數(shù)據(jù)集上運行網(wǎng)格搜索的計算量可能很大,但令人尷尬的是,這些計算都是并行的(parallel)。這也就是說,在一種交叉驗證劃分下使用特定參數(shù)設(shè)置來構(gòu)建一個模型,與利用其他參數(shù)的模型是完全獨立的。這使得網(wǎng)格搜索與交叉驗證成為多個 CPU 內(nèi)核或集群上并行化的理想選擇。你可以將 n_jobs 參數(shù)設(shè)置為你想使用的 CPU 內(nèi)核數(shù)量,從而在 GridSearchCV 和 cross_val_score 中使用多個內(nèi)核。你可以設(shè)置 n_jobs=-1 來使用所有可用的內(nèi)核。
總結(jié)
以上是生活随笔為你收集整理的数据挖掘—网格搜索2的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 常用图像像素格式 NV12、NV2、I4
- 下一篇: linux svn切换分支,玩转SVN-