Quasi-Newton Methods

I. Newton's Method

1.1 Overview

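Newton's method is an unconstrained optimization method that uses both the first and the second derivatives of the objective. Its basic idea is to update along the Newton direction, used as the search direction, in every iteration. It places stricter requirements on the objective, which must be twice differentiable, and it has the drawback that inverting the Hessian (Hesse) matrix is computationally expensive. XGBoost is essentially optimized with Newton's method: each boosting round uses the first and second derivatives of the loss.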

1.2 Derivation

We now derive Newton's method.
Consider the unconstrained optimization problem

$$\min_x f(x)$$
For one-dimensional $x$, $f(x^{(t+1)})$ can be approximated near $x^{(t)}$ by its second-order Taylor expansion, with $\Delta x = x^{(t+1)}-x^{(t)}$:
$$f(x^{(t+1)})\approx f(x^{(t)})+f'(x^{(t)})\Delta x+\frac{1}{2}f''(x^{(t)})\Delta x^2$$
We then approximate the extremum of $f(x)$ by the extremum of this Taylor expansion, setting its derivative to zero:
$$\frac{\partial f(x^{(t+1)})}{\partial x^{(t+1)}}=f'(x^{(t)})+f''(x^{(t)})\Delta x=0$$
Therefore
$$\Delta x = x^{(t+1)}-x^{(t)}=-\frac{f'(x^{(t)})}{f''(x^{(t)})}=-\frac{g_t}{h_t}$$
This gives the iteration formula, where $g_t$ and $h_t$ are the first and second derivatives of the objective at the current point $x^{(t)}$:
$$x^{(t+1)}=x^{(t)}-\frac{g_t}{h_t}$$
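As a minimal sketch (Python, standard library only), the one-dimensional iteration $x^{(t+1)}=x^{(t)}-g_t/h_t$ can be written as follows. The test function $f(x)=e^x-2x$ (so $g=e^x-2$, $h=e^x$, minimizer $\ln 2$), the name `newton_1d` and the tolerance are illustrative assumptions, not part of the original text.

```python
import math


def newton_1d(fprime, fsecond, x0, tol=1e-12, max_iter=50):
    """Pure Newton iteration x <- x - g/h, with the step length fixed at 1."""
    x = x0
    for _ in range(max_iter):
        g = fprime(x)          # g_t = f'(x^(t))
        if abs(g) < tol:       # stationary point reached
            break
        h = fsecond(x)         # h_t = f''(x^(t))
        x = x - g / h
    return x


# Illustrative objective f(x) = e^x - 2x: f'(x) = e^x - 2, f''(x) = e^x.
x_star = newton_1d(lambda x: math.exp(x) - 2.0, math.exp, x0=0.0)
print(x_star, math.log(2.0))
```

Starting from $x^{(0)}=0$ the iterates are $1$, then $2/e\approx 0.736$, and they converge quickly to $\ln 2\approx 0.693$.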
Generalizing to a multidimensional vector $x$, $g_t$ becomes the gradient vector and $H_t$ is the Hessian matrix
$$H=\left[\frac{\partial^2 f}{\partial x_i\partial x_j}\right]$$
Taking two-dimensional $x=(x_1,x_2)$ as an example:
$$H=\begin{bmatrix} \frac{\partial^2 f}{\partial x_1^2} & \frac{\partial^2 f}{\partial x_1\partial x_2} \\ \frac{\partial^2 f}{\partial x_2\partial x_1} & \frac{\partial^2 f}{\partial x_2^2} \end{bmatrix}$$
The parameter update generalizes to:
$$x^{(t+1)}=x^{(t)}-H_t^{-1}g_t$$
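So every iteration moves along the Newton direction at the current point, with the step length fixed at 1. Each iteration must compute the gradient $g$, the Hessian (all second derivatives) and its inverse; for high-dimensional features, inverting an $n\times n$ matrix costs $O(n^3)$ and quickly becomes prohibitively expensive.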
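A hedged multidimensional sketch follows. It does not form $H_t^{-1}$ explicitly but solves the linear system $H_t s=g_t$ by Gaussian elimination, which is the usual way to apply an inverse in practice. The helper names and the quadratic test function $f(x_1,x_2)=(x_1-1)^2+(x_2+2)^2+x_1x_2$ are assumptions for illustration, and the Hessian is assumed nonsingular. Because the test function is quadratic, its second-order Taylor expansion is exact, so a single Newton step lands on the minimizer $(8/3,-10/3)$.

```python
def solve(A, b):
    """Solve A s = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [list(map(float, row)) + [float(b[i])] for i, row in enumerate(A)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    s = [0.0] * n
    for i in reversed(range(n)):          # back substitution
        acc = sum(M[i][j] * s[j] for j in range(i + 1, n))
        s[i] = (M[i][n] - acc) / M[i][i]
    return s


def newton(grad, hess, x0, tol=1e-10, max_iter=50):
    """Multivariate Newton: x <- x - H^{-1} g, with H^{-1} g computed as a linear solve."""
    x = list(x0)
    for it in range(max_iter):
        g = grad(x)
        if sum(gi * gi for gi in g) ** 0.5 < tol:
            return x, it
        step = solve(hess(x), g)          # step = H_t^{-1} g_t without forming H_t^{-1}
        x = [xi - si for xi, si in zip(x, step)]
    return x, max_iter


# Illustrative quadratic f(x1, x2) = (x1 - 1)^2 + (x2 + 2)^2 + x1 * x2
def grad(x):
    return [2 * (x[0] - 1) + x[1], 2 * (x[1] + 2) + x[0]]


def hess(x):
    return [[2.0, 1.0], [1.0, 2.0]]


x_min, iters = newton(grad, hess, [0.0, 0.0])
print(x_min, iters)
```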

1.3 Damped Newton Method

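The derivation above shows that the Newton direction $-H^{-1}g$ moves to a stationary point of the local quadratic model, which is not necessarily a minimum: when the Hessian is not positive definite, the Newton direction may be an ascent direction rather than a descent direction, and even a descent direction taken with the fixed step length of 1 may overshoot. As a result, Newton's method may fail to converge when the starting point is far from the minimizer. The damped Newton method therefore adds, on top of computing the update direction (the Newton direction) in every iteration, a one-dimensional search for the optimal step length.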

Algorithm steps

(1) Given an initial point $x^{(0)}$ and tolerance $\epsilon$; set $t=0$.
(2) Compute the gradient $g_t$ and the Hessian $H_t$ at $x^{(t)}$; if $\|g_t\|<\epsilon$, stop.
(3) Compute the Newton direction at $x^{(t)}$ as the search direction:

$$d^{(t)}=-H_t^{-1}g_t$$
(4) Starting from $x^{(t)}$, perform a one-dimensional search along the Newton direction $d^{(t)}$ to obtain the optimal step length:
$$\lambda_t = \arg\min_{\lambda} f(x^{(t)}+\lambda\cdot d^{(t)})$$
(5) Update the parameters:
$$x^{(t+1)}=x^{(t)}+\lambda_t\cdot d^{(t)}$$
(6) Set $t=t+1$ and go to (2).
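The following one-dimensional sketch contrasts pure Newton with damped Newton on $f(x)=\sqrt{1+x^2}$, a function chosen here (an assumption, not from the original) because its pure Newton step is $x\mapsto -x^3$, which diverges whenever $|x^{(0)}|>1$. Instead of the exact $\arg\min$ in step (4), it uses a backtracking line search with the Armijo sufficient-decrease condition, a common practical substitute; the constants `c` and `shrink` are illustrative.

```python
import math


def f(x):
    return math.sqrt(1.0 + x * x)


def fp(x):                       # f'(x)
    return x / math.sqrt(1.0 + x * x)


def fpp(x):                      # f''(x) = (1 + x^2)^(-3/2) > 0
    return (1.0 + x * x) ** -1.5


def pure_newton(x, steps):
    """Newton with the step length fixed at 1 (for this f the update is x -> -x^3)."""
    for _ in range(steps):
        x = x - fp(x) / fpp(x)
    return x


def damped_newton(x, tol=1e-10, max_iter=100, c=1e-4, shrink=0.5):
    """Newton direction plus a backtracking (Armijo) search for the step length."""
    for _ in range(max_iter):
        g = fp(x)
        if abs(g) < tol:
            break
        d = -g / fpp(x)                   # Newton direction
        lam = 1.0
        while f(x + lam * d) > f(x) + c * lam * g * d and lam > 1e-12:
            lam *= shrink                 # shrink until sufficient decrease holds
        x = x + lam * d
    return x


print(pure_newton(1.5, 3))    # diverges
print(damped_newton(1.5))     # converges to the minimizer 0
```

From $x^{(0)}=1.5$ the pure Newton iterates are $-3.375$, about $38.4$, then about $-5.7\times10^4$, while the damped version halves the first step once and then converges to $0$.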


II. Quasi-Newton Methods

2.1 Motivation

In Newton's method, inverting the Hessian $H$ is expensive when it is dense, and $H$ may not be invertible at all, or may fail to be positive definite. Quasi-Newton methods replace $H_t^{-1}$ with a matrix $U_t$ built without any second derivatives, and then perform a one-dimensional search along the direction $-U_tg_t$. Different ways of constructing $U_t$ give different quasi-Newton methods.
Key points of quasi-Newton methods:

  • No second derivatives are computed
  • No matrix is inverted

2.2 The Quasi-Newton Condition

The search direction of Newton's method is

$$d^{(t)}=-H_t^{-1}g_t$$
To avoid computing second derivatives and their inverse, we try to construct a matrix $U$ that approximates $H^{-1}$.
To simplify the derivation, assume for now that $f(x)$ is quadratic, so the Hessian $H$ is a constant matrix. The gradient difference between any two points $x^{(t)}$ and $x^{(t+1)}$ is then:
$$\nabla f(x^{(t+1)}) - \nabla f(x^{(t)}) = H\cdot (x^{(t+1)}-x^{(t)})$$
which, for invertible $H$, is equivalent to
$$x^{(t+1)}-x^{(t)} = H^{-1}\cdot [\nabla f(x^{(t+1)}) - \nabla f(x^{(t)})]$$
For non-quadratic functions we mimic this form and require the approximating matrix $U$ to satisfy a similar relation:
$$x^{(t+1)}-x^{(t)}=U_{t+1}\cdot [\nabla f(x^{(t+1)})-\nabla f(x^{(t)})]$$
or, writing $\Delta x_t=x^{(t+1)}-x^{(t)}$ and $\Delta g_t=\nabla f(x^{(t+1)})-\nabla f(x^{(t)})$,
$$\Delta x_t=U_{t+1}\cdot \Delta g_t$$
This is the quasi-Newton condition. Different quasi-Newton methods differ only in how they construct $U$.
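In one dimension the quasi-Newton condition $\Delta x_t=U_{t+1}\Delta g_t$ determines $U_{t+1}=\Delta x_t/\Delta g_t$ uniquely, which replaces $1/f''$ by a finite-difference slope of $f'$ (the secant method applied to $f'$). A minimal sketch, reusing the illustrative test function $f(x)=e^x-2x$; the function name and the two starting points are assumptions:

```python
import math


def secant_minimize(fprime, x0, x1, tol=1e-12, max_iter=100):
    """1-D quasi-Newton: 1/f''(x) is replaced by U = dx / dg from the last two points."""
    g0, g1 = fprime(x0), fprime(x1)
    for _ in range(max_iter):
        if abs(g1) < tol:
            break
        dx, dg = x1 - x0, g1 - g0
        if dg == 0.0:              # the condition dx = U * dg cannot be met
            break
        u = dx / dg                # the unique U with dx = U * dg in one dimension
        x0, g0 = x1, g1
        x1 = x1 - u * g1           # next point = current point - U * current gradient
        g1 = fprime(x1)
    return x1


# Illustrative objective f(x) = e^x - 2x (minimizer ln 2); f'' is never evaluated.
x_star = secant_minimize(lambda x: math.exp(x) - 2.0, 0.0, 1.0)
print(x_star, math.log(2.0))
```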

2.3 The DFP Method

To keep the methods apart, below we call $U$ by the name $D$ (for DFP).

Deriving DFP

We start from the quasi-Newton condition

$$\Delta x_t=D_{t+1}\cdot \Delta g_t$$
Suppose $D_t$ is known and we want to obtain $D_{t+1}$ additively, i.e. $D_{t+1}=D_{t}+\Delta D_t$. Substituting gives
$$\Delta D_t \Delta g_t=\Delta x_t - D_t \Delta g_t$$
Assume that a $\Delta D_t$ satisfying this equation has the form:
$$\Delta D_t=\Delta x_t \cdot q_t^T-D_t\Delta g_t\cdot w_t^T$$
First, substituting this form into the previous equation gives $\Delta x_t\,(q_t^T\Delta g_t)-D_t\Delta g_t\,(w_t^T\Delta g_t)=\Delta x_t-D_t\Delta g_t$, so by comparing the two sides it suffices that both scalars equal 1:
$$q_t^T\cdot \Delta g_t=w_t^T \cdot \Delta g_t = 1$$
Second, to keep $\Delta D_t$ symmetric, the simplest choice, given the form of $\Delta D_t$, is
$$q_t=\alpha_t \Delta x_t,\qquad w_t=\beta_t D_t\Delta g_t$$
Substituting the second condition into the first (using the symmetry of $D_t$) gives:
$$\alpha_t=\frac{1}{\Delta g_t^T\Delta x_t},\qquad \beta_t=\frac{1}{\Delta g_t^TD_t\Delta g_t}$$
Substituting back into the expression for $\Delta D_t$:
$$\Delta D_t = \frac{\Delta x_t\Delta x_t^T}{\Delta g_t^T\Delta x_t}-\frac{D_t\Delta g_t\Delta g_t^TD_t}{\Delta g_t^TD_t\Delta g_t}$$
Looking at the two fractions: their denominators are inner products costing $O(n)$, their numerators are $n\times n$ outer products costing $O(n^2)$, and the second term also needs the matrix-vector product $D_t\Delta g_t$, which is $O(n^2)$. Each update is therefore $O(n^2)$ overall, compared with $O(n^3)$ for inverting the Hessian.
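A small sketch of the DFP update $D_{t+1}=D_t+\Delta D_t$ with plain Python lists, which also checks numerically that the result satisfies the quasi-Newton condition $D_{t+1}\Delta g_t=\Delta x_t$. The vectors `dx` and `dg` are arbitrary illustrative values chosen so that $\Delta g_t^T\Delta x_t>0$; comments mark the $O(n)$ and $O(n^2)$ steps.

```python
def dfp_update(D, dx, dg):
    """DFP: D + dx dx^T / (dg^T dx) - (D dg)(D dg)^T / (dg^T D dg), with D symmetric."""
    n = len(dx)
    Ddg = [sum(D[i][j] * dg[j] for j in range(n)) for i in range(n)]  # D_t dg: O(n^2)
    a = sum(dg[i] * dx[i] for i in range(n))                           # dg^T dx: O(n)
    b = sum(dg[i] * Ddg[i] for i in range(n))                          # dg^T D_t dg: O(n)
    # Two outer products added entrywise: O(n^2)
    return [[D[i][j] + dx[i] * dx[j] / a - Ddg[i] * Ddg[j] / b
             for j in range(n)] for i in range(n)]


D0 = [[1.0, 0.0], [0.0, 1.0]]
dx = [1.0, 0.5]            # illustrative Delta x_t
dg = [2.0, 1.5]            # illustrative Delta g_t, with dg^T dx > 0
D1 = dfp_update(D0, dx, dg)

# Quasi-Newton condition: D_{t+1} Delta g_t should reproduce Delta x_t
check = [sum(D1[i][j] * dg[j] for j in range(2)) for i in range(2)]
print(check)
```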

DFP algorithm steps

(1) Given an initial point $x^{(0)}$ and tolerance $\epsilon$, set $D_0=I_n$ ($n$ is the dimension of $x$) and $t=0$.
(2) Compute the search direction $d^{(t)}=-D_t\cdot g_t$.
(3) Starting from $x^{(t)}$, perform a one-dimensional search along $d^{(t)}$ to obtain the optimal step length, and update the parameters:

$$\begin{aligned}\lambda_t &= \arg\min_{\lambda} f(x^{(t)}+\lambda\cdot d^{(t)})\\ x^{(t+1)} &= x^{(t)}+\lambda_t\cdot d^{(t)}\end{aligned}$$
(4) Check the precision: if $\|g_{t+1}\|<\epsilon$, stop; otherwise go to (5).
(5) Compute $\Delta g_t=g_{t+1}-g_t$ and $\Delta x_t=x^{(t+1)}-x^{(t)}$, and update $D$:
$$D_{t+1}=D_{t}+\frac{\Delta x_t\Delta x_t^T}{\Delta g_t^T\Delta x_t}-\frac{D_t\Delta g_t\Delta g_t^TD_t}{\Delta g_t^TD_t\Delta g_t}$$
(6) Set $t=t+1$ and go to (2).
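Putting steps (1) to (6) together gives the following sketch. Two assumptions differ from the text: the exact $\arg\min$ in step (3) is replaced by Armijo backtracking, and the update in step (5) is skipped whenever $\Delta g_t^T\Delta x_t\le 0$, since the DFP update is only guaranteed to keep $D$ positive definite when that quantity is positive. The test function $f(x,y)=x^2+y^2+e^{x+y}$ and the tolerance `1e-6` are illustrative choices; the function is strictly convex, so $\Delta g_t^T\Delta x_t>0$ always holds for it.

```python
import math


def f(x):
    return x[0] ** 2 + x[1] ** 2 + math.exp(x[0] + x[1])


def grad(x):
    e = math.exp(x[0] + x[1])
    return [2 * x[0] + e, 2 * x[1] + e]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def matvec(M, v):
    return [dot(row, v) for row in M]


def dfp(f, grad, x0, tol=1e-6, max_iter=200):
    n = len(x0)
    D = [[float(i == j) for j in range(n)] for i in range(n)]  # (1) D_0 = I_n
    x, g = list(x0), grad(x0)
    for t in range(max_iter):
        if math.sqrt(dot(g, g)) < tol:                         # (4) stopping test
            return x, t
        d = [-v for v in matvec(D, g)]                         # (2) d = -D_t g_t
        fx, slope, lam = f(x), dot(g, d), 1.0
        # (3) Armijo backtracking stands in for the exact argmin over lambda
        while (f([xi + lam * di for xi, di in zip(x, d)]) > fx + 1e-4 * lam * slope
               and lam > 1e-10):
            lam *= 0.5
        x_new = [xi + lam * di for xi, di in zip(x, d)]
        g_new = grad(x_new)
        dx = [p - q for p, q in zip(x_new, x)]
        dg = [p - q for p, q in zip(g_new, g)]
        a = dot(dg, dx)
        if a > 1e-12:                                          # (5) DFP update of D
            Ddg = matvec(D, dg)
            b = dot(dg, Ddg)
            D = [[D[i][j] + dx[i] * dx[j] / a - Ddg[i] * Ddg[j] / b
                  for j in range(n)] for i in range(n)]
        x, g = x_new, g_new                                    # (6) t = t + 1
    return x, max_iter


x_min, iters = dfp(f, grad, [1.0, -2.0])
print(x_min, iters)
```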

2.4 The BFGS Method

To keep the methods apart, below we call $U$ by the name $B^{-1}$ (for BFGS), so that $B$ itself approximates the Hessian.

Deriving BFGS

The quasi-Newton condition becomes

$$\Delta x_t=B_{t+1}^{-1}\cdot \Delta g_t \quad\Longleftrightarrow\quad \Delta g_t=B_{t+1} \cdot \Delta x_t$$
The derivation is similar to DFP. Notice, however, that the BFGS condition $\Delta g_t=B_{t+1}\Delta x_t$ is dual to the DFP condition $\Delta x_t=D_{t+1}\Delta g_t$, so the update formula is obtained simply by swapping $\Delta x_t$ and $\Delta g_t$ (and replacing $D$ with $B$):
$$\Delta B_t = \frac{\Delta g_t\Delta g_t^T}{\Delta x_t^T\Delta g_t}-\frac{B_t\Delta x_t\Delta x_t^TB_t}{\Delta x_t^TB_t\Delta x_t}$$
There is one catch: according to the condition below, the search direction needs $B_{t+1}^{-1}$, so doesn't this still require a matrix inverse? This is where the Sherman-Morrison formula comes in.
$$\Delta x_t=B_{t+1}^{-1}\cdot \Delta g_t$$

Sherman-Morrison formula

For any nonsingular square matrix $A$ and $n$-dimensional vectors $u,v\in\mathbb{R}^n$, if $1+v^TA^{-1}u\neq 0$, then

$$(A+uv^T)^{-1} = A^{-1}-\frac{(A^{-1}u)(v^TA^{-1})}{1+v^TA^{-1}u}$$
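The formula shows how to obtain the new inverse from the previously computed $A^{-1}$ when $A$ receives a rank-one update $uv^T$, at a cost of $O(n^2)$ instead of the $O(n^3)$ of a fresh inversion.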
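A quick numerical check of the Sherman-Morrison formula, as a sketch with illustrative values: $A=\begin{bmatrix}2&1\\1&3\end{bmatrix}$ has inverse $\frac15\begin{bmatrix}3&-1\\-1&2\end{bmatrix}$, and with $u=(1,0)^T$, $v=(0,1)^T$ we get $A+uv^T=\begin{bmatrix}2&2\\1&3\end{bmatrix}$, whose inverse is $\begin{bmatrix}0.75&-0.5\\-0.25&0.5\end{bmatrix}$. The function name is hypothetical.

```python
def sherman_morrison(A_inv, u, v):
    """(A + u v^T)^{-1} from A^{-1} in O(n^2); requires 1 + v^T A^{-1} u != 0."""
    n = len(u)
    Au = [sum(A_inv[i][k] * u[k] for k in range(n)) for i in range(n)]   # A^{-1} u
    vA = [sum(v[k] * A_inv[k][j] for k in range(n)) for j in range(n)]   # v^T A^{-1}
    denom = 1.0 + sum(v[i] * Au[i] for i in range(n))
    if denom == 0.0:
        raise ValueError("1 + v^T A^{-1} u is zero: A + u v^T is singular")
    return [[A_inv[i][j] - Au[i] * vA[j] / denom for j in range(n)] for i in range(n)]


A_inv = [[0.6, -0.2], [-0.2, 0.4]]          # inverse of [[2, 1], [1, 3]]
new_inv = sherman_morrison(A_inv, u=[1.0, 0.0], v=[0.0, 1.0])
print(new_inv)                              # inverse of [[2, 2], [1, 3]]
```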
Since $\Delta B_t$ is the sum of two rank-one terms, applying the Sherman-Morrison formula twice to the update $B_{t+1}=B_t+\Delta B_t$ yields
$$B^{-1}_{t+1}=\left(I_n-\frac{\Delta x_t \Delta g_t^T}{\Delta x_t^T \Delta g_t}\right)B_{t}^{-1}\left(I_n-\frac{\Delta g_t \Delta x_t^T}{\Delta x_t^T \Delta g_t}\right)+\frac{\Delta x_t \Delta x_t^T}{\Delta x_t^T \Delta g_t}$$
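This gives a recursion directly between the inverse matrices. One might ask whether the first matrix still has to be inverted. It does not: this is an iterative algorithm, and if the initial matrix is set to the identity (a diagonal matrix also works), no inversion is ever needed.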
A detailed derivation of this formula can be found here or here.
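The claim that the inverse recursion matches the $B$ update can be checked numerically. In this sketch (illustrative values; `bfgs_update`, `bfgs_inverse_update` and `inv2` are hypothetical helper names), $B_{t+1}$ is formed with $\Delta B_t$ and inverted with the explicit $2\times 2$ formula only for comparison, while the inverse recursion never inverts anything:

```python
def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


def bfgs_update(B, dx, dg):
    """B_{t+1} = B + dg dg^T / (dx^T dg) - (B dx)(B dx)^T / (dx^T B dx)."""
    n = len(dx)
    Bdx = [sum(B[i][j] * dx[j] for j in range(n)) for i in range(n)]
    a = sum(dx[i] * dg[i] for i in range(n))
    b = sum(dx[i] * Bdx[i] for i in range(n))
    return [[B[i][j] + dg[i] * dg[j] / a - Bdx[i] * Bdx[j] / b
             for j in range(n)] for i in range(n)]


def bfgs_inverse_update(H, dx, dg):
    """B^{-1}_{t+1} = (I - rho dx dg^T) H (I - rho dg dx^T) + rho dx dx^T, rho = 1/(dx^T dg)."""
    n = len(dx)
    rho = 1.0 / sum(dx[i] * dg[i] for i in range(n))
    left = [[float(i == j) - rho * dx[i] * dg[j] for j in range(n)] for i in range(n)]
    right = [[float(i == j) - rho * dg[i] * dx[j] for j in range(n)] for i in range(n)]
    core = matmul(matmul(left, H), right)
    return [[core[i][j] + rho * dx[i] * dx[j] for j in range(n)] for i in range(n)]


def inv2(M):
    """Explicit 2x2 inverse, used here only to check the result."""
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    return [[M[1][1] / det, -M[0][1] / det], [-M[1][0] / det, M[0][0] / det]]


B0 = [[2.0, 0.0], [0.0, 1.0]]
H0 = [[0.5, 0.0], [0.0, 1.0]]          # B0^{-1}, written down directly (diagonal)
dx, dg = [1.0, 0.5], [2.0, 1.5]        # illustrative pair with dx^T dg > 0

B1_inv = inv2(bfgs_update(B0, dx, dg))  # invert the updated B explicitly
H1 = bfgs_inverse_update(H0, dx, dg)     # inverse recursion, no inversion
print(B1_inv)
print(H1)
```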

BFGS algorithm steps

Although the matrix below is written $B^{-1}$, note that BFGS never computes an inverse from start to finish; renaming $B^{-1}$ below to the symbol $H$ would change nothing.
(1) Given an initial point $x^{(0)}$ and tolerance $\epsilon$, set $B_0^{-1}$ (for example $I_n$) and $t=0$.
(2) Compute the search direction $d^{(t)}=-B_t^{-1}\cdot g_t$.
(3) Starting from $x^{(t)}$, perform a one-dimensional search along $d^{(t)}$ to obtain the optimal step length, and update the parameters:

$$\begin{aligned}\lambda_t &= \arg\min_{\lambda} f(x^{(t)}+\lambda\cdot d^{(t)})\\ x^{(t+1)} &= x^{(t)}+\lambda_t\cdot d^{(t)}\end{aligned}$$
(4) Check the precision: if $\|g_{t+1}\|<\epsilon$, stop; otherwise go to (5).
(5) Compute $\Delta g_t=g_{t+1}-g_t$ and $\Delta x_t=x^{(t+1)}-x^{(t)}$, and update $B^{-1}$:
$$B^{-1}_{t+1}=\left(I_n-\frac{\Delta x_t \Delta g_t^T}{\Delta x_t^T \Delta g_t}\right)B_{t}^{-1}\left(I_n-\frac{\Delta g_t \Delta x_t^T}{\Delta x_t^T \Delta g_t}\right)+\frac{\Delta x_t \Delta x_t^T}{\Delta x_t^T \Delta g_t}$$
(6) Set $t=t+1$ and go to (2).
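A sketch of the BFGS steps above on a strictly convex quadratic $f(x)=\frac12x^TAx-b^Tx$, assumed here because the exact line search in step (3) then has the closed form $\lambda_t=-g_t^Td^{(t)}/(d^{(t)T}Ad^{(t)})$. With exact line search on such a quadratic, BFGS is known to reach the minimizer in at most $n$ iterations, ending with $B_n^{-1}=A^{-1}$. The values of $A$, $b$ and $x^{(0)}$ and all function names are illustrative.

```python
import math


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def matvec(M, v):
    return [dot(row, v) for row in M]


def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


def bfgs_inverse_update(H, dx, dg):
    n = len(dx)
    rho = 1.0 / dot(dx, dg)
    left = [[float(i == j) - rho * dx[i] * dg[j] for j in range(n)] for i in range(n)]
    right = [[float(i == j) - rho * dg[i] * dx[j] for j in range(n)] for i in range(n)]
    core = matmul(matmul(left, H), right)
    return [[core[i][j] + rho * dx[i] * dx[j] for j in range(n)] for i in range(n)]


def bfgs_quadratic(A, b, x0, tol=1e-10, max_iter=50):
    """BFGS on f(x) = 0.5 x^T A x - b^T x (A symmetric positive definite), exact line search."""
    n = len(x0)
    H = [[float(i == j) for j in range(n)] for i in range(n)]   # (1) B_0^{-1} = I_n
    x = list(x0)
    g = [p - q for p, q in zip(matvec(A, x), b)]                # gradient A x - b
    for t in range(max_iter):
        if math.sqrt(dot(g, g)) < tol:                          # (4)
            return x, H, t
        d = [-v for v in matvec(H, g)]                          # (2) d = -B_t^{-1} g_t
        lam = -dot(g, d) / dot(d, matvec(A, d))                 # (3) exact argmin for a quadratic
        x_new = [xi + lam * di for xi, di in zip(x, d)]
        g_new = [p - q for p, q in zip(matvec(A, x_new), b)]
        dx = [p - q for p, q in zip(x_new, x)]
        dg = [p - q for p, q in zip(g_new, g)]
        H = bfgs_inverse_update(H, dx, dg)                      # (5)
        x, g = x_new, g_new                                     # (6)
    return x, H, max_iter


A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x_min, H_final, iters = bfgs_quadratic(A, b, [2.0, 1.0])
print(iters, x_min, H_final)
```

For these values the minimizer is $A^{-1}b=(1/11,\,7/11)^T$ and $A^{-1}=\frac1{11}\begin{bmatrix}3&-1\\-1&4\end{bmatrix}$.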

2.5 L-BFGS (Limited-memory BFGS)

For an $n$-dimensional parameter, BFGS must store the $n\times n$ matrix $B^{-1}$, i.e. $O(n^2)$ memory. In fact, the $\Delta x$ and $\Delta g$ of each iteration are enough to reconstruct the current $B^{-1}$ recursively. L-BFGS builds on this idea to obtain a memory-saving version of BFGS.

Deriving L-BFGS

The BFGS recursion:

$$B^{-1}_{t+1}=\left(I_n-\frac{\Delta x_t \Delta g_t^T}{\Delta x_t^T \Delta g_t}\right)B_{t}^{-1}\left(I_n-\frac{\Delta g_t \Delta x_t^T}{\Delta x_t^T \Delta g_t}\right)+\frac{\Delta x_t\Delta x_t^T}{\Delta x_t^T \Delta g_t}$$
Now define $\rho_t = \frac{1}{\Delta x_t^T \Delta g_t}$ and $V_t = I_n-\rho_t \Delta g_t \Delta x_t^T$; the recursion can then be written as
$$B^{-1}_{t+1}=V_t^TB^{-1}_{t}V_t+\rho_t \Delta x_t \Delta x_t^T$$
Given the initial matrix $B^{-1}_{0}$, every subsequent iterate can be computed recursively:
$$\begin{aligned} B^{-1}_{1} &= V_0^TB^{-1}_{0}V_0+\rho_0 \Delta x_0 \Delta x_0^T\\ B^{-1}_{2} &= V_1^TB^{-1}_{1}V_1+\rho_1 \Delta x_1 \Delta x_1^T\\ &= (V_1^TV_0^T)B^{-1}_{0}(V_0V_1)+V_1^T\rho_0\Delta x_0 \Delta x_0^TV_1+\rho_1 \Delta x_1 \Delta x_1^T \end{aligned}$$
Continuing in this way, $B^{-1}_{t+1}$ can be expressed in terms of the $\Delta x_i$ and $\Delta g_i$ for $i=0$ through $i=t$:
$$\begin{aligned} B^{-1}_{t+1} = {} & (V_t^TV_{t-1}^T\cdots V_1^TV_0^T)\,B^{-1}_0\,(V_0V_1\cdots V_{t-1}V_t) \\ & + (V_t^TV_{t-1}^T\cdots V_2^TV_1^T)\,(\rho_0\Delta x_0\Delta x_0^T)\,(V_1V_2\cdots V_{t-1}V_t) \\ & + \cdots \\ & + V_t^T(\rho_{t-1} \Delta x_{t-1} \Delta x_{t-1}^T)V_t \\ & + \rho_t \Delta x_t \Delta x_t^T \end{aligned}$$
This looks long, but it can be written compactly as a sum:
$$B^{-1}_{t+1} = \left(\prod_{i=t}^{0} V_i^T \right)B_0^{-1} \left(\prod_{i=0}^{t} V_i\right)+\sum_{j=0}^{t} \left(\prod_{i=t}^{j+1} V_i^T \right) \left( \rho_j\Delta x_j \Delta x_j^T\right) \left(\prod_{i=j+1}^{t} V_i \right)$$
where $\prod_{i=t}^{0}$ denotes a product taken in decreasing order of $i$, and an empty product (the $j=t$ term) is $I_n$.
This sum involves all the $\Delta x$ and $\Delta g$ from $0$ to $t$; in practice only the most recent $m$ pairs need to be kept, which gives (with $B_0^{-1}$ now playing the role of the initial matrix at the start of the window):
$$B^{-1}_{t} = \left(\prod_{i=t-1}^{t-m} V_i^T \right)B_0^{-1} \left(\prod_{i=t-m}^{t-1} V_i\right)+\sum_{j=t-m}^{t-1} \left(\prod_{i=t-1}^{j+1} V_i^T \right) \left( \rho_j\Delta x_j \Delta x_j^T\right) \left(\prod_{i=j+1}^{t-1} V_i \right)$$

L-BFGS in practice

What we actually care about is not $B^{-1}_t$ itself: the only reason to compute $B^{-1}_t$ is to obtain the product $B^{-1}_tg_t$, whose negative is this iteration's search direction.
The algorithm below is taken from Numerical Optimization and computes the quasi-Newton search direction efficiently in each iteration. Look closely and you will see that it simply reproduces the long recursion derived above: all it needs are the $\Delta x$ and $\Delta g$ from the most recent $m$ iterations, and after the backward and forward passes the final $r$ equals $B^{-1}_tg_t$, so the search direction is $-r$; after that you are free to do a one-dimensional search or anything else.
The algorithm's notation maps to ours as follows: $s_i=\Delta x_i$, $y_i=\Delta g_i$, $H_k=B_k^{-1}$.
For a code implementation, see here.

[Figure: Algorithm 9.1 from Numerical Optimization, the L-BFGS two-loop recursion]
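Since the algorithm figure is not reproduced in this text, here is a sketch of the two-loop recursion as it is commonly stated, in the notation $s_i=\Delta x_i$, $y_i=\Delta g_i$, $\rho_i=1/(y_i^Ts_i)$, with the initial matrix taken as $H_k^0=\gamma I$. The backward loop runs from the newest pair to the oldest and the forward loop back again, and the result is $r=H_kg_k$ without ever forming an $n\times n$ matrix. To check it, the script also builds $H_k$ explicitly with the BFGS inverse recursion from the same pairs (keeping all of them, i.e. $m$ at least the number of pairs) and compares the two products; the matrix used to generate the pairs and the vector $g$ are illustrative.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


def two_loop(g, pairs, gamma=1.0):
    """Return H_k g via the two-loop recursion with H_k^0 = gamma * I.

    `pairs` holds (s_i, y_i, rho_i) tuples ordered oldest first.
    """
    q = list(g)
    alphas = []
    for s, y, rho in reversed(pairs):                         # backward pass: newest -> oldest
        alpha = rho * dot(s, q)
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
        alphas.append(alpha)
    r = [gamma * qi for qi in q]                              # r = H_k^0 q
    for (s, y, rho), alpha in zip(pairs, reversed(alphas)):   # forward pass: oldest -> newest
        beta = rho * dot(y, r)
        r = [ri + si * (alpha - beta) for ri, si in zip(r, s)]
    return r


def bfgs_inverse_update(H, s, y):
    """Explicit recursion H <- V^T H V + rho s s^T with V = I - rho y s^T."""
    n = len(s)
    rho = 1.0 / dot(s, y)
    Vt = [[float(i == j) - rho * s[i] * y[j] for j in range(n)] for i in range(n)]
    V = [[float(i == j) - rho * y[i] * s[j] for j in range(n)] for i in range(n)]
    core = matmul(matmul(Vt, H), V)
    return [[core[i][j] + rho * s[i] * s[j] for j in range(n)] for i in range(n)]


# Illustrative pairs: y = A s for a symmetric positive definite A, so s^T y > 0.
A = [[3.0, 1.0, 0.0], [1.0, 4.0, 1.0], [0.0, 1.0, 2.0]]
s_list = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0], [1.0, -1.0, 0.5]]
pairs = []
for s in s_list:
    y = [dot(row, s) for row in A]
    pairs.append((s, y, 1.0 / dot(s, y)))

g = [0.3, -1.2, 0.7]
H = [[float(i == j) for j in range(3)] for i in range(3)]   # H^0 = I (gamma = 1)
for s, y, _ in pairs:
    H = bfgs_inverse_update(H, s, y)

explicit = [dot(row, g) for row in H]   # needs the full n x n matrix
implicit = two_loop(g, pairs)           # needs only the stored pairs
print(explicit)
print(implicit)
```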
L-BFGS algorithm steps

(1) Given an initial point $x^{(0)}$ and tolerance $\epsilon$, decide to keep the most recent $m$ vector pairs, set $B_0^{-1}$, and set $t=0$.
(2) Use Algorithm 9.1 to compute the search direction $d^{(t)}=-B_t^{-1}\cdot g_t$.
(3) Starting from $x^{(t)}$, perform a one-dimensional search along $d^{(t)}$ to obtain the optimal step length, and update the parameters:

$$\begin{aligned}\lambda_t &= \arg\min_{\lambda} f(x^{(t)}+\lambda\cdot d^{(t)})\\ x^{(t+1)} &= x^{(t)}+\lambda_t\cdot d^{(t)}\end{aligned}$$
(4) Check the precision: if $\|g_{t+1}\|<\epsilon$, stop; otherwise go to (5).
(5) If $t\ge m$, discard the stored $\Delta x_{t-m}$ and $\Delta g_{t-m}$.
(6) Compute and store $\Delta g_t=g_{t+1}-g_t$ and $\Delta x_t=x^{(t+1)}-x^{(t)}$; set $t=t+1$ and go to (2).
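A compact end-to-end sketch of the L-BFGS steps above, under these assumptions: the stored pairs live in a `collections.deque` with `maxlen=m`, which discards the oldest pair automatically when a new one is appended and so performs step (5); $B_0^{-1}$ is taken as $\gamma I$ with $\gamma=s^Ty/y^Ty$ from the newest pair, a common scaling heuristic; Armijo backtracking replaces the exact line search; and pairs with $s^Ty\le 0$ are skipped. The test function $f(x)=\sum_i(e^{x_i}-2x_i)+\frac12\sum_i(x_i-x_{i+1})^2$ is an illustrative strictly convex choice whose minimizer has every $x_i=\ln 2$.

```python
import math
from collections import deque


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def two_loop(g, pairs, gamma):
    """H_k g via the two-loop recursion with H_k^0 = gamma * I; pairs are oldest first."""
    q = list(g)
    alphas = []
    for s, y, rho in reversed(pairs):
        alpha = rho * dot(s, q)
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
        alphas.append(alpha)
    r = [gamma * qi for qi in q]
    for (s, y, rho), alpha in zip(pairs, reversed(alphas)):
        beta = rho * dot(y, r)
        r = [ri + si * (alpha - beta) for ri, si in zip(r, s)]
    return r


def lbfgs(f, grad, x0, m=5, tol=1e-6, max_iter=500):
    x, g = list(x0), grad(x0)
    pairs = deque(maxlen=m)        # (5): appending to a full deque drops the oldest pair
    for t in range(max_iter):
        if math.sqrt(dot(g, g)) < tol:                       # (4)
            return x, t
        gamma = 1.0
        if pairs:                                            # scaled B_0^{-1} = gamma * I
            s, y, _ = pairs[-1]
            gamma = dot(s, y) / dot(y, y)
        d = [-v for v in two_loop(g, pairs, gamma)]          # (2) d = -B_t^{-1} g_t
        fx, slope, lam = f(x), dot(g, d), 1.0
        while (f([xi + lam * di for xi, di in zip(x, d)]) > fx + 1e-4 * lam * slope
               and lam > 1e-10):                             # (3) Armijo backtracking
            lam *= 0.5
        x_new = [xi + lam * di for xi, di in zip(x, d)]
        g_new = grad(x_new)
        s = [p - q for p, q in zip(x_new, x)]                # (6) new pair
        y = [p - q for p, q in zip(g_new, g)]
        sy = dot(s, y)
        if sy > 1e-12:
            pairs.append((s, y, 1.0 / sy))
        x, g = x_new, g_new
    return x, max_iter


def f(x):
    n = len(x)
    return (sum(math.exp(v) - 2.0 * v for v in x)
            + 0.5 * sum((x[i] - x[i + 1]) ** 2 for i in range(n - 1)))


def grad(x):
    n = len(x)
    g = [math.exp(v) - 2.0 for v in x]
    for i in range(n - 1):
        diff = x[i] - x[i + 1]
        g[i] += diff
        g[i + 1] -= diff
    return g


x_min, iters = lbfgs(f, grad, [0.0, 1.0, -1.0, 2.0, 0.5], m=3)
print(iters, x_min)
```

Here $m=3$ is deliberately smaller than the dimension $5$, so only part of the history is ever stored.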


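Finally, if you can't work out what BFGS is supposed to mean, it isn't your English: the acronym is just the initials of its four inventors (Broyden, Fletcher, Goldfarb and Shanno), just as DFP stands for Davidon, Fletcher and Powell.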



References

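  1. [Blog] Derivation of the LBFGS method (慢慢的回味)
  2. [Blog] Numerical optimization: understanding the L-BFGS algorithm
  3. [Blog] Unconstrained optimization algorithms: Newton's method and quasi-Newton methods (DFP, BFGS, LBFGS)
  4. [Blog] Unconstrained optimization methods: Newton's method, quasi-Newton methods, BFGS, LBFGS
  5. [Blog] Numerical Optimization: Understanding L-BFGS
  6. [Paper] A Stochastic Quasi-Newton Method for Online Convex Optimization
  7. [Book] Numerical Optimization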