[ PROMPT_NODE_22348 ]
rope
[ SKILL_DOCUMENTATION ]
# RoPE: 旋转位置编码 (Rotary Position Embeddings)
基于 RoFormer 论文 (arXiv 2104.09864) 和 HuggingFace transformers 实现的完整技术指南。
## 目录
- 数学公式
- 实现细节
- 缩放技术
- 生产使用
## 数学公式
**来源**: RoFormer: Enhanced Transformer with Rotary Position Embedding (arXiv 2104.09864)
### 核心思想
RoPE 使用旋转矩阵编码绝对位置,同时在注意力机制中自然地结合了相对位置依赖性。
### 公式
给定位置索引 `m` 和嵌入维度 `d`:
旋转矩阵 R_θ(m):
[cos(mθ₁) -sin(mθ₁) 0 0 ]
[sin(mθ₁) cos(mθ₁) 0 0 ]
[0 0 cos(mθ₂) -sin(mθ₂) ]
[0 0 sin(mθ₂) cos(mθ₂) ]
...
其中 θⱼ = base^(-2j/d) 对于 j ∈ [0, 1, 2, ..., d/2)
**关键属性**: 位置 m 和 n 之间的注意力仅取决于相对距离 (m - n)。
### 推导
**步骤 1: 通过旋转进行位置编码**
q_m = W_q x_m 旋转 mθ
k_n = W_k x_n 旋转 nθ
**步骤 2: 注意力分数**
score(q_m, k_n) = q_m^T k_n
= (旋转后的查询) · (旋转后的键)
= f(q, k, m-n)
分数取决于相对位置 `m - n`,而非绝对位置。
## 实现细节
**来源**: HuggingFace transformers/modeling_rope_utils.py
### 基础 RoPE 实现
python
import torch
import math
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""预计算旋转频率 (cos + i*sin)。"""
# 计算逆频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 位置索引
t = torch.arange(end, device=freqs.device)
# 外积: (end, dim/2)
freqs = torch.outer(t, freqs).float()
# 转换为复指数 (欧拉公式)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # e^(i*θ) = cos(θ) + i*sin(θ)
return freqs_cis
def reshape_for_broadcast(freqs_cis, x):
"""重塑频率张量以匹配 x 的维度。"""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(xq, xk, freqs_cis):
"""将旋转嵌入应用于查询和键。"""
# 转换为复数
xq_ = torch.view_as_complex(