How LSTMs work
I’ve always been compared to a goldfish. My short-term memory can be fuzzy.
Most neural networks (like CNNs and GANs) have no memory at all: each input is processed on its own, with nothing carried over from the one before. They’re worse than goldfish.
But RNNs (recurrent neural networks) are a type of neural network that can recall information from earlier inputs.
For example, if you had a neural net that predicted an output (y) from an input (x), normally (y) would be output and never used again by the network. But RNNs keep using past information to help improve the model’s performance.
So, if we wanted a neural net to understand each word of a sentence, we’d use a recurrent neural network (RNN).
E.g. “What time is it?” would be fed into the network as 4 separate inputs (x), one word at a time. The network would put the most weight or relevance on the latest inputs (i.e. “it”).
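To make the idea of reusing past information concrete, here is a minimal sketch of a single recurrent unit in plain Python. The weights `w_x`, `w_h`, and `b`, and the numbers standing in for the four words, are made-up values for illustration, not anything learned from data.

```python
import math

def rnn_step(x, h_prev, w_x=0.5, w_h=0.8, b=0.0):
    """One step of a tiny scalar RNN: mix the new input with the previous state."""
    return math.tanh(w_x * x + w_h * h_prev + b)

def run_rnn(inputs):
    """Feed the inputs one at a time, carrying the hidden state forward."""
    h = 0.0  # the network starts with an empty memory
    states = []
    for x in inputs:
        h = rnn_step(x, h)
        states.append(h)
    return states

# "What time is it?" as four hypothetical numeric inputs, one per word
sentence = [0.2, -0.4, 0.9, 0.1]
states = run_rnn(sentence)
print(states)
```

Because each step receives the previous hidden state, changing the first word changes every state after it. That carried-over state is the "memory".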
But this causes a problem called the “vanishing gradient”. The short version is that by the time the network has received all the inputs (4 words), it will basically have forgotten the first words, because their influence (and the gradient the network uses to learn from them) shrinks at every step until it is close to zero.
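Here is a rough sketch of why that happens, using a tiny scalar tanh RNN. It applies the chain rule to measure how much the final hidden state depends on the very first input. The weights `w_x` and `w_h` and the constant input of 0.5 are assumptions picked only to make the effect visible.

```python
import math

def first_input_gradient(inputs, w_x=1.0, w_h=0.5):
    """Return d(h_last) / d(x_first) for h_t = tanh(w_x * x_t + w_h * h_prev)."""
    h = 0.0
    grad = 0.0
    for t, x in enumerate(inputs):
        h = math.tanh(w_x * x + w_h * h)
        local = 1.0 - h * h  # derivative of tanh at this step, always <= 1
        if t == 0:
            grad = w_x * local   # direct effect of the first input
        else:
            grad *= w_h * local  # each later step scales that effect down
    return grad

for n in (1, 4, 10, 20):
    print(n, first_input_gradient([0.5] * n))
```

Every extra step multiplies the gradient by a number smaller than 1, so after enough steps the first word barely affects what the network learns.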
So, neural networks have non-existent memory, while RNNs are part of the me-and-goldfish club (kind of have a memory).
Fortunately, there’s a type of RNN called an LSTM (long short-term memory), which helps solve the problem. 😁 It surpasses me and the goldfish.
How LSTMs work
LSTMs are a type of recurrent neural network, but instead of simply feeding its outcome into the next part of the network, an LSTM does a bunch of math operations so it can have a better memory.
An LSTM has four “gates”: forget, remember, learn, and use (or output).
It also has three inputs: long-term memory, short-term memory, and E. (E is some training example/new data)
Step 1: When the 3 inputs enter the LSTM, they go into either the forget gate or the learn gate.
The long-term info goes into the forget gate, where, shocker, some of it is forgotten (the irrelevant parts).
The short-term info and “E” go into the learn gate. This gate decides what info will be learned. Bet you didn’t see that one coming!!!!
Step 2: Information that passes the forget gate (whatever is not forgotten; forgotten info stays at the gate) and information that passes the learn gate (whatever is learned) both go to the remember gate, which produces the new long-term memory, and to the use gate, which updates the short-term memory and is also the output of the network.
Learn Gate
TL;DR: the learn gate combines STM + E (input) and chooses to ignore the unneeded info.
This gate combines the existing short-term memory (STM) and some input “E”, multiplies the result by a matrix (W), and adds a bias (b). Then it squishes all of this through a tanh function.
This combination gives us “N”.
Then it ignores some of this information by multiplying the combined result by an “ignore factor”.
The ignore factor (I) is calculated by combining STM and E with a new set of W (weights) and b (biases), then passing the result through a sigmoid so that each value lands between 0 and 1.
Once we have N and I, we multiply them together, and that’s the result of the learn gate.
We have “learned” our new information (E).
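The learn gate can be sketched in plain Python like this. The names `W_n`, `b_n`, `W_i`, and `b_i`, the vector sizes, and all the numbers are assumptions for illustration; a real LSTM learns these weights during training.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, v, b):
    """Compute W·v + b, where W is a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + b_k for row, b_k in zip(W, b)]

def learn_gate(stm, e, W_n, b_n, W_i, b_i):
    combined = stm + e  # list concatenation: [STM, E]
    n = [math.tanh(z) for z in affine(W_n, combined, b_n)]  # new info N
    i = [sigmoid(z) for z in affine(W_i, combined, b_i)]    # ignore factor I
    return [n_k * i_k for n_k, i_k in zip(n, i)]            # N x I

# Hypothetical sizes: STM holds 2 values, E holds 1 value
stm = [0.5, -0.3]
e = [1.0]
W_n = [[0.1, 0.2, 0.3], [-0.4, 0.5, 0.6]]
b_n = [0.0, 0.1]
W_i = [[0.7, -0.8, 0.9], [0.2, 0.1, -0.3]]
b_i = [0.0, 0.0]

learned = learn_gate(stm, e, W_n, b_n, W_i, b_i)
print(learned)
```

Since tanh stays between -1 and 1 and the sigmoid stays between 0 and 1, the ignore factor can only shrink each value of N, never grow it.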
Forget Gate
Forget gate is the gate you use to dump out all the unnecessary long term information. Kind of like when you study for a big exam, and the next day you forget everything. That’s the power of the forget gate.
Basically, the long-term memory (LTM) gets multiplied by a forget factor (f). This factor makes some of the long-term information be “forgotten”.
The forget factor (f) is computed by taking the short-term memory (STM) and the input (E), multiplying them by some weights, adding biases, and squishing the result through a sigmoid function, so every value in f lands between 0 and 1.
This factor (f) gets multiplied by LTM, and boom, we’re left with the LTM that we need.
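As a rough sketch, the forget gate looks like this in plain Python. The names `W_f` and `b_f` and every number here are assumptions; the weights are deliberately extreme so that one memory slot is kept almost entirely and the other is almost wiped.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, v, b):
    """Compute W·v + b, where W is a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + b_k for row, b_k in zip(W, b)]

def forget_gate(ltm, stm, e, W_f, b_f):
    f = [sigmoid(z) for z in affine(W_f, stm + e, b_f)]  # forget factor, 0..1
    return [l_k * f_k for l_k, f_k in zip(ltm, f)]       # LTM x f

ltm = [2.0, -1.0]
stm = [0.5, -0.3]
e = [1.0]
W_f = [[10.0, 0.0, 10.0],    # pushes f[0] toward 1: keep this slot
       [-10.0, 0.0, -10.0]]  # pushes f[1] toward 0: forget this slot
b_f = [0.0, 0.0]

kept = forget_gate(ltm, stm, e, W_f, b_f)
print(kept)
```

A forget factor near 1 passes the memory through untouched, and a factor near 0 erases it.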
Remember Gate
This gate takes the information from the forget gate and adds it to the information from the learn gate, to compute the new long term memory.
Remember gate = Learn gate output + Forget gate output
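In code, that is just an element-wise sum. This is a minimal sketch; the function name `remember_gate` and the sample numbers are made up for illustration.

```python
def remember_gate(forget_out, learn_out):
    """New long-term memory = forget gate output + learn gate output."""
    if len(forget_out) != len(learn_out):
        raise ValueError("both gate outputs must have the same length")
    return [f + l for f, l in zip(forget_out, learn_out)]

new_ltm = remember_gate([1.5, 0.0, -0.2], [0.25, 0.5, 0.2])
print(new_ltm)  # [1.75, 0.5, 0.0]
```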
Use Gate
The use gate takes the LTM coming out of the forget gate, plus the STM and E that fed the learn gate, and uses them to come up with a new short-term memory, which is also the output (same thing).
For example, if we were trying to classify images, the output would be the network classification.
It takes the output of the forget gate and puts it through a tanh activation function. Call the result U.
It takes STM and E, the same inputs the learn gate uses, and applies a sigmoid function. Call the result V.
Then the gate multiplies V x U to get the new short-term memory!
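Here is a hedged sketch of the use gate in plain Python, following the description above. The weight names `W_u`, `b_u`, `W_v`, and `b_v` are assumptions (each branch gets its own weights, as in the other gates), and the numbers are chosen so the result is easy to check by hand. Library implementations such as PyTorch’s `nn.LSTM` differ slightly: they apply tanh to the new long-term memory directly.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, v, b):
    """Compute W·v + b, where W is a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + b_k for row, b_k in zip(W, b)]

def use_gate(forget_out, stm, e, W_u, b_u, W_v, b_v):
    u = [math.tanh(z) for z in affine(W_u, forget_out, b_u)]  # U: from LTM
    v = [sigmoid(z) for z in affine(W_v, stm + e, b_v)]       # V: from STM, E
    return [u_k * v_k for u_k, v_k in zip(u, v)]              # V x U

forget_out = [1.0, -0.5]
stm = [0.5, -0.3]
e = [1.0]
W_u = [[1.0, 0.0], [0.0, 1.0]]            # identity, so U = tanh(forget_out)
b_u = [0.0, 0.0]
W_v = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # all zeros, so V = 0.5 everywhere
b_v = [0.0, 0.0]

new_stm = use_gate(forget_out, stm, e, W_u, b_u, W_v, b_v)
print(new_stm)
```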
Sentiment Analysis
I used LSTMs to understand some text, specifically movie reviews, and to determine whether each review was positive or negative.
Sentiment analysis can be used for machines to understand human emotions — which is pretttyyyy cool!
My RNN uses two LSTM layers to help with its memory when analyzing text. Here’s the model architecture:
import torch.nn as nn


class SentimentRNN(nn.Module):
    """
    The RNN model that will be used to perform Sentiment analysis.
    """

    def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):
        """
        Initialize the model by setting up the layers.
        """
        super(SentimentRNN, self).__init__()

        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # embedding and LSTM layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers,
                            dropout=drop_prob, batch_first=True)

        # dropout layer
        self.dropout = nn.Dropout(0.3)

        # linear and sigmoid layers
        self.fc = nn.Linear(hidden_dim, output_size)
        self.sig = nn.Sigmoid()

    def forward(self, x, hidden):
        """
        Perform a forward pass of our model on some input and hidden state.
        """
        batch_size = x.size(0)

        # embeddings and lstm_out
        x = x.long()
        embeds = self.embedding(x)
        lstm_out, hidden = self.lstm(embeds, hidden)

        # stack up lstm outputs
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)

        # dropout and fully-connected layer
        out = self.dropout(lstm_out)
        out = self.fc(out)

        # sigmoid function
        sig_out = self.sig(out)

        # reshape to be batch_size first
        sig_out = sig_out.view(batch_size, -1)
        sig_out = sig_out[:, -1]  # get last batch of labels

        # return last sigmoid output and hidden state
        return sig_out, hidden

    def init_hidden(self, batch_size):
        ''' Initializes hidden state '''
        weight = next(self.parameters()).data
        hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_(),
                  weight.new(self.n_layers, batch_size, self.hidden_dim).zero_())
        return hidden
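The reshaping at the end of `forward` can be confusing, so here is the same idea with plain Python lists instead of tensors. It assumes `output_size` is 1 (one sigmoid value per word), and the batch size, sequence length, and numbers are made up for illustration.

```python
batch_size, seq_len = 2, 3

# Pretend per-word sigmoid outputs in the order the flattened tensor holds them:
# every word of review 0 first, then every word of review 1.
flat = [0.1, 0.2, 0.9,   # review 0, words 0..2
        0.6, 0.4, 0.3]   # review 1, words 0..2

# Equivalent of sig_out.view(batch_size, -1): one row per review
rows = [flat[i * seq_len:(i + 1) * seq_len] for i in range(batch_size)]

# Equivalent of sig_out[:, -1]: keep only the last word's output per review
last = [row[-1] for row in rows]
print(last)  # [0.9, 0.3]
```

So the model returns one score per review, taken after the LSTM has read the whole review.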
I hope you learned a bit about how LSTMs work!
To connect with me, email igrandic03@gmail.com, find me on Twitter or LinkedIn, or sign up for my monthly newsletter.