为什么需要RNN

当我们遇到一些数据是序列的,长度不定的,数据的先后,顺序,是存在相互影响的语义的。对于这类问题,因此就出现了RNN这种结构,能够能够的提取序列数据的特征。

RNN结构

image-20200521141220698

最简单的RNN的结构如上所示,左边是一个RNN单元,右边是将这个单元展示后得到的网络。最早的激活函数使用tanh,该网络的特殊之处在于,下一个阶段网络的输入由上一阶段的输出以及x共同组成,用公式表示如下:
$$
\begin{array}{l}O_{t}=g\left(V \cdot S_{t}\right) \ S_{t}=f\left(U \cdot X_{t}+W \cdot S_{t-1}\right)\end{array}
$$

RNN的优点

  1. RNN可以记录时间序列上的信息,对于序列数据,前后语义有着相互联系的场景比较适用。
  2. RNN可以处理文本,语音这些数据,数据的输出长度可以是不定的。

RNN的缺点

  1. 梯度消失和梯度爆炸问题,当对RNN进行梯度求导的时候,得到的表达式是参数的一个连乘形式,任意时刻对$W_s$求偏导如下:
    $$
    \frac{\partial L_{t}}{\partial W_{x}}=\sum_{k=0}^{t} \frac{\partial L_{t}}{\partial O_{t}} \frac{\partial O_{t}}{\partial S_{t}}\left(\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}\right) \frac{\partial S_{k}}{\partial W_{x}}
    $$
    随着网络加深,连乘项越来越多,将S用tanh激活函数带入,下面表达式可变为:
    $$
    \prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}= \prod_{j=k+1}^{t} \tanh ^{\prime} W_{s}
    $$
    即一个参数累乘的形式,当网络足够深的时候,如果参数小于一,则会出现梯度消失的问题,如果参数大于1,多次连乘的结果将导致梯度爆炸。

  2. RNN网络难以训练,并且如果使用的是tanh或者relu激活函数,它无法处理非常长的序列。

通过上面可以发现,只要解决了掉偏导公式中参数连乘的哪一项就可以解决梯度问题,LSTM就是按照这个思路,将这一项变成0或者1。

LSTM

LSTM即long short Term memory,LSTM的结构比普通的RNN要复杂一些,由三个门结构组成,分别是遗忘门,输入门,输出门:

image-20200521155345276

首先是遗忘门,对输入的数据做一些选择性的遗忘,控制是否遗忘由sigmoid决定。其次是输入门,利用sigmoid对输入数据进行取舍,tanh对输入数据赋予权重。输出门:利用sigmoid对输入进行取舍,然后用tanh对数据进行加权,得到下一个输入。

(通过sigmoid后的特征,最后通过一个乘法加入到网络中)

为什么LSTM能够解决梯度消失问题

接在RNN的后面分析,LSTM梯度求导过程每一项中也存在一个累乘项,但是LSTM这个累乘项在LSTM中为0或者为1,因此有效避免了累乘导致的梯度消失问题。

传统RNN梯度计算如下:
$$
\frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial S_{1}} \frac{\partial S_{1}}{\partial W_{s}}
$$
LSTM中有表达式:
$$
\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh ^{\prime} \sigma\left(W_{f} X_{t}+b_{f}\right) \approx 0 | 1
$$
因此LSTM:
$$
\frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{1}}{\partial W_{s}}
$$
梯度中不存在累乘项,因此可以克服梯度消失和梯度爆炸的问题。

LSTM具有记忆功能

由于LSTM每次计算都有参考到上一时刻的LSTM状态,每一步决策均使用到了上一次的中间结果,因此具有记忆功能。

LSTM具记忆时间长

由于LSTM将连乘项转化为1或者0,因此有效解决了梯度爆炸和梯度消失的问题,可以保存距离当前位置比较远的位置的信息,因此LSTM具有记忆时间长的功能。

LSTM存在的问题

无法并行运算,LSTM计算效率太低。