“Intuition behind LSTM”
Vikram Voleti
IIIT Hyderabad
Neural Networks
Perceptron: linear combination -> non-linearity
Non-linearity “squashifies” the output - sigmoid: 0 to 1, tanh: -1 to 1, ReLU: 0 to ∞
Sigmoid was chosen because: 1) it “squashifies” the output between 0 & 1, convenient for binary classification
2) it can act as a probability, which a threshold can be put on
3) its derivative is easy to compute: sigmoid’(x) = sigmoid(x) · (1 − sigmoid(x)) (see the sketch below)
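A minimal NumPy sketch of such a unit (the toy weights and input below are illustrative, not from the slides):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(x):
    # easy derivative: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
    s = sigmoid(x)
    return s * (1.0 - s)

def perceptron(x, w, b):
    # linear combination -> non-linearity
    return sigmoid(np.dot(w, x) + b)

# Thresholding the sigmoid output, read as a probability, gives a binary prediction.
x, w, b = np.array([0.5, -1.0]), np.array([2.0, 1.0]), 0.1
prediction = perceptron(x, w, b) > 0.5
```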
Neural Networks
Why non-linearity?
A stack of purely linear layers is still linear, so it can only give linear boundaries between classes. Non-linearities are what make neural networks universal approximators.
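A quick toy check of the first claim (illustrative sizes and random weights; not from the slides):

```python
import numpy as np

# Stacking linear layers with no non-linearity collapses into one linear map,
# so the decision boundary between classes stays linear.
rng = np.random.default_rng(0)
W1 = rng.normal(size=(5, 3))
W2 = rng.normal(size=(2, 5))
x = rng.normal(size=3)
assert np.allclose(W2 @ (W1 @ x), (W2 @ W1) @ x)   # same as a single layer with weights W2 @ W1
```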
Vanishing Gradient Problem
Key reason:
The derivatives of saturating non-linearities are fractions (sigmoid’ ≤ 0.25, tanh’ ≤ 1), and backpropagation multiplies one such factor per layer.
(That’s why ReLU, whose derivative is 1 for positive inputs, is preferred.)
(The gradient gets worse with the number of layers.)
Problem 1: Training neural networks via gradient descent using backpropagation incurs the vanishing/exploding gradient problem.
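A toy sketch of just the activation-derivative factor (the weights are ignored for simplicity, an assumption of this demo):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Each sigmoid layer multiplies the backpropagated gradient by sigmoid'(z) <= 0.25,
# so the factor shrinks roughly exponentially with the number of layers.
grad = 1.0
z = 0.0                      # pre-activation where sigmoid' is largest (0.25)
for layer in range(1, 11):
    s = sigmoid(z)
    grad *= s * (1.0 - s)
    print(f"after layer {layer}: gradient factor = {grad:.2e}")
```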
Vanilla RNN
Problem 2: Fixed input size. How do we learn from sequences of arbitrary length?
Solution: Recurrent Neural Networks
Source: colah.github.io, Nature
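A minimal vanilla RNN step in NumPy (the weight names, sizes and initialization are illustrative assumptions); the same cell is reused at every time step, so it handles sequences of any length:

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    # new hidden state from the current input and the previous hidden state
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

x_dim, h_dim = 4, 8
rng = np.random.default_rng(0)
W_xh = rng.normal(scale=0.1, size=(h_dim, x_dim))
W_hh = rng.normal(scale=0.1, size=(h_dim, h_dim))
b_h = np.zeros(h_dim)

h = np.zeros(h_dim)
for x_t in rng.normal(size=(5, x_dim)):   # a toy sequence of 5 inputs
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)
```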
Vanilla RNN - Vanishing Gradient Problem
Back-Propagation Through Time (BPTT)
(The gradient gets worse with the number of time steps.)
Key reason: the whole hidden state is overwritten (multiplied by the recurrent weights and squashed) at every time step
Hint: It is related to the eigenvalues of the recurrent weight matrix; repeated multiplication shrinks the gradient when they are below 1 and blows it up when they are above 1.
Problem 3: Training recurrent neural networks (via BPTT) incurs the vanishing/exploding gradient problem.
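A toy BPTT sketch of why this happens (the sizes, scales and stand-in hidden states are illustrative assumptions):

```python
import numpy as np

# During BPTT the gradient w.r.t. an early hidden state is multiplied by
# W_hh^T * diag(tanh') once per time step, so its norm collapses or explodes
# depending on the eigenvalues (spectral radius) of W_hh.
rng = np.random.default_rng(0)
h_dim = 8
W_hh = rng.normal(scale=0.3, size=(h_dim, h_dim))   # try scale 0.3 (vanish) vs 1.5 (explode)
grad = np.ones(h_dim)
for t in range(1, 21):
    h = rng.normal(size=h_dim)                      # stand-in hidden state values
    grad = W_hh.T @ (grad * (1.0 - np.tanh(h) ** 2))
    print(f"step {t}: ||grad|| = {np.linalg.norm(grad):.2e}")
```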
LSTM
Problem: Vanishing/Exploding gradients in RNNs
Solution: Long Short-Term Memory (Hochreiter and Schmidhuber, 1997) [link]
Source: colah.github.io
Introducing: Long-term memory (cell state), short-term memory (working memory/cell output)
LSTM
3 gates (the sigmoid units in the diagram): forget gate, input gate, output gate
LSTM - forget gate
1. Forget Gate:
Eg.: Remember that a character had died, forget the colour of their shirt.
Remember the currently called function, forget a returned value.
[Diagram: LSTM cell; previous cell state C(t-1) and previous output h(t-1) flow in]
forget_gate(t) = sigmoid( Wf ( x(t), h(t-1) ) )
remembered_cell_state(t) = forget_gate(t) .* C(t-1)
The forget_gate has a sigmoid activation so that it acts as a fraction (between 0 and 1) on the previous long-term memory/cell state, deciding what fraction to remember and what fraction to forget.
(Wf includes bias)
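A NumPy sketch of the forget gate equations above (shapes are illustrative; the bias inside Wf is modelled here by appending a constant 1 to the concatenated input, an assumption of this sketch):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forget_step(x_t, h_prev, C_prev, Wf):
    v = np.concatenate([x_t, h_prev, [1.0]])       # ( x(t), h(t-1) ) plus bias term
    forget_gate_t = sigmoid(Wf @ v)                # fractions in (0, 1)
    remembered_cell_state_t = forget_gate_t * C_prev
    return remembered_cell_state_t
```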
LSTM - input gate
2. Input Gate:
Eg.: The latest murder news, not an irrelevant character.
A new variable, not a comment.
[Diagram: LSTM cell; previous cell state C(t-1) and previous output h(t-1) flow in]
input_gate(t) = sigmoid( Wi ( x(t), h(t-1) ) )
input_information(t) = tanh( Wa ( x(t), h(t-1) ) )
relevant_input_information(t) = input_gate(t) .* input_information(t)
The input_gate has a sigmoid activation so that it acts as a fraction (between 0 and 1) on the new input information, deciding what fraction to take in and what fraction to let go.
The input_information has a tanh activation so as to squash the candidate information to between -1 and 1.
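A NumPy sketch of the input gate equations above (same illustrative conventions as the forget gate sketch; Wi and Wa include the bias via the appended 1):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def input_step(x_t, h_prev, Wi, Wa):
    v = np.concatenate([x_t, h_prev, [1.0]])
    input_gate_t = sigmoid(Wi @ v)                 # what fraction of the new info to accept
    input_information_t = np.tanh(Wa @ v)          # candidate info, squashed to (-1, 1)
    relevant_input_information_t = input_gate_t * input_information_t
    return relevant_input_information_t
```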
LSTM - update long-term memory
Update long-term memory:
Eg.: Remember the latest news, don’t remember an irrelevant character.
Remember a new variable, don’t remember a comment.
[Diagram: LSTM cell; C(t-1) and h(t-1) flow in, C(t) and h(t) flow out]
C(t) = remembered_cell_state(t) + relevant_input_information(t)
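The update is purely additive, which is what keeps the gradient along the cell state well-behaved compared to the vanilla RNN's overwriting of its hidden state. A tiny numeric sketch (the vectors are illustrative stand-ins for the quantities from the previous slides):

```python
import numpy as np

C_prev = np.array([0.5, -1.2, 0.8])                         # C(t-1)
forget_gate_t = np.array([0.9, 0.1, 0.5])                   # sigmoid outputs
relevant_input_information_t = np.array([0.2, 0.7, -0.3])   # input_gate .* input_information

remembered_cell_state_t = forget_gate_t * C_prev
C_t = remembered_cell_state_t + relevant_input_information_t   # new long-term memory C(t)
```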
LSTM - output gate
3. Output Gate:
Eg.: Retrieve the name of murderer, don’t retrieve the parents of victim.
Retrieve the updated variable, don’t retrieve the nesting structure.
[Diagram: LSTM cell; C(t-1) and h(t-1) flow in, C(t) and h(t) flow out]
output_gate(t) = sigmoid( Wo ( x(t), h(t-1) ) )
retrieved_memory(t) = tanh( C(t) )
h(t) = output_gate(t) .* retrieved_memory(t)
The output_gate has a sigmoid activation so that it acts as a fraction (between 0 and 1) on the retrieved memory, deciding what fraction to output and what fraction to hold back.
The retrieved_memory has a tanh activation so as to squash the retrieved cell state to between -1 and 1.
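A NumPy sketch of the output gate equations above (same illustrative conventions as before; Wo includes the bias via the appended 1):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def output_step(x_t, h_prev, C_t, Wo):
    v = np.concatenate([x_t, h_prev, [1.0]])
    output_gate_t = sigmoid(Wo @ v)                # what fraction of memory to expose
    retrieved_memory_t = np.tanh(C_t)              # cell state squashed to (-1, 1)
    h_t = output_gate_t * retrieved_memory_t       # new working memory / cell output
    return h_t
```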
LSTM
SUMMARY:
Using ( x(t), h(t-1) ), i.e. the current input and the previous working memory, the LSTM computes the forget, input and output gates, updates the long-term memory to C(t), and emits the new working memory h(t).
[Diagram: full LSTM cell, mapping x(t), C(t-1), h(t-1) to C(t), h(t)]
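Putting the three gates together, one full LSTM step (a minimal NumPy sketch; the shapes, initialization and bias-as-appended-1 convention are illustrative assumptions, not the exact formulation from the original paper):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, Wf, Wi, Wa, Wo):
    v = np.concatenate([x_t, h_prev, [1.0]])            # ( x(t), h(t-1) ) plus bias term

    forget_gate_t = sigmoid(Wf @ v)                     # 1. what to keep of C(t-1)
    input_gate_t = sigmoid(Wi @ v)                      # 2. what to accept of the new info
    input_information_t = np.tanh(Wa @ v)               #    candidate information
    output_gate_t = sigmoid(Wo @ v)                     # 3. what to expose of C(t)

    C_t = forget_gate_t * C_prev + input_gate_t * input_information_t
    h_t = output_gate_t * np.tanh(C_t)
    return h_t, C_t

# Usage on a toy sequence
x_dim, h_dim = 4, 8
rng = np.random.default_rng(0)
Wf, Wi, Wa, Wo = (rng.normal(scale=0.1, size=(h_dim, x_dim + h_dim + 1)) for _ in range(4))
h, C = np.zeros(h_dim), np.zeros(h_dim)
for x_t in rng.normal(size=(5, x_dim)):
    h, C = lstm_step(x_t, h, C, Wf, Wi, Wa, Wo)
```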
Variant - Peephole
LSTM
NOTE:
[Diagram: LSTM cell with C(t-1), h(t-1) in and C(t), h(t) out]
References