Many of our readers may have heard of BERT, the acronym for Bidirectional Encoder Representations from Transformers. Some may even have fine-tuned the pretrained BERT model on their own tasks using their own datasets. However, according to a recent survey, even among users, very few understand the architecture of BERT. The original idea for BERT came from earlier advanced work on Transformers.
A detailed discussion of what a Transformer is and how it works appeared in the 2017 paper “Attention Is All You Need”, jointly published by the Google Brain team and the University of Toronto. That interesting paper was itself inspired by earlier work on neural machine translation (NMT; see the reference section below). The core idea behind the Transformer model is self-attention: the ability to attend to different positions of an input sequence to compute a representation of that sequence. In this regard, Transformers are similar to human vision. When a person’s visual cortex detects an object and its surroundings, it does not typically scan the entire object and scene; rather, it focuses on a specific feature or portion of the item depending on what the person is searching for.
Over time, when the visual cortex notices that a feature or object appears in a particular part of a scene, it looks at the same portion of similar scenes to find that object. There are many articles that discuss what Transformers do and achieve, but few outside scientific and academic papers discuss the architecture of Transformers. In this post we will mainly discuss that architecture and explore, in some detail, the technical details behind BERT.
The problem with any Recurrent Neural Network (RNN), such as a Long Short-Term Memory (LSTM) or Gated Recurrent Unit (GRU), is that the beginning of a long input sequence is largely forgotten by the time the encoder reaches the end token. A Transformer, however, computes a score for every token in the sequence, so that every token has its own meaningful representation even in a long sequence. That per-token score is called the context vector. We will not dwell on the context vector abstraction, as our focus is on the architecture. Most RNNs have another problem: long-term dependencies. Take, for instance, a language model trying to predict the next word given a prior sequence of words, with each word a separate token. If we have the sentence, “The cat loves to drink milk”, we can easily ascertain that the last token has to be “milk” given the previous tokens. We don’t need any further context to understand this short sequence.
Figure: Common RNN
Since the sequence is short, it is easy to ascertain the next token. But in many cases sequences are much longer, and the same language model then has trouble predicting the next token. Sometimes the predictions go horribly wrong when using RNN cells. Take, for instance, the incomplete sentence, “The cats are very active and fast, they have strong ____.” Language models find this type of prediction hard to figure out, but humans easily guess that the next token is some part of a cat’s body; and given the context “active and fast”, the most likely token is “legs”. As we stated, when using an RNN, previously seen context in long sequences is forgotten. Further, during the training of an RNN there is the problem of vanishing and exploding gradients (details in the reference section below).
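The vanishing and exploding gradient problem can be sketched in a few lines. This is a hypothetical illustration (not from the original paper): in a simple RNN, the gradient flowing back through T time steps is scaled by the recurrent weight at every step, roughly like w**T for a scalar weight w.

```python
# Illustrative sketch: gradient magnitude after backpropagating
# through `steps` time steps of a scalar recurrent weight `w`.
def backprop_signal(w, steps):
    grad = 1.0
    for _ in range(steps):
        grad *= w          # each step multiplies the gradient by w
    return abs(grad)

vanishing = backprop_signal(0.5, 50)   # |w| < 1: shrinks toward zero
exploding = backprop_signal(1.5, 50)   # |w| > 1: blows up
```

For |w| < 1 the signal all but disappears after 50 steps, and for |w| > 1 it explodes, which is exactly why training a plain RNN on long sequences is unstable.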
However, most of the problems of an RNN can be mitigated by using an LSTM. The internal architecture of an LSTM is depicted in the figure below; we will not say much about LSTMs here, but there is more in the reference section below.
An LSTM typically solves the problem of vanishing and exploding gradients. However, it fails to solve one problem the RNN model also faces: long-term dependencies. An LSTM does not do well on long sentences because, as the token we are trying to predict gets farther from the first tokens that shaped the context vector, the information that vector carries about them decreases with distance. The longer the sentence, the less the context vector can tell us. Please refer to the architecture above and its related equations. As you may note, by the end of a long sequence, the decoder must predict using a context vector that is no longer rich in information about the very first tokens.
To solve this issue, we can associate a context score with each token individually, so that during decoding both its score and its context vector are available for predicting the next token. We say that the Transformer model uses attention (self-attention) to make its predictions. The figure below shows what the attention mechanism does, with a simple example adapted from the Neural Machine Translation paper.
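The basic attention idea can be sketched as follows. This is a minimal sketch in numpy (function and variable names are ours, not from the paper): score each encoder hidden state against the current decoder state, then build a context vector as the softmax-weighted sum of the encoder states.

```python
import numpy as np

def attention_context(decoder_state, encoder_states):
    # One score per input token: dot product with the decoder state.
    scores = encoder_states @ decoder_state
    # Softmax turns scores into weights that sum to 1.
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()
    # Context vector: weighted sum of the encoder states.
    return weights @ encoder_states, weights

enc = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 tokens, dim 2
ctx, w = attention_context(np.array([1.0, 0.0]), enc)
```

Tokens whose hidden states align with the decoder state receive higher weights, so every input token contributes to the prediction in proportion to its relevance, no matter how far back in the sequence it occurred.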
But astute readers may note that the attention mechanism alone still does not solve the problem of parallelism. Practically, this means that as long as encoding and decoding proceed sequentially, the attention mechanism spends a lot of expensive compute time on a large corpus.
Representative Architecture of Transformer
Figure: Complete Transformer (“Attention Is All You Need”)
The figure above shows the complete architecture of a Transformer. The left part of the model is the encoder and the right part is the decoder. A Transformer generally follows a pattern similar to a sequence-to-sequence (seq2seq) model. In the figure, both the encoder and the decoder consist of a “Multi-Head Attention” and a “Feed Forward” network, each associated with a “Normalization” layer. The decoder additionally has a “Masked Multi-Head Attention”.
Roughly, during encoding, each input sentence is passed through the model, which generates an output for each token in the sequence. During decoding, the decoder combines the encoder’s information with its own right-shifted output to make a prediction.
Here we look at some of the parts in detail.
The Encoder consists of three parts: i) Input Embedding; ii) Positional Encoding; and iii) N encoder layers.
The encoder layers consist of two sub-layers: i) a Multi-Head Attention (with padding mask) layer; and ii) a point-wise feed-forward network. Each of these sub-layers has a residual connection followed by a normalization layer. The job of the residual connection is to avoid the vanishing and exploding gradients that, as we saw, trouble any RNN.
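The residual connection and normalization around each sub-layer can be sketched in a few lines of numpy. This is a simplified sketch (function names are illustrative, not from the paper): the input is added back to the sub-layer’s output, then normalized per token.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token's feature vector to zero mean, unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def sublayer_with_residual(x, sublayer):
    # Residual connection: the input skips around the sub-layer, keeping
    # a direct gradient path and avoiding vanishing gradients.
    return layer_norm(x + sublayer(x))

np.random.seed(0)
x = np.random.randn(4, 8)                          # (tokens, d_model)
out = sublayer_with_residual(x, lambda t: 0.1 * t)  # toy sub-layer
```

Because the gradient can flow straight through the `x +` term, deep stacks of these layers train stably where a comparably deep RNN would not.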
The Multi-Head Attention is comprised of four parts: i) linear layers (split into heads); ii) scaled dot-product attention; iii) concatenation of heads; and iv) a final linear layer.
Figure: Multi-Headed attention
The inputs to the Multi-Head Attention layer are the Query (Q), Key (K), and Value (V). These are passed through dense (linear) layers and split into multiple heads. Each head is processed by the scaled dot-product attention layer, and the heads are then concatenated before being sent through a final dense (linear) layer. We use Multi-Head Attention because the heads can be computed in parallel, an efficiency that a recurrent model such as an LSTM cannot achieve. Splitting Q, K, and V into multiple heads lets the model jointly attend to information at different positions from different representational subspaces. Because each head works at reduced dimensionality, the total computational cost remains similar to that of a single-head attention network with full dimensionality.
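The split-into-heads and concatenate steps can be sketched as pure reshapes. This is a minimal numpy sketch (function names are ours): a `(seq_len, d_model)` matrix is split into `num_heads` slices of depth `d_model / num_heads`, and concatenation simply inverts the split.

```python
import numpy as np

def split_heads(x, num_heads):
    # (seq_len, d_model) -> (num_heads, seq_len, depth)
    seq_len, d_model = x.shape
    depth = d_model // num_heads
    return x.reshape(seq_len, num_heads, depth).transpose(1, 0, 2)

def combine_heads(x):
    # (num_heads, seq_len, depth) -> (seq_len, d_model)
    num_heads, seq_len, depth = x.shape
    return x.transpose(1, 0, 2).reshape(seq_len, num_heads * depth)

x = np.arange(24, dtype=float).reshape(4, 6)  # 4 tokens, d_model = 6
heads = split_heads(x, num_heads=2)           # 2 heads, depth 3 each
```

Each head sees only its own slice of the feature dimensions, which is what lets the heads attend to different representational subspaces in parallel.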
For Scaled Dot-Product Attention, given an input of Query (Q), Key (K), and Value (V), we compute the attention score as:

Attention(Q, K, V) = softmax(QK^T / √d_k) V
The figure below shows the single Scaled Dot-Product Attention architecture, and the equation above gives the computation of the attention weights shown in the figure.
Figure: Scaled Dot-Product Attention
The attention score is the dot product of the Query (Q) and the transpose of the Key (K), normalized by the square root of the depth d_k. We take the softmax of this quantity and then take its dot product with the Values (V).
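The computation just described can be written directly from the formula. This is a minimal numpy sketch (the example Q, K, V values are ours, chosen for illustration): dot Q with K-transpose, scale by √d_k, softmax, then dot with V.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    d_k = q.shape[-1]
    # Dot product of Q and K-transpose, scaled by sqrt of the depth.
    scores = (q @ k.T) / np.sqrt(d_k)
    # Softmax over the key dimension (numerically stabilized).
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of the values.
    return weights @ v, weights

q = np.array([[1.0, 0.0]])                     # one query, depth 2
k = np.array([[1.0, 0.0], [0.0, 1.0]])         # two keys
v = np.array([[10.0, 0.0], [0.0, 10.0]])       # two values
out, w = scaled_dot_product_attention(q, k, v)
```

The query here aligns with the first key, so the first value dominates the output; the √d_k scaling keeps the dot products from pushing the softmax into regions with tiny gradients as the depth grows.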
Positional Encoding is used to give the relative position of the words in the sentence. The positional encoding is added to the embedding vector. The embedding vector represents each token in d-dimensional space, where tokens with similar meaning are closer to each other (the same concept behind word embeddings). Since the embedding vectors do not encode the position of a token in the sentence, adding positional encodings captures both the similarity of tokens and their positions in the sentence in d-dimensional space.
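One common way to build this encoding, used in “Attention Is All You Need”, is with sines and cosines of different frequencies. This is a sketch of that scheme in numpy (the function name is ours): even feature indices get a sine, odd indices a cosine, with wavelengths that grow geometrically along the depth.

```python
import numpy as np

def positional_encoding(max_len, d_model):
    pos = np.arange(max_len)[:, None]        # (max_len, 1) positions
    i = np.arange(d_model)[None, :]          # (1, d_model) feature indices
    # Wavelength grows geometrically from 2*pi to 10000*2*pi along depth.
    angle = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angle[:, 0::2])     # even indices: sine
    pe[:, 1::2] = np.cos(angle[:, 1::2])     # odd indices: cosine
    return pe

pe = positional_encoding(50, 16)             # one row per position
```

Each row is a unique, bounded fingerprint of its position, so adding `pe` to the embedding matrix injects position information without disturbing the embedding scale.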
Masking checks whether the value at each position is just padding or a real token. This enables the model to identify positions and treat them differently based on the mask value: a value of 0 is padding, while a value of 1 is a token.
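Building such a mask is a one-liner. This is a sketch following the post’s convention (0 for padding, 1 for a token), and it assumes, as is common, that token id 0 is the padding id:

```python
import numpy as np

def make_padding_mask(token_ids):
    # 1 where there is a real token, 0 where the position is padding
    # (assumes token id 0 is the padding id).
    return (np.array(token_ids) != 0).astype(float)

mask = make_padding_mask([7, 6, 0, 0, 1])
```

During attention, the scores at masked (padding) positions are pushed toward negative infinity before the softmax, so padding contributes nothing to the weighted sum.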
Here we have summarized some of the complicated architecture behind BERT. BERT is a new and powerful way to extend language models to become more accurate at a lower overall computing cost. In future posts we will explore more of this and other models.
2) BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding [Devlin Jacob, Chang Ming-Wei, Lee Kenton, Toutanova Kristina]
3) Neural Machine Translation by Jointly Learning to Align and Translate [Bahdanau Dzmitry, Cho Kyunghyun, Bengio Yoshua]
4) Attention Is All You Need [Vaswani Ashish, Shazeer Noam, Parmar Niki, Uszkoreit Jakob, Jones Llion, Gomez Aidan N., Kaiser Lukasz, Polosukhin Illia]