[ PROMPT_NODE_22810 ]
Optimization Flash Attention 基准测试
[ SKILL_DOCUMENTATION ]
# 性能基准测试
## 内容
- 各 GPU 速度对比
- 内存使用分析
- 随序列长度的扩展性
- 训练与推理性能
- Flash Attention 版本对比
## 各 GPU 速度对比
### A100 80GB (Ampere)
**前向传播时间**(毫秒,batch=8, heads=32, dim=64):
| 序列长度 | 标准 | Flash Attn 2 | Flash Attn 3 | 加速比 (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
### H100 80GB (Hopper)
**前向传播时间**(毫秒,相同配置):
| 序列长度 | 标准 | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | 最佳加速比 |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
**关键洞察**:H100 上使用 FP8 的 Flash Attention 3 可达到约 1.2 PFLOPS(理论峰值的 75%)。
### A10G 24GB (Ampere)
**前向传播时间**(毫秒,batch=4):
| 序列长度 | 标准 | Flash Attn 2 | 加速比 |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |
## 内存使用分析
### GPU 内存消耗 (batch=8, heads=32, dim=64)
**标准注意力内存**:
| 序列长度 | 注意力矩阵 | KV 缓存 | 总计 | 备注 |
|------------|------------------|----------|-------|-------|
| 512 | 8 MB | 32 MB | 40 MB | 可控 |
| 2048 | 128 MB | 128 MB | 256 MB | 增长中 |
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | 较大 |
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | 24GB GPU 上 OOM |
**Flash Attention 2 内存**:
| 序列长度 | 注意力 (片上) | KV 缓存 | 总计 | 减少量 |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (重计算) | 32 MB | 32 MB | 20% |
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
**关键洞察**:Flash Attention 不会实例化注意力矩阵,节省了 O(N²) 的内存。
### 内存扩展性对比
**Llama 2 7B 模型内存** (float16, batch=1):
| 上下文长度 | 标准 A