Linformer: Making Transformers Linear, Efficient, and Scalable

Ruhaan838


💪 Motivation

We all know the Transformer revolutionized the AI world after the paper Attention Is All You Need came out. The Transformer can handle very long sequence tasks and makes fewer errors than models like RNNs and CNNs. Its major problem is that it requires a lot of computation power.

👋 Introduction

So, our traditional Transformer was proposed in Attention Is All You Need. That paper solves the biggest problem of RNNs (Recurrent Neural Networks), vanishing gradients, but the Transformer requires too much time and computation to train and evaluate. To solve this problem, the Linformer came out, a model proposed by Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, and Hao Ma in the paper,

Linformer: Self-Attention with Linear Complexity
Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma Facebook AI, Seattle, WA

So the Linformer allows you to compute the attention in time and memory that grow only linearly with the sequence length.

🪫 Power Hunger: Transformer

The Transformer is very power-hungry. Self-attention builds an N x N matrix of scores, where N is the sequence length (often 512, 1024, or more), so the cost grows quadratically with N. This N x N matrix multiplication happens in every head of every layer, which is expensive for the device and takes a lot of memory and time. Many papers have come out to solve this; Linformer is one of them.
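To make the quadratic cost concrete, here is a minimal sketch (my own illustration with made-up sizes, not from the paper) that builds the n x n score matrix of standard attention and prints how its memory grows with the sequence length:

import torch

d = 64  # per-head dimension
for n in (512, 1024, 4096):
    q = torch.randn(n, d)
    k = torch.randn(n, d)
    scores = q @ k.transpose(-1, -2)  # shape (n, n): grows quadratically with n
    mem_mb = scores.numel() * scores.element_size() / 1e6
    print(f"n={n}: score matrix {tuple(scores.shape)}, ~{mem_mb:.1f} MB per head")

Doubling the sequence length quadruples the size of this matrix, which is exactly the bottleneck the Linformer removes.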

💡 Linformer Self-Attention vs Traditional Self-Attention

Let’s look at a high-level comparison between the Linformer and the traditional Transformer.

Fig.1: Multihead Attention and Linformer Multihead Attention (Source)

Transformer High-Level Flow

The top of Fig.1 shows the traditional equation of the Transformer's multi-head attention.

Fig.2: (left) Scaled Dot-Product Attention. (right) Multi-head attention consists of several attention layers running in parallel. (Source)

The input first passes through linear projections, and then the scaled dot-product attention (self-attention) is computed.

We perform this operation in parallel for each of the D_model / head_size heads, then concatenate the outputs of all the self-attention blocks and pass them through a final linear layer.
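As a reference point, here is a minimal sketch of standard scaled dot-product attention for a single head (my own simplification, not the author's code); the Linformer version later in this post only changes how the key and value are handled:

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, head, n, d_k); scores: (batch, head, n, n)
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return scores.softmax(dim=-1) @ v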

Linformer Idea

At the bottom of Fig.1 is the equation for performing the linear self-attention.

Fig.3: Visual Representation of Fig.1 bottom equation (Source)

So in simple words, the Linformer has two extra shared, learnable parameters, the two linear projection matrices, each of dimension k x n. Throughout this blog these parameters are called E and F, and they are applied to the key and the value respectively.

These first project the original n x d key and value matrices down to k x d, and then attention is applied just like in the traditional Transformer, as shown in Fig. 3.

❓How does the Linformer calculate the Attention?

To calculate the Linformer attention, first project the original (n x d)-dimensional key and value layers K·W_i^K and V·W_i^V into (k x d)-dimensional projected key and value layers E_i·(K·W_i^K) and F_i·(V·W_i^V).

Note that the equation at the bottom of Fig.1 requires only O(nk) time and space complexity. If we select the projected dimension k such that k << n, we can greatly reduce the amount of memory and time; with k fixed, the cost is linear in the sequence length.
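A quick shape check (a sketch in the paper's notation, with made-up sizes) shows where the saving comes from: the score matrix is n x k instead of n x n:

import torch

n, d, k = 4096, 64, 256
Q = torch.randn(n, d)
K = torch.randn(n, d)
V = torch.randn(n, d)
E = torch.randn(k, n)  # projection matrix for the key
F = torch.randn(k, n)  # projection matrix for the value

K_proj = E @ K         # (k, d) projected key
V_proj = F @ V         # (k, d) projected value
scores = Q @ K_proj.transpose(-1, -2) / d ** 0.5  # (n, k) instead of (n, n)
out = scores.softmax(dim=-1) @ V_proj             # (n, d)
print(scores.shape, out.shape)  # torch.Size([4096, 256]) torch.Size([4096, 64])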

🌐 Parameter Sharing

One thing the Linformer authors also describe is sharing the parameters of the linear projection matrices Ei, Fi across layers and heads. There are three ways to do that (a small sketch follows the list):

  • Headwise sharing: for each layer, use the same E and F across all heads, i.e. Ei = E and Fi = F for every head i.
  • Key-value sharing: in addition to headwise sharing, use a single matrix for both projections, i.e. Ei = Fi = E for the key and value projection of every head i.
  • Layerwise sharing: use one single projection matrix E across all layers, all heads, and both key and value.
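Here is a minimal sketch (my own illustration with a hypothetical helper, not the repository code) of how the three sharing schemes could be wired:

import torch.nn as nn

def make_projections(seq_len, k, num_layers, mode="headwise"):
    """Return per-layer (E, F) projection pairs under the given sharing scheme."""
    if mode == "layerwise":
        E = nn.Linear(seq_len, k)  # one matrix shared by every layer, head, key and value
        return [(E, E) for _ in range(num_layers)]
    pairs = []
    for _ in range(num_layers):
        E = nn.Linear(seq_len, k)  # shared by all heads in this layer
        F = E if mode == "key_value" else nn.Linear(seq_len, k)
        pairs.append((E, F))
    return pairs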

🧑🏻‍💻 Some Coding Guidelines

See 🐈‍⬛ GitHub for the full code.

In this code, I try to keep the attention mechanism simple, reliable, and readable, so you can relate it directly to the equations.

Linformer Self Attention

import math

import torch
import torch.nn as nn


class LinearSelfAttention(nn.Module):
    def __init__(self, d_model: int, seq_len: int, k: int, d_k: int,
                 headwise_sharing: bool = True, key_val_sharing: bool = False,
                 dropout: float = 0.2):
        """LinearSelfAttention args

        Args:
            d_model (int): model dimension, e.g. 512, 1024, ...
            seq_len (int): sequence length, e.g. 6, 1000, 20000, ...
            k (int): reduction term (projected dimension) as per the paper
            d_k (int): d_model // head
            headwise_sharing (bool, optional): separate E and F shared across heads. Defaults to True.
            key_val_sharing (bool, optional): reuse E for the value projection. Defaults to False.
            dropout (float, optional): dropout rate. Defaults to 0.2.
        """
        super().__init__()

        self.d_model = d_model
        self.seq_len = seq_len
        self.k = k
        self.d_k = d_k
        self.dropout = nn.Dropout(dropout)
        self.headwise_sharing = headwise_sharing
        self.key_val_sharing = key_val_sharing

        # E and F are the learnable (k x n) projections for the key and value
        self.E = nn.Linear(seq_len, k)
        self.F = nn.Linear(seq_len, k) if headwise_sharing else self.E if key_val_sharing else None

        assert self.F is not None, "At least one sharing scheme is required to perform the linear attention."

    def forward(self, q, k, v, mask=None):
        # Project the key: (batch, head, n, d_k) -> (batch, head, d_k, k)
        K = k.transpose(-1, -2)
        K_ = self.E(K)

        # Scores are (batch, head, n, k) instead of (batch, head, n, n)
        attention = (q @ K_) / math.sqrt(self.d_k)

        if mask is not None:
            # The passed mask no longer matches the projected shape, so rebuild a
            # triangular (seq_len x k) mask over the projected dimension
            mask = torch.triu(torch.ones((self.seq_len, self.k))).bool().to(q.device)
            attention.masked_fill_(mask == 0, -1e9)

        attention = attention.softmax(dim=-1)
        attention = self.dropout(attention)

        # Project the value: (batch, head, n, d_k) -> (batch, head, k, d_k)
        V = v.transpose(-1, -2)
        V_ = self.F(V)
        V_ = V_.transpose(-1, -2)
        return attention @ V_
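Building on the class above, here is a quick (hypothetical) shape check with per-head tensors of shape (batch, head, seq_len, d_k):

attn = LinearSelfAttention(d_model=512, seq_len=6, k=4, d_k=64)
q = torch.randn(1, 8, 6, 64)  # (batch, head, seq_len, d_k)
print(attn(q, q, q).shape)    # torch.Size([1, 8, 6, 64])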

Multihead Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, seq_len: int, k: int, head: int,
                 headwise_sharing: bool = True, key_val_sharing: bool = False,
                 dropout: float = 0.2):
        """MultiHeadAttention

        Args:
            d_model (int): model dimension
            seq_len (int): sequence length
            k (int): reduction term (projected dimension)
            head (int): number of attention heads
            dropout (float, optional): dropout rate. Defaults to 0.2.
        """
        super().__init__()

        assert d_model % head == 0, "d_model must be divisible by head"

        self.d_model = d_model
        self.seq_len = seq_len
        self.k = k
        self.head = head
        self.d_k = d_model // head

        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)

        self.Wo = nn.Linear(d_model, d_model, bias=False)

        self.self_attention = LinearSelfAttention(self.d_k, seq_len, k, self.d_k,
                                                  headwise_sharing, key_val_sharing, dropout)

    def forward(self, q, k, v, mask=None):
        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        # Split into heads: (batch, seq_len, d_model) -> (batch, head, seq_len, d_k)
        query = q.view(q.shape[0], q.shape[1], self.head, self.d_k).transpose(1, 2)
        key = k.view(k.shape[0], k.shape[1], self.head, self.d_k).transpose(1, 2)
        value = v.view(v.shape[0], v.shape[1], self.head, self.d_k).transpose(1, 2)

        attention_score = self.self_attention(query, key, value, mask)

        # Concatenate the heads: (batch, head, seq_len, d_k) -> (batch, seq_len, d_model)
        x = attention_score.transpose(1, 2).contiguous().view(attention_score.shape[0], -1, self.head * self.d_k)

        return self.Wo(x)


if __name__ == "__main__":
    m = MultiHeadAttention(512, 6, 256, 8, True, False)
    mask = torch.triu(torch.ones((1, 1, 256)), diagonal=1)
    mask = mask == 0
    a = torch.randn(1, 6, 512)
    print(m(a, a, a, mask).size())  # torch.Size([1, 6, 512])

🧾 Result (from paper)

inference time vs. sequence length for various Linformer models (Source)
Inference-time efficiency improvements of the Linformer over the Transformer, across various projected dimensions k and sequence lengths n. The left table shows time saved. The right table shows the memory saved. (Source)

These results clearly show that, for a fixed projected dimension k, the Linformer's inference time stays nearly constant as the sequence length grows, while the standard Transformer's time and memory grow much faster.

🔖 Reference and links

  • Linformer: Self-Attention with Linear Complexity. Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma. Facebook AI (arXiv:2006.04236)
  • Attention Is All You Need. Ashish Vaswani et al. NeurIPS 2017 (arXiv:1706.03762)
