LSTM Primer With Real-Life Application ( DeepMind Continuous Risk Model of Acute Kidney Injury Prediction from EHR )*
LSTM is the key algorithm that enabled major ML successes like Google speech recognition and Translate¹. It was invented in 1997 by Hochreiter and Schmidhuber to address the vanishing/exploding gradient problem of RNNs. LSTM can be used to model many types of sequential data², from time series data to continuous handwriting and speech recognition³,⁸. What is it that makes LSTMs so versatile¹⁶ and successful⁹?
An LSTM cell can dynamically modify its state⁴ with each new input ( time step ). Past experience shapes how new input will be interpreted, i.e. the LSTM does not start learning from scratch with each new input/step; it uses previous knowledge ( expressed as state ) to decide on the output and hidden state values⁵.
One of the best explanations of how LSTM cells operate is Stanford Lecture 10 ( CS231n ).
An LSTM has two hidden states: a standard RNN state hₜ and a cell state cₜ ( an internal state that is not exposed to the outside world ). The input vector xₜ and the previous hidden state hₜ₋₁ are combined ( stacked ) and multiplied by a weight matrix W to compute the i, f, o and g gates.
Here is an explanation of the key terms from the slide above:
- i is the input gate: whether we want to write into our cell
- f is the forget gate: whether we want to keep or erase the cell memory from the previous time step
- o is the output gate: how much of the cell state we want to reveal to the outside world
- g is how much we want to write into our cell ( the candidate value )
- i, f and o use the sigmoid as their non-linearity, which means their output values are between 0 and 1
- g uses tanh, which means its output is between -1 and 1
- Element-wise ( Hadamard ) product f ⊙ cₜ₋₁: to forget/remember a particular element of the cell state, the corresponding element of the forget gate is set close to 0/1
- i ⊙ g product: i is ( ideally ) a vector of ones and zeroes specifying whether writing to a particular element of the cell state is enabled. g has values between -1 and 1 ( tanh ) and is the candidate value we consider writing at each time step
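For reference, the gate definitions above can be collected into three equations. This is a sketch in the CS231n notation used here, with the bias term omitted for brevity; σ and tanh are applied element-wise to the corresponding quarter of the product of W and the stacked vector:

```latex
\begin{pmatrix} i \\ f \\ o \\ g \end{pmatrix}
= \begin{pmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh \end{pmatrix}
  W \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix},
\qquad
c_t = f \odot c_{t-1} + i \odot g,
\qquad
h_t = o \odot \tanh(c_t)
```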
Now we can put it all together in a hypothetical scenario where the sigmoids are pushed to be either 0 or 1: each element of the cell state can be thought of as a scalar integer counter that is incremented or decremented by one at each time step ( the limit values of tanh are -1 and 1 ).
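The counter picture can be checked with a tiny simulation. This is a toy sketch, not a real LSTM: the gate values below are hypothetical and hard-coded to their saturated limits, so only the update cₜ = f ⊙ cₜ₋₁ + i ⊙ g is exercised for a single cell-state element:

```python
def saturated_cell_update(c_prev, i, f, g):
    """One cell-state element update c_t = f * c_{t-1} + i * g, with the
    gates already saturated: i, f in {0, 1} and g in {-1, 1}."""
    return f * c_prev + i * g


# Hypothetical gate values (i, f, g) for five time steps.
steps = [
    (1, 1, 1),   # keep old value, write +1: increment
    (1, 1, 1),   # increment again
    (0, 1, 1),   # input gate closed: value unchanged
    (1, 1, -1),  # keep old value, write -1: decrement
    (0, 0, 1),   # forget gate closed, nothing written: reset to 0
]

c = 0
history = []
for i, f, g in steps:
    c = saturated_cell_update(c, i, f, g)
    history.append(c)

print(history)  # [1, 2, 2, 1, 0]
```

With the forget gate open the counter accumulates writes; closing it ( f = 0 ) wipes the element, a case the pure increment/decrement picture leaves out.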
The next step is to compute a hidden state that will be revealed to the outside world:
hₜ = o ⊙ tanh(cₜ)
The cell state is squashed via tanh and multiplied element-wise by the output gate o ( again, in our idealized scenario the sigmoid values are either 0 or 1 ), which tells us whether to reveal a particular cell state element or not when computing the external hidden state for this time step.
Recap: we start with the previous hidden state hₜ₋₁ and the input vector xₜ, stack them and multiply them by the weight matrix W to produce the four gates ( i, f, o, g ). The forget gate f is then multiplied element-wise with the previous cell state cₜ₋₁. The input gate i and the gate gate g⁵ are multiplied element-wise and added to the result to give us the next cell state cₜ, which gets squashed through a tanh and multiplied element-wise with the output gate o to produce the next hidden state hₜ¹².
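The recap maps almost line by line onto code. Below is a minimal pure-Python sketch of one forward step, assuming the CS231n layout: the rows of W are ordered i, f, o, g, and a bias vector b is included even though the slide leaves it out. The helper names ( sigmoid, matvec, lstm_step ) are my own, not from any library:

```python
import math
import random


def sigmoid(v):
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def matvec(W, v):
    # Plain matrix-vector product on nested lists.
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM time step.

    x: input of size D; h_prev, c_prev: vectors of size H.
    W: 4H x (H + D) matrix applied to the stacked [h_prev; x], rows ordered i, f, o, g.
    b: bias of size 4H (an assumption; the lecture slide omits it).
    """
    H = len(h_prev)
    a = [a_k + b_k for a_k, b_k in zip(matvec(W, h_prev + x), b)]  # stack, then multiply
    i = [sigmoid(v) for v in a[0:H]]              # input gate
    f = [sigmoid(v) for v in a[H:2 * H]]          # forget gate
    o = [sigmoid(v) for v in a[2 * H:3 * H]]      # output gate
    g = [math.tanh(v) for v in a[3 * H:4 * H]]    # "gate gate": candidate values
    c = [f_k * c_k + i_k * g_k for f_k, c_k, i_k, g_k in zip(f, c_prev, i, g)]
    h = [o_k * math.tanh(c_k) for o_k, c_k in zip(o, c)]
    return h, c


# Usage: run a 3-step sequence through a randomly initialized cell (D = 2, H = 3).
random.seed(0)
D, H = 2, 3
W = [[random.uniform(-0.5, 0.5) for _ in range(H + D)] for _ in range(4 * H)]
b = [0.0] * (4 * H)
h, c = [0.0] * H, [0.0] * H
for x in ([1.0, 0.0], [0.0, 1.0], [0.5, -0.5]):
    h, c = lstm_step(x, h, c, W, b)
print(h, c)
```

Frameworks pack the same computation into a single matrix multiply, but the gate order is a convention: PyTorch's nn.LSTM, for example, stores its gate blocks as i, f, g, o, so weights are not interchangeable without reordering.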
What happens to the LSTM cell during the backward pass¹¹?
This is where the LSTM's improvement over the standard RNN⁸ comes in. Backpropagating through the addition operation ( + ) copies the upstream gradient directly into both branches ( multiplying by 1 on each branch )⁶. On the cell-state path, this upstream gradient is then multiplied element-wise by the forget gate, which makes backprop simpler ( no matrix multiplication ). The forget gate will probably be different at every step, as opposed to a vanilla RNN, where the same W is shared across all time steps and the gradient flows through a tanh at every time step during the backward pass⁷. Because the forget gate comes out of a sigmoid, this element-wise multiplier is guaranteed to be between zero and one, which leads to nicer numerical properties.
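To see the difference numerically, here is a scalar sketch. It compares only the direct cell-state path of an LSTM ( the full gradient also flows through hₜ and the gates, which this ignores ) with a one-unit vanilla RNN hₜ = tanh(w·hₜ₋₁ + xₜ); the forget gate value 0.99, weight w = 0.9 and input 0.1 are arbitrary choices for illustration:

```python
import math


def cell_path_gradient(forget_gates):
    """d c_T / d c_0 along the direct path c_t = f_t * c_{t-1} + i_t * g_t,
    treating i_t * g_t as constant: just the product of the forget gates."""
    grad = 1.0
    for f in forget_gates:
        grad *= f
    return grad


def vanilla_rnn_forward(w, xs, h0):
    """Scalar vanilla RNN h_t = tanh(w * h_{t-1} + x_t); returns h_T."""
    h = h0
    for x in xs:
        h = math.tanh(w * h + x)
    return h


def vanilla_rnn_gradient(w, xs, h0):
    """d h_T / d h_0 for the scalar vanilla RNN: every step multiplies the
    gradient by the same shared weight w and by the tanh derivative 1 - h_t**2."""
    h, grad = h0, 1.0
    for x in xs:
        h = math.tanh(w * h + x)
        grad *= w * (1.0 - h * h)
    return grad


T = 50
xs = [0.1] * T
lstm_grad = cell_path_gradient([0.99] * T)     # about 0.605: the signal survives
rnn_grad = vanilla_rnn_gradient(0.9, xs, 0.0)  # each step shrinks it by a factor below 1
print(lstm_grad, rnn_grad)

# Sanity check of the analytic RNN gradient against a central finite difference.
eps, short = 1e-6, xs[:5]
numeric = (vanilla_rnn_forward(0.9, short, eps) - vanilla_rnn_forward(0.9, short, -eps)) / (2 * eps)
print(abs(numeric - vanilla_rnn_gradient(0.9, short, 0.0)))  # close to zero
```

Over 50 steps the product of forget gates still passes about 60% of the gradient, while the vanilla RNN gradient collapses toward zero because every step multiplies it by the same w and a tanh derivative below one.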
The Gated Recurrent Unit ( GRU ) is a newer type of hidden unit, motivated by the LSTM unit but much simpler to compute and implement: it has only two gates ( reset and update ) and no separate cell state.
GRU is widely used ( PyTorch, TensorFlow both come with GRU implementations ).
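For comparison with the LSTM step, here is a minimal GRU step in pure Python. It follows the original Cho et al. formulation without biases, with an update gate z and a reset gate r and no separate cell state; the parameter names are hypothetical. Conventions differ between sources: some swap the roles of z and 1 − z, and PyTorch applies r after the recurrent matrix multiply rather than before it:

```python
import math
import random


def sigmoid(v):
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def gru_step(x, h_prev, p):
    """One GRU step. p is a dict with hypothetical keys Wz, Uz, Wr, Ur, Wh, Uh:
    W* act on the input x, U* act on the previous hidden state h_prev."""
    # Update gate z and reset gate r, both in (0, 1).
    z = [sigmoid(a + b) for a, b in zip(matvec(p["Wz"], x), matvec(p["Uz"], h_prev))]
    r = [sigmoid(a + b) for a, b in zip(matvec(p["Wr"], x), matvec(p["Ur"], h_prev))]
    # Candidate state: the reset gate decides how much of h_prev it may see.
    rh = [r_k * h_k for r_k, h_k in zip(r, h_prev)]
    h_tilde = [math.tanh(a + b) for a, b in zip(matvec(p["Wh"], x), matvec(p["Uh"], rh))]
    # Interpolate between the old state and the candidate.
    return [z_k * h_k + (1.0 - z_k) * t_k for z_k, h_k, t_k in zip(z, h_prev, h_tilde)]


# Usage: random parameters with input size D = 2 and hidden size H = 3.
random.seed(1)
D, H = 2, 3
params = {"Wz": rand_matrix(H, D), "Uz": rand_matrix(H, H),
          "Wr": rand_matrix(H, D), "Ur": rand_matrix(H, H),
          "Wh": rand_matrix(H, D), "Uh": rand_matrix(H, H)}
h = [0.0] * H
for x in ([1.0, 0.0], [0.0, 1.0]):
    h = gru_step(x, h, params)
print(h)
```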
Serial Nature of RNN/LSTM
Because every time step depends on the previous hidden state, the RNN/LSTM architecture is inherently sequential, which makes it hard for RNNs to take advantage of GPUs and other parallel hardware. There have been quite a few attempts to improve RNN training performance ( QRNN, Simple Recurrent Units ). LSTM-influenced architectures also continue to appear ( Highway Networks, Recurrent Highway Networks ), as do variations on the LSTM theme ( GRU ).
Highway Networks make it possible to train very deep ( hundreds of layers ) feedforward networks. We present them here because they are LSTM-inspired¹⁵ and used in the DeepMind kidney paper⁸. The gist is that multiple layers are grouped into blocks, with the output of each block defined as:
y = H(x, Wₕ) · T(x, Wₜ) + x · C(x, W𝒸)
T is the transform gate and C is the carry gate: learned values between 0 and 1 that determine how much of the output is the transformed input and how much is the original input carried over unchanged via the highway ( a highway layer can smoothly vary its behavior between that of H and that of a layer which simply passes its inputs through ). In the original Highway Networks paper the carry gate is simply coupled to the transform gate as C = 1 − T.
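A single highway layer is only a few lines. This sketch assumes the coupled carry gate C = 1 − T, uses tanh for H ( an arbitrary choice ) and plain Python lists; the function and parameter names are illustrative. The two calls show the extremes of the gate:

```python
import math


def sigmoid(v):
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def highway_layer(x, W_H, b_H, W_T, b_T):
    """y = H(x) * T(x) + x * (1 - T(x)). H and T must map x to a vector
    of the same size as x so that the carried input can be added."""
    h = [math.tanh(a + b) for a, b in zip(matvec(W_H, x), b_H)]  # transformed input H(x)
    t = [sigmoid(a + b) for a, b in zip(matvec(W_T, x), b_T)]    # transform gate T(x)
    return [h_k * t_k + x_k * (1.0 - t_k) for h_k, t_k, x_k in zip(h, t, x)]


x = [0.5, -1.0]
W = [[0.3, -0.2], [0.1, 0.4]]
zeros = [0.0, 0.0]
# Strongly negative transform bias: T is nearly 0, the layer carries x through.
print(highway_layer(x, W, zeros, W, [-30.0, -30.0]))  # approximately [0.5, -1.0]
# Strongly positive transform bias: T is nearly 1, the layer behaves like plain H(x).
print(highway_layer(x, W, zeros, W, [30.0, 30.0]))
```

This is also why the Highway Networks paper initializes the transform gate bias to a negative value: layers start out close to carry behavior, so a very deep stack initially passes its input through almost unchanged.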
Recurrent Highway Networks
x[t] is directly transformed only by the first Highway layer ( l = 1 ) in the recurrent transition, and for this layer s[t]ₗ₋₁ is the RHN layer’s output from the previous time step. Subsequent Highway layers only process the outputs of the previous layers. Dotted vertical lines in Figure 3 separate the multiple Highway layers in the recurrent transition.
H, C and T are the feedforward layer, the carry gate and the transform gate explained above in the Highway Networks section.
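Putting those pieces together, one RHN time step is a loop of highway layers in which only the first layer sees xₜ and the state entering the first layer is the previous time step's output. This is a sketch under simplifying assumptions ( coupled carry gate C = 1 − T, no biases, illustrative names ), not the reference implementation:

```python
import math
import random


def sigmoid(v):
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def rhn_step(x, s_prev, W_H, W_T, R_H, R_T):
    """One RHN time step with recurrence depth L = len(R_H).

    s_prev is the previous time step's output s_L, which plays the role of s_0.
    Only the first highway layer (l = 1, index 0 here) receives the input x.
    """
    s = s_prev
    for l in range(len(R_H)):
        in_H = matvec(W_H, x) if l == 0 else [0.0] * len(s)
        in_T = matvec(W_T, x) if l == 0 else [0.0] * len(s)
        h = [math.tanh(a + b) for a, b in zip(in_H, matvec(R_H[l], s))]  # H
        t = [sigmoid(a + b) for a, b in zip(in_T, matvec(R_T[l], s))]    # T
        s = [h_k * t_k + s_k * (1.0 - t_k) for h_k, t_k, s_k in zip(h, t, s)]
    return s


# Usage: D = 2 inputs, H = 3 state units, recurrence depth L = 2.
random.seed(2)
D, H, L = 2, 3, 2
W_H, W_T = rand_matrix(H, D), rand_matrix(H, D)
R_H = [rand_matrix(H, H) for _ in range(L)]
R_T = [rand_matrix(H, H) for _ in range(L)]
s = [0.0] * H
for x in ([1.0, 0.0], [0.0, 1.0], [0.5, 0.5]):
    s = rhn_step(x, s, W_H, W_T, R_H, R_T)
print(s)
```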
I am now moving on to one of the attempts mentioned above to work around the inherently serial nature of RNNs: QRNN ( Quasi-Recurrent Neural Network ), which significantly improves training time over an LSTM baseline while maintaining or exceeding its accuracy¹⁴. QRNN borrows concepts from both CNN and LSTM to implement almost fully parallel passes through time sequence data.
QRNN has two major subcomponents:
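- convolution, where a filter bank slides along the input sequence to produce candidate vectors ( dot products of the filters with windows of the input sequence ); the gates are computed by the same kind of convolution, which is masked so that a time step never sees future inputs
- pooling, where LSTM-like gates ( forget, output and input, giving the f, fo and ifo pooling variants ) are used to pool values from the candidate vectors ( this is the only recurrent, i.e. non-parallel, part )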
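The split between the parallel and the recurrent part is easiest to see in code. Below is a pure-Python sketch of one QRNN layer with fo-pooling, assuming a filter width of 2 ( each time step sees xₜ₋₁ and xₜ ) and illustrative parameter names; the three convolutions have no dependence between time steps, and only the short pooling loop is sequential:

```python
import math
import random


def sigmoid(v):
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def masked_conv(xs, Wa, Wb):
    """Width-2 masked convolution: Wa @ x_{t-1} + Wb @ x_t for every t
    (x_{-1} is zero padding). No step depends on another, so this is parallel."""
    prev = [[0.0] * len(xs[0])] + xs[:-1]
    return [[a + b for a, b in zip(matvec(Wa, xp), matvec(Wb, xt))]
            for xp, xt in zip(prev, xs)]


def qrnn_fo_layer(xs, params):
    """params maps "z", "f", "o" to (Wa, Wb) weight pairs (hypothetical layout)."""
    Z = [[math.tanh(v) for v in row] for row in masked_conv(xs, *params["z"])]  # candidates
    F = [[sigmoid(v) for v in row] for row in masked_conv(xs, *params["f"])]    # forget gates
    O = [[sigmoid(v) for v in row] for row in masked_conv(xs, *params["o"])]    # output gates
    # fo-pooling, the only sequential part: element-wise work per step.
    c = [0.0] * len(Z[0])
    hs = []
    for z, f, o in zip(Z, F, O):
        c = [f_k * c_k + (1.0 - f_k) * z_k for f_k, c_k, z_k in zip(f, c, z)]
        hs.append([o_k * c_k for o_k, c_k in zip(o, c)])
    return hs


# Usage: D = 2 inputs, H = 3 channels, 4 time steps.
random.seed(3)
D, H = 2, 3
def rand_pair():
    return tuple([[random.uniform(-0.5, 0.5) for _ in range(D)] for _ in range(H)]
                 for _ in range(2))
params = {"z": rand_pair(), "f": rand_pair(), "o": rand_pair()}
out = qrnn_fo_layer([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0]], params)
print(out)
```

Because the pooling loop does only element-wise work, it is cheap compared with an LSTM step, which needs a matrix multiply by the previous hidden state at every time step.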
*The article is best viewed in a desktop browser or the Medium app on a smartphone, as Medium does not support subscripts.
¹ The latest breakthroughs in NLP and language modeling ( BERT ) do not use LSTM; BERT is built on the attention-based Transformer architecture. Google Duplex, on the other hand, is RNN ( probably LSTM ) based.
Even business outlets like Bloomberg are aware of the importance of LSTM: “These powers make LSTM arguably the most commercial AI achievement, used for everything from predicting diseases to composing music.”
The DeepMind kidney paper is important because it goes beyond the usual RNN applications to short sequences ( voice, text ). It demonstrates RNN applicability to massive datasets where you need to aggregate, learn and remember information over time ( 6 billion entries: prescriptions, lab results and diagnoses for over 700k patients ), over long time horizons ( 12 years ), and on sparse time series data ( each 24-hour period was broken into four 6-hour periods ). The model operates sequentially over the electronic health record, performing continuous risk prediction of Acute Kidney Injury in the next 48 hours based on blood creatinine levels. This approach could also be applied to continuous risk prediction in financial environments ( bank risk models: probability of default, regime changes ).
The paper also compares the effectiveness of major RNN types ( the cells trialed include long short-term memory (LSTM), update gate RNN, intersection RNN, simple recurrent unit (SRU), gated recurrent unit (GRU), neural Turing machine (NTM), memory-augmented neural network, differentiable neural computer (DNC) and relational memory core ) and ends up using a three-layer LSTM with highway connections ( Fig 2 above; initially it was SRU ). The paper confirms the superiority of RNNs over feed-forward models ( MLP, Gradient Boosted Trees, Random Forest; Supplementary Information Section D ), which we think comes from the RNNs' ability to remember history.
The RNN model is significantly better than XGBoost on both the AUPRC and AUROC measures.
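The 6-hour bucketing can be illustrated with a small grouping function. This is only a sketch of the time-bucketing idea with a hypothetical ( timestamp, feature, value ) record format and made-up feature names, not DeepMind's preprocessing pipeline:

```python
from collections import defaultdict
from datetime import datetime


def bucket_events(events):
    """Group timestamped record entries into 6-hour periods.

    events: iterable of (timestamp, feature_name, value) tuples (hypothetical format).
    Each calendar day is split into four periods:
    0 = 00:00-05:59, 1 = 06:00-11:59, 2 = 12:00-17:59, 3 = 18:00-23:59.
    Returns a chronologically sorted list of ((date, period), entries).
    """
    buckets = defaultdict(list)
    for ts, name, value in events:
        buckets[(ts.date(), ts.hour // 6)].append((name, value))
    return sorted(buckets.items())


events = [
    (datetime(2015, 3, 1, 7, 30), "creatinine", 1.1),
    (datetime(2015, 3, 1, 2, 0), "creatinine", 0.9),
    (datetime(2015, 3, 1, 9, 15), "drug_x", 1),
    (datetime(2015, 3, 2, 19, 45), "creatinine", 1.4),
]
for (day, period), entries in bucket_events(events):
    print(day, period, entries)
```

Periods with no entries simply do not appear in this output; a real pipeline would still have to decide how to present those empty steps to the RNN.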
² RNNs/LSTMs are a general paradigm for handling variable-sized sequence data that allow us to pretty naturally capture .. different types of setups in our models ( see picture below ). Vanilla neural networks have a fixed structure: they receive a single input ( a fixed-size vector ), which is fed through a set of layers to produce a single output ( image classification, for example ).
³ Video can also be analyzed via LSTM ( caption generation ). Images can be generated ( PixelRNN ).
⁴ A state is represented as a vector of floating point numbers. It is updated at every time step by the forward computation, which uses the learned weights and biases; the weights and biases themselves are what backprop modifies.
⁵ “Gate gate” is the best name the course author came up with for g.
⁶ More detail on how backprop is computed through addition and multiplication can be found in the MIT RNN course.
⁷ The picture below illustrates exploding and vanishing gradients in a vanilla RNN: the error gradient has to flow through all the cells, with a repeated tanh and multiplication by the W matrix at every step.
⁸ Recurrent neural networks (RNNs) run sequentially over the electronic health record entries and are able to implicitly model the historical context of a patient by modifying an internal representation (or state) through time. We use a stacked multiple-layer recurrent network with highway connections between each layer, which at each time step takes the embedding vector as an input. We use the simple recurrent unit network as the RNN architecture, with tanh activations. We chose this from a broad range of alternative RNN architectures: specifically, the long short-term memory, update gate RNN and intersection RNN, simple recurrent units, gated recurrent units, the neural Turing machine, memory-augmented neural network, the Differentiable Neural Computer¹⁰ and the relational memory core. These alternatives did not provide significant performance improvements over the simple recurrent unit ( modified GRU — RM ) architecture.
⁹ Alex Graves ( DeepMind ) was Schmidhuber’s PhD student who demonstrated how powerful LSTM is: There is always this kind of Occam’s razor principle that gets applied once things get out in the field .. people try them out and inevitably the simple thing ends up working; people said the same thing about LSTM ; for years and years you know all the time I was working on it people kept saying .. well why are you using all these gates .. this looks very intricate, why don’t you just use a normal RNN or why don’t use a hidden Markov model ; and in the end it ( LSTM ) really did prove its worth but it took a long time; LSTM was way behind for a long time .. I was applying it — it took a long time before we could actually get state of the art in things like speech recognition .. I think what’s probably happening now is a simpler model that scaled up bigger run faster and run on more data does better and that continues for a while and then the trend reverses
¹⁰ Technically speaking a recurrent neural network can .. in theory emulate a Turing machine ( it ) is already something like a neural computer ( since it can approximate any function and remember any information — RM). Long Short-Term Memory already has quite a substantial ability to remember things but their memory is tied up in the activations of the latent state of this network and there’s several issues with that; the simplest one .. is that you just you can’t increase the amount of memory without increasing the size of the network and therefore the computational costs; another issue another kind of side effect is that the the contents of the memory tend to be rather fragile because they’re held in this the space of activations that’s being constantly updated, it’s being constantly bombarded with new information; so LSTM was actually designed to protect, to safeguard that information in some senses with the gates; we thought maybe we could take this one step further by removing this information from the network activations and making it external ( Neural Turing Machine and Differentiable Neural Computer are attempts in that direction — RM )
¹¹ The original LSTM paper didn’t use backpropagation through time ( BPTT ); it uses this kind of hybrid of RTRL ( Real Time Recurrent Learning ) for the connections within the LSTM cells and backpropagation truncated to one step outside the cells; the reasoning there was that all the interesting stuff goes on inside the cells and the rest is really just something you can you can truncate away
¹² A great illustration of LSTM architecture ( Jozefowicz ):
¹³ Stacking RNN layers is non-trivial; it is much more complex than stacking feedforward layers. Many papers have been written on this topic.
¹⁴ QRNN is one of the major influences on the Simple Recurrent Unit architecture.
¹⁵ LSTM’s traditional additive activation-based approach[LSTM1–13] is mirrored in the LSTM-inspired Highway Network (May 2015),[HW1][HW1a][HW3] the first working really deep feedforward neural network with hundreds of layers. It is essentially a feedforward version of LSTM[LSTM1] with forget gates.
¹⁶ Cruise replaced the commonly used Kalman-filter-based object tracking with an LSTM to associate objects across sensors and through time and to measure their kinematics, such as velocity and turn rate.