拟牛顿法(DFP、BFGS、L-BFGS)
拟牛顿法一、牛顿法1.1 基本介绍牛顿法属于利用一阶和二阶导数的无约束目标最优化方法。基本思想是,在每一次迭代中,以牛顿方向为搜索方向进行更新。牛顿法对目标的可导性更严格,要求二阶可导,有Hesse矩阵求逆的计算复杂的缺点。XGBoost本质上就是利用牛顿法进行优化的。1.2 基本原理现在推导牛顿法。假设无约束最优化问题是minxf(x)minxf(x)\m...
拟牛顿法
一、牛顿法
1.1 基本介绍
牛顿法属于利用一阶和二阶导数的无约束目标最优化方法。基本思想是,在每一次迭代中,以牛顿方向为搜索方向进行更新。牛顿法对目标的可导性更严格,要求二阶可导,有Hesse矩阵求逆的计算复杂的缺点。XGBoost本质上就是利用牛顿法进行优化的。
1.2 基本原理
现在推导牛顿法。
假设无约束最优化问题是
对于一维 x x <script type="math/tex" id="MathJax-Element-2">x</script> 的情况,可以将 <script type="math/tex" id="MathJax-Element-3">f(x^{(t+1)})</script> 在 x(t) x ( t ) <script type="math/tex" id="MathJax-Element-4">x^{(t)}</script> 附近用二阶泰勒展开近似:
然后用泰勒展开的极值点近似 f(x) f ( x ) <script type="math/tex" id="MathJax-Element-6">f(x)</script> 的极值点:
因此
于是得到迭代公式, g g <script type="math/tex" id="MathJax-Element-9">g</script>和 <script type="math/tex" id="MathJax-Element-10">h</script>分别是目标在当前 x x <script type="math/tex" id="MathJax-Element-11">x</script>上的一阶和二阶导
推广到 x x <script type="math/tex" id="MathJax-Element-13">x</script>是多维向量的情况, <script type="math/tex" id="MathJax-Element-14">g_t</script> 仍然是向量,而 Ht H t <script type="math/tex" id="MathJax-Element-15">H_t</script> 是Hesse矩阵
以二维 x=(x1,x2) x = ( x 1 , x 2 ) <script type="math/tex" id="MathJax-Element-17">x=(x_1,x_2)</script>为例:
参数更新方程推广为:
可见,每一次迭代的更新方向都是当前点的牛顿方向,步长固定为1。每一次都需要计算一阶导数 g g <script type="math/tex" id="MathJax-Element-20">g</script>以及Hesse矩阵的逆矩阵,对于高维特征而言,求逆矩阵的计算量巨大且耗时。
1.3 阻尼牛顿法
从上面的推导中看出,牛顿方向 <script type="math/tex" id="MathJax-Element-21">-H^{-1}g</script> 能使得更新后函数处于极值点,但是它不一定是极小点,也就是说牛顿方向可能是下降方向,也可能是上升方向,以至于当初始点远离极小点时,牛顿法有可能不收敛。因此提出 阻尼牛顿法,在牛顿法的基础上,每次迭代除了计算更新方向(牛顿方向),还要对最优步长做一维搜索。
算法步骤
(1)给定给初始点 x(0) x ( 0 ) <script type="math/tex" id="MathJax-Element-22">x^{(0)}</script>,允许误差 ϵ ϵ <script type="math/tex" id="MathJax-Element-23">\epsilon</script>
(2)计算点 x(t) x ( t ) <script type="math/tex" id="MathJax-Element-24">x^{(t)}</script> 处梯度 gt g t <script type="math/tex" id="MathJax-Element-25">g_t</script>和Hesse矩阵 H H <script type="math/tex" id="MathJax-Element-26">H</script>,若
<script type="math/tex" id="MathJax-Element-27">|g_t|<\epsilon</script>则停止迭代
(3)计算点 x(t) x ( t ) <script type="math/tex" id="MathJax-Element-28">x^{(t)}</script> 处的牛顿方向作为搜索方向:
(4)从点 x(t) x ( t ) <script type="math/tex" id="MathJax-Element-30">x^{(t)}</script> 出发,沿着牛顿方向 d(t) d ( t ) <script type="math/tex" id="MathJax-Element-31">d^{(t)}</script> 做一维搜索,获得最优步长:
(5)更新参数
二、拟牛顿法
2.1 提出的初衷
牛顿法中的Hesse矩阵 H H <script type="math/tex" id="MathJax-Element-34">H</script>在稠密时求逆计算量大,也有可能没有逆(Hesse矩阵非正定)。拟牛顿法提出,用不含二阶导数的矩阵
<script type="math/tex" id="MathJax-Element-35">U_t</script> 替代牛顿法中的 H−1t H t − 1 <script type="math/tex" id="MathJax-Element-36">H_t^{-1}</script>,然后沿搜索方向 −Utgt − U t g t <script type="math/tex" id="MathJax-Element-37">-U_tg_t</script> 做一维搜索。根据不同的 Ut U t <script type="math/tex" id="MathJax-Element-38">U_t</script> 构造方法有不同的拟牛顿法。
注意拟牛顿法的 关键词:
- 不用算二阶导数
- 不用求逆
2.2 拟牛顿条件
牛顿法的搜索方向是
为了不算二阶导及其逆矩阵,设法构造一个矩阵 U U <script type="math/tex" id="MathJax-Element-40">U</script>,用它来逼近 <script type="math/tex" id="MathJax-Element-41">H^{-1}</script>
现在为了方便推导,假设 f(x) f ( x ) <script type="math/tex" id="MathJax-Element-42">f(x)</script> 是二次函数,于是 Hesse 矩阵 H H <script type="math/tex" id="MathJax-Element-43">H</script> 是常数阵,任意两点 <script type="math/tex" id="MathJax-Element-44">x^{(t)}</script> 和 x(t+1) x ( t + 1 ) <script type="math/tex" id="MathJax-Element-45">x^{(t+1)}</script> 处的梯度之差是:
等价于
那么对非二次型的情况,也仿照这种形式,要求近似矩阵 U U <script type="math/tex" id="MathJax-Element-48">U</script> 满足类似的关系:
或者写成
以上就是 拟牛顿条件,不同的拟牛顿法,区别就在于如何确定 U U <script type="math/tex" id="MathJax-Element-51">U</script>。
2.3 DFP法
为了方便区分,下面把 <script type="math/tex" id="MathJax-Element-52">U</script>称作 D D <script type="math/tex" id="MathJax-Element-53">D</script>(表示DFP)。
DFP推导
现在已知拟牛顿条件
假设已知 Dt D t <script type="math/tex" id="MathJax-Element-55">D_t</script>,希望用叠加的方式求 Dt+1 D t + 1 <script type="math/tex" id="MathJax-Element-56">D_{t+1}</script>,即 Dt+1=Dt+ΔDt D t + 1 = D t + Δ D t <script type="math/tex" id="MathJax-Element-57">D_{t+1}=D_{t}+\Delta D_t</script>,代入得到
假设满足这个等式的 ΔDt Δ D t <script type="math/tex" id="MathJax-Element-59">\Delta D_t</script> 是这样的形式:
首先,对照一下就能发现:
其次,要保证 ΔDt Δ D t <script type="math/tex" id="MathJax-Element-62">\Delta D_t</script> 是对称的,参照 ΔDt Δ D t <script type="math/tex" id="MathJax-Element-63">\Delta D_t</script> 的表达式,最简单就是令
第二个条件代入第一个得到:
然后代入回 ΔDt Δ D t <script type="math/tex" id="MathJax-Element-66">\Delta D_t</script> 的表达式:
观察一下两项分式,第一项仅涉及向量乘法,时间复杂度是 O(n) O ( n ) <script type="math/tex" id="MathJax-Element-68">O(n)</script>,第二项涉及矩阵乘法,时间复杂度是 O(n2) O ( n 2 ) <script type="math/tex" id="MathJax-Element-69">O(n^2)</script>,综合起来是 O(n2) O ( n 2 ) <script type="math/tex" id="MathJax-Element-70">O(n^2)</script>。
DFP算法步骤
(1)给定初始点 x(0) x ( 0 ) <script type="math/tex" id="MathJax-Element-71">x^{(0)}</script>,允许误差 ϵ ϵ <script type="math/tex" id="MathJax-Element-72">\epsilon</script>,令 D0=In D 0 = I n <script type="math/tex" id="MathJax-Element-73">D_0=I_n</script>( n n <script type="math/tex" id="MathJax-Element-74">n</script>是
<script type="math/tex" id="MathJax-Element-75">x</script>的维数), t=0 t = 0 <script type="math/tex" id="MathJax-Element-76">t=0</script>
(2)计算搜索方向 d(t)=−D−1t⋅gt d ( t ) = − D t − 1 ⋅ g t <script type="math/tex" id="MathJax-Element-77">d^{(t)}=-D_t^{-1}\cdot g_t</script>
(3)从点 x(t) x ( t ) <script type="math/tex" id="MathJax-Element-78">x^{(t)}</script> 出发,沿着 d(t) d ( t ) <script type="math/tex" id="MathJax-Element-79">d^{(t)}</script> 做一维搜索,获得最优步长并更新参数:
(4)判断精度,若 |gt+1|<ϵ | g t + 1 | < ϵ <script type="math/tex" id="MathJax-Element-81">|g_{t+1}|<\epsilon</script>则停止迭代,否则转(5)
(5)计算 Δg=gt+1−gt Δ g = g t + 1 − g t <script type="math/tex" id="MathJax-Element-82">\Delta g=g_{t+1}-g_t</script>, Δx=x(t+1)−x(t) Δ x = x ( t + 1 ) − x ( t ) <script type="math/tex" id="MathJax-Element-83">\Delta x=x^{(t+1)}-x^{(t)}</script>,更新 H H <script type="math/tex" id="MathJax-Element-84">H</script>
(6) t=t+1 t = t + 1 <script type="math/tex" id="MathJax-Element-86">t=t+1</script>,转(2)
2.4 BFGS法
为了方便区分,下面把 U U <script type="math/tex" id="MathJax-Element-87">U</script>称作 <script type="math/tex" id="MathJax-Element-88">B^{-1}</script>(表示BFGS)。
BFGS推导
拟牛顿条件
推导与DFP相似,但是,可以看到BFGS这种拟牛顿条件的形式与BFP的是对偶的,所以迭代公式只要把 Δxt Δ x t <script type="math/tex" id="MathJax-Element-90">\Delta x_t</script> 和 Δgt Δ g t <script type="math/tex" id="MathJax-Element-91">\Delta g_t</script> 调换一下就好。
只不过有个问题,按照下面这个迭代公式,不也一样要求逆吗?这就要引入谢尔曼莫里森公式了。
Sherman-Morrison 公式
对于任意非奇异方阵 A A <script type="math/tex" id="MathJax-Element-94">A</script>,
<script type="math/tex" id="MathJax-Element-95">u,v\in R^n</script>是 n n <script type="math/tex" id="MathJax-Element-96">n</script>维向量,若
<script type="math/tex" id="MathJax-Element-97">1+v^TA^{-1}u\neq 0</script>,则
该公式描述了在矩阵 A A <script type="math/tex" id="MathJax-Element-99">A</script>发生某种变化时,如何利用之前求好的逆,求新的逆。
对迭代公式引入两次 Sherman-Morrison 公式就能得到
就得到了逆矩阵之间的推导。可能有人会问,第一个矩阵不也要求逆吗?其实这是一个迭代算法,初始矩阵设为单位矩阵(对角阵也可以)就不用求逆了。
这个公式的详细推导可以参考 这里或者 这里。
BFGS算法步骤
虽然下面的矩阵写成 B−1 B − 1 <script type="math/tex" id="MathJax-Element-101">B^{-1}</script>,但要明确,BFGS从头到尾都不需要算逆,把下面的 B−1 B − 1 <script type="math/tex" id="MathJax-Element-102">B^{-1}</script>换成 H H <script type="math/tex" id="MathJax-Element-103">H</script>这个符号,也是一样的。
(1)给定初始点
<script type="math/tex" id="MathJax-Element-104">x^{(0)}</script>,允许误差 ϵ ϵ <script type="math/tex" id="MathJax-Element-105">\epsilon</script>,设置 B−10 B 0 − 1 <script type="math/tex" id="MathJax-Element-106">B_0^{-1}</script>, t=0 t = 0 <script type="math/tex" id="MathJax-Element-107">t=0</script>
(2)计算搜索 d(t)=−B−1t⋅gt d ( t ) = − B t − 1 ⋅ g t <script type="math/tex" id="MathJax-Element-108">d^{(t)}=-B_t^{-1}\cdot g_t</script>
(3)从点 x(t) x ( t ) <script type="math/tex" id="MathJax-Element-109">x^{(t)}</script> 出发,沿着 d(t) d ( t ) <script type="math/tex" id="MathJax-Element-110">d^{(t)}</script> 做一维搜索,获得最优步长并更新参数:
(4)判断精度,若 |gt+1|<ϵ | g t + 1 | < ϵ <script type="math/tex" id="MathJax-Element-112">|g_{t+1}|<\epsilon</script>则停止迭代,否则转(5)
(5)计算 Δg=gt+1−gt Δ g = g t + 1 − g t <script type="math/tex" id="MathJax-Element-113">\Delta g=g_{t+1}-g_t</script>, Δx=x(t+1)−x(t) Δ x = x ( t + 1 ) − x ( t ) <script type="math/tex" id="MathJax-Element-114">\Delta x=x^{(t+1)}-x^{(t)}</script>,更新 B−1 B − 1 <script type="math/tex" id="MathJax-Element-115">B^{-1}</script>,然后
(6) t=t+1 t = t + 1 <script type="math/tex" id="MathJax-Element-117">t=t+1</script>,转(2)
2.5 L-BFGS法(Limited-memory BFGS)
对于 d d <script type="math/tex" id="MathJax-Element-118">d</script>维参数,BFGS算法需要保存一个 <script type="math/tex" id="MathJax-Element-119">O(d^2)</script>大小的 B−1 B − 1 <script type="math/tex" id="MathJax-Element-120">B^{-1}</script>矩阵,实际上只需要每一轮的 Δx Δ x <script type="math/tex" id="MathJax-Element-121">\Delta x</script>和 Δg Δ g <script type="math/tex" id="MathJax-Element-122">\Delta g</script>,也可以递归计算出当前迭代的 B−1 B − 1 <script type="math/tex" id="MathJax-Element-123">B^{-1}</script>矩阵,L-BFGS就是基于这种思想,实现了节省内存的BFGS。
L-BFGS推导
BFGS的递推公式:
现在假设 ρt=1ΔxTtΔgt ρ t = 1 Δ x t T Δ g t <script type="math/tex" id="MathJax-Element-125">\rho_t = \frac{1}{\Delta x_t^T \Delta g_t}</script>, Vt=In−ρtΔgtΔxTt V t = I n − ρ t Δ g t Δ x t T <script type="math/tex" id="MathJax-Element-126">V_t = I_n-\rho_t \Delta g_t \Delta x_t^T</script>,则递推公式可以写成
给定的初始矩阵 B−10 B 0 − 1 <script type="math/tex" id="MathJax-Element-128">B^{-1}_{0}</script>后,之后的每一轮都可以递推计算
一直到最后 B−1k+1 B k + 1 − 1 <script type="math/tex" id="MathJax-Element-130">B^{-1}_{k+1}</script>可以由 t=0 t = 0 <script type="math/tex" id="MathJax-Element-131">t=0</script>到 t=k t = k <script type="math/tex" id="MathJax-Element-132">t=k</script>的 Δxt Δ x t <script type="math/tex" id="MathJax-Element-133">\Delta x_t</script>和 Δgt Δ g t <script type="math/tex" id="MathJax-Element-134">\Delta g_t</script>表示:
看起来很长,其实可以写成一个求和项
这个求和项包含了从 0 0 <script type="math/tex" id="MathJax-Element-137">0</script>到 <script type="math/tex" id="MathJax-Element-138">t</script>的所有 Δx Δ x <script type="math/tex" id="MathJax-Element-139">\Delta x</script>和 Δg Δ g <script type="math/tex" id="MathJax-Element-140">\Delta g</script>,而根据实际需要,可以只取最近的 m m <script type="math/tex" id="MathJax-Element-141">m</script>个,也就是:
工程上的L-BFGS
我们关心的其实不是 B−1t B t − 1 <script type="math/tex" id="MathJax-Element-143">B^{-1}_t</script>本身如何,算 B−1t B t − 1 <script type="math/tex" id="MathJax-Element-144">B^{-1}_t</script>的根本目的是要算本轮搜索方向 B−1tgt B t − 1 g t <script type="math/tex" id="MathJax-Element-145">B^{-1}_tg_t</script>
以下算法摘自《Numerical Optimization》,它可以高效地计算出拟牛顿法每一轮的搜索方向。仔细观察一下,你会发现它实际上就是复现上面推导的那一堆很长的递推公式,你所需要的是最近 m m <script type="math/tex" id="MathJax-Element-146">m</script>轮的
<script type="math/tex" id="MathJax-Element-147">\Delta x</script>和 Δg Δ g <script type="math/tex" id="MathJax-Element-148">\Delta g</script>,后向和前向算完得到最终的 r r <script type="math/tex" id="MathJax-Element-149">r</script> 就是搜索方向
<script type="math/tex" id="MathJax-Element-150">B^{-1}_tg_t</script>,之后要做一维搜索或者什么的都可以。
解释一下算法的符号和本文符号之间的对应关系, si=Δxi s i = Δ x i <script type="math/tex" id="MathJax-Element-151">s_i=\Delta x_i</script>, yi=Δgi y i = Δ g i <script type="math/tex" id="MathJax-Element-152">y_i=\Delta g_i</script>, Hk=B−1k H k = B k − 1 <script type="math/tex" id="MathJax-Element-153">H_k=B_k^{-1}</script>
代码实现可以参考这里。

L-BFGS算法步骤
(1)给定初始点 x(0) x ( 0 ) <script type="math/tex" id="MathJax-Element-250">x^{(0)}</script>,允许误差 ϵ ϵ <script type="math/tex" id="MathJax-Element-251">\epsilon</script>,预定保留最近 m m <script type="math/tex" id="MathJax-Element-252">m</script>个向量,设置
<script type="math/tex" id="MathJax-Element-253">B_0^{-1}</script>, t=0 t = 0 <script type="math/tex" id="MathJax-Element-254">t=0</script>
(2)用Algorithm 9.1计算搜索方向 d(t)=−B−1t⋅gt d ( t ) = − B t − 1 ⋅ g t <script type="math/tex" id="MathJax-Element-255">d^{(t)}=-B_t^{-1}\cdot g_t</script>
(3)从点 x(t) x ( t ) <script type="math/tex" id="MathJax-Element-256">x^{(t)}</script> 出发,沿着 d(t) d ( t ) <script type="math/tex" id="MathJax-Element-257">d^{(t)}</script> 做一维搜索,获得最优步长并更新参数:
(4)判断精度,若 |gt+1|<ϵ | g t + 1 | < ϵ <script type="math/tex" id="MathJax-Element-259">|g_{t+1}|<\epsilon</script>则停止迭代,否则转(5)
(5)判断 t>m t > m <script type="math/tex" id="MathJax-Element-260">t>m</script>,删掉存储的 Δxt−m Δ x t − m <script type="math/tex" id="MathJax-Element-261">\Delta x_{t-m}</script> 和 Δgt−m Δ g t − m <script type="math/tex" id="MathJax-Element-262">\Delta g_{t-m}</script>
(5)计算 Δg=gt+1−gt Δ g = g t + 1 − g t <script type="math/tex" id="MathJax-Element-263">\Delta g=g_{t+1}-g_t</script>, Δx=x(t+1)−x(t) Δ x = x ( t + 1 ) − x ( t ) <script type="math/tex" id="MathJax-Element-264">\Delta x=x^{(t+1)}-x^{(t)}</script>,令 t=t+1 t = t + 1 <script type="math/tex" id="MathJax-Element-265">t=t+1</script>,转(2)
最后,有时候你看不懂BFGS到底意味着什么,并不是你英文差,而是因为这个简称真的没有意义。。。。。

参考资料
更多推荐
所有评论(0)