[ PROMPT_NODE_22742 ]
fsdp
[ SKILL_DOCUMENTATION ]
# TorchTitan 中的 FSDP2
## 为什么选择 FSDP2?
FSDP2 是 PyTorch 完全分片数据并行 (FSDP) API 的重写版本,移除了 `FlatParameter` 抽象,以实现更好的可组合性和更简单的实现。
### 相比 FSDP1 的关键改进
- **基于 DTensor 的分片**:分片参数是 dim-0 上的 `DTensor`,支持轻松操作和无需通信的分片状态字典 (state dict)
- **更好的内存管理**:通过避免 `recordStream`,实现了确定性且更低的 GPU 内存占用 (降低 7%)
- **简化 API**:参数更少,无需包装类
### 性能
在 8x H100s 的 Llama-7B 上,FSDP2 实现了比 FSDP1 更高的 MFU,峰值内存降低了 7%,且保持相同的损失曲线。
## API 参考
python
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy
@contract(state_cls=FSDPState)
def fully_shard(
module: nn.Module,
*,
mesh: Optional[DeviceMesh] = None,
reshard_after_forward: Union[bool, int] = True,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
) -> nn.Module:
## 分片策略 (ZeRO 等效项)
| FSDP2 配置 | FSDP1 等效项 | DeepSpeed |
|---------------------|------------------|-----------|
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
## Meta 设备初始化
FSDP2 支持在分片后将张量具体化到 GPU 上:
python
# 在 meta 设备上初始化 (无内存占用)
with torch.device("meta"):
model = Transformer()
# 应用 FSDP2 分片
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
# 参数仍处于 meta 设备
for tensor in itertools.chain(model.parameters(), model.buffers()):
assert tensor.device == torch.device("meta")
# 在 GPU 上分配分片参数
model.to_empty(device="cuda")
# 初始化权重
model.init_weights()
## 状态字典 (State Dict) 的差异
| 操作 | FSDP1 | FSDP2 |
|-----------|-------|-------|
| `model.state_dict()` | 完整状态字典 | 分片状态字典 (无通信) |
| `optim.state_dict()` | 本地状态字典 | 分片状态字典 (无通信) |
| `summon_full_params()` | 支持 | 使用 `DTensor` API 如 `full_tensor()` |
| 梯度裁剪