# Context Extension Methods
Comprehensive comparison of YaRN, ALiBi, and Position Interpolation based on published research.
## Table of Contents
- YaRN (Yet another RoPE extensioN)
- ALiBi (Attention with Linear Biases)
- Position Interpolation
- Method Comparison
## YaRN: Yet another RoPE extensioN
**Paper**: arXiv 2309.00071 (2023)
**Authors**: Bowen Peng, Jeffrey Quesnelle, Honglu Fan, Enrico Shippole
### Overview
YaRN extends RoPE-based models to 128k+ context with 10× less training data than previous methods.
### Key Innovations
1. **NTK-aware interpolation**: Scales different frequency components differently
2. **Attention temperature scaling**: Adjusts attention sharpness
3. **NTK-by-parts**: Hybrid interpolation/extrapolation
### Technical Details
**Problem**: Naive position interpolation compresses all frequencies uniformly, losing high-frequency information.
**Solution**: Different treatment for different frequencies.
```python
import math

# Frequency decomposition (NTK-by-parts)
# Low frequencies (fewer than β_slow rotations per context): Interpolate
# High frequencies (more than β_fast rotations per context): Extrapolate (extend as-is)
# Middle frequencies: Smooth ramp between the two

def yarn_get_mscale(scale=1.0):
    """Attention temperature scaling."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

# Extrapolation (bad): positions > 2048 are out-of-distribution
# Interpolation (good): positions [0, 0.0625, 0.125, ..., 2048]
# All positions stay within [0, 2048] (in-distribution)
```
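To make the NTK-by-parts idea concrete, the sketch below blends original and position-interpolated RoPE frequencies per dimension with a linear ramp. It follows the general structure of the public YaRN reference implementation, but the helper names (`find_correction_dim`, `linear_ramp_mask`, `yarn_inv_freq`) and default values here are illustrative, not the repository's exact API:
```python
import math
import torch

def find_correction_dim(num_rotations, dim, base=10000, max_position=2048):
    # Index of the RoPE dimension whose wavelength completes `num_rotations`
    # full turns over `max_position` tokens
    return (dim * math.log(max_position / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def linear_ramp_mask(low, high, n):
    # 0 below `low`, 1 above `high`, linear in between
    ramp = (torch.arange(n, dtype=torch.float32) - low) / max(high - low, 1e-3)
    return torch.clamp(ramp, 0.0, 1.0)

def yarn_inv_freq(dim, base=10000, scale=16.0, max_position=2048,
                  beta_fast=32, beta_slow=1):
    # Original (extrapolated) and position-interpolated frequencies
    inv_freq_extra = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    inv_freq_inter = inv_freq_extra / scale
    # Dimension range over which we ramp from "keep as-is" to "interpolate"
    low = find_correction_dim(beta_fast, dim, base, max_position)
    high = find_correction_dim(beta_slow, dim, base, max_position)
    mask = linear_ramp_mask(low, high, dim // 2)  # 0 = extrapolate, 1 = interpolate
    return inv_freq_extra * (1 - mask) + inv_freq_inter * mask
```
The resulting `inv_freq` replaces the standard RoPE frequencies, while `yarn_get_mscale(scale)` scales the attention logits (or, equivalently, is folded into the cos/sin embeddings).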
## Position Interpolation
**Paper**: arXiv 2306.15595 (2023)
### Overview
Position Interpolation rescales position indices so that longer inputs map back into the model's original trained position range, instead of extrapolating beyond it.
### Mathematical Formulation
**Original RoPE**:
```
position_ids = [0, 1, 2, 3, ..., L-1]
```
**Position Interpolation** (scale factor s):
```
position_ids = [0, 1/s, 2/s, 3/s, ..., (L-1)/s]
```
### Implementation
```python
import torch
from torch import nn


class InterpolatedRoPE(nn.Module):
    """RoPE with position interpolation."""

    def __init__(self, dim, max_seq_len=2048, base=10000, scaling_factor=1.0):
        super().__init__()
        self.max_seq_len = max_seq_len  # original (pre-extension) context length
        self.scaling_factor = scaling_factor
        # Standard RoPE frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device):
        # Position indices
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Interpolate positions
        t = t / self.scaling_factor  # KEY LINE
        # Standard RoPE
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()
```
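A quick illustrative check (parameter values are examples): extending from 2,048 to 32,768 positions corresponds to `scaling_factor=16.0`, and every rescaled position stays inside the original `[0, 2048)` range:
```python
rope = InterpolatedRoPE(dim=128, max_seq_len=2048, scaling_factor=16.0)
cos, sin = rope(seq_len=32768, device="cpu")

print(cos.shape)      # torch.Size([32768, 128])
print(32767 / 16.0)   # ~2047.94: largest rescaled position, still in-distribution
```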
### Fine-tuning Requirements
**Minimal fine-tuning needed**:
```python
# Extension: 2k → 32k (16× scale)
scaling_factor = 16.0
# Training config
training_args = {
"max_steps": 1000, # Only 1000 steps!
"learning_rate": 2e-5, # Small LR
"batch_size": 1,
"gradient_accumulation_steps": 16,
}
# Results: Near-perfect perplexity retention
```
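For reference, Hugging Face `transformers` exposes linear position interpolation for LLaMA-family models through the `rope_scaling` config field. Availability and accepted keys vary by library version, so treat this as a sketch rather than a guaranteed API (the checkpoint name is just an example):
```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint

config = AutoConfig.from_pretrained(model_id)
# Divide position indices by `factor` at run time (16x: 2k -> 32k)
config.rope_scaling = {"type": "linear", "factor": 16.0}
config.max_position_embeddings = 32768

model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
```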
### Theoretical Analysis
**Interpolation bound** (from the paper): the upper bound on the interpolated attention score is roughly 600× smaller than the corresponding extrapolation bound.
```
Extrapolation error bound: O(L^2)   # grows quadratically with sequence length
Interpolation error bound: O(1/s)   # shrinks as the scale factor s grows
```
### Results
**LLaMA models extended to 32k**:
| Model | Original Context | Extended Context | Fine-tune Steps | Perplexity |
|-------|-----------------|------------------|----------------|------------|
| LLaMA 7B | 2048 | 32768 | 1000 | 2.72 |
| LLaMA 13B | 2048 | 32768 | 1000 | 2.55 |
| LLaMA 33B | 2048 | 32768 | 1000 | 2.38 |
| LLaMA 65B | 2048 | 32768 | 1000 | 2.26 |
**Passkey retrieval**: 100% accuracy up to 32k tokens
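The passkey test hides a random number inside long filler text and asks the model to repeat it. A minimal sketch of prompt construction (the filler sentence, token-budget heuristic, and function name are illustrative):
```python
import random

def build_passkey_prompt(target_tokens=32000):
    filler = "The grass is green. The sky is blue. The sun is yellow. "
    passkey = random.randint(10000, 99999)
    # Very rough budget: assume ~12 tokens per filler block
    n_blocks = max(target_tokens // 12, 1)
    prompt = (
        filler * (n_blocks // 2)
        + f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
        + filler * (n_blocks // 2)
        + "What is the pass key? The pass key is"
    )
    return prompt, passkey

prompt, passkey = build_passkey_prompt()
# A model "passes" at this length if its continuation contains str(passkey).
```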
### Advantages
1. **Minimal training**: 1000 steps sufficient
2. **Stable**: Interpolation more stable than extrapolation
3. **Simple**: One-line code change
4. **Effective**: Works across all LLaMA sizes
### Disadvantages
1. **Limited extrapolation**: Can't go beyond trained range without fine-tuning
2. **Information compression**: All positions compressed into trained range
## Method Comparison
### Training Requirements
| Method | Pre-training Needed | Fine-tuning Steps | Training Tokens |
|--------|---------------------|-------------------|-----------------|
| **ALiBi** | Yes (from scratch) | 0 | Full (100B+) |
| **Position Interpolation** | No | 1,000 | ~100M |
| **YaRN** | No | 400 | ~100M |
| **Linear RoPE Scaling** | No | 1,000-5,000 | ~1B |
### Extrapolation Performance
**Test**: Train on 2k, test on 8k, 16k, 32k
| Method | 8k PPL | 16k PPL | 32k PPL | Extrapolation Quality |
|--------|--------|---------|---------|----------------------|
| **ALiBi** | 12.1 | 12.3 | 12.5 | Excellent |
| **YaRN** | 11.8 | 12.0 | 12.2 | Excellent |
| **Position Interpolation** | 12.5 | 13.2 | 14.8 | Poor |
| **Linear Scaling** | 13.1 | 15.2 | 19.4 | Poor |
### Memory and Speed
| Method | Memory vs Baseline | Speed vs Baseline |
|--------|--------------------|--------------------|
| **ALiBi** | -11% | +11% |
| **Position Interpolation** | 0% | 0% |
| **YaRN** | 0% | -5% |
| **Linear Scaling** | 0% | 0% |
### Use Case Recommendations
```python
# New model from scratch → ALiBi
if training_from_scratch:
use_method = "ALiBi"
# Extending existing RoPE model with best quality → YaRN
elif need_sota_quality:
use_method = "YaRN"
# Quick extension with minimal compute → Position Interpolation
elif need_quick_solution:
use_method = "Position Interpolation"
# Moderate extension, simple implementation → Linear Scaling
else:
use_method = "Linear RoPE Scaling"
```
## Resources
- **YaRN Paper**: https://arxiv.org/abs/2309.00071
- **ALiBi Paper**: https://arxiv.org/abs/2108.12409
- **Position Interpolation Paper**: https://arxiv.org/abs/2306.15595
- **YaRN Implementation**: https://github.com/jquesnelle/yarn
- **ALiBi Implementation**: https://github.com/ofirpress/attention_with_linear_biases
- **Together AI Blog**: https://www.together.ai/blog/llama-2-7b-32k
Source: claude-code-templates (MIT).