ADMM算法(交替方向乘子法)
有了前面標準Lagrangian乘子法與對偶上升法和增廣Lagrangian法的基礎,理解ADMM就容易了很多。本文主要來自張賢達《矩陣分析與優化(第二版)》4.7.4節。
ADMM算法
ADMM認為,在統計學與機器學習中,經常會遇到大尺度的等式約束優化問題,即x∈Rnx\in \mathbb{R}^nx∈Rn的維數nnn很大。如果xxx可以分解為幾個子向量,即x=(x1,?,xr)x=(x_1,\cdots,x_r)x=(x1?,?,xr?),其目標函數也可以分解為:
f(x)=∑i=1rfi(x)xi∈Rni,∑i=1rni=nf(x)=\sum_{i=1}^r f_i(x) \\ x_i\in \mathbb{R}^{n_i},\sum_{i=1}^r n_i=n f(x)=i=1∑r?fi?(x)xi?∈Rni?,i=1∑r?ni?=n
則大尺度的優化問題可以轉化為分布式優化問題。相應的,等式約束矩陣Ax=bAx=bAx=b也分塊為:
A=[A1,?,Ar],Ax=∑i=1rAixi=bA=[A_1,\cdots,A_r], Ax=\sum_{i=1}^r A_ix_i=b A=[A1?,?,Ar?],Ax=i=1∑r?Ai?xi?=b
于是增廣Lagrangian目標函數Lρ(x,λ)L_\rho(x,\lambda)Lρ?(x,λ)可以寫作:
Lρ(x,λ)=∑i=1r(fi(xi)+λTAixi)?λTb+ρ2∥∑i=1r(Aixi)?b∥22L_\rho(x,\lambda)=\sum_{i=1}^r(f_i(x_i)+\lambda^TA_ix_i)-\lambda^Tb+\frac{\rho}{2}\|\sum_{i=1}^r(A_ix_i)-b\|_2^2 Lρ?(x,λ)=i=1∑r?(fi?(xi?)+λTAi?xi?)?λTb+2ρ?∥i=1∑r?(Ai?xi?)?b∥22?
所取的罰函數與增廣Lagrangian乘子法中的仍相同。再采用對偶上升法,即可得到能進行并行運算的分散算法:
xik+1=arg?min?xi∈RniLi(xi,λk),i=1,?,rλk+1=λk+ρk(∑i=1rAixik+1?b)x_i^{k+1}=\argmin_{x_i\in \mathbb{R}^{n_i}} L_i(x_i,\lambda_k),i=1,\cdots,r \\ \lambda_{k+1}=\lambda_k+\rho_k (\sum_{i=1}^r A_i x_i^{k+1}-b) xik+1?=xi?∈Rni?argmin?Li?(xi?,λk?),i=1,?,rλk+1?=λk?+ρk?(i=1∑r?Ai?xik+1??b)
這里的xix_ixi?是可以獨立更新的。由于xix_ixi?以一種交替的或序貫的方式進行更新,所以稱為“交替方向”乘子法(ADMM算法)。
舉個r=2r=2r=2的例子
r=2r=2r=2,則目標函數為:
min?f(x)+g(z)s.t.Ax+Bz=c\min f(x)+g(z)\\ s.t.\ Ax+Bz=c minf(x)+g(z)s.t.?Ax+Bz=c
上式中,x∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rpx\in \mathbb{R}^n,z\in \mathbb{R}^m,A\in \mathbb{R}^{p\times n},B\in \mathbb{R}^{p\times m},c\in \mathbb{R}^{p}x∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rp。則增廣Lagrangian目標函數為:
Lρ(x,z,λ)=f(x)+g(z)+λT(Ax+Bz?c)+ρ2∥Ax+Bz?c∥22(1)L_\rho(x,z,\lambda)=f(x)+g(z)+\lambda^T(Ax+Bz-c)+\frac{\rho}{2}\|Ax+Bz-c\|_2^2 \\ \tag{1} Lρ?(x,z,λ)=f(x)+g(z)+λT(Ax+Bz?c)+2ρ?∥Ax+Bz?c∥22?(1)
上式的交替方向乘子法的更新公式為:
xk+1=arg?min?x∈RnLρ(x,zk,λk)zk+1=arg?min?z∈RmLρ(xk+1,z,λk)λk+1=λk+ρk(Axk+1+Bzk+1?c)x_{k+1}=\argmin_{x\in \mathbb{R}^n} L_\rho(x,z_k,\lambda_k)\\ z_{k+1}=\argmin_{z\in \mathbb{R}^m} L_\rho(x_{k+1},z,\lambda_k)\\ \lambda_{k+1}=\lambda_k+\rho_k(Ax_{k+1}+Bz_{k+1}-c) xk+1?=x∈Rnargmin?Lρ?(x,zk?,λk?)zk+1?=z∈Rmargmin?Lρ?(xk+1?,z,λk?)λk+1?=λk?+ρk?(Axk+1?+Bzk+1??c)
誤差分析與停止條件
公式(1)(1)(1)的最優化條件分為原始可行性:
Ax+Bz?c=0Ax+Bz-c=0 Ax+Bz?c=0
和對偶可行性:
0∈?f(x)+ATλ+ρAT(Ax+Bz?c)=?f(x)+ATλ0∈?f(x)+BTλ+ρBT(Ax+Bz?c)=?g(z)+BTλ(2)0\in \partial f(x)+A^T\lambda+\rho A^T(Ax+Bz-c)=\partial f(x)+A^T\lambda \\ 0\in \partial f(x)+B^T\lambda+\rho B^T(Ax+Bz-c)=\partial g(z)+B^T\lambda \\ \tag{2} 0∈?f(x)+ATλ+ρAT(Ax+Bz?c)=?f(x)+ATλ0∈?f(x)+BTλ+ρBT(Ax+Bz?c)=?g(z)+BTλ(2)
根據我的企業級理解,這原始和對偶可行性關系分別是等式約束成立和偏導為0,是從KKT條件來的,都是必要條件。不過書里沒有明確指出是或者不是。
關于推導,這里用的是0∈0\in0∈而不是0=0=0=,這是什么企業級邏輯我沒弄懂,不過我覺得不影響理解,意思差不多。書上這里求導有問題,疑似紕漏,我改成了公式(1)(1)(1)的正確的求導結果,這樣也和后文更對的上。要記得,公式(1)(1)(1)中向量Ax+Bz?cAx+Bz-cAx+Bz?c二范數平方的一階導等于向量的二倍,再乘一個系數ATA^TAT,就可以得到這個結果。求導法則可以參考我的這篇總結。再加上Ax+Bz?c=0Ax+Bz-c=0Ax+Bz?c=0的約束,就能推導下來了。
在迭代的過程中,原始可行性不可能完全滿足,設其誤差為:
rk=Axk+Bzk?cr_k=Ax_k+Bz_k-c rk?=Axk?+Bzk??c
稱為第kkk次迭代的原始殘差(向量),這樣Lagrangian乘子向量的更新可以用這個殘差重寫為:
λk+1=λk+ρkrk+1\lambda_{k+1}=\lambda_k+\rho_k r_{k+1} λk+1?=λk?+ρk?rk+1?
同樣,對偶可行性也不會完全滿足:
0∈?f(xk+1)+ATλk+ρAT(Axk+1+Bzk?c)=?f(xk+1)+AT[λk+ρ(Axk+1+Bzk+1?c)+ρB(zk?zk+1)]=?f(xk+1)+AT[λk+ρrk+1+ρB(zk?zk+1)]=?f(xk+1)+ATλk+1+ρATB(zk?zk+1)0\in \partial f(x_{k+1})+A^T\lambda_k+\rho A^T(Ax_{k+1}+Bz_k-c) \\ =\partial f(x_{k+1})+A^T[\lambda_k+\rho (Ax_{k+1}+Bz_{k+1}-c)+\rho B(z_k-z_{k+1})] \\ =\partial f(x_{k+1})+A^T[\lambda_k+\rho r_{k+1}+\rho B(z_k-z_{k+1})]\\ =\partial f(x_{k+1}) +A^T\lambda_{k+1}+\rho A^TB(z_k-z_{k+1}) 0∈?f(xk+1?)+ATλk?+ρAT(Axk+1?+Bzk??c)=?f(xk+1?)+AT[λk?+ρ(Axk+1?+Bzk+1??c)+ρB(zk??zk+1?)]=?f(xk+1?)+AT[λk?+ρrk+1?+ρB(zk??zk+1?)]=?f(xk+1?)+ATλk+1?+ρATB(zk??zk+1?)
注意,由于書上公式(2)(2)(2)求導是錯的,所以這一步更別扭,怎么看都不對勁,這里我也改成了我認為的正確的推導形式。
對照公式(2)(2)(2)中的第一個式子可知對偶殘差為:
sk+1=ρATB(zk?zk+1)s_{k+1}=\rho A^TB(z_k-z_{k+1}) sk+1?=ρATB(zk??zk+1?)
交替方向乘子法的停止條件就是兩個殘差都小于閾值:
∥rk+1∥2≤εpriand∥sk+1∥≤εdual\|r_{k+1}\|_2\le \varepsilon_{pri} \ and \ \|s_{k+1}\|\le \varepsilon_{dual} ∥rk+1?∥2?≤εpri??and?∥sk+1?∥≤εdual?
縮放形式的ADMM
令v=(1/ρ)λv=(1/\rho)\lambdav=(1/ρ)λ為經過1/ρ1/\rho1/ρ縮放的Lagrangian乘子向量,則更新公式變為:
xk+1=arg?min?x∈RnLρ(x,zk,vk)zk+1=arg?min?z∈RmLρ(xk+1,z,vk)vk+1=vk+Axk+1+Bzk+1?c=vk+rk+1x_{k+1}=\argmin_{x\in \mathbb{R}^n} L_\rho(x,z_k,v_k)\\ z_{k+1}=\argmin_{z\in \mathbb{R}^m} L_\rho(x_{k+1},z,v_k)\\ v_{k+1}=v_k+Ax_{k+1}+Bz_{k+1}-c=v_k+r_{k+1} xk+1?=x∈Rnargmin?Lρ?(x,zk?,vk?)zk+1?=z∈Rmargmin?Lρ?(xk+1?,z,vk?)vk+1?=vk?+Axk+1?+Bzk+1??c=vk?+rk+1?
其第kkk次迭代的殘差rkr_krk?為:
rk=Axk+Bzk?c=v0+∑i=1krir_k=Ax_k+Bz_k-c=v_0+\sum_{i=1}^kr_i rk?=Axk?+Bzk??c=v0?+i=1∑k?ri?
即,第kkk次迭代的縮放對偶向量是所有kkk次迭代的原始殘差之和。這種方法稱為有縮放的交替方向乘子法。
最后,個人認為,不論是對偶上升法,增廣Lagrangian乘子法,還是ADMM算法,核心思想都相似,而且具體使用時都要與其他最優化方法結合,因為arg?min?L(x,z,λ)\argmin L(x,z,\lambda)argminL(x,z,λ)的求解是還需要別的方法的,停止條件需要根據使用環境具體再去確定。
總結
以上是生活随笔為你收集整理的ADMM算法(交替方向乘子法)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: java二维数组水平翻转,C 语言 利用
- 下一篇: c++ 数组的输入遇到特定字符停止输入_