Overview#
The Transformer architecture, introduced in “Attention Is All You Need” (Vaswani et al., 2017), revolutionized natural language processing by replacing recurrent and convolutional layers with self-attention mechanisms.
Key Components#
1. Self-Attention Mechanism#
The core innovation of transformers is the self-attention mechanism, which allows the model to focus on different parts of the input sequence when processing each element.
Mathematical Foundation:
Attention(Q, K, V) = softmax(QK^T / √d_k)V
Where:
- Q (Query): the representation of the position we are currently computing attention for
- K (Key): the representations each position offers to be matched against the query
- V (Value): the content that gets aggregated according to the attention weights
- d_k: the dimension of the key vectors, used to scale the dot products so the softmax stays in a well-behaved range
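The formula maps almost directly onto a few tensor operations. The sketch below uses PyTorch; the function name and shapes are chosen for illustration, and the same helper is reused by the multi-head example in the next section.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (batch_size, seq_len, d_k), shapes chosen for illustration
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)  # (batch, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)              # each row sums to 1
    return weights @ V                               # weighted sum of the values
```

Recent PyTorch releases also ship a fused `torch.nn.functional.scaled_dot_product_attention`; the explicit version above is kept because it mirrors the formula term by term.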
2. Multi-Head Attention#
Instead of using a single attention function, transformers use multiple attention “heads” to capture different types of relationships.
```python
# Simplified multi-head attention: reuses scaled_dot_product_attention defined
# above; the learned output projection is passed in as W_o for brevity
def multi_head_attention(Q, K, V, num_heads, W_o):
    batch_size, seq_len, d_model = Q.shape
    d_head = d_model // num_heads

    # Split into multiple heads: num_heads tensors of shape (batch, seq_len, d_head)
    def split_heads(x):
        return x.view(batch_size, seq_len, num_heads, d_head).unbind(dim=2)

    Q_heads, K_heads, V_heads = split_heads(Q), split_heads(K), split_heads(V)

    # Apply attention to each head independently
    attention_outputs = []
    for i in range(num_heads):
        attention_outputs.append(
            scaled_dot_product_attention(Q_heads[i], K_heads[i], V_heads[i])
        )

    # Concatenate the heads and apply the output projection W_O
    concatenated = torch.cat(attention_outputs, dim=-1)  # (batch, seq_len, d_model)
    return concatenated @ W_o                            # W_o: (d_model, d_model)
```
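A quick shape check with random tensors (the sizes are arbitrary; in a real model Q, K, and V come from learned linear projections of the input, and W_o is a learned parameter):

```python
# Illustrative usage with random tensors
batch_size, seq_len, d_model, num_heads = 2, 10, 64, 8
x = torch.randn(batch_size, seq_len, d_model)
W_o = torch.randn(d_model, d_model) / d_model ** 0.5
out = multi_head_attention(x, x, x, num_heads, W_o)
print(out.shape)  # torch.Size([2, 10, 64])
```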
3. Positional Encoding#
Because self-attention is permutation-invariant, transformers have no built-in notion of token order, so positional encodings are added to the input embeddings.
Sinusoidal Encoding:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
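A minimal sketch of generating these encodings in PyTorch (the function name and tensor layout are illustrative, and d_model is assumed to be even):

```python
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    dim = torch.arange(0, d_model, 2, dtype=torch.float32)              # even indices 2i
    div_term = torch.pow(10000.0, dim / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position / div_term)
    pe[:, 1::2] = torch.cos(position / div_term)
    return pe  # added to the input embeddings: embeddings + pe[:seq_len]
```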
Architecture Components#
Encoder Stack#
- Layer Normalization: Applied after each sub-layer's residual connection in the original paper (post-LN); many later variants apply it before the sub-layer (pre-LN) for more stable training
- Residual Connections: Help with gradient flow
- Feed-Forward Networks: Position-wise fully connected layers applied independently at each position (a minimal encoder layer combining these pieces is sketched after this list)
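A minimal sketch of how these pieces fit together in one encoder layer, using the post-LN arrangement of the original paper and PyTorch's nn.MultiheadAttention; the hyperparameter names (d_ff, dropout) are illustrative, not taken from any particular implementation:

```python
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention + position-wise FFN,
    each wrapped in a residual connection followed by LayerNorm (post-LN)."""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads,
                                               dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.self_attn(x, x, x)         # self-attention sub-layer
        x = self.norm1(x + self.dropout(attn_out))    # residual + LayerNorm
        ffn_out = self.ffn(x)                         # position-wise feed-forward
        return self.norm2(x + self.dropout(ffn_out))  # residual + LayerNorm
```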
Decoder Stack#
- Masked Self-Attention: Prevents each position from attending to future positions (see the mask sketch after this list)
- Encoder-Decoder Attention: Allows decoder to focus on encoder outputs
- Output Projection: Maps to vocabulary size for next token prediction
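A minimal sketch of how the causal mask is applied inside masked self-attention (names and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def causal_self_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_k)
    seq_len, d_k = Q.size(-2), Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
    # Upper-triangular mask: position i may only attend to positions <= i
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ V
```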
Key Insights#
- Parallelization: Unlike RNNs, all positions can be processed simultaneously
- Long-Range Dependencies: Self-attention can directly connect distant positions
- Interpretability: Attention weights provide insights into model focus
- Scalability: Architecture scales well with increased model size
Practical Considerations#
Training Tips#
- Learning Rate Scheduling: Use warmup followed by decay (the paper's schedule is sketched after this list)
- Gradient Clipping: Prevent exploding gradients
- Dropout: Apply to attention weights and feed-forward layers
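As a concrete example of warmup followed by decay, the schedule from the original paper, lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), can be written as a small function; the defaults below follow the paper's base model:

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    # Learning-rate schedule from Vaswani et al. (2017):
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```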
Common Challenges#
- Quadratic Complexity: Attention time and memory grow as O(n²) with sequence length n (a rough estimate follows this list)
- Memory Requirements: Large models require substantial GPU memory
- Training Instability: Requires careful hyperparameter tuning
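To make the quadratic growth concrete, here is a back-of-the-envelope estimate of the attention-score matrices alone; the batch size, head count, and sequence length are arbitrary illustrative values, not measurements:

```python
# Rough size of the attention-score matrices alone (illustrative numbers)
batch, heads, seq_len, bytes_per_float = 8, 12, 2048, 2  # fp16
score_bytes = batch * heads * seq_len ** 2 * bytes_per_float
print(f"{score_bytes / 2**30:.2f} GiB")  # ~0.75 GiB, and it grows with seq_len^2
```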
Applications & Variants#
Pre-trained Models#
- BERT: Bidirectional encoder representations
- GPT: Generative pre-trained transformer (decoder-only)
- T5: Text-to-text transfer transformer
Efficiency Improvements#
- Linformer: Linear attention complexity
- Performer: Fast attention via random features
- Reformer: Reversible residual layers and locality-sensitive hashing
Code Example: Simple Attention#
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x):
        Q = self.query(x)  # (batch_size, seq_len, d_model)
        K = self.key(x)
        V = self.value(x)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_model ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)

        # Apply attention to values
        output = torch.matmul(attention_weights, V)
        return output, attention_weights
```
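Illustrative usage (the sizes below are arbitrary):

```python
# Quick shape check with random input
attn = SimpleAttention(d_model=64)
x = torch.randn(2, 10, 64)          # (batch_size, seq_len, d_model)
output, weights = attn(x)
print(output.shape, weights.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 10, 10])
```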
Further Reading#
Last Updated: January 2024
Next Topics: BERT fine-tuning, GPT architecture comparison