Linformer: Making Transformers Linear, Efficient, and Scalable
💪 Motivation
We all know the Transformer revolutionized the AI world after the paper "Attention Is All You Need" came out. The Transformer can handle very long sequence tasks and makes fewer errors than earlier models such as RNNs and CNNs. Its major problem is that it requires a lot of computational power.
👋 Introduction
The traditional Transformer was proposed in "Attention Is All You Need". That paper solves the biggest problem of RNNs (Recurrent Neural Networks), the vanishing gradient, but the Transformer requires too much time and computation to train and evaluate. To address this, the Linformer was introduced by Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, and Hao Ma in the paper "Linformer: Self-Attention with Linear Complexity".
The Linformer lets you compute attention in time and memory that grow linearly with the sequence length.
🪫 Power Hunger: Transformer
The Transformer is extremely power hungry. Self-attention builds an n x n score matrix, where n is the sequence length (e.g. 512, 1024, ...), and this n x n matrix multiplication is repeated in every head of every layer, which is expensive on the device and takes a lot of memory and time. Many papers have come out to address this, and Linformer is one of them.
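As a rough illustration (the sizes and the random projection below are my own toy example, not from the paper's code), here is how the attention score matrix shrinks from n x n to n x k once the keys are projected down:
import torch

n, k, d = 4096, 256, 64                       # toy sizes: sequence length, projected dim, head dim
q = torch.randn(1, n, d)
key = torch.randn(1, n, d)

scores_full = q @ key.transpose(-1, -2)       # (1, n, n): ~16.8M scores, quadratic in n
E = torch.randn(k, n)                         # stand-in for Linformer's learned projection
scores_lin = q @ (E @ key).transpose(-1, -2)  # (1, n, k): ~1M scores, linear in n
print(scores_full.shape, scores_lin.shape)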
💡 Linformer Self-Attention vs Traditional Self-Attention
Let's look at a high-level comparison between the Linformer and the traditional Transformer.
Transformer High-Level Flow
At the top of Fig. 1 is the traditional equation for the Transformer's multi-head attention.
The input first passes through linear projections, and then scaled dot-product attention (self-attention) is calculated.
This operation is performed once per head, i.e. h = d_model / d_k times; the outputs of all the self-attention blocks are then concatenated and passed through a final linear layer.
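Written out (reconstructed from "Attention Is All You Need" rather than copied from Fig. 1), the equation is:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · Wo
head_i = Attention(Q·Wq_i, K·Wk_i, V·Wv_i) = softmax( (Q·Wq_i)(K·Wk_i)^T / sqrt(d_k) ) · (V·Wv_i)
The softmax term here is an n x n matrix, which is exactly the quadratic cost discussed above.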
Linformer Idea
At the bottom of Fig.1, there is the equation for performing the linear Self-Attention.
In simple words, the Linformer adds two learnable shared parameters, the "two linear projection matrices". Their dimensions are k x n, and throughout this blog they are called E and F; E is applied to the keys and F to the values.
They first project the original n x d key and value matrices down to k x d, and then attention is applied just like in the traditional Transformer, as shown in Fig. 3.
❓How does the Linformer calculate the Attention?
To calculate Linformer attention, first project the original (n x d)-dimensional key and value layers K·Wk_i and V·Wv_i into (k x d)-dimensional projected key and value layers using Ei and Fi.
Note that the bottom equation of Fig. 1 requires only O(nk) time and space complexity. If we select the projected dimension k such that k << n, we greatly reduce the memory and time needed.
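Spelled out (reconstructed from the Linformer paper, not from the figure itself), head i becomes:
head_i = Attention(Q·Wq_i, Ei·K·Wk_i, Fi·V·Wv_i) = softmax( (Q·Wq_i)(Ei·K·Wk_i)^T / sqrt(d_k) ) · (Fi·V·Wv_i)
where Ei and Fi are the k x n projection matrices, so the softmax now operates on an n x k matrix instead of an n x n one.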
🌐 Parameter Sharing
One thing the Linformer authors describe is sharing parameters for the linear projection matrices Ei, Fi across layers and heads. There are three ways to do that (a small sketch follows the list):
- Headwise sharing: within each layer, use the same projections for every head, i.e. Ei = E and Fi = F for all heads i.
- Key-value sharing: in addition to headwise sharing, use a single projection for both keys and values, i.e. Ei = Fi = E for all heads i.
- Layerwise sharing: use one single projection matrix E across all layers, all heads, and both keys and values.
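As a minimal sketch (the helper below is hypothetical, not from the paper or my repo; it only illustrates how many projection matrices each scheme actually creates):
import torch.nn as nn

def build_projections(num_layers: int, num_heads: int, seq_len: int, k: int, mode: str = "headwise"):
    """Return per-layer, per-head (E, F) pairs under the chosen sharing scheme."""
    if mode == "layerwise":
        E = nn.Linear(seq_len, k, bias=False)         # one matrix reused everywhere
        return [[(E, E)] * num_heads for _ in range(num_layers)]
    layers = []
    for _ in range(num_layers):
        E = nn.Linear(seq_len, k, bias=False)         # shared by all heads in this layer
        F = E if mode == "key_value" else nn.Linear(seq_len, k, bias=False)
        layers.append([(E, F)] * num_heads)           # headwise: same E, F for every head
    return layers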
🧑🏻💻 Some Coding Guidelines
“See 🐈⬛ GitHub for the full code”
In this section I try to keep the attention mechanism simple, reliable, and readable, so you can relate the code back to the equations above.
Linformer Self Attention
import math

import torch
import torch.nn as nn


class LinearSelfAttention(nn.Module):
    def __init__(self, d_model:int, seq_len:int, k:int, d_k:int, headwise_sharing:bool=True, key_val_sharing:bool=False, dropout:float=0.2):
        """LinearSelfAttention args

        Args:
            d_model (int): e.g. 512, 1024 ...
            seq_len (int): e.g. 6, 1000, 20000 ...
            k (int): reduction term as per the paper
            d_k (int): d_model // head
            dropout (float, optional): dropout rate. Defaults to 0.2.
        """
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.k = k
        self.d_k = d_k
        self.dropout = nn.Dropout(dropout)
        self.headwise_sharing = headwise_sharing
        self.key_val_sharing = key_val_sharing
        # E and F are the (k x n) projections; with key-value sharing, F simply reuses E.
        self.E = nn.Linear(seq_len, k)
        self.F = nn.Linear(seq_len, k) if headwise_sharing else self.E if key_val_sharing else None
        assert self.F is not None, "At least one sharing scheme is required to perform the linear attention."

    def forward(self, q, k, v, mask=None):
        # Project the keys along the sequence dimension: (..., seq_len, d_k) -> (..., d_k, k)
        K = k.transpose(-1, -2)
        K_ = self.E(K)
        # Scores are now (..., seq_len, k) instead of (..., seq_len, seq_len)
        attention = (q @ K_) / math.sqrt(self.d_k)
        if mask is not None:  # `if mask:` would fail on a multi-element tensor
            mask = torch.triu(torch.ones((self.seq_len, self.k))).bool().to(q.device)
            attention.masked_fill_(mask == 0, -1e9)
        attention = attention.softmax(dim=-1)
        attention = self.dropout(attention)
        # Project the values the same way, then bring them back to (..., k, d_k)
        V = v.transpose(-1, -2)
        V_ = self.F(V)
        V_ = V_.transpose(-1, -2)
        return attention @ V_
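A quick shape check (the sizes below are arbitrary, chosen only to show that the output keeps the (batch, seq_len, d_k) shape):
attn = LinearSelfAttention(d_model=64, seq_len=1024, k=128, d_k=64)
x = torch.randn(2, 1024, 64)   # (batch, seq_len, d_k)
print(attn(x, x, x).shape)     # torch.Size([2, 1024, 64])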
Multihead Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model:int, seq_len:int, k:int, head:int,
                 headwise_sharing:bool=True, key_val_sharing:bool=False, dropout:float=0.2):
        """MultiHeadAttention

        Args:
            d_model (int)
            seq_len (int)
            k (int)
            head (int)
            dropout (float, optional). Defaults to 0.2.
        """
        super().__init__()
        assert d_model % head == 0, "d_model must be divisible by the number of heads"
        self.d_model = d_model
        self.seq_len = seq_len
        self.k = k
        self.head = head
        self.d_k = d_model // head
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        self.Wo = nn.Linear(d_model, d_model, bias=False)
        # Each head works on d_k-dimensional projections
        self.self_attention = LinearSelfAttention(self.d_k, seq_len, k, self.d_k, headwise_sharing, key_val_sharing, dropout)

    def forward(self, q, k, v, mask=None):
        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)
        # Split d_model into (head, d_k) and move the head axis forward: (batch, head, seq_len, d_k)
        query = q.view(q.shape[0], q.shape[1], self.head, self.d_k).transpose(1, 2)
        key = k.view(k.shape[0], k.shape[1], self.head, self.d_k).transpose(1, 2)
        value = v.view(v.shape[0], v.shape[1], self.head, self.d_k).transpose(1, 2)
        attention_score = self.self_attention(query, key, value, mask)
        # Merge the heads back: (batch, seq_len, head * d_k)
        x = attention_score.transpose(1, 2).contiguous().view(attention_score.shape[0], -1, self.head * self.d_k)
        return self.Wo(x)
if __name__ == "__main__":
    m = MultiHeadAttention(512, 6, 256, 8, True, False)
    # The attention module rebuilds its own (seq_len x k) mask internally,
    # so this tensor only signals that masking should be applied.
    mask = torch.triu(torch.ones((1, 1, 256)), diagonal=1)
    mask = mask == 0
    a = torch.randn(1, 6, 512)
    print(m(a, a, a, mask).size())   # torch.Size([1, 6, 512])
🧾 Result (from paper)
Here the results clearly show that d_model does not affect the time complexity of the model: the runtime stays roughly constant as d_model changes.
🔖 Reference and links
- Linformer: Self-Attention with Linear Complexity
- Attention Is All You Need
- Special Thanks to hkproj 🙏🏻