[ PROMPT_NODE_22360 ]
wanda
[ SKILL_DOCUMENTATION ]
# Wanda:基于权重和激活值的剪枝
基于 ICLR 2024 论文 (arXiv 2306.11695) - A Simple and Effective Pruning Approach for Large Language Models
## 概述
**来源**: https://arxiv.org/abs/2306.11695
**会议**: ICLR 2024
**GitHub**: https://github.com/locuslab/wanda
Wanda 通过权重幅度 × 输入激活值对大语言模型进行剪枝,实现 50% 的稀疏度且精度损失 <1%,无需重新训练。
## 核心创新
### 剪枝准则
**关键见解**: 权重重要性 = 幅度 × 使用频率
python
importance(w_ij) = |w_ij| × ||X_i||
其中:
- w_ij: 连接输入 i 到输出 j 的权重
- X_i: 维度 i 的输入激活范数
- ||·||: L2 范数
**直觉**:
- 大权重幅度 → 重要参数
- 高激活值 → 频繁使用的维度
- 乘积同时捕捉了这两个因素
### 与幅度剪枝的对比
**幅度剪枝** (基准):
python
importance = |weight| # 仅考虑权重大小
**Wanda**:
python
importance = |weight| × activation # 同时考虑了使用频率
**示例**:
权重 A: 幅度=0.5, 激活值=0.1 → 重要性=0.05
权重 B: 幅度=0.3, 激活值=0.8 → 重要性=0.24
幅度剪枝: 保留 A (权重更大)
Wanda: 保留 B (整体更重要) ✓
## 算法
### 一次性剪枝
python
import torch
from transformers import AutoModelForCausalLM
def wanda_prune(model, calib_data, sparsity=0.5):
"""
Wanda 剪枝算法。
步骤:
1. 在校准数据上收集激活统计信息
2. 计算重要性 = |权重| × 激活值
3. 剪除重要性最低的权重
4. 返回剪枝后的模型(无需重新训练!)
"""
# 步骤 1: 收集激活值
activations = {}
def activation_hook(name):
def hook(module, input, output):
# 存储输入激活范数
X = input[0].detach()
# 每个输入维度的范数
act_norm = X.abs().mean(dim=0) # 在批次/序列上取平均
if name in activations:
activations[name] += act_norm
else:
activations[name] = act_norm
return hook
# 注册钩子
hooks = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
hook = module.register_forward_hook(activation_hook(name))
hooks.append(hook)
# 运行校准
model.eval()
with torch.no_grad():
for batch in calib_data: