TL;DR We investigate a fundamental question in recurrent neural network training: why is backpropagation through time always run backwards? We show, by deriving an exact gradient-based algorithm that propagates errors forward in time (in multiple phases), that this does not necessarily need to be the case! However, while the math holds up, the algorithm suffers from critical numerical stability issues that get worse the faster the network forgets information. This post details the derivation, the successful experiments, an analysis of why this promising idea fails numerically, and the reasons why we did not investigate it further.
Do we necessarily need to propagate error signals backward in time (as in backpropagation through time) when training recurrent neural networks? Surprisingly, no.
In this post, we derive a method to propagate errors forward in time. By using a “warm-up” phase to determine initial conditions, we can reconstruct the exact error gradients required for learning without ever moving backward through the sequence. This sounds like a perfect solution for neuromorphic or analog hardware, and a potential theory for how our brains learn, as we no longer suffer from the huge memory requirements of backpropagation through time or its need to reverse the arrow of time.
However, there is a catch.
While we successfully trained deep RNNs with this method on non-trivial tasks, it has a critical flaw that prevents it from being widely applicable: the algorithm becomes severely numerically unstable whenever the network is in a “forgetting” regime. Essentially, the math works, but the floating-point arithmetic does not.
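To make the failure mode concrete, here is a minimal toy sketch (our illustrative stand-in, not the actual algorithm derived below) using a scalar linear RNN $h_t = a\,h_{t-1}$. Propagating errors backward multiplies by $a$ at every step, which is stable when $|a| < 1$; any scheme that reconstructs the same quantities while moving forward in time has to undo that contraction, i.e. repeatedly divide by $a$, and floating-point round-off is amplified exponentially.

```python
import numpy as np

# Toy illustration of the instability mechanism (an assumption about the
# failure mode in a scalar linear RNN h_t = a * h_{t-1}; the algorithm in
# this post operates on full deep RNNs).
a = 0.5          # |a| < 1: the network "forgets" its past
T = 120          # sequence length

# Backward accumulation (what BPTT does): multiply by a at every step.
# This is a contraction, so round-off stays bounded.
delta = 1.0
for _ in range(T):
    delta *= a
print(delta)     # ~ a**T ≈ 7.5e-37, tiny but accurately represented

# Reconstructing the same quantity forward in time has to undo the
# contraction, i.e. repeatedly divide by a. A perturbation of ~1e-16
# (roughly float64 round-off) is amplified by (1/a)**T.
err = 1e-16
for _ in range(T):
    err /= a
print(err)       # ~ 1e-16 * 2**120 ≈ 1.3e+20 -- round-off has exploded
```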
Despite these negative results, we decided to report our findings through this blog post, as we learned a lot about the fundamentals of backpropagation through time, an algorithm that is at the core of deep learning. While not directly applicable, we hope that the ideas we discuss here may spark some new ones and that they may be useful in the search for alternative physical computing paradigms!
Backpropagation of errors through time
Before diving into how we can propagate error signals forward in time, let us review backpropagation through time (BPTT) and why we would like to improve upon it in the first place.
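As a quick preview before the formal derivation, here is a minimal BPTT sketch for a vanilla tanh RNN with a per-step squared-error loss (an assumed toy architecture for illustration, not the networks we actually trained):

```python
import numpy as np

# Minimal BPTT sketch for a vanilla tanh RNN (illustrative assumption):
#   h_t = tanh(W @ h_{t-1} + U @ x_t),  loss = sum_t 0.5 * ||h_t - y_t||^2
def bptt(W, U, xs, ys):
    T, n = len(xs), W.shape[0]
    hs = [np.zeros(n)]                       # h_0 = 0; every h_t must be STORED
    for t in range(T):                       # forward in time
        hs.append(np.tanh(W @ hs[-1] + U @ xs[t]))

    dW, dU = np.zeros_like(W), np.zeros_like(U)
    delta = np.zeros(n)                      # error signal carried BACKWARD in time
    for t in reversed(range(T)):             # backward in time
        delta = delta + (hs[t + 1] - ys[t])  # local loss gradient dL/dh_t
        pre = delta * (1.0 - hs[t + 1] ** 2) # through tanh: tanh' = 1 - tanh^2
        dW += np.outer(pre, hs[t])
        dU += np.outer(pre, xs[t])
        delta = W.T @ pre                    # propagate error to h_{t-1}
    return dW, dU
```

Note the two properties we would like to get rid of: the list of stored hidden states grows linearly with the sequence length, and the second loop runs in reversed time.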
Notation and derivation
We consider a recurrent network with hidden state $h_t$ that satisfies $h_0 = 0$ and evolves according to the following dynamics:
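One concrete example of such dynamics, used here purely for illustration, is the vanilla RNN update

$$h_t = \phi\left(W h_{t-1} + U x_t + b\right),$$

where $x_t$ is the input at time $t$, $\phi$ a pointwise nonlinearity such as $\tanh$, and $W$, $U$, $b$ the trainable parameters.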