Back Propagation in Convolutional Neural Networks — Intuition and Code
Disclaimer: If you are not familiar with how back propagation operates on a computational graph, I recommend having a look at this lecture from the famous CS231n course.
I have scratched my head for a long time wondering how the back propagation algorithm works for convolutions. I could not find a simple and intuitive explanation of the algorithm online. So, I decided to write one myself. Hope you enjoy!
Why Understand Back Propagation?
Andrej Karpathy wrote in his blog about the need to understand back propagation, calling it a leaky abstraction:
‘‘it is easy to fall into the trap of abstracting away the learning process — believing that you can simply stack arbitrary layers together and backprop will “magically make them work” on your data’’
The Chain Rule
The following figure summarises the use of chain rule for the backward pass in computational graphs.
The forward pass on the left calculates z as a function f(x, y) of the input variables x and y. The right side of the figure shows the backward pass. Given dL/dz, the gradient of the loss with respect to z received from above, the gradients of the loss with respect to x and y can be calculated by applying the chain rule, as shown in the figure (borrowed from this post).
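Written out, the chain rule multiplies the upstream gradient by each local gradient. The addition example below (z = x + y) is an illustrative choice and does not come from the figure. Both local gradients of an addition are 1, so the upstream gradient passes through unchanged:

```latex
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z}\,\frac{\partial z}{\partial x},
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z}\,\frac{\partial z}{\partial y}

% Example (assumed): z = f(x, y) = x + y, so \partial z/\partial x = \partial z/\partial y = 1
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z}\cdot 1,
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial z}\cdot 1
```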
Here is another illustration that highlights the local gradients.
Back propagation illustration from CS231n Lecture 4. The variables x and y are cached during the forward pass and later used to calculate the local gradients.
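The following minimal Python sketch shows how caching works for a single multiply gate, z = x·y. The class name `MultiplyGate` and the numbers are illustrative assumptions and do not come from the lecture. The forward pass stores x and y because they are exactly the local gradients the backward pass needs (dz/dx = y, dz/dy = x).

```python
class MultiplyGate:
    """Computes z = x * y and caches the inputs for the backward pass."""

    def forward(self, x, y):
        # Cache the inputs: they are the local gradients needed later.
        self.x, self.y = x, y
        return x * y

    def backward(self, dz):
        # Local gradients: dz/dx = y and dz/dy = x.
        # Chain rule: multiply each by the upstream gradient dL/dz.
        dx = self.y * dz
        dy = self.x * dz
        return dx, dy


gate = MultiplyGate()
z = gate.forward(3.0, -4.0)   # z = -12.0
dx, dy = gate.backward(2.0)   # pretend dL/dz = 2.0 arrives from above
print(z, dx, dy)              # -12.0 -8.0 6.0
```

The gate does not need to know anything about the rest of the graph. It only uses its cached inputs and the gradient it receives.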
If you understand the chain rule, you are good to go.
Let’s Begin
We will try to understand how the backward pass works for a single convolutional layer by taking a simple case where the number of channels is one across all computations. We will also dive into the code later.
The following convolution operation takes an input X of size 3×3 and applies a single filter W of size 2×2 with no padding and stride 1, generating an output H of size 2×2. Also note that during the forward pass we cache the input X and the filter W, because both are needed in the backward pass.
Convolution Operation (Forward Pass)
Note: Here we perform the convolution operation without flipping the filter. In the literature this is also referred to as the cross-correlation operation. The animation above is provided only for the sake of clarity.
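The difference between cross-correlation and textbook convolution is only the 180° flip of the filter. The pure-Python sketch below makes this concrete for a single channel with stride 1 and no padding. The function names and the example values of X and W are assumptions chosen for illustration.

```python
def cross_correlate(X, W):
    """Valid 2-D cross-correlation: slide W over X without flipping it."""
    n, f = len(X), len(W)
    out = n - f + 1  # 3 - 2 + 1 = 2 for the running example
    return [[sum(X[i + a][j + b] * W[a][b] for a in range(f) for b in range(f))
             for j in range(out)]
            for i in range(out)]


def flip(W):
    """Rotate a square filter by 180 degrees (reverse rows and columns)."""
    return [row[::-1] for row in W[::-1]]


def convolve(X, W):
    """Textbook (flipped-filter) convolution, expressed via cross-correlation."""
    return cross_correlate(X, flip(W))


X = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
W = [[1, 2],
     [3, 4]]
print(cross_correlate(X, W))  # [[37, 47], [67, 77]]  (what this layer computes)
print(convolve(X, W))         # [[23, 33], [53, 63]]
```

Since the filter is learned, it makes no practical difference which variant a layer uses. Cross-correlation simply avoids the extra flip.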
Input Size: 3×3, Filter Size: 2×2, Output Size: 2×2
Output Equations
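These are the output equations implied by the setup above. They assume 1-based indices (h_11, x_11, w_11, ...) and an unflipped filter with stride 1:

```latex
h_{11} = x_{11}w_{11} + x_{12}w_{12} + x_{21}w_{21} + x_{22}w_{22} \\
h_{12} = x_{12}w_{11} + x_{13}w_{12} + x_{22}w_{21} + x_{23}w_{22} \\
h_{21} = x_{21}w_{11} + x_{22}w_{12} + x_{31}w_{21} + x_{32}w_{22} \\
h_{22} = x_{22}w_{11} + x_{23}w_{12} + x_{32}w_{21} + x_{33}w_{22}
```

In general, h_{ij} = \sum_{a,b} x_{i+a-1,\,j+b-1}\, w_{ab}.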
Backward Pass
Before moving further, note the following notation.
Notations
Now, to implement the back propagation step for the current layer, we can assume that we receive 𝜕h as input (from the backward pass of the next layer), and our aim is to calculate 𝜕w and 𝜕x. It is important to understand that 𝜕x (which plays the role of 𝜕h for the previous layer) becomes the input to the backward pass of the previous layer. This is the core principle behind the success of back propagation.
Each weight in the filter contributes to every pixel in the output map, so a change in any single weight affects all the output pixels. All of these effects add up in the final loss, which means the gradient of a weight is the sum of its contributions through every output pixel. We can therefore calculate the derivatives as follows.
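Written out for the 3×3 example, the filter gradients look like this. The shorthand ∂w_ab = ∂L/∂w_ab and ∂h_ij = ∂L/∂h_ij, with 1-based indices, is assumed. Each ∂h_ij is multiplied by the input pixel that w_ab touched when producing h_ij (read off the output equations):

```latex
\partial w_{11} = \partial h_{11}\,x_{11} + \partial h_{12}\,x_{12} + \partial h_{21}\,x_{21} + \partial h_{22}\,x_{22} \\
\partial w_{12} = \partial h_{11}\,x_{12} + \partial h_{12}\,x_{13} + \partial h_{21}\,x_{22} + \partial h_{22}\,x_{23} \\
\partial w_{21} = \partial h_{11}\,x_{21} + \partial h_{12}\,x_{22} + \partial h_{21}\,x_{31} + \partial h_{22}\,x_{32} \\
\partial w_{22} = \partial h_{11}\,x_{22} + \partial h_{12}\,x_{23} + \partial h_{21}\,x_{32} + \partial h_{22}\,x_{33}
```

Compactly, \partial w_{ab} = \sum_{i,j} \partial h_{ij}\, x_{i+a-1,\,j+b-1}, which is itself a valid cross-correlation of X with ∂H.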
Derivative Computation (Backward Pass), since pictures speak louder than words
Final derivatives after performing back propagation
Similarly, we can derive 𝜕x. Moving further, let’s see some code.
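Following the same reasoning for the 3×3 example (same shorthand and 1-based indices), each input pixel collects gradient from every output it contributed to, weighted by the filter entry it was multiplied with:

```latex
\partial x_{11} = \partial h_{11}\,w_{11} \\
\partial x_{12} = \partial h_{11}\,w_{12} + \partial h_{12}\,w_{11} \\
\partial x_{13} = \partial h_{12}\,w_{12} \\
\partial x_{21} = \partial h_{11}\,w_{21} + \partial h_{21}\,w_{11} \\
\partial x_{22} = \partial h_{11}\,w_{22} + \partial h_{12}\,w_{21} + \partial h_{21}\,w_{12} + \partial h_{22}\,w_{11} \\
\partial x_{23} = \partial h_{12}\,w_{22} + \partial h_{22}\,w_{12} \\
\partial x_{31} = \partial h_{21}\,w_{21} \\
\partial x_{32} = \partial h_{21}\,w_{22} + \partial h_{22}\,w_{21} \\
\partial x_{33} = \partial h_{22}\,w_{22}
```

Equivalently, ∂X is a "full" cross-correlation of the zero-padded ∂H with the filter rotated by 180°. The corner pixels, which fed only one output, receive a single term. The centre pixel, which fed all four outputs, receives four.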
Note: Much of the code is inspired by a programming assignment from the course Convolutional Neural Networks by deeplearning.ai, taught by Andrew Ng on Coursera.
Naive implementation of forward and backward pass for a convolution function
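Below is a minimal pure-Python sketch of a naive single-channel forward and backward pass. It assumes square inputs and filters, stride 1, and no padding, with no bias term. The function names are illustrative. The surrogate loss and the finite-difference gradient check are additions for verification and are not part of the original code.

```python
import random


def conv_forward(X, W):
    """Single-channel valid cross-correlation, stride 1, no padding.

    Returns the output H and a cache (X, W) for the backward pass.
    """
    n, f = len(X), len(W)
    out = n - f + 1
    H = [[sum(X[i + a][j + b] * W[a][b] for a in range(f) for b in range(f))
          for j in range(out)]
         for i in range(out)]
    return H, (X, W)


def conv_backward(dH, cache):
    """Given dL/dH from the next layer, return (dL/dX, dL/dW)."""
    X, W = cache
    n, f, out = len(X), len(W), len(dH)
    dX = [[0.0] * n for _ in range(n)]
    dW = [[0.0] * f for _ in range(f)]
    for i in range(out):
        for j in range(out):
            for a in range(f):
                for b in range(f):
                    # h_ij used the product X[i+a][j+b] * W[a][b], so its
                    # gradient flows back to both factors and accumulates.
                    dW[a][b] += dH[i][j] * X[i + a][j + b]
                    dX[i + a][j + b] += dH[i][j] * W[a][b]
    return dX, dW


def surrogate_loss(X, W, dH):
    # L = sum(dH * H) has dL/dH = dH exactly, so it lets us test
    # conv_backward against an arbitrary upstream gradient dH.
    H, _ = conv_forward(X, W)
    return sum(dH[i][j] * H[i][j] for i in range(len(H)) for j in range(len(H)))


def numerical_grad(loss_fn, M, eps=1e-6):
    """Central-difference estimate of d loss_fn / dM for a 2-D list M."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for r in range(len(M)):
        for c in range(len(M[0])):
            Mp = [row[:] for row in M]
            Mm = [row[:] for row in M]
            Mp[r][c] += eps
            Mm[r][c] -= eps
            grad[r][c] = (loss_fn(Mp) - loss_fn(Mm)) / (2 * eps)
    return grad


def max_abs_diff(A, B):
    return max(abs(p - q) for ra, rb in zip(A, B) for p, q in zip(ra, rb))


random.seed(0)
X = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(3)]
W = [[random.uniform(-1, 1) for _ in range(2)] for _ in range(2)]
dH = [[random.uniform(-1, 1) for _ in range(2)] for _ in range(2)]

H, cache = conv_forward(X, W)
dX, dW = conv_backward(dH, cache)
num_dW = numerical_grad(lambda Wp: surrogate_loss(X, Wp, dH), W)
num_dX = numerical_grad(lambda Xp: surrogate_loss(Xp, W, dH), X)
print(max_abs_diff(dW, num_dW), max_abs_diff(dX, num_dX))  # both should be close to 0
```

The four nested loops make the "every output pixel contributes" argument explicit. A real implementation would vectorize them and also handle channels, padding, stride and bias.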
If you enjoyed this article, you might also want to check the following articles to delve deeper into the mathematics:
I also found Back propagation in Convnets lecture by Dhruv Batra very useful for understanding the concept.
Since I might not be an expert on the topic, if you find any mistakes in the article or have any suggestions for improvement, please mention them in the comments.