Gauss-Newton算法代码详细解释(转载+自己注释)
這篇博客是對[1]中不詳細的地方進行細節上的闡述,
并且每句代碼都加了注釋,使得更加容易理解
下面的論述(包括偽代碼和算法)特指被最小化的目標函數是MSE的時候
需要注意,如果不是MSE為目標函數,那么下面的第二張截圖開始的偽代碼都是需要更換的,這里的偽代碼僅僅針對目標函數恰好為MSE的時候
本文的代碼來自[2],因為[1]中的代碼參數名字有點奇怪,所以就不予采用
該代碼的運行需要安裝opencv,安裝流程使用[3]
###################偽代碼#######################################
這里解釋下:
上面的▽f\triangledown f▽f為啥是:
2∑i=1mri?ri?xj2\sum_{i=1}^m r_i\frac{\partial r_i}{\partial x_j}2i=1∑m?ri??xj??ri??
因為這里的fff其實是MSE,也就是:
所以求導后就有了上面▽f\triangledown f▽f的模樣
################具體案例與代碼解釋#############################
例子1,根據美國1815年至1885年數據,估計人口模型中的參數A和B。如下表所示,已知年份和人口總量,及人口模型方程,求方程中的參數(為什么Gauss-Newton用來求解參數在本文后面有解釋)。
代碼main.cpp如下:
#include <cstdio> #include <vector> #include <opencv2/core.hpp> #include<iostream> using namespace std; using namespace cv;const double DERIV_STEP = 1e-5; const int MAX_ITER = 100;void GaussNewton(double(*Func)(const Mat &input, const Mat params),const Mat &inputs, const Mat &outputs, Mat params); //算法頂層函數聲明 double Deriv(double(*Func)(const Mat &input, const Mat params),const Mat &input, const Mat params, int n); //導數函數聲明 // The user defines their function here double Func(const Mat &input, const Mat params);//下面是被調用的函數//這個函數的作用就是f=A·e^(Bt) double Func(const Mat &input, const Mat params) { // 這里的params樣例如下: // params=[7.000153882130696;0.262076597545448] // 也就是說,這里的params類似一個列表. //所以,其實這里的params=[A,B],input=t// Assumes input is a single row matrix// Assumes params is a column matrixdouble A = params.at<double>(0, 0);//一個浮點數//這里的(0,0)代表獲取上述列表的第0個元素double B = params.at<double>(1, 0);//一個浮點數//這里的(0,1)代表獲取上述列表的第1個元素double x = input.at<double>(0, 0);//x就是上面函數表達式中的treturn A*exp(x*B); }//calc the n-th params' partial derivation , the params are our final target double Deriv(double(*Func)(const Mat &input, const Mat params), const Mat &input, const Mat params, int n) {// Assumes input is a single row matrix// Returns the derivative of the nth parameterMat params1 = params.clone();Mat params2 = params.clone();// Use central difference to get derivativeparams1.at<double>(n, 0) -= DERIV_STEP;//對A,B兩個系數縮減params2.at<double>(n, 0) += DERIV_STEP;//對A,B兩個系數遞增double p1 = Func(input, params1);double p2 = Func(input, params2);double d = (p2 - p1) / (2 * DERIV_STEP);//這里是在計算導數return d; }void GaussNewton(double(*Func)(const Mat &input, const Mat params),const Mat &inputs, const Mat &outputs, Mat params) {int m = inputs.rows;//行m=8,表示8條數據int n = inputs.cols;//列n=1int num_params = params.rows;//nfum_params=2,表示目標函數表達式中的未知參數的個數Mat r(m, 1, CV_64F); // residual matrix殘差矩陣Mat Jf(m, num_params, CV_64F); // Jacobian of Func()雅各比矩陣Mat input(1, n, CV_64F); // single row inputdouble last_mse = 0;for (int i = 0; i < MAX_ITER; i++){double mse = 0;//擬合效果的指標for (int j = 0; j < m; j++)//遍歷行, ?Xj{for (int k = 0; k < n; k++)//遍歷列,?Xk{//copy Independent variable vector, the yearinput.at<double>(0, k) = inputs.at<double>(j, k);}r.at<double>(j, 0) = outputs.at<double>(j, 0) - Func(input, params);//擬合值與實際值之間的差//之所以是(j,0)是因為輸出值肯定只有一列mse += r.at<double>(j, 0)*r.at<double>(j, 0);//殘差矩陣的平方,這里之所以要平方是根據MSE的定義來的for (int k = 0; k < num_params; k++)//遍歷列{Jf.at<double>(j, k) = Deriv(Func, input, params, k);//對第k個元素求偏導//雅各比矩陣中的某個元素是求導值}}mse /= m;//MSE的定義中需要除以整體元素的個數// The difference in mse is very small, so quitif (fabs(mse - last_mse) < 1e-8)//如果MSE不再變化,就認為擬合結束{break;}Mat delta = ((Jf.t()*Jf)).inv() * Jf.t()*r;//計算△params += delta;//printf("%d: mse=%f\n", i, mse);printf("%d %f\n", i, mse);last_mse = mse;} }//------------------下面是頂層文件----------------------------------------------int main() {// For this demo we're going to try and fit to the function// F = A*exp(t*B), There are 2 parameters: A Bint num_params = 2;// Generate random data using these parametersint total_data = 8;Mat inputs(total_data, 1, CV_64F);//這里的CV_64F代表一種單通道的矩陣類型Mat outputs(total_data, 1, CV_64F);//-------------------------------下面是采樣數據----------------------------------------for (int i = 0; i < total_data; i++){inputs.at<double>(i, 0) = i + 1; //load year}//load America populationoutputs.at<double>(0, 0) = 8.3;outputs.at<double>(1, 0) = 11.0;outputs.at<double>(2, 0) = 14.7;outputs.at<double>(3, 0) = 19.7;outputs.at<double>(4, 0) = 26.7;outputs.at<double>(5, 0) = 35.2;outputs.at<double>(6, 0) = 44.4;outputs.at<double>(7, 0) = 55.9;//-------------------------------------上面是采樣數據-----------------------------------------// Guess the parameters, it should be close to the true value, else it can fail for very sensitive functions!Mat params(num_params, 1, CV_64F);//下面是初始值設置params.at<double>(0, 0) = 6;params.at<double>(1, 0) = 0.3;GaussNewton(Func, inputs, outputs, params);cout<<"最終params="<<params<<endl;printf("Parameters from GaussNewton: %f %f\n", params.at<double>(0, 0), params.at<double>(1, 0));return 0; }Clion2018.3中的CMakeLists.txt是(沒這個玩意兒還真運行不了):
cmake_minimum_required(VERSION 3.12) project(GaussNewton)set(CMAKE_CXX_STANDARD 14) include_directories($ENV{CMAKE_INCLUDE_PATH}) set(CMAKE_CXX_STANDARD 14)#C++ standard set(OpenCV_DIR /home/appleyuchi/opencv/opencv_install/lib/cmake/opencv4) find_package( OpenCV REQUIRED ) # locate OpenCV in system include_directories( ${OpenCV_INCLUDE_DIRS} ) # provide library headers add_executable(GaussNewton main.cpp) target_link_libraries(GaussNewton ${OpenCV_LIBS} /home/appleyuchi/opencv/opencv_install/lib/libopencv_highgui.so) # link OpenCV libraries , hightgui.so not found by cmake so this hack MESSAGE("OpenCV_LIBS: " ${OpenCV_LIBS} ) #display opencv libs found使用的opencv版本是4.0.1
關于這個代碼的一個問題:
NewtonGauss算法明明是為了計算最小值而存在的,為什么到了代碼里變成了求參數值A和B?
答:
因為這個代碼中,我們常見的Jacobian是對于x1,x2求偏導,其實在代碼里面是對應于A和B求偏導,
也就是說,代碼其實是在求解對于A和B的Jacobian矩陣.
所以這里關注的"目標函數迭代到最小值"其實是MSE,并不是fff
算法的目的是,當A和B為何值時,目標函數的值最小.
這樣就符合了Newton-Gauss算法的本意:
求解目標函數MSE的最低值以及最低值對應的變量A,B的具體數值
代碼和偽代碼的對應關系:
| num_params=A,B | x1,x2x_1,x_2x1?,x2? |
| m | m |
| Jf.at(j, k)函數在某個點的導數 | Jr=αfαxkJ_r=\frac{\alpha f}{\alpha {x_k}}Jr?=αxk?αf? |
| 點坐標(A,B,outputs.at(i,0)) | 點坐標(x1,x2,f(x1,x2)x_1,x_2,f(x_1,x_2)x1?,x2?,f(x1?,x2?)) |
Reference:
[1]Gauss-Newton算法學習
[2]梯度下降法,牛頓法,高斯-牛頓迭代法,附代碼實現
[3]opencv4.0.1配合contrib在linux下面安裝編譯全過程
總結
以上是生活随笔為你收集整理的Gauss-Newton算法代码详细解释(转载+自己注释)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: CV_64F,CV_64FC1以及CV_
- 下一篇: theano中的vector和dvect