9.7. Backpropagation Through Time
If you completed the exercises in Section 9.5, you would
have seen that gradient clipping is vital to prevent the occasional
massive gradients from destabilizing training. We hinted that the
exploding gradients stem from backpropagating across long sequences.
Before introducing a slew of modern RNN architectures, let’s take a
closer look at how backpropagation works in sequence models in
mathematical detail. Hopefully, this discussion will bring some
precision to the notion of vanishing and exploding gradients. If you
recall our discussion of forward and backward propagation through
computational graphs when we introduced MLPs in
Section 5.3, then forward propagation in RNNs should be
relatively straightforward. Applying backpropagation in RNNs is called
backpropagation through time (Werbos, 1990). This procedure
requires us to expand (or unroll) the computational graph of an RNN one
time step at a time. The unrolled RNN is essentially a feedforward
neural network with the special property that the same parameters are
repeated throughout the unrolled network, appearing at each time step.
Then, just as in any feedforward neural network, we can apply the chain
rule, backpropagating gradients through the unrolled net. The gradient
with respect to each parameter must be summed across all places that the
parameter occurs in the unrolled net. Handling such weight tying should
be familiar from our chapters on convolutional neural networks.
Complications arise because sequences can be rather long. It is not
unusual to work with text sequences consisting of over a thousand
tokens. Note that this poses problems from both a computational (too
much memory) and an optimization (numerical instability) standpoint. Input
from the first step passes through over 1000 matrix products before
arriving at the output, and another 1000 matrix products are required to
compute the gradient. We now analyze what can go wrong and how to
address it in practice.
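To make the idea of unrolling concrete, the following minimal sketch (assuming PyTorch; the tanh cell, dimensions, and squared loss are illustrative choices, not the book's implementation) unrolls a tiny RNN for a few time steps and lets autograd sum the gradient contributions of the shared weights across all of their occurrences.

```python
import torch

# A minimal sketch of unrolling (assumptions: PyTorch, a tanh cell, a squared
# loss, made-up dimensions; this is not the book's d2l implementation).
torch.manual_seed(0)
T, d, h = 5, 3, 4
W_hx = torch.randn(h, d, requires_grad=True)   # shared across all time steps
W_hh = torch.randn(h, h, requires_grad=True)
W_qh = torch.randn(1, h, requires_grad=True)

X = torch.randn(T, d)          # one input vector per time step
Y = torch.randn(T, 1)          # one target per time step
H = torch.zeros(h)             # initial hidden state

loss = 0.0
for t in range(T):             # unroll the computational graph over time
    H = torch.tanh(W_hx @ X[t] + W_hh @ H)
    O = W_qh @ H
    loss = loss + (O - Y[t]).pow(2).sum() / T

loss.backward()                # backpropagation through time
# W_hh.grad holds the sum of gradient terms from all T occurrences of W_hh.
print(W_hh.grad.shape)         # torch.Size([4, 4])
```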
9.7.1. Analysis of Gradients in RNNs
We start with a simplified model of how an RNN works. This model ignores
details about the specifics of the hidden state and how it is updated.
The mathematical notation here does not explicitly distinguish scalars,
vectors, and matrices. We are just trying to develop some intuition. In
this simplified model, we denote \(h_t\) as the hidden state,
\(x_t\) as input, and \(o_t\) as output at time step \(t\).
Recall from our discussion in Section 9.4.2 that
the input and the hidden state can be concatenated before being
multiplied by one weight variable in the hidden layer. Thus, we use
\(w_h\) and \(w_o\) to indicate the weights of the hidden layer
and the output layer, respectively. As a result, the hidden state and
output at each time step are
(9.7.1)
\[\begin{split}\begin{aligned}h_t &= f(x_t, h_{t-1}, w_h),\\o_t &= g(h_t, w_o),\end{aligned}\end{split}\]
where \(f\) and \(g\) are transformations of the hidden layer
and the output layer, respectively. Hence, we have a chain of values
\(\{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\}\)
that depend on each other via recurrent computation. The forward
propagation is fairly straightforward. All we need is to loop through
the \((x_t, h_t, o_t)\) triples one time step at a time. The
discrepancy between output \(o_t\) and the desired target
\(y_t\) is then evaluated by an objective function across all the
\(T\) time steps as
(9.7.2)
\[L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_h, w_o) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t).\]
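As a concrete illustration, here is a minimal sketch of this forward loop for the simplified scalar model. The specific choices of \(f\), \(g\), and the loss \(l\) below are illustrative assumptions; the analysis above deliberately leaves them abstract.

```python
# A minimal sketch of forward propagation in the simplified model
# (9.7.1)-(9.7.2). The concrete choices of f, g, and l below are
# illustrative assumptions; the analysis leaves them abstract.
def f(x_t, h_prev, w_h):          # hidden-state update
    return w_h * (x_t + h_prev)

def g(h_t, w_o):                  # output transformation
    return w_o * h_t

def l(y_t, o_t):                  # per-step loss
    return 0.5 * (y_t - o_t) ** 2

def forward(xs, ys, w_h, w_o, h0=0.0):
    h, L = h0, 0.0
    for x_t, y_t in zip(xs, ys):  # loop through the (x_t, h_t, o_t) triples
        h = f(x_t, h, w_h)
        o = g(h, w_o)
        L += l(y_t, o) / len(xs)  # objective (9.7.2)
    return L

print(forward([0.1, 0.2, 0.3], [1.0, 1.0, 1.0], w_h=0.5, w_o=2.0))
```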
For backpropagation, matters are a bit trickier, especially when we
compute the gradient of the objective function \(L\) with regard to the
parameter \(w_h\). To be specific, by the chain rule,
(9.7.3)
\[\begin{split}\begin{aligned}\frac{\partial L}{\partial w_h} & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_h} \\& = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t} \frac{\partial h_t}{\partial w_h}.\end{aligned}\end{split}\]
The first and the second factors of the product in
(9.7.3) are easy to compute. The third factor
\(\partial h_t/\partial w_h\) is where things get tricky, since we
need to recurrently compute the effect of the parameter \(w_h\) on
\(h_t\). According to the recurrent computation in
(9.7.1), \(h_t\) depends on both \(h_{t-1}\)
and \(w_h\), where computation of \(h_{t-1}\) also depends on
\(w_h\). Thus, evaluating the total derivative of \(h_t\) with
respect to \(w_h\) using the chain rule yields
(9.7.4)
\[\frac{\partial h_t}{\partial w_h}= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.\]
To derive the above gradient, assume that we have three sequences
\(\{a_{t}\},\{b_{t}\},\{c_{t}\}\) satisfying \(a_{0}=0\) and
\(a_{t}=b_{t}+c_{t}a_{t-1}\) for \(t=1, 2,\ldots\). Then for
\(t\geq 1\), it is easy to show
(9.7.5)
\[a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.\]
By substituting \(a_t\), \(b_t\), and \(c_t\) according to
(9.7.6)
\[\begin{split}\begin{aligned}a_t &= \frac{\partial h_t}{\partial w_h},\\ b_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}, \\ c_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}},\end{aligned}\end{split}\]
the gradient computation in (9.7.4)
satisfies \(a_{t}=b_{t}+c_{t}a_{t-1}\). Thus, per
(9.7.5), we can remove the recurrent computation in
(9.7.4) with
(9.7.7)
\[\frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_h)}{\partial w_h}.\]
While we can use the chain rule to compute
\(\partial h_t/\partial w_h\) recursively, this chain can get very
long whenever \(t\) is large. Let’s discuss a number of strategies
for dealing with this problem.
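Before comparing the strategies, a quick numerical check of the unrolling identity (9.7.5) can be reassuring. The sketch below, with arbitrary made-up values for \(b_t\) and \(c_t\), confirms that the recursion and the closed-form sum agree.

```python
import random

# A quick numerical check of the unrolling identity (9.7.5), using made-up
# values: the recursion a_t = b_t + c_t * a_{t-1} with a_0 = 0 matches the
# closed-form sum a_t = b_t + sum_{i=1}^{t-1} (prod_{j=i+1}^{t} c_j) * b_i.
random.seed(0)
T = 8
b = [0.0] + [random.random() for _ in range(T)]   # b[1..T]
c = [0.0] + [random.random() for _ in range(T)]   # c[1..T]

a = [0.0] * (T + 1)                # recursive evaluation
for t in range(1, T + 1):
    a[t] = b[t] + c[t] * a[t - 1]

closed = b[T]                      # closed-form evaluation at t = T
for i in range(1, T):
    prod = 1.0
    for j in range(i + 1, T + 1):
        prod *= c[j]
    closed += prod * b[i]

print(abs(a[T] - closed) < 1e-12)  # True
```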
9.7.1.1. Full Computation
One idea might be to compute the full sum in
(9.7.7). However, this is very slow and
gradients can blow up, since subtle changes in the initial conditions
can potentially affect the outcome a lot. That is, we could see things
similar to the butterfly effect, where minimal changes in the initial
conditions lead to disproportionate changes in the outcome. This is
generally undesirable. After all, we are looking for robust estimators
that generalize well. Hence this strategy is almost never used in
practice.
9.7.1.2. Truncating Time Steps
Alternatively, we can truncate the sum in
(9.7.7) after \(\tau\) steps. This is
what we have been discussing so far. This leads to an approximation of
the true gradient, simply by terminating the sum at
\(\partial h_{t-\tau}/\partial w_h\). In practice this works quite
well. It is what is commonly referred to as truncated backpropagation
through time (Jaeger, 2002). One of the consequences of this is
that the model focuses primarily on short-term influence rather than
long-term consequences. This is actually desirable, since it biases
the estimate towards simpler and more stable models.
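In deep learning frameworks, this truncation is typically implemented by detaching the hidden state between minibatches. The sketch below is one way to do this in PyTorch, with made-up data and hyperparameters; it is not the book's training loop, just an illustration of where the detach (and the gradient clipping from Section 9.5) goes.

```python
import torch
from torch import nn

# A sketch of truncated BPTT in PyTorch (hypothetical data, dimensions, and
# hyperparameters; not the book's training loop). The hidden state is carried
# across minibatches, but detach() cuts its graph history so that gradients
# never flow back more than tau time steps.
tau, d, h = 35, 8, 16
rnn = nn.RNN(input_size=d, hidden_size=h, batch_first=True)
head = nn.Linear(h, d)
params = list(rnn.parameters()) + list(head.parameters())
opt = torch.optim.SGD(params, lr=0.1)

state = torch.zeros(1, 1, h)                 # (num_layers, batch, hidden)
long_sequence = torch.randn(1, 10 * tau, d)  # stand-in for a long sequence

for start in range(0, long_sequence.shape[1] - tau, tau):
    x = long_sequence[:, start:start + tau]
    y = long_sequence[:, start + 1:start + tau + 1]
    state = state.detach()                   # truncate the gradient here
    out, state = rnn(x, state)
    loss = ((head(out) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(params, 1.0)    # gradient clipping, cf. Section 9.5
    opt.step()
```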
9.7.1.3. Randomized Truncation
Last, we can replace \(\partial h_t/\partial w_h\) by a random
variable which is correct in expectation but truncates the sequence.
This is achieved by using a sequence of random variables \(\xi_t\) with
predefined \(0 \leq \pi_t \leq 1\), where \(P(\xi_t = 0) = 1-\pi_t\) and
\(P(\xi_t = \pi_t^{-1}) = \pi_t\), and thus \(E[\xi_t] = 1\). We use
this to replace the gradient \(\partial h_t/\partial w_h\) in
(9.7.4) with
(9.7.8)
\[z_t= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.\]
It follows from the definition of \(\xi_t\) that
\(E[z_t] = \partial h_t/\partial w_h\). Whenever \(\xi_t = 0\)
the recurrent computation terminates at that time step \(t\). This
leads to a weighted sum of sequences of varying lengths, where long
sequences are rare but appropriately overweighted. This idea was
proposed by Tallec and Ollivier (2017).
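The following sketch illustrates the estimator (9.7.8) in the scalar setting, with hypothetical values standing in for the partial derivatives; averaging many samples shows that the randomized truncation is unbiased.

```python
import random

# An illustrative sketch of the randomized-truncation estimator (9.7.8) in the
# scalar setting, with hypothetical values for the partial derivatives: with
# probability pi_t the recurrent term is kept and rescaled by 1/pi_t, otherwise
# it is dropped; averaging many samples recovers the untruncated gradient.
def xi(pi_t):
    """Sample xi_t with P(xi_t = 0) = 1 - pi_t and P(xi_t = 1/pi_t) = pi_t."""
    return (1.0 / pi_t) if random.random() < pi_t else 0.0

def randomized_grad(dfs_dw, dfs_dh, pi=0.7):
    """z_t = df_t/dw_h + xi_t * df_t/dh_{t-1} * z_{t-1}, as in (9.7.8)."""
    z = 0.0
    for df_dw, df_dh in zip(dfs_dw, dfs_dh):
        z = df_dw + xi(pi) * df_dh * z
    return z

def full_grad(dfs_dw, dfs_dh):
    """The exact recursion (9.7.4): a_t = df_t/dw_h + df_t/dh_{t-1} * a_{t-1}."""
    a = 0.0
    for df_dw, df_dh in zip(dfs_dw, dfs_dh):
        a = df_dw + df_dh * a
    return a

random.seed(0)
dfs_dw = [0.3, -0.2, 0.5, 0.1]   # hypothetical values of df_t/dw_h
dfs_dh = [0.9, 0.8, 0.7, 0.6]    # hypothetical values of df_t/dh_{t-1}
n = 100000
est = sum(randomized_grad(dfs_dw, dfs_dh) for _ in range(n)) / n
print(est, full_grad(dfs_dw, dfs_dh))   # the two values are close (~0.417)
```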
9.7.1.4. Comparing Strategies
Fig. 9.7.1 illustrates the three strategies when
analyzing the first few characters of The Time Machine using
backpropagation through time for RNNs:
- The first row is the randomized truncation that partitions the text into segments of varying lengths.
- The second row is the regular truncation that breaks the text into subsequences of the same length. This is what we have been doing in RNN experiments.
- The third row is the full backpropagation through time that leads to a computationally infeasible expression.
Unfortunately, while appealing in theory, randomized truncation does not
work much better than regular truncation, most likely due to a number of
factors. First, the effect of an observation after a moderate number of
backpropagation steps into the past is already sufficient to capture
dependencies in practice. Second, the increased variance counteracts the
fact that the gradient is more accurate with more steps. Third, we
actually want models that have only a short range of interactions.
Hence, regularly truncated backpropagation through time has a slight
regularizing effect that can be desirable.
9.7.2. Backpropagation Through Time in Detail
Having discussed the general principle, let’s now work through
backpropagation through time in detail. In contrast with the analysis in
Section 9.7.1, in the following we will show how to
compute the gradients of the objective function with respect to all the
decomposed model parameters. To keep things simple, we consider an RNN
without bias parameters, whose activation function in the hidden layer
uses the identity mapping (\(\phi(x)=x\)). For time step \(t\),
let the single example input and the target be
\(\mathbf{x}_t \in \mathbb{R}^d\) and \(y_t\), respectively. The
hidden state \(\mathbf{h}_t \in \mathbb{R}^h\) and the output
\(\mathbf{o}_t \in \mathbb{R}^q\) are computed as
(9.7.9)
\[\begin{split}\begin{aligned}\mathbf{h}_t &= \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1},\\ \mathbf{o}_t &= \mathbf{W}_{qh} \mathbf{h}_{t},\end{aligned}\end{split}\]
where \(\mathbf{W}_{hx} \in \mathbb{R}^{h \times d}\),
\(\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\), and
\(\mathbf{W}_{qh} \in \mathbb{R}^{q \times h}\) are the weight
parameters. Denote by \(l(\mathbf{o}_t, y_t)\) the loss at time step
\(t\). Our objective function, the loss over \(T\) time steps
from the beginning of the sequence, is thus
(9.7.10)
\[L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t).\]
In order to visualize the dependencies among model variables and
parameters during computation of the RNN, we can draw a computational
graph for the model, as shown in Fig. 9.7.2. For example,
the computation of the hidden state at time step 3,
\(\mathbf{h}_3\), depends on the model parameters
\(\mathbf{W}_{hx}\) and \(\mathbf{W}_{hh}\), the hidden state of
the previous time step, \(\mathbf{h}_2\), and the input of the current
time step, \(\mathbf{x}_3\).
As just mentioned, the model parameters in Fig. 9.7.2 are
\(\mathbf{W}_{hx}\), \(\mathbf{W}_{hh}\), and
\(\mathbf{W}_{qh}\). Generally, training this model requires computing
the gradients with respect to these parameters, namely
\(\partial L/\partial \mathbf{W}_{hx}\),
\(\partial L/\partial \mathbf{W}_{hh}\), and
\(\partial L/\partial \mathbf{W}_{qh}\). According to the
dependencies in Fig. 9.7.2, we can traverse in the
opposite direction of the arrows to calculate and store the gradients in
turn. To flexibly express the multiplication of matrices, vectors, and
scalars of different shapes in the chain rule, we continue to use the
\(\text{prod}\) operator as described in Section 5.3.
First of all, differentiating the objective function with respect to the
model output at any time step \(t\) is fairly straightforward:
(9.7.11)
\[\frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q.\]
Now, we can calculate the gradient of the objective with respect to the
parameter \(\mathbf{W}_{qh}\) in the output layer:
\(\partial L/\partial \mathbf{W}_{qh} \in \mathbb{R}^{q \times h}\).
Based on Fig. 9.7.2, the objective \(L\) depends on
\(\mathbf{W}_{qh}\) via \(\mathbf{o}_1, \ldots, \mathbf{o}_T\).
Using the chain rule yields
(9.7.12)
\[\frac{\partial L}{\partial \mathbf{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top,\]
where \(\partial L/\partial \mathbf{o}_t\) is given by
(9.7.11).
Next, as shown in Fig. 9.7.2, at the final time step
\(T\), the objective function \(L\) depends on the hidden state
\(\mathbf{h}_T\) only via \(\mathbf{o}_T\). Therefore, we can
easily find the gradient
\(\partial L/\partial \mathbf{h}_T \in \mathbb{R}^h\) using the
chain rule:
(9.7.13)
\[\frac{\partial L}{\partial \mathbf{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}.\]
It gets trickier for any time step \(t < T\), where the objective
function \(L\) depends on \(\mathbf{h}_t\) via
\(\mathbf{h}_{t+1}\) and \(\mathbf{o}_t\). According to the
chain rule, the gradient of the hidden state
\(\partial L/\partial \mathbf{h}_t \in \mathbb{R}^h\) at any time
step \(t < T\) can be recurrently computed as:
(9.7.14)
\[\frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}.\]
For analysis, expanding the recurrent computation for any time step
\(1 \leq t \leq T\) gives
(9.7.15)
\[\frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_{hh}^\top\right)}^{T-i} \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}.\]
We can see from (9.7.15) that this simple linear
example already exhibits some key problems of long sequence models: it
involves potentially very large powers of \(\mathbf{W}_{hh}^\top\).
In this matrix power, eigenvalues smaller than 1 vanish and eigenvalues
larger than 1 diverge. This is numerically unstable, which manifests itself in the
form of vanishing and exploding gradients. One way to address this is to
truncate the time steps at a computationally convenient size as
discussed in Section 9.7.1. In practice, this
truncation can also be effected by detaching the gradient after a given
number of time steps. Later on, we will see how more sophisticated
sequence models such as long short-term memory can alleviate this
further.
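The following small sketch (with a random matrix standing in for \(\mathbf{W}_{hh}\)) illustrates the effect of the matrix powers in (9.7.15): scaling the matrix so that its spectral radius is below or above 1 makes the backpropagated vector vanish or explode over 50 time steps.

```python
import torch

# A small sketch illustrating (9.7.15): repeatedly multiplying by W_hh^T makes
# the contribution of distant time steps vanish or explode, depending on
# whether the spectral radius of W_hh is below or above 1. W_hh and the
# initial vector are random stand-ins.
torch.manual_seed(0)
h = 16
W = torch.randn(h, h) / h ** 0.5   # spectral radius roughly around 1
g = torch.ones(h)                  # stand-in for W_qh^T dL/do_T

for scale, name in [(0.5, 'contracting'), (2.0, 'expanding')]:
    v = g.clone()
    for _ in range(50):            # backpropagate 50 time steps into the past
        v = (scale * W).T @ v
    print(name, v.norm().item())   # tiny norm vs. an astronomically large norm
```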
Finally, Fig. 9.7.2 shows that the objective function
\(L\) depends on model parameters \(\mathbf{W}_{hx}\) and
\(\mathbf{W}_{hh}\) in the hidden layer via hidden states
\(\mathbf{h}_1, \ldots, \mathbf{h}_T\). To compute gradients with
respect to such parameters
\(\partial L / \partial \mathbf{W}_{hx} \in \mathbb{R}^{h \times d}\)
and
\(\partial L / \partial \mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\),
we apply the chain rule that gives
(9.7.16)
\[\begin{split}\begin{aligned} \frac{\partial L}{\partial \mathbf{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top,\\ \frac{\partial L}{\partial \mathbf{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top, \end{aligned}\end{split}\]
where \(\partial L/\partial \mathbf{h}_t\), which is recurrently
computed by (9.7.13) and
(9.7.14), is the key quantity that affects
numerical stability.
Since backpropagation through time is the application of backpropagation
in RNNs, as we have explained in Section 5.3, training RNNs
alternates forward propagation with backpropagation through time.
Moreover, backpropagation through time computes and stores the above
gradients in turn. Specifically, stored intermediate values are reused
to avoid duplicate calculations, such as storing
\(\partial L/\partial \mathbf{h}_t\) so that it can be used in the computation of
both \(\partial L / \partial \mathbf{W}_{hx}\) and
\(\partial L / \partial \mathbf{W}_{hh}\).
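As a sanity check (not part of the book's code), the sketch below implements the linear RNN of (9.7.9) with an assumed squared loss and compares the analytic gradients (9.7.12) and (9.7.16) against the gradients that torch.autograd computes; the two agree up to floating-point error.

```python
import torch

# A sanity-check sketch (not part of the book's code): for the bias-free linear
# RNN of (9.7.9) with an assumed squared loss l(o, y) = 0.5 * ||o - y||^2,
# compare the analytic gradients (9.7.12) and (9.7.16) against autograd.
torch.manual_seed(0)
dt = torch.float64
T, d, h, q = 6, 3, 4, 2
W_hx = torch.randn(h, d, dtype=dt, requires_grad=True)
W_hh = torch.randn(h, h, dtype=dt, requires_grad=True)
W_qh = torch.randn(q, h, dtype=dt, requires_grad=True)
X, Y = torch.randn(T, d, dtype=dt), torch.randn(T, q, dtype=dt)

# Forward pass (9.7.9) and objective (9.7.10); H[s] stores h_s, O[s-1] stores o_s.
H, O, loss = [torch.zeros(h, dtype=dt)], [], 0.0
for t in range(T):
    H.append(W_hx @ X[t] + W_hh @ H[-1])
    O.append(W_qh @ H[-1])
    loss = loss + 0.5 * ((O[-1] - Y[t]) ** 2).sum() / T
loss.backward()

# Analytic backward pass: (9.7.11), (9.7.13), (9.7.14), then (9.7.12), (9.7.16).
dO = [(O[t] - Y[t]) / T for t in range(T)]                   # dL/do_{t+1}
dH = [None] * (T + 1)
dH[T] = W_qh.T @ dO[T - 1]                                   # (9.7.13)
for t in range(T - 1, 0, -1):                                # (9.7.14)
    dH[t] = W_hh.T @ dH[t + 1] + W_qh.T @ dO[t - 1]
dW_qh = sum(torch.outer(dO[t], H[t + 1]) for t in range(T))  # (9.7.12)
dW_hx = sum(torch.outer(dH[t + 1], X[t]) for t in range(T))  # (9.7.16)
dW_hh = sum(torch.outer(dH[t + 1], H[t]) for t in range(T))  # (9.7.16)

print(torch.allclose(dW_qh, W_qh.grad),
      torch.allclose(dW_hx, W_hx.grad),
      torch.allclose(dW_hh, W_hh.grad))                      # True True True
```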
9.7.3. Summary
Backpropagation through time is merely an application of backpropagation
to sequence models with a hidden state. Truncation, such as regular
truncation or randomized truncation, is needed for computational
convenience and numerical stability. High powers of matrices can lead
to divergent or vanishing eigenvalues; this manifests itself in the form
of exploding or vanishing gradients. For efficient computation,
intermediate values are cached during backpropagation through time.
9.7.4. Exercises
- Assume that we have a symmetric matrix \(\mathbf{M} \in \mathbb{R}^{n \times n}\) with eigenvalues \(\lambda_i\) whose corresponding eigenvectors are \(\mathbf{v}_i\) (\(i = 1, \ldots, n\)). Without loss of generality, assume that they are ordered such that \(|\lambda_i| \geq |\lambda_{i+1}|\).
  - Show that \(\mathbf{M}^k\) has eigenvalues \(\lambda_i^k\).
  - Prove that for a random vector \(\mathbf{x} \in \mathbb{R}^n\), with high probability \(\mathbf{M}^k \mathbf{x}\) will be very much aligned with the eigenvector \(\mathbf{v}_1\) of \(\mathbf{M}\). Formalize this statement.
  - What does the above result mean for gradients in RNNs?
- Besides gradient clipping, can you think of any other methods to cope with gradient explosion in recurrent neural networks?