[ PROMPT_NODE_22728 ]
architecture-details
[ SKILL_DOCUMENTATION ]
# RWKV 架构细节
## Time-Mixing 和 Channel-Mixing 模块
RWKV 在 **Time-Mixing**(序列处理)和 **Channel-Mixing**(特征处理)模块之间交替进行。
### Time-Mixing 模块 (WKV 操作)
核心创新是 **WKV (Weighted Key-Value)** 机制:
python
# Traditional attention (O(n^2))
scores = Q @ K.T / sqrt(d)  # n×n matrix
attention = softmax(scores)
output = attention @ V

# RWKV Time-Mixing (O(n))
# WKV is a linear-time recurrence over two running sums per channel:
#   aa accumulates exp(k)*v (numerator), bb accumulates exp(k) (denominator).
#   w = per-channel time decay, u = "first-token" bonus applied only to the
#   current step. All operations are element-wise (no matrix products).
aa = bb = 0
for t in range(T):
    wkv[t] = (aa + exp(u + k[t]) * v[t]) / (bb + exp(u + k[t]))
    aa = exp(-w) * aa + exp(k[t]) * v[t]
    bb = exp(-w) * bb + exp(k[t])
**完整的 Time-Mixing 实现**:
python
class RWKV_TimeMix(nn.Module):
    """RWKV time-mixing block: token-shift interpolation + linear-time WKV.

    Replaces O(T^2) attention with an O(T) recurrence over two running
    sums per channel (numerator ``aa`` and denominator ``bb``), gated by
    a sigmoid "receptance".

    State layout is ``(B, C, 3)`` along the last axis: ``[aa, bb, x_prev]``
    where ``x_prev`` is the last input token (used for token shift when
    processing the next chunk).
    """

    def __init__(self, d_model, n_layer):
        # n_layer is kept for interface compatibility (real RWKV uses it
        # for layer-dependent parameter initialization); unused here.
        super().__init__()
        self.d_model = d_model
        # Linear projections (no bias, as in the reference implementation)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)
        # Token-shift interpolation weights (per-channel, broadcast over B, T)
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        # Per-channel time decay (w) and first-token bonus (u)
        self.time_decay = nn.Parameter(torch.ones(d_model))  # w
        self.time_first = nn.Parameter(torch.ones(d_model))  # u

    def forward(self, x, state=None):
        """Process ``x`` of shape (B, T, C); return ``(out, new_state)``.

        ``out`` has shape (B, T, C); ``new_state`` has shape (B, C, 3)
        and can be fed back in to continue the sequence chunk by chunk.
        """
        B, T, C = x.shape
        if state is None:
            # [aa, bb, x_prev] all start at zero for a fresh sequence.
            state = torch.zeros(B, C, 3, device=x.device, dtype=x.dtype)
        aa = state[:, :, 0]                      # running numerator
        bb = state[:, :, 1]                      # running denominator
        x_prev = state[:, :, 2].unsqueeze(1)     # last token of prior chunk

        # Token shift: each position mixes with the *previous* token.
        # (Broadcasting a single x_prev over all T positions would be wrong
        # for T > 1 — the sequence must be shifted by one.)
        x_shift = torch.cat([x_prev, x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + x_shift * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + x_shift * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + x_shift * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        # Sequential WKV recurrence, O(T). Element-wise throughout.
        # NOTE: this naive exp() form can overflow for large k; production
        # RWKV kernels use a max-shifted numerically stable variant.
        decay = torch.exp(-self.time_decay)  # per-step decay factor e^{-w}
        bonus = self.time_first              # u, applied only to current step
        outputs = []
        for t in range(T):
            kt = k[:, t]
            vt = v[:, t]
            e_uk = torch.exp(bonus + kt)
            outputs.append((aa + e_uk * vt) / (bb + e_uk))
            e_k = torch.exp(kt)
            aa = decay * aa + e_k * vt
            bb = decay * bb + e_k
        wkv = torch.stack(outputs, dim=1)

        # Receptance gate and output projection.
        out = self.output(torch.sigmoid(r) * wkv)
        # Carry the recurrence sums and the last input token forward.
        new_state = torch.stack([aa, bb, x[:, -1]], dim=2)
        return out, new_state
d