机器学习算法实现解析:libFM之libFM的训练过程之SGD的方法
本節主要介紹的是libFM源碼分析的第五部分之一——libFM的訓練過程之SGD的方法。
5.1、基于梯度的模型訓練方法
在libFM中,提供了兩大類的模型訓練方法,一類是基于梯度的訓練方法,另一類是基于MCMC的模型訓練方法。對于基于梯度的訓練方法,其類為fm_learn_sgd類,其父類為fm_learn類,主要關系為:
?
?
fm_learn_sgd類是所有基于梯度的訓練方法的父類,其具體的代碼如下所示:
#include "fm_learn.h" #include "../../fm_core/fm_sgd.h"// 繼承自fm_learn class fm_learn_sgd: public fm_learn {protected://DVector<double> sum, sum_sqr;public:int num_iter;// 迭代次數double learn_rate;// 學習率DVector<double> learn_rates;// 多個學習率 // 初始化virtual void init() { fm_learn::init(); learn_rates.setSize(3);// 設置學習率// sum.setSize(fm->num_factor); // sum_sqr.setSize(fm->num_factor);} // 利用梯度下降法進行更新,具體的訓練的過程在其子類中virtual void learn(Data& train, Data& test) { fm_learn::learn(train, test);// 該函數并沒有具體實現// 輸出運行時的參數,包括:學習率,迭代次數std::cout << "learnrate=" << learn_rate << std::endl;std::cout << "learnrates=" << learn_rates(0) << "," << learn_rates(1) << "," << learn_rates(2) << std::endl;std::cout << "#iterations=" << num_iter << std::endl;if (train.relation.dim > 0) {// 判斷relationthrow "relations are not supported with SGD";}std::cout.flush();// 刷新}// SGD重新修正fm模型的權重void SGD(sparse_row<DATA_FLOAT> &x, const double multiplier, DVector<double> &sum) {fm_SGD(fm, learn_rate, x, multiplier, sum);// 調用fm_sgd中的fm_SGD函數} // debug函數,主要用于打印中間結果void debug() {std::cout << "num_iter=" << num_iter << std::endl;fm_learn::debug(); }// 對數據進行預測virtual void predict(Data& data, DVector<double>& out) {assert(data.data->getNumRows() == out.dim);// 判斷樣本個數是否相等for (data.data->begin(); !data.data->end(); data.data->next()) {double p = predict_case(data);// 得到線性項和交叉項的和,調用的是fm_learn中的方法if (task == TASK_REGRESSION ) {// 回歸任務p = std::min(max_target, p);p = std::max(min_target, p);} else if (task == TASK_CLASSIFICATION) {// 分類任務p = 1.0/(1.0 + exp(-p));// Sigmoid函數處理} else {// 異常處理throw "task not supported";}out(data.data->getRowIndex()) = p;} } };在fm_learn_sgd類中,主要包括五個函數,分別為:初始化init函數,訓練learn函數,SGD訓練SGD函數,debug的debug函數和預測predict函數。
5.1.1、初始化init函數
在初始化中,對學習率的大小進行了初始化,同時繼承了父類中的初始化方法。
5.1.2、訓練learn函數
在learn函數中,沒有具體的訓練的過程,只是對訓練中需要用到的參數進行輸出,具體的訓練的過程在其對應的子類中定義,如fm_learn_sgd_element類和fm_learn_sgd_element_adapt_reg類。
5.1.3、SGD訓練SGD函數
SGD函數使用的是fm_sgd.h文件中的fm_SGD函數。fm_SGD函數是利用梯度下降法對模型中的參數進行調整,以得到最終的模型中的參數。在利用梯度下降法對模型中的參數進行調整的過程中,假設損失函數為ll,那么,對于回歸問題來說,其損失函數為:
?
l=12(y^(i)?y(i))2l=12(y^(i)?y(i))2
?
對于二分類問題,其損失函數為:
?
l=?lnσ(y^(i)y(i))l=?lnσ(y^(i)y(i))
?
其中,σσ為Sigmoid函數:
?
σ(x)=11+e(?x)σ(x)=11+e(?x)
?
對于σ(x)σ(x),其導函數為:
?
σ′=σ(1?σ)σ′=σ(1?σ)
?
在可用SGD更新的過程中,首先需要計算損失函數的梯度,因此,對應于上述的回歸問題和二分類問題,其中回歸問題的損失函數的梯度為:
?
?l?θ=(y^(i)?y(i))??y^(i)?θ?l?θ=(y^(i)?y(i))??y^(i)?θ
?
分類問題的損失函數的梯度為:
?
?l?θ=(σ(y^(i)y(i))?1)?y(i)??y^(i)?θ?l?θ=(σ(y^(i)y(i))?1)?y(i)??y^(i)?θ
?
其中,λλ稱為正則化參數,在具體的應用中,通常加上L2L2正則,即:
?
?l?θ+λθ?l?θ+λθ
?
在定義好上述的計算方法后,其核心的問題是如何計算?y^(i)?θ?y^(i)?θ,在“機器學習算法實現解析——libFM之libFM的模型處理部分”中已知:
?
y^:=w0+∑i=1nwixi+∑i=1n?1∑j=i+1n?vi,vj?xixjy^:=w0+∑i=1nwixi+∑i=1n?1∑j=i+1n?vi,vj?xixj
?
因此,當y^y^分別對w0w0,wiwi以及vi,fvi,f求偏導時,其結果分別為:
?
?y^?θ=???????1xixi(∑j=1xjvj,f?xivi,f)?if?θ=w0?if?θ=wi?if?θ=vi,f?y^?θ={1?if?θ=w0xi?if?θ=wixi(∑j=1xjvj,f?xivi,f)?if?θ=vi,f
?
在利用梯度的方法中,其參數θθ的更新方法為:
?
θ=θ?η?(?l?θ+λθ)θ=θ?η?(?l?θ+λθ)
?
其中,ηη為學習率,在libFM中,其具體的代碼如下所示:
// 利用SGD更新模型的參數 void fm_SGD(fm_model* fm, const double& learn_rate, sparse_row<DATA_FLOAT> &x, const double multiplier, DVector<double> &sum) {// 1、常數項的修正if (fm->k0) {double& w0 = fm->w0;w0 -= learn_rate * (multiplier + fm->reg0 * w0);}// 2、一次項的修正if (fm->k1) {for (uint i = 0; i < x.size; i++) {double& w = fm->w(x.data[i].id);w -= learn_rate * (multiplier * x.data[i].value + fm->regw * w);}}// 3、交叉項的修正for (int f = 0; f < fm->num_factor; f++) {for (uint i = 0; i < x.size; i++) {double& v = fm->v(f,x.data[i].id);double grad = sum(f) * x.data[i].value - v * x.data[i].value * x.data[i].value; v -= learn_rate * (multiplier * grad + fm->regv * v);}} }以上的更新的過程分別對應著上面的更新公式,其中multiplier變量分別對應著回歸中的(y^(i)?y(i))(y^(i)?y(i))和分類中的(σ(y^(i)y(i))?1)?y(i)(σ(y^(i)y(i))?1)?y(i)。
5.1.4、預測predict函數
predict函數用于對樣本進行預測,這里使用到了predict_case函數,該函數在“機器學習算法實現解析——libFM之libFM的訓練過程概述”中有詳細的說明,得到值后,分別對回歸問題和分類問題做處理,在回歸問題中,主要是防止超出最大值和最小值,在分類問題中,將其值放入Sigmoid函數,得到最終的結果。
5.2、SGD的訓練方法
隨機梯度下降法(Stochastic Gradient Descent ,SGD)是一種簡單有效的優化方法。對于梯度下降法的更多內容,可以參見“梯度下降優化算法綜述”。在利用SGD對FM模型訓練的過程如下圖所示:
?
?
在libFM中,SGD的實現在fm_learn_sgd_element.h文件中。在該文件中,定義了fm_learn_sgd_element類,fm_learn_sgd_element類繼承自fm_learn_sgd類,主要實現了fm_learn_sgd類中的learn方法,具體的程序代碼如下所示:
#include "fm_learn_sgd.h"// 繼承了fm_learn_sgd class fm_learn_sgd_element: public fm_learn_sgd {public:// 初始化virtual void init() {fm_learn_sgd::init();// 日志輸出if (log != NULL) {log->addField("rmse_train", std::numeric_limits<double>::quiet_NaN());}}// 利用SGD訓練FM模型virtual void learn(Data& train, Data& test) {fm_learn_sgd::learn(train, test);// 輸出參數信息std::cout << "SGD: DON'T FORGET TO SHUFFLE THE ROWS IN TRAINING DATA TO GET THE BEST RESULTS." << std::endl; // SGDfor (int i = 0; i < num_iter; i++) {// 開始迭代,每一輪的迭代過程double iteration_time = getusertime();// 記錄開始的時間for (train.data->begin(); !train.data->end(); train.data->next()) {// 對于每一個樣本double p = fm->predict(train.data->getRow(), sum, sum_sqr);// 得到樣本的預測值double mult = 0;// 損失函數的導數if (task == 0) {// 回歸p = std::min(max_target, p);p = std::max(min_target, p);// loss=(y_ori-y_pre)^2mult = -(train.target(train.data->getRowIndex())-p);// 對損失函數求導} else if (task == 1) {// 分類// lossmult = -train.target(train.data->getRowIndex())*(1.0-1.0/(1.0+exp(-train.target(train.data->getRowIndex())*p)));}// 利用梯度下降法對參數進行學習SGD(train.data->getRow(), mult, sum); } iteration_time = (getusertime() - iteration_time);// 記錄時間差// evaluate函數是調用的fm_learn類中的方法double rmse_train = evaluate(train);// 對訓練結果評估double rmse_test = evaluate(test);// 將模型應用在測試數據上std::cout << "#Iter=" << std::setw(3) << i << "\tTrain=" << rmse_train << "\tTest=" << rmse_test << std::endl;// 日志輸出if (log != NULL) {log->log("rmse_train", rmse_train);log->log("time_learn", iteration_time);log->newLine();}} }};在learn函數中,實現了SGD訓練FM模型的主要過程,在實現的過程中,分別調用了SGD函數和evaluate函數,其中SGD函數如上面的5.1.3、SGD訓練SGD函數小節所示,利用SGD函數對FM模型中的參數進行更新,evaluate函數如“機器學習算法實現解析——libFM之libFM的訓練過程概述”中所示,evaluate函數用于評估學習出的模型的效果。其中mult變量分別對應著回歸中的(y^(i)?y(i))(y^(i)?y(i))和分類中的(σ(y^(i)y(i))?1)?y(i)(σ(y^(i)y(i))?1)?y(i)。
參考文獻
- Rendle S. Factorization Machines[C]// IEEE International Conference on Data Mining. IEEE Computer Society, 2010:995-1000.
- Rendle S. Factorization Machines with libFM[M]. ACM, 2012.
--------------------- 本文來自 zhiyong_will 的CSDN 博客 ,全文地址請點擊:https://blog.csdn.net/google19890102/article/details/72866334?utm_source=copy
創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的机器学习算法实现解析:libFM之libFM的训练过程之SGD的方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: scala与java混合编译出现的问题
- 下一篇: FAIR重磅发布大规模语料库XNLI:解