详解神经网络的前向传播和反向传播

本篇博客是对Michael Nielsen所著的《Neural Network and Deep Learning》第2章内容的解读,有兴趣的朋友可以直接阅读原文Neural Network and Deep Learning

  对神经网络有些了解的人可能都知道,神经网络其实就是一个输入 X X <script type="math/tex" id="MathJax-Element-1">X</script>到输出 Y <script type="math/tex" id="MathJax-Element-2">Y</script>的映射函数: f(X)=Y f ( X ) = Y <script type="math/tex" id="MathJax-Element-3">f(X)=Y</script>,函数的系数就是我们所要训练的网络参数 W W <script type="math/tex" id="MathJax-Element-4">W</script>,只要函数系数确定下来,对于任何输入 x i <script type="math/tex" id="MathJax-Element-5">x_i</script>我们就能得到一个与之对应的输出 yi y i <script type="math/tex" id="MathJax-Element-6">y_i</script>,至于 yi y i <script type="math/tex" id="MathJax-Element-7">y_i</script>是否符合我们预期,这就属于如何提高模型性能方面的问题了,本文不做讨论。

  那么问题来了,现在我们手中只有训练集的输入 X X <script type="math/tex" id="MathJax-Element-8">X</script>和输出 Y <script type="math/tex" id="MathJax-Element-9">Y</script>,我们应该如何调整网络参数 W W <script type="math/tex" id="MathJax-Element-10">W</script>使网络实际的输出 f ( X ) = Y ^ <script type="math/tex" id="MathJax-Element-11">f(X)=\hat{Y}</script>与训练集的 Y Y <script type="math/tex" id="MathJax-Element-12">Y</script>尽可能接近?

  在开始正式讲解之前,让我们先对反向传播过程有一个直观上的印象。反向传播算法的核心是代价函数 C <script type="math/tex" id="MathJax-Element-13">C</script>对网络中参数(各层的权重 w w <script type="math/tex" id="MathJax-Element-14">w</script>和偏置 b <script type="math/tex" id="MathJax-Element-15">b</script>)的偏导表达式 Cw ∂ C ∂ w <script type="math/tex" id="MathJax-Element-16">\frac{\partial{C}}{\partial{w}}</script>和 Cb ∂ C ∂ b <script type="math/tex" id="MathJax-Element-17">\frac{\partial{C}}{\partial{b}}</script>。这些表达式描述了代价函数值 C C <script type="math/tex" id="MathJax-Element-18">C</script>随权重 w <script type="math/tex" id="MathJax-Element-19">w</script>或偏置 b b <script type="math/tex" id="MathJax-Element-20">b</script>变化而变化的程度。到这里,BP算法的思路就很容易理解了:如果当前代价函数值距离预期值较远,那么我们通过调整 w <script type="math/tex" id="MathJax-Element-21">w</script>和 b b <script type="math/tex" id="MathJax-Element-22">b</script>的值使新的代价函数值更接近预期值(和预期值相差越大,则 w <script type="math/tex" id="MathJax-Element-23">w</script>和 b b <script type="math/tex" id="MathJax-Element-24">b</script>调整的幅度就越大)。一直重复该过程,直到最终的代价函数值在误差范围内,则算法停止。

  BP算法可以告诉我们神经网络在每次迭代中,网络的参数是如何变化的,理解这个过程对于我们分析网络性能或优化过程是非常有帮助的,所以还是尽可能搞透这个点。我也是之前大致看过,然后发现看一些进阶知识还是需要BP的推导过程作为支撑,所以才重新整理出这么一篇博客。

前向传播过程

  在开始反向传播之前,先提一下前向传播过程,即网络如何根据输入 X <script type="math/tex" id="MathJax-Element-25">X</script>得到输出 Y Y <script type="math/tex" id="MathJax-Element-26">Y</script>的。这个很容易理解,粗略看一下即可,这里主要是为了统一后面的符号表达。

w j k l <script type="math/tex" id="MathJax-Element-27">w_{jk}^{l}</script>为第 l1 l − 1 <script type="math/tex" id="MathJax-Element-28">l-1</script>层第 k k <script type="math/tex" id="MathJax-Element-29">k</script>个神经元到第 l <script type="math/tex" id="MathJax-Element-30">l</script>层第 j j <script type="math/tex" id="MathJax-Element-31">j</script>个神经元的权重, b j l <script type="math/tex" id="MathJax-Element-32">b_j^l</script>为第 l l <script type="math/tex" id="MathJax-Element-33">l</script>层第 j <script type="math/tex" id="MathJax-Element-34">j</script>个神经元的偏置, alj a j l <script type="math/tex" id="MathJax-Element-35">a_j^l</script>为第 l l <script type="math/tex" id="MathJax-Element-36">l</script>层第 j <script type="math/tex" id="MathJax-Element-37">j</script>个神经元的激活值(激活函数的输出)。不难看出, alj a j l <script type="math/tex" id="MathJax-Element-38">a_j^l</script>的值取决于上一层神经元的激活:

alj=σ(kwljkal1k+blj)(1) (1) a j l = σ ( ∑ k w j k l a k l − 1 + b j l )
<script type="math/tex; mode=display" id="MathJax-Element-39">a_j^l=\sigma{(\sum_k{w_{jk}^l a_k^{l-1}}+b_j^l)} \tag{1}</script> 将上式重写为矩阵形式:
al=σ(wlal1+bl)(2) (2) a l = σ ( w l a l − 1 + b l )
<script type="math/tex; mode=display" id="MathJax-Element-40">a^l=\sigma{(w^l a^{l-1} +b^l)} \tag{2}</script>为了方便表示,记 zl=wlal1+bl z l = w l a l − 1 + b l <script type="math/tex" id="MathJax-Element-41">z^l=w^l a^{l-1} +b^l</script>为每一层的权重输入, (2) ( 2 ) <script type="math/tex" id="MathJax-Element-42">(2)</script>式则变为 al=σ(zl) a l = σ ( z l ) <script type="math/tex" id="MathJax-Element-43">a^l=\sigma{(z^l)}</script>。
  利用 (2) ( 2 ) <script type="math/tex" id="MathJax-Element-44">(2)</script>式一层层计算网络的激活值,最终能够根据输入 X X <script type="math/tex" id="MathJax-Element-45">X</script>得到相应的输出 Y ^ <script type="math/tex" id="MathJax-Element-46">\hat Y</script>。

反向传播过程

  反向传播过程中要计算 Cw ∂ C ∂ w <script type="math/tex" id="MathJax-Element-47">\frac{\partial{C}}{\partial w}</script>和 Cb ∂ C ∂ b <script type="math/tex" id="MathJax-Element-48">\frac{\partial{C}}{\partial b}</script>,我们先对代价函数做两个假设,以二次损失函数为例:

C=12nxy(x)aL(x)2(3) (3) C = 1 2 n ∑ x ‖ y ( x ) − a L ( x ) ‖ 2
<script type="math/tex; mode=display" id="MathJax-Element-49">C=\frac{1}{2n} \sum_x{\| y(x) - a^L(x)\| ^ 2} \tag{3}</script>其中 n n <script type="math/tex" id="MathJax-Element-50">n</script>为训练样本 x <script type="math/tex" id="MathJax-Element-51">x</script>的总数, y=y(x) y = y ( x ) <script type="math/tex" id="MathJax-Element-52">y=y(x)</script>为期望的输出,即ground truth, L L <script type="math/tex" id="MathJax-Element-53">L</script>为网络的层数, a L ( x ) <script type="math/tex" id="MathJax-Element-54">a^L(x)</script>为网络的输出向量。
假设1:总的代价函数可以表示为单个样本的代价函数之和的平均:
C=1nxCx  Cx=12yaL2(4) (4) C = 1 n ∑ x C x     C x = 1 2 ‖ y − a L ‖ 2
<script type="math/tex; mode=display" id="MathJax-Element-55">C=\frac{1}{n} \sum_x{C_x}   C_x=\frac{1}{2}\|y-a^L\|^2 \tag{4}</script>
  这个假设的意义在于,因为反向传播过程中我们只能计算单个训练样本的 Cxw ∂ C x ∂ w <script type="math/tex" id="MathJax-Element-56">\frac{\partial{C_x}}{\partial w}</script>和 Cxb ∂ C x ∂ b <script type="math/tex" id="MathJax-Element-57">\frac{\partial{C_x}}{\partial b}</script>,在这个假设下,我们可以通过计算所有样本的平均来得到总体的 Cw ∂ C ∂ w <script type="math/tex" id="MathJax-Element-58">\frac{\partial{C}}{\partial w}</script>和 Cb ∂ C ∂ b <script type="math/tex" id="MathJax-Element-59">\frac{\partial{C}}{\partial b}</script>
假设2:代价函数可以表达为网络输出的函数 costC=C(aL) c o s t C = C ( a L ) <script type="math/tex" id="MathJax-Element-60">costC=C(a^L)</script>,比如单个样本 x x <script type="math/tex" id="MathJax-Element-61">x</script>的二次代价函数可以写为:
(5) C x = 1 2 y a L 2 = 1 2 j ( y j a j L ) 2
<script type="math/tex; mode=display" id="MathJax-Element-62">C_x=\frac{1}{2}\|y-a^L\|^2=\frac{1}{2} \sum_j{(y_j - a_j^L)^2} \tag{5}</script>

反向传播的四个基本方程

  权重 w w <script type="math/tex" id="MathJax-Element-3202">w</script>和偏置 b <script type="math/tex" id="MathJax-Element-3203">b</script>的改变如何影响代价函数 C C <script type="math/tex" id="MathJax-Element-3204">C</script>是理解反向传播的关键。最终,这意味着我们需要计算出每个 C w j k l <script type="math/tex" id="MathJax-Element-3205">\frac{\partial{C}}{\partial w_{jk}^l}</script>和 Cblj ∂ C ∂ b j l <script type="math/tex" id="MathJax-Element-3206">\frac{\partial{C}}{\partial b_j^l}</script>,在讨论基本方程之前,我们引入误差 δ δ <script type="math/tex" id="MathJax-Element-3207">\delta</script>的概念, δlj δ j l <script type="math/tex" id="MathJax-Element-3208">\delta_j^l</script>表示第 l l <script type="math/tex" id="MathJax-Element-3209">l</script>层第 j <script type="math/tex" id="MathJax-Element-3210">j</script>个单元的误差。关于误差的理解,《Neural Network and Deep Learning》书中给了一个比较形象的例子。

  如上图所示,假设有个小恶魔在第 l l <script type="math/tex" id="MathJax-Element-3211">l</script>层第 j <script type="math/tex" id="MathJax-Element-3212">j</script>个单元捣蛋,他让这个神经元的权重输出变化了 Δzlj Δ z j l <script type="math/tex" id="MathJax-Element-3213">\Delta z_j^l</script>,那么这个神经元的激活输出为 σ(zlj+Δzlj) σ ( z j l + Δ z j l ) <script type="math/tex" id="MathJax-Element-3214">\sigma(z_j^l+\Delta z_j^l)</script>,然后这个误差向后逐层传播下去,导致最终的代价函数变化了 CzljΔzlj ∂ C ∂ z j l Δ z j l <script type="math/tex" id="MathJax-Element-3215">\frac{\partial{C}}{\partial z_j^l}\Delta z_j^l</script>。现在这个小恶魔改过自新,它想帮助我们尽可能减小代价函数的值(使网络输出更符合预期)。假设 Czlj ∂ C ∂ z j l <script type="math/tex" id="MathJax-Element-3216">\frac{\partial{C}}{\partial z_j^l}</script>一开始是个很大的正值或者负值,小恶魔通过选择一个和 Czlj ∂ C ∂ z j l <script type="math/tex" id="MathJax-Element-3217">\frac{\partial{C}}{\partial z_j^l}</script>方向相反的 Δzlj Δ z j l <script type="math/tex" id="MathJax-Element-3218">\Delta z_j^l</script>使代价函数更小(这就是我们熟知的梯度下降法)。随着迭代的进行, Czlj ∂ C ∂ z j l <script type="math/tex" id="MathJax-Element-3219">\frac{\partial{C}}{\partial z_j^l}</script>会逐渐趋向于0,那么 Δzlj Δ z j l <script type="math/tex" id="MathJax-Element-3220">\Delta z_j^l</script>对于代价函数的改进效果就微乎其微了,这时小恶魔就一脸骄傲的告诉你:“俺已经找到了最优解了(局部最优)”。这启发我们可以用 Czlj ∂ C ∂ z j l <script type="math/tex" id="MathJax-Element-3221">\frac{\partial{C}}{\partial z_j^l}</script>来衡量神经元的误差:

δlj=Czlj δ j l = ∂ C ∂ z j l
<script type="math/tex; mode=display" id="MathJax-Element-3222">\delta_j^l=\frac{\partial{C}}{\partial z_j^l}</script>下面就来看看四个基本方程是怎么来的。
  
1. 输出层的误差方程
δLj=CzLj=CaLjaLjzLj=CaLjσ(zLj)(BP1) (BP1) δ j L = ∂ C ∂ z j L = ∂ C ∂ a j L ∂ a j L ∂ z j L = ∂ C ∂ a j L σ ′ ( z j L )
<script type="math/tex; mode=display" id="MathJax-Element-3223">\delta_j^L=\frac{\partial C}{\partial z_j^L}=\frac{\partial C}{\partial a_j^L}\frac{\partial a_j^L}{\partial z_j^L}=\frac{\partial C}{\partial a_j^L}\sigma'(z_j^L) \tag{BP1}</script>如果上面的东西你看明白了,这个方程应该不难理解,等式右边第一项 CaLj ∂ C ∂ a j L <script type="math/tex" id="MathJax-Element-3224">\frac{\partial C}{\partial a_j^L}</script>衡量了代价函数随网络最终输出的变化快慢,而第二项 σ(zLj) σ ′ ( z j L ) <script type="math/tex" id="MathJax-Element-3225">\sigma'(z_j^L)</script>则衡量了激活函数输出随 zLj z j L <script type="math/tex" id="MathJax-Element-3226">z_j^L</script>的变化快慢。当激活函数饱和,即 σ(zLj)0 σ ′ ( z j L ) ≈ 0 <script type="math/tex" id="MathJax-Element-3227">\sigma'(z_j^L)\approx0</script>时,无论 CaLj ∂ C ∂ a j L <script type="math/tex" id="MathJax-Element-3228">\frac{\partial C}{\partial a_j^L}</script>多大,最终 δLj0 δ j L ≈ 0 <script type="math/tex" id="MathJax-Element-3229">\delta_j^L\approx0</script>,输出神经元进入饱和区,停止学习。
  (BP1)方程中两项都很容易计算,如果代价函数为二次代价函数 C=12j(yjaLj)2 C = 1 2 ∑ j ( y j − a j L ) 2 <script type="math/tex" id="MathJax-Element-3230">C=\frac{1}{2} \sum_j{(y_j - a_j^L)^2}</script>,则 CaLj=aLjyj ∂ C ∂ a j L = a j L − y j <script type="math/tex" id="MathJax-Element-3231">\frac{\partial C}{\partial a_j^L}=a_j^L-y_j</script>,同理,对激活函数 σ(z) σ ( z ) <script type="math/tex" id="MathJax-Element-3232">\sigma(z)</script>求 zLj z j L <script type="math/tex" id="MathJax-Element-3233">z_j^L</script>的偏导即可求得 σ(zLj) σ ′ ( z j L ) <script type="math/tex" id="MathJax-Element-3234">\sigma'(z_j^L)</script>。将(BP1)重写为矩阵形式:
δL=aCσ(zL)(BP1a) (BP1a) δ L = ∇ a C ⊙ σ ′ ( z L )
<script type="math/tex; mode=display" id="MathJax-Element-3235">\delta^L=\nabla_aC \odot \sigma'(z^L) \tag{BP1a}</script> <script type="math/tex" id="MathJax-Element-3236">\odot</script>为Hadamard积,即矩阵的点积。
2. 误差传递方程
δl=((wl+1)Tδl+1)σ(zl)(BP2) (BP2) δ l = ( ( w l + 1 ) T δ l + 1 ) ⊙ σ ′ ( z l )
<script type="math/tex; mode=display" id="MathJax-Element-3237">\delta^l=((w^{l+1})^T\delta^{l+1})\odot \sigma'(z^l) \tag{BP2}</script>这个方程说明我们可以通过第 l+1 l + 1 <script type="math/tex" id="MathJax-Element-3238">l+1</script>层的误差 δl+1 δ l + 1 <script type="math/tex" id="MathJax-Element-3239">\delta^{l+1}</script>计算第 l l <script type="math/tex" id="MathJax-Element-3240">l</script>层的误差 δ l <script type="math/tex" id="MathJax-Element-3241">\delta^{l}</script>,结合(BP1)和(BP2)两个方程,我们现在可以计算网络中任意一层的误差了,先计算 δL δ L <script type="math/tex" id="MathJax-Element-3242">\delta^L</script>,然后计算 δL1 δ L − 1 <script type="math/tex" id="MathJax-Element-3243">\delta^{L-1}</script>, δL2 δ L − 2 <script type="math/tex" id="MathJax-Element-3244">\delta^{L-2}</script>,…,直到输入层。
证明过程如下:
δlj=Czlj=kCzl+1kzl+1kzlj=kδl+1kzl+1kzlj δ j l = ∂ C ∂ z j l = ∑ k ∂ C ∂ z k l + 1 ∂ z k l + 1 ∂ z j l = ∑ k δ k l + 1 ∂ z k l + 1 ∂ z j l
<script type="math/tex; mode=display" id="MathJax-Element-3245">\delta_j^l=\frac{\partial C}{\partial z_j^l} = \sum_k \frac{\partial C}{\partial z_k^{l+1}} \frac{\partial z_k^{l+1}}{\partial z_j^l}= \sum_k \delta_k^{l+1} \frac{\partial z_k^{l+1}}{\partial z_j^l} </script>因为 zl+1k=jwl+1kjalj+bl+1k=jwl+1kjσ(zlj)+bl+1k z k l + 1 = ∑ j w k j l + 1 a j l + b k l + 1 = ∑ j w k j l + 1 σ ( z j l ) + b k l + 1 <script type="math/tex" id="MathJax-Element-3246">z_k^{l+1}=\sum_j{w_{kj}^{l+1}a_j^l+b_k^{l+1}}=\sum_j{w_{kj}^{l+1}\sigma{(z_j^l)}+b_k^{l+1}}</script>,所以 zl+1kzlj=wl+1kjσ(zlj) ∂ z k l + 1 ∂ z j l = w k j l + 1 σ ′ ( z j l ) <script type="math/tex" id="MathJax-Element-3247">\frac{\partial z_k^{l+1}}{\partial z_j^l}=w_{kj}^{l+1}\sigma'(z_j^l)</script>,因此可以得到(BP2),
δlj=kwl+1kjδl+1kσ(zlj) δ j l = ∑ k w k j l + 1 δ k l + 1 σ ′ ( z j l )
<script type="math/tex; mode=display" id="MathJax-Element-3248">\delta_j^l=\sum_k w_{kj}^{l+1} \delta_k^{l+1} \sigma'(z_j^l)</script>
3. 代价函数对偏置的改变率
Cblj=Czljzljblj=Czlj=δlj(BP3) (BP3) ∂ C ∂ b j l = ∂ C ∂ z j l ∂ z j l ∂ b j l = ∂ C ∂ z j l = δ j l
<script type="math/tex; mode=display" id="MathJax-Element-3249">\frac{\partial C}{\partial b_j^l}=\frac{\partial C}{\partial z_j^l}\frac{\partial z_j^l}{\partial b_j^l}=\frac{\partial C}{\partial z_j^l}=\delta_j^l \tag{BP3}</script>这里因为 zlj=kwljkal1k+blj z j l = ∑ k w j k l a k l − 1 + b j l <script type="math/tex" id="MathJax-Element-3250">z_j^l=\sum_k{w_{jk}^l a_k^{l-1}}+b_j^l</script>所以 zLjbLj=1 ∂ z j L ∂ b j L = 1 <script type="math/tex" id="MathJax-Element-3251">\frac{\partial z_j^L}{\partial b_j^L}=1</script>
4. 代价函数对权重的改变率
Cwljk=CzljzLjwljk=Czljal1k=al1kδlj(BP4) (BP4) ∂ C ∂ w j k l = ∂ C ∂ z j l ∂ z j L ∂ w j k l = ∂ C ∂ z j l a k l − 1 = a k l − 1 δ j l
<script type="math/tex; mode=display" id="MathJax-Element-3252">\frac{\partial C}{\partial w_{jk}^l}=\frac{\partial C}{\partial z_j^l}\frac{\partial z_j^L}{\partial w_{jk}^l}=\frac{\partial C}{\partial z_j^l}a_k^{l-1}=a_k^{l-1}\delta_j^l \tag{BP4}</script>可以简写为
Cw=ainδout(6) (6) ∂ C ∂ w = a i n δ o u t
<script type="math/tex; mode=display" id="MathJax-Element-3253">\frac{\partial C}{\partial w}=a_{in}\delta_{out} \tag{6}</script>,不难发现,当上一层激活输出接近0的时候,无论返回的误差有多大, Cw ∂ C ∂ w <script type="math/tex" id="MathJax-Element-3254">\frac{\partial C}{\partial w}</script>的改变都很小,这也就解释了为什么神经元饱和不利于训练。

  从上面的推导我们不难发现,当输入神经元没有被激活,或者输出神经元处于饱和状态,权重和偏置会学习的非常慢,这不是我们想要的效果。这也说明了为什么我们平时总是说激活函数的选择非常重要。

  当我计算得到 Cwljk ∂ C ∂ w j k l <script type="math/tex" id="MathJax-Element-3255">\frac{\partial C}{\partial w_{jk}^l}</script>和 Cblj ∂ C ∂ b j l <script type="math/tex" id="MathJax-Element-3256">\frac{\partial C}{\partial b_j^l}</script>后,就能愉悦地使用梯度下降法对参数进行一轮轮更新了,直到最后模型收敛。

反向传播为什么快

  回答这个问题前,我们先看一下普通方法怎么求梯度。以计算权重为例,我们将代价函数看成是权重的函数 C=C(w) C = C ( w ) <script type="math/tex" id="MathJax-Element-1830">C=C(w)</script>,假设现在网络中有100万个参数,我们可以利用微分的定义式来计算代价函数对其中某个权重 wj w j <script type="math/tex" id="MathJax-Element-1831">w_j</script>的偏导:

CwjC(w+εej)C(w)ε(7) (7) ∂ C ∂ w j ≈ C ( w + ε e j → ) − C ( w ) ε
<script type="math/tex; mode=display" id="MathJax-Element-1832">\frac{\partial C}{\partial w_j}\approx\frac{C(w+\varepsilon \vec{e_j}) - C(w)}{\varepsilon} \tag{7}</script>然后我们算一下,为了计算 Cwj ∂ C ∂ w j <script type="math/tex" id="MathJax-Element-1833">\frac{\partial C}{\partial w_j}</script>,我们需要从头到尾完整进行一次前向传播才能得到最终 C(w+εej) C ( w + ε e j → ) <script type="math/tex" id="MathJax-Element-1834">C(w+\varepsilon \vec{e_j})</script>的值,要计算100万个参数的偏导就需要前向传播100万次,而且这还只是一次迭代,想想是不是特别可怕?
  再反观反向传播算法,如方程(BP4)所示,我们只要知道 al1k a k l − 1 <script type="math/tex" id="MathJax-Element-1835">a_k^{l-1}</script>和 δlj δ j l <script type="math/tex" id="MathJax-Element-1836">\delta_j^l </script>就能计算出偏导 Cwljk ∂ C ∂ w j k l <script type="math/tex" id="MathJax-Element-1837">\frac{\partial C}{\partial w_{jk}^l}</script>。激活函数值 al1k a k l − 1 <script type="math/tex" id="MathJax-Element-1838">a_k^{l-1}</script>在一次前向传播后就能全部得到,然后利用(BP1)和(PB2)可以计算出 δlj δ j l <script type="math/tex" id="MathJax-Element-1839">\delta_j^l </script>,反向传播和前向传播计算量相当,所以总共只需2次前向传播的计算量就能计算出所有的 Cwljk ∂ C ∂ w j k l <script type="math/tex" id="MathJax-Element-1840">\frac{\partial C}{\partial w_{jk}^l}</script>。这比使用微分定义式求偏导的计算量少了不止一点半点,简直是质的飞跃。

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐