Backpropagation
In a previous post, we defined the loss function as a measure of how wrong a model's predictions are, and minimized it with gradient descent. But we never addressed how to efficiently compute the gradient of the loss with respect to every parameter in a large neural network. That's what backpropagation does.
The "gradient" in gradient descent is a vector of partial derivatives of the loss with respect to each parameter.
A partial derivative $\partial L / \partial w$ measures how the output changes as one input changes while all other inputs are held fixed.

What if we want the derivative of a function applied to the output of another function? If $y = f(u)$ and $u = g(x)$, the chain rule multiplies the derivatives along the chain:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$$

What if the dependency isn't a single chain?

If $y$ depends on $x$ through several intermediate variables $u_1, \ldots, u_n$, the contributions along each path add up:

$$\frac{\partial y}{\partial x} = \sum_{i=1}^{n} \frac{\partial y}{\partial u_i} \cdot \frac{\partial u_i}{\partial x}$$
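As a quick sanity check on the chain rule, here's a small numerical comparison in Python. The functions and the input value are arbitrary illustrative choices:

```python
import math

# y = f(u) = sin(u), u = g(x) = x**2, so y = sin(x**2).
# Chain rule: dy/dx = cos(x**2) * 2x.
x = 1.3
analytic = math.cos(x**2) * 2 * x

# Central finite-difference approximation of dy/dx.
eps = 1e-6
numeric = (math.sin((x + eps)**2) - math.sin((x - eps)**2)) / (2 * eps)

print(analytic, numeric)  # the two values agree closely
```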
A neural network is a long chain of differentiable operations, where the loss depends on each parameter through many intermediate layers. Backpropagation is just the systematic application of the chain rule, working backwards from the loss through each operation to compute every partial derivative. The most fundamental operation in a neural network is matrix multiplication, so let's start there.
The Gradient of a Matrix Multiplication
Suppose $C = AB$, where $A$ is $m \times n$ and $B$ is $n \times p$.

Each element of $C$ is a sum of products:

$$c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj}$$

From these equations we can read off partial derivatives directly. For example,

$$\frac{\partial c_{ij}}{\partial a_{ik}} = b_{kj}, \qquad \frac{\partial c_{ij}}{\partial b_{kj}} = a_{ik}$$

Now suppose $C$ feeds into some scalar loss $L$ further downstream, and that we already know the upstream gradient $\partial L / \partial C$.

Consider the gradient of $L$ with respect to a single entry $a_{ik}$. Since $a_{ik}$ affects every element in row $i$ of $C$, the chain rule sums over all of those paths:

$$\frac{\partial L}{\partial a_{ik}} = \sum_{j=1}^{p} \frac{\partial L}{\partial c_{ij}} \frac{\partial c_{ij}}{\partial a_{ik}}$$

Substituting the partial derivatives we computed above:

$$\frac{\partial L}{\partial a_{ik}} = \sum_{j=1}^{p} \frac{\partial L}{\partial c_{ij}} b_{kj}$$

Writing out this sum,

we see that it is exactly the $(i, k)$ entry of the matrix product $\frac{\partial L}{\partial C} B^\top$.

The same pattern holds for every entry:

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C} B^\top$$

Similarly,

$$\frac{\partial L}{\partial B} = A^\top \frac{\partial L}{\partial C}$$

which is the gradient with respect to the other factor.

The dimensions confirm this: $\frac{\partial L}{\partial C}$ is $m \times p$ and $B^\top$ is $p \times n$, so their product is $m \times n$, the shape of $A$; likewise $A^\top \frac{\partial L}{\partial C}$ is $n \times p$, the shape of $B$.

To compute these gradients, the only information we needed from upstream was $\partial L / \partial C$. The matrix multiplication never needs to know anything about the rest of the network.
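Here's a minimal NumPy sketch verifying these formulas against finite differences. The matrix shapes and the choice of loss ($L = \sum_{ij} c_{ij}$, so $\partial L / \partial C$ is all ones) are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))
C = A @ B

# Take L = sum(C) as a simple downstream loss, so dL/dC is all ones.
dL_dC = np.ones_like(C)

# Backprop through the matmul using the formulas derived above.
dL_dA = dL_dC @ B.T
dL_dB = A.T @ dL_dC

# Check dL/dA entry by entry with a central finite difference.
eps = 1e-6
numeric = np.zeros_like(A)
for i in range(A.shape[0]):
    for k in range(A.shape[1]):
        A_plus, A_minus = A.copy(), A.copy()
        A_plus[i, k] += eps
        A_minus[i, k] -= eps
        numeric[i, k] = ((A_plus @ B).sum() - (A_minus @ B).sum()) / (2 * eps)

print(np.allclose(dL_dA, numeric, atol=1e-4))
```

Note that the gradients land in the right shapes automatically: `dL_dA` has the shape of `A` and `dL_dB` the shape of `B`.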
A Small Neural Network
Now let's apply this to a neural network with one hidden layer: a matrix multiplication, a ReLU activation, then a second matrix multiplication producing the prediction:

$$z = W_1 x, \qquad h = \mathrm{ReLU}(z), \qquad \hat{y} = W_2 h$$

The loss for a target $y$ is the squared error:

$$L = (\hat{y} - y)^2$$

The forward pass computes each layer in sequence: $z$, then $h$, then $\hat{y}$, and finally $L$.

Now we work backwards from the loss, applying the chain rule through each operation:

$$\frac{\partial L}{\partial \hat{y}} = 2(\hat{y} - y)$$

Since $\hat{y} = W_2 h$ is a matrix multiplication, the results from the previous section give us two gradients at once:

$$\frac{\partial L}{\partial W_2} = \frac{\partial L}{\partial \hat{y}} h^\top, \qquad \frac{\partial L}{\partial h} = W_2^\top \frac{\partial L}{\partial \hat{y}}$$

Through the ReLU, whose derivative is $1$ where $z > 0$ and $0$ elsewhere, the gradient flows only through the neurons that were active:

$$\frac{\partial L}{\partial z} = \frac{\partial L}{\partial h} \odot \mathbf{1}[z > 0]$$

Finally, one more application of the matrix-multiplication result:

$$\frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial z} x^\top$$

A row of $\partial L / \partial W_1$ is entirely zero whenever the corresponding neuron's pre-activation was negative: the ReLU zeroed that neuron's output, so no change to its incoming weights can affect the loss.
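A minimal sketch of this forward and backward pass in NumPy. The network here is a small illustrative stand-in with two inputs, two hidden neurons, and one output; the weights, input, and target are my own choices, not values from the text:

```python
import numpy as np

x = np.array([[1.0], [2.0]])           # input, shape (2, 1)
W1 = np.array([[0.5, 0.3],
               [-1.0, 0.2]])           # hidden weights, shape (2, 2)
W2 = np.array([[1.5, -0.5]])           # output weights, shape (1, 2)
y = np.array([[1.0]])                  # target

# Forward pass.
z = W1 @ x                             # pre-activations
h = np.maximum(z, 0)                   # ReLU
y_hat = W2 @ h                         # prediction
L = float(((y_hat - y) ** 2).sum())    # squared-error loss

# Backward pass: chain rule through each operation, in reverse.
dL_dyhat = 2 * (y_hat - y)
dL_dW2 = dL_dyhat @ h.T                # matmul gradient: upstream @ input^T
dL_dh = W2.T @ dL_dyhat                # matmul gradient: weights^T @ upstream
dL_dz = dL_dh * (z > 0)                # ReLU gate: gradient only where z > 0
dL_dW1 = dL_dz @ x.T

print(L)
print(dL_dW1)                          # second row is zero: that neuron's z < 0
print(dL_dW2)
```

With these particular weights, the second hidden neuron's pre-activation is negative, so the second row of `dL_dW1` comes out zero, just as the argument above predicts.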
The parameters of this network are the entries of $W_1$ and $W_2$.

Each gradient tells us how the loss changes when we adjust that parameter. To verify them all at once, let's increase every parameter by a small amount $\epsilon$. To first order, the chain rule predicts the loss will change by $\epsilon$ times the sum of every gradient entry:

$$\Delta L \approx \epsilon \left( \sum_{i,k} \frac{\partial L}{\partial (W_1)_{ik}} + \sum_{i,k} \frac{\partial L}{\partial (W_2)_{ik}} \right)$$

Rerunning the forward pass with every parameter increased by $\epsilon$ gives a new loss.

The actual change in $L$ closely matches this prediction, which is strong evidence that the gradients are correct.
In practice, we use these gradients to improve the model through gradient descent: each parameter is updated by subtracting its gradient scaled by a learning rate $\eta$:

$$\theta \leftarrow \theta - \eta \frac{\partial L}{\partial \theta}$$

The predicted change in loss sums each parameter's contribution:

$$\Delta L \approx \sum_{\theta} \frac{\partial L}{\partial \theta} \, \Delta \theta = -\eta \sum_{\theta} \left( \frac{\partial L}{\partial \theta} \right)^2 \leq 0$$

Rerunning the forward pass with the updated parameters, the actual change is roughly this predicted decrease: the loss goes down, as the gradients said it would.
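Both checks can be sketched end to end in NumPy. The network below is a small illustrative stand-in; its sizes, weights, $\epsilon$, and learning rate are my own choices, not values from the text:

```python
import numpy as np

def forward(W1, W2, x, y):
    """Two-layer ReLU network with squared-error loss."""
    h = np.maximum(W1 @ x, 0)
    y_hat = W2 @ h
    return float(((y_hat - y) ** 2).sum())

def gradients(W1, W2, x, y):
    """Backward pass: chain rule through each operation, in reverse."""
    z = W1 @ x
    h = np.maximum(z, 0)
    y_hat = W2 @ h
    dL_dyhat = 2 * (y_hat - y)
    dL_dW2 = dL_dyhat @ h.T
    dL_dz = (W2.T @ dL_dyhat) * (z > 0)
    dL_dW1 = dL_dz @ x.T
    return dL_dW1, dL_dW2

x = np.array([[1.0], [2.0]])
y = np.array([[1.0]])
W1 = np.array([[0.5, 0.3], [-1.0, 0.2]])
W2 = np.array([[1.5, -0.5]])

L0 = forward(W1, W2, x, y)
dW1, dW2 = gradients(W1, W2, x, y)

# Check 1: add eps to every parameter; the predicted first-order change
# is eps times the sum of all gradient entries.
eps = 1e-5
predicted = eps * (dW1.sum() + dW2.sum())
actual = forward(W1 + eps, W2 + eps, x, y) - L0
print(predicted, actual)  # these agree closely

# Check 2: one gradient-descent step should decrease the loss by roughly
# lr times the sum of squared gradient entries.
lr = 0.01
predicted_step = -lr * ((dW1 ** 2).sum() + (dW2 ** 2).sum())
L1 = forward(W1 - lr * dW1, W2 - lr * dW2, x, y)
print(predicted_step, L1 - L0)
```

The second check's prediction is only first-order accurate, so with a non-tiny learning rate the actual decrease differs a bit from the prediction, but the loss still goes down.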
That's all backpropagation is: a forward pass to compute the loss, then a backward pass applying the chain rule through each operation, tracking how much the loss changes with respect to each variable. Each operation only needs its local inputs and the gradient flowing back from downstream, so it's a computationally efficient way to determine how to nudge the parameters of a model to decrease the loss.