1 of 17

“Intuition behind LSTM”

Vikram Voleti

IIIT Hyderabad

2 of 17

Neural Networks

3 of 17

Neural Networks

Perceptron: linear combination -> non-linearity

The non-linearity “squashifies” the output - sigmoid: 0 to 1, tanh: -1 to 1, ReLU: 0 to ∞

Sigmoid was chosen because:

1) It “squashifies” the output between 0 & 1, convenient for binary classification.

2) It can act as a probability, on which a threshold can be put.

3) Its derivative is easy to compute.

4 of 17

Neural Networks

Why non-linearity?

Linear combination will only give linear boundaries between classes. Non-linearities make neural networks universal approximators.
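A quick numerical check of this claim, as a minimal sketch (the layer sizes and random weights below are arbitrary, purely for illustration): stacking linear layers collapses into a single linear layer, so only the non-linearity adds expressive power.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5,))        # arbitrary input
W1 = rng.normal(size=(4, 5))     # "layer 1" weights
W2 = rng.normal(size=(3, 4))     # "layer 2" weights

# Two stacked linear layers...
two_linear_layers = W2 @ (W1 @ x)
# ...are exactly one linear layer with weight matrix W2 @ W1:
one_linear_layer = (W2 @ W1) @ x
print(np.allclose(two_linear_layers, one_linear_layer))   # True

# With a non-linearity in between, the collapse no longer happens:
with_tanh = W2 @ np.tanh(W1 @ x)
print(np.allclose(with_tanh, one_linear_layer))           # False (in general)
```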

5 of 17

Vanishing Gradient Problem

Key reason:

Derivatives of saturating non-linearities are fractions (the sigmoid’s derivative is at most 0.25), and backpropagation multiplies one such factor per layer.

(That’s why ReLU, whose derivative is 1 for positive inputs, is preferred.)

(The gradient gets worse with the number of layers.)

Problem 1: Training neural networks via gradient descent using backpropagation incurs the vanishing/exploding gradient problem.
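A minimal sketch of the “fractional derivatives” point, assuming a toy chain of single sigmoid units with unit weights (an assumption for illustration only): each layer contributes a factor of at most 0.25 to the backpropagated gradient, so the product decays quickly with depth.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_derivative(z):
    s = sigmoid(z)
    return s * (1.0 - s)        # at most 0.25, attained at z = 0

# Best case: every unit sits at z = 0, so each layer multiplies the
# backpropagated gradient by 0.25 (toy chain of single sigmoid units).
for n_layers in (1, 5, 10, 20):
    gradient_factor = sigmoid_derivative(0.0) ** n_layers
    print(n_layers, gradient_factor)
# 1 -> 0.25, 5 -> ~1e-3, 10 -> ~1e-6, 20 -> ~1e-12: the gradient vanishes.
```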

6 of 17

Vanilla RNN

Problem 2: Feed-forward networks have a fixed input size. (How to do sequence learning?)

Solution: Recurrent Neural Networks

Source: colah.github.io, Nature
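A minimal sketch of one vanilla RNN step (the dimensions, weight names and scales are illustrative assumptions): the same weights are reused at every time step, so the network can consume a sequence of any length.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
W_xh = rng.normal(scale=0.1, size=(hidden_size, input_size))    # input -> hidden
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))   # hidden -> hidden (recurrent)
b_h = np.zeros(hidden_size)

def rnn_step(x_t, h_prev):
    # One step of a vanilla RNN: new state from current input and previous state.
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

# The same rnn_step (same weights) is applied at every position of the sequence,
# so the sequence length is not fixed in advance.
sequence = rng.normal(size=(7, input_size))   # a 7-step input sequence
h = np.zeros(hidden_size)
for x_t in sequence:
    h = rnn_step(x_t, h)
print(h)   # the final hidden state summarises the whole sequence
```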

7 of 17

Vanilla RNN - Vanishing Gradient Problem

Back-Propagation Through Time (BPTT)

(The gradient gets worse with the number of time steps.)

Key reason: haphazard updating of the cell state (the same weights act on it at every time step)

Hint: Related to eigenvalues of weight matrices

Problem 3: Training recurrent neural networks incurs the vanishing/exploding gradient problem.
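A minimal sketch of the eigenvalue hint, using a toy linear recurrence with a recurrent matrix of known spectral radius (the orthogonal construction and the numbers are illustrative assumptions): BPTT multiplies the gradient by the recurrent Jacobian once per time step, so the gradient norm behaves roughly like the largest eigenvalue magnitude raised to the number of steps.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, T = 4, 50

# A random orthogonal matrix has all eigenvalue magnitudes equal to 1;
# scaling it gives a recurrent matrix with a known spectral radius.
Q, _ = np.linalg.qr(rng.normal(size=(hidden_size, hidden_size)))

for radius in (0.9, 1.1):
    W_hh = radius * Q
    grad = np.eye(hidden_size)
    for _ in range(T):
        # Toy linear recurrence h(t) = W_hh h(t-1):
        # BPTT multiplies the gradient by W_hh^T once per time step.
        grad = W_hh.T @ grad
    print("spectral radius", radius,
          "-> gradient norm after", T, "steps:", np.linalg.norm(grad))
# radius 0.9 -> norm ~ 0.9**50 (vanishes); radius 1.1 -> norm ~ 1.1**50 (explodes)
```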

8 of 17

LSTM

Problem: Vanishing/Exploding gradients in RNNs

Solution: Long Short-Term Memory (Hochreiter and Schmidhuber, 1997) [link]

Introducing: Long-term memory (cell state), short-term memory (working memory/cell output)

9 of 17

LSTM

3 Gates (the sigmoid units in the diagram):

  1. Forget gate
  2. Input gate
  3. Output gate

10 of 17

LSTM - forget gate

1. Forget Gate:

  • Remember only some parts of the long-term memory and forget the rest.

  • Decide what to remember based on current input, and previous working memory.

Eg.: Remember that a character had died, forget the colour of their shirt.

Remember the currently called function, forget a returned value.


forget_gate(t) = sigmoid( Wf ( x(t), h(t-1) ) )

remembered_cell_state(t) = forget_gate(t) .* C(t-1)

The forget_gate has a sigmoid activation so as to act as a fraction on the previous long-term memory/cell state - hence deciding what fraction to remember and what fraction to forget.

(Wf includes bias)
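A minimal numpy sketch of the forget gate exactly as written above (the sizes and random weights are illustrative assumptions; Wf acts on the concatenation of x(t) and h(t-1), with the bias folded in via a trailing 1):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4

# Wf acts on [x(t), h(t-1), 1]; the trailing 1 folds the bias into Wf.
Wf = rng.normal(scale=0.1, size=(hidden_size, input_size + hidden_size + 1))

x_t = rng.normal(size=input_size)        # current input
h_prev = rng.normal(size=hidden_size)    # previous working memory
C_prev = rng.normal(size=hidden_size)    # previous long-term memory (cell state)

forget_gate = sigmoid(Wf @ np.concatenate([x_t, h_prev, [1.0]]))   # fractions in (0, 1)
remembered_cell_state = forget_gate * C_prev                        # element-wise ".*"
print(forget_gate, remembered_cell_state)
```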

11 of 17

LSTM - input gate

2. Input Gate:

  • Remember only some parts of the current input & previous working memory.

  • Decide what to remember based on current input & previous working memory.

Eg.: The latest murder news, not an irrelevant character.

A new variable, not a comment.


input_gate(t) = sigmoid( Wi ( x(t), h(t-1) ) )

input_information(t) = tanh( Wa ( x(t), h(t-1) ) )

relevant_input_information(t) = input_gate(t) .* input_information(t)

The input_gate has a sigmoid activation so as to act as a fraction on the input information - hence deciding what fraction to consider and what fraction to let go.

The input_information has a tanh activation so as to squashify the information between -1 and 1.
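A matching numpy sketch of the input gate (same illustrative assumptions about sizes, weights and the folded-in bias):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
concat_size = input_size + hidden_size + 1   # [x(t), h(t-1), 1], bias folded in

Wi = rng.normal(scale=0.1, size=(hidden_size, concat_size))   # input gate weights
Wa = rng.normal(scale=0.1, size=(hidden_size, concat_size))   # candidate information weights

x_t = rng.normal(size=input_size)
h_prev = rng.normal(size=hidden_size)
xh = np.concatenate([x_t, h_prev, [1.0]])

input_gate = sigmoid(Wi @ xh)           # fractions: how much of each component to keep
input_information = np.tanh(Wa @ xh)    # candidate information, squashed to (-1, 1)
relevant_input_information = input_gate * input_information
print(relevant_input_information)
```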

12 of 17

LSTM - update long-term memory

Update long-term memory:

  • Add the relevant input information to the long-term memory.

Eg.: Remember the latest news, don’t remember an irrelevant character.

Remember a new variable, don’t remember a comment.


C(t) = remembered_cell_state(t) + relevant_input_information(t)

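A tiny worked example of this update with hand-picked gate values (assumed, purely to show the behaviour): a forget gate of 1 keeps the old memory untouched, 0 erases it, and the relevant input information is simply added on top.

```python
import numpy as np

C_prev = np.array([2.0, -1.0, 0.5])                      # previous cell state
forget_gate = np.array([1.0, 0.0, 0.5])                  # keep, erase, half-remember
relevant_input_information = np.array([0.0, 0.3, 0.3])   # nothing new, new info, new info

# C(t) = remembered_cell_state(t) + relevant_input_information(t)
C_t = forget_gate * C_prev + relevant_input_information
print(C_t)   # [2.0, 0.3, 0.55]: kept as-is, overwritten, blended
```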

13 of 17

LSTM - output gate

3. Output Gate:

  • Having saved relevant information into long-term memory, retrieve some working memory.

  • Decide what to retrieve based on current input & previous working memory.

Eg.: Retrieve the name of the murderer, don’t retrieve the parents of the victim.

Retrieve the updated variable, don’t retrieve the nesting structure.


output_gate(t) = sigmoid( Wo ( x(t), h(t-1) ) )

retrieved_memory(t) = tanh( C(t) )

h(t) = output_gate(t) .* retrieved_memory(t)

The output_gate has a sigmoid activation so as to act as a fraction on the retrieved information - hence deciding what fraction to keep and what fraction to ignore.

The retrieved_memory has a tanh activation so as to squashify the retrieved information between -1 and 1.
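A matching numpy sketch of the output gate (same illustrative assumptions; C(t) stands for the long-term memory just updated above):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4

Wo = rng.normal(scale=0.1, size=(hidden_size, input_size + hidden_size + 1))  # bias folded in

x_t = rng.normal(size=input_size)        # current input
h_prev = rng.normal(size=hidden_size)    # previous working memory
C_t = rng.normal(size=hidden_size)       # freshly updated long-term memory

output_gate = sigmoid(Wo @ np.concatenate([x_t, h_prev, [1.0]]))
retrieved_memory = np.tanh(C_t)          # squash the cell state to (-1, 1)
h_t = output_gate * retrieved_memory     # new working memory / cell output
print(h_t)
```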


14 of 17

LSTM

SUMMARY:

Using ( x(t), h(t-1) ), i.e. current input and previous working memory,

  • forget unimportant long-term memory,
  • compute relevant input information, and add it to the long-term memory,
  • retrieve relevant working memory from long-term memory.

(Diagram: LSTM cell taking C(t-1) and h(t-1) as inputs and producing C(t) and h(t).)
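Putting the three gates together, a minimal one-step LSTM sketch (the weight shapes, the concatenation convention and the folded-in biases are illustrative assumptions, matching the notation of the previous slides):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, Wf, Wi, Wa, Wo):
    # All weight matrices act on [x(t), h(t-1), 1]; the trailing 1 folds in the bias.
    xh = np.concatenate([x_t, h_prev, [1.0]])

    forget_gate = sigmoid(Wf @ xh)               # what fraction of C(t-1) to keep
    input_gate = sigmoid(Wi @ xh)                # what fraction of new information to keep
    input_information = np.tanh(Wa @ xh)         # candidate new information
    C_t = forget_gate * C_prev + input_gate * input_information   # update long-term memory

    output_gate = sigmoid(Wo @ xh)               # what fraction of memory to retrieve
    h_t = output_gate * np.tanh(C_t)             # new working memory / cell output
    return h_t, C_t

# Tiny usage example over a random sequence.
rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
shape = (hidden_size, input_size + hidden_size + 1)
Wf, Wi, Wa, Wo = (rng.normal(scale=0.1, size=shape) for _ in range(4))

h, C = np.zeros(hidden_size), np.zeros(hidden_size)
for x_t in rng.normal(size=(7, input_size)):     # a 7-step input sequence
    h, C = lstm_step(x_t, h, C, Wf, Wi, Wa, Wo)
print(h, C)
```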

15 of 17

Variant - Peephole

  • Same as LSTM, except use long-term memory as well for all decisions:
    • ( x(t), h(t-1), C(t-1) ) for forget and input gates,
    • ( x(t), h(t-1), C(t) ) for output gate.
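A minimal sketch of the peephole change relative to the step above (weight shapes again illustrative): the cell state is appended to the gate inputs, C(t-1) for the forget and input gates and the freshly updated C(t) for the output gate.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def peephole_lstm_step(x_t, h_prev, C_prev, Wf, Wi, Wa, Wo):
    # Forget and input gates also peek at the previous cell state C(t-1).
    xhC_prev = np.concatenate([x_t, h_prev, C_prev, [1.0]])
    forget_gate = sigmoid(Wf @ xhC_prev)
    input_gate = sigmoid(Wi @ xhC_prev)
    input_information = np.tanh(Wa @ np.concatenate([x_t, h_prev, [1.0]]))
    C_t = forget_gate * C_prev + input_gate * input_information

    # The output gate peeks at the freshly updated cell state C(t).
    xhC_t = np.concatenate([x_t, h_prev, C_t, [1.0]])
    output_gate = sigmoid(Wo @ xhC_t)
    h_t = output_gate * np.tanh(C_t)
    return h_t, C_t

# Gate weights see [x, h, C, 1]; the candidate information sees [x, h, 1].
rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
gate_shape = (hidden_size, input_size + 2 * hidden_size + 1)
Wf, Wi, Wo = (rng.normal(scale=0.1, size=gate_shape) for _ in range(3))
Wa = rng.normal(scale=0.1, size=(hidden_size, input_size + hidden_size + 1))

h, C = np.zeros(hidden_size), np.zeros(hidden_size)
h, C = peephole_lstm_step(rng.normal(size=input_size), h, C, Wf, Wi, Wa, Wo)
print(h, C)
```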

16 of 17

LSTM

NOTE:

  • De-coupling short-term and long-term memory avoids vanishing/exploding gradients (the haphazard updating of the cell state in the vanilla RNN was the primary culprit)
  • Methodical design of structure - no “mystery” as to why it works!

(Diagram: LSTM cell taking C(t-1) and h(t-1) as inputs and producing C(t) and h(t).)
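A small numeric illustration of why the de-coupled cell state helps, as a sketch under the simplifying assumption that the gate values are treated as given constants: the gradient flowing back through C(t) = forget_gate(t) .* C(t-1) + ... is multiplied by the forget gate element-wise at each step, not by a full weight matrix, so it decays only as fast as the network chooses to forget.

```python
import numpy as np

T, hidden_size = 50, 4

# With the gate values treated as constants, dC(t)/dC(t-1) = diag(forget_gate(t)),
# so the gradient through the cell-state path over T steps is the element-wise
# product of the forget gates: nothing structural forces it towards 0 or infinity.
for forget_value in (0.99, 0.5):
    forget_gates = np.full((T, hidden_size), forget_value)
    gradient_through_cell_path = np.prod(forget_gates, axis=0)
    print("forget gate", forget_value, "->", gradient_through_cell_path)
# Forget gate near 1 -> gradient preserved over long spans;
# a smaller forget gate means the network is deliberately discarding that memory.
```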

17 of 17

References

  1. Hacker’s guide to NN: http://karpathy.github.io/neuralnets/
  2. Interactive visualization: http://neuralnetworksanddeeplearning.com/chap4.html
  3. LSTMs: colah.github.io
  4. Quora answer: https://www.quora.com/What-is-an-intuitive-explanation-of-LSTMs-and-GRUs/answer/Edwin-Chen-1?share=b6d3b009&srid=Xfgu