[ PROMPT_NODE_22344 ]
extension_methods
[ SKILL_DOCUMENTATION ]
# 上下文扩展方法
基于已发表研究对 YaRN、ALiBi 和位置插值(Position Interpolation)的综合对比。
## 目录
- YaRN (Yet another RoPE extensioN)
- ALiBi (Attention with Linear Biases)
- 位置插值 (Position Interpolation)
- 方法对比
## YaRN: Yet another RoPE extensioN
**论文**: arXiv 2309.00071 (2023)
**作者**: Bowen Peng, Jeffrey Quesnelle, Honglu Fan, Enrico Shippole
### 概述
YaRN 将基于 RoPE 的模型扩展至 128k+ 上下文,且训练数据量比以往方法少 10 倍。
### 关键创新
1. **NTK 感知插值 (NTK-aware interpolation)**:对不同频率分量进行差异化缩放
2. **注意力温度缩放 (Attention temperature scaling)**:调整注意力锐度
3. **分段 NTK (NTK-by-parts)**:插值/外推混合策略
### 技术细节
**问题**:朴素的位置插值会均匀压缩所有频率,导致高频信息丢失。
**解决方案**:对不同频率进行差异化处理。
python
# 频率分解
# 低频 ( 1/β_fast): 外推 (保持原样)
# 中频: 两者之间的平滑过渡
def yarn_get_mscale(scale=1.0):
"""注意力温度缩放。"""
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
"""查找分段 NTK 的维度截断点。"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
"""查找插值的频率范围。"""
low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)
def yarn_linear_ramp_mask(min_val, max_val, dim):
"""创建插值与外推之间的平滑过渡。"""
if min_val == max_val:
max_val += 0.001 # 避免除以零
linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
### 完整的 YaRN 实现
python
class YaRNScaledRoPE(nn.Module):
"""完整的 YaRN 实现。"""
def __init__(
self,
dim,
max_position_embeddings=2048,