[ SKILL_DOCUMENTATION ]
# Flash Attention - 高效内存优化的注意力机制
## 快速开始
Flash Attention 通过 IO 感知平铺(tiling)和重计算,为 Transformer 注意力机制提供了 2-4 倍的加速和 10-20 倍的内存减少。
**PyTorch 原生(最简单,PyTorch 2.2+)**:
python
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# 如果可用,自动使用 Flash Attention
out = F.scaled_dot_product_attention(q, k, v)
**flash-attn 库(更多功能)**:
bash
pip install flash-attn --no-build-isolation
python
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
## 常见工作流
### 工作流 1:在现有 PyTorch 模型中启用
复制此检查清单:
Flash Attention 集成:
- [ ] 第 1 步:检查 PyTorch 版本 (≥2.2)
- [ ] 第 2 步:启用 Flash Attention 后端
- [ ] 第 3 步:通过性能分析验证加速效果
- [ ] 第 4 步:测试准确度是否与基准一致
**第 1 步:检查 PyTorch 版本**
bash
python -c "import torch; print(torch.__version__)"
# 应 ≥2.2.0
如果 <2.2,请升级:
bash
pip install --upgrade torch
**第 2 步:启用 Flash Attention 后端**
替换标准注意力:
python
# 之前(标准注意力)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# 之后(Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
强制使用 Flash Attention 后端:
python
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)
**第 3 步:通过性能分析验证加速效果**
python
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v
# 基准测试
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())