[ PROMPT_NODE_22342 ]
long-context
[ SKILL_DOCUMENTATION ]
# 长上下文:扩展 Transformer 上下文窗口
## 何时使用此技能
当你需要以下操作时,请使用长上下文技术:
- **处理长文档** (32k, 64k, 128k+ tokens) 使用 Transformer 模型
- **扩展上下文窗口** 预训练模型 (LLaMA, Mistral 等)
- **实现高效位置编码** (RoPE, ALiBi)
- **训练模型** 具备长度外推能力
- **部署模型** 高效处理变长输入
- **微调** 现有模型以适应更长上下文,且计算开销最小
**关键技术**: RoPE (旋转位置嵌入), YaRN, ALiBi (带线性偏置的注意力机制), 位置插值 (Position Interpolation)
**论文**: RoFormer (arXiv 2104.09864), YaRN (arXiv 2309.00071), ALiBi (arXiv 2108.12409), Position Interpolation (arXiv 2306.15595)
## 安装
bash
# HuggingFace Transformers (包含 RoPE, YaRN 支持)
pip install transformers torch
# 用于自定义实现
pip install einops # 张量操作
pip install rotary-embedding-torch # 独立 RoPE 实现
# 可选: 用于效率优化的 FlashAttention
pip install flash-attn --no-build-isolation
## 快速开始
### RoPE (旋转位置嵌入)
python
import torch
import torch.nn as nn
class RotaryEmbedding(nn.Module):
"""旋转位置嵌入 (RoPE)。"""
def __init__(self, dim, max_seq_len=8192, base=10000):
super().__init__()
# 计算逆频率
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len = max_seq_len
def forward(self, seq_len, device):
# 位置索引
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
# 计算频率
freqs = torch.outer(t, self.inv_freq) # (seq_len, dim/2)
# 计算 sin 和 cos
emb = torch.cat((freqs, freqs), dim=-1) # (seq_len, dim)
return emb.cos(), emb.sin()
def rotate_half(x):
"""旋转隐藏维度的一半。"""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
"""将旋转嵌入应用于查询和键。"""
# q, k 形状: (batch, heads, seq_len, dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# 用法
rope = RotaryEmbedding(dim=64, max_seq_len=8192)
cos, sin = rope(seq_len=2048, device='cuda')
# 在注意力层中使用