[ PROMPT_NODE_22740 ]
float8
[ SKILL_DOCUMENTATION ]
# TorchTitan 中的 Float8 训练
对于 GEMM(通用矩阵乘法)足够大以至于 FP8 TensorCore 加速超过动态量化开销的模型,Float8 训练可提供显著的加速。
## 硬件要求
- NVIDIA H100 或更新的 GPU (FP8 Tensor Cores)
- 用于 MXFP8 训练的 Blackwell GPU
## 安装
bash
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
## 用法:张量级缩放 (Tensorwise Scaling)
使用张量级动态缩放的标准 Float8:
bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
--model.converters="quantize.linear.float8"
--quantize.linear.float8.enable_fsdp_float8_all_gather
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp
--compile.enable
### 关键参数
| 参数 | 描述 |
|----------|-------------|
| `--model.converters="quantize.linear.float8"` | 将 `nn.Linear` 替换为 `Float8Linear` |
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | 以 float8 通信以节省带宽 |
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | 为所有 AMAX/scales 执行单次 all-reduce |
| `--compile.enable` | 必需 - 融合 float8 缩放/转换内核 |
## 用法:行级缩放 (Rowwise Scaling)
比张量级缩放具有更高的精度:
bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
--model.converters="quantize.linear.float8"
--quantize.linear.float8.recipe_name rowwise
--compile.enable
## 层过滤
并非所有层都受益于 Float8。过滤掉小层:
bash
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
### 自动过滤
自动跳过太小而无法受益的层:
bash
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
阈值基于 H100 微基准测试,即加速比 > 开销。
## TOML 配置
toml
[model]
converters = ["quantize.linear.float8"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "auto_filter_small_kn"]
[compile]
enable = true
components = ["model", "loss"]
## Float8 如何与分布式训练协同工作
### 单设备
在调用 `torch._scaled_mm` 之前,在 forward 内部将输入和权重转换为 float8:
python
# Float8 矩阵乘法需要缩放因子
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
### FSDP + Float8
1. 在 forward 之前将分片的高精度...