# Mamba Training Guide

## Training from Scratch

### Environment Setup
```bash
# Install dependencies
pip install "torch>=1.12.0" --extra-index-url https://download.pytorch.org/whl/cu116
pip install packaging ninja
pip install "causal-conv1d>=1.1.0"
pip install mamba-ssm

# Verify that PyTorch can see the CUDA device
python -c "import torch; print(torch.cuda.is_available())"
```
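
To confirm that the CUDA kernels built correctly, a quick forward pass through a single Mamba block is enough. The sketch below mirrors the minimal usage example from the mamba-ssm README; the batch, length, and dimension values are arbitrary:

```python
import torch
from mamba_ssm import Mamba

# Random input of shape (batch, sequence length, model dimension)
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")

# A single Mamba block maps (B, L, D) -> (B, L, D)
model = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).to("cuda")
y = model(x)
assert y.shape == x.shape
print("Mamba forward pass OK:", y.shape)
```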
### Basic Training Loop
```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from mamba_ssm import Mamba

# Model setup: a single Mamba block, standing in for a full model.
# A real language model stacks blocks with an embedding and an LM head
# (e.g. MambaLMHeadModel in mamba_ssm), so that logits cover vocab_size.
model = Mamba(
    d_model=512,   # model dimension
    d_state=16,    # SSM state dimension
    d_conv=4,      # local convolution width
    expand=2,      # block expansion factor
).cuda()

# Optimizer (same hyperparameters as GPT-style training)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

# Training loop (dataloader and vocab_size are assumed to be defined elsewhere)
for batch in dataloader:
    inputs, targets = batch
    inputs, targets = inputs.cuda(), targets.cuda()

    # Forward pass
    logits = model(inputs)
    loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

    # Backward pass with gradient clipping
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
```
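
The constant `lr=6e-4` above is a simplification; GPT-style pretraining typically uses linear warmup followed by cosine decay. A minimal sketch with `torch.optim.lr_scheduler.LambdaLR`, where `warmup_steps` and `total_steps` are placeholder values:

```python
import math
import torch

warmup_steps, total_steps = 2000, 100_000  # placeholder values

def lr_lambda(step):
    # Linear warmup, then cosine decay to 10% of the peak learning rate
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Call scheduler.step() once after every optimizer.step() in the training loop
```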
## Distributed Training

### Single-Node Multi-GPU (DDP)
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from mamba_ssm import Mamba

# Initialize the process group (torchrun sets LOCAL_RANK, RANK, WORLD_SIZE)
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Wrap the model
model = Mamba(...).cuda()
model = DDP(model, device_ids=[local_rank])

# Training loop (each rank should see its own shard of the data;
# see the DistributedSampler sketch below)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
for batch in dataloader:
    loss = compute_loss(model, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
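
DDP does not shard the data by itself; each rank needs a `DistributedSampler` so every GPU sees a different slice of the dataset. A minimal sketch, assuming `train_dataset`, `batch_size`, and `num_epochs` are defined elsewhere:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# One sampler per rank; each rank reads a disjoint shard of the dataset
sampler = DistributedSampler(train_dataset, shuffle=True)
dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reshuffle with a different seed each epoch
    for batch in dataloader:
        ...
```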
**Launch command**:
```bash
torchrun --nproc_per_node=8 train.py
```
### Multi-Node Training
```bash
# Node 0 (master node)
torchrun --nproc_per_node=8 \
    --nnodes=4 --node_rank=0 \
    --master_addr=$MASTER_ADDR --master_port=29500 \
    train.py

# Nodes 1-3 (worker nodes)
torchrun --nproc_per_node=8 \
    --nnodes=4 --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR --master_port=29500 \
    train.py
```
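
Before committing to a long run, an optional diagnostic (not part of the training script) is to launch a tiny all-reduce with the same `torchrun` commands and confirm every rank can reach the others. The file name `sanity_check.py` is hypothetical:

```python
# sanity_check.py -- launch with the same torchrun commands as train.py
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Every rank contributes 1; the sum should equal the world size (32 for 4x8 GPUs)
t = torch.ones(1, device="cuda")
dist.all_reduce(t)
print(f"rank {dist.get_rank()}/{dist.get_world_size()}: sum = {t.item()}")
```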
## Mixed-Precision Training

### BF16 (Recommended)
```python
from torch.cuda.amp import autocast, GradScaler

# BF16: no GradScaler needed on A100/H100 (GradScaler is only used for FP16, below)
for batch in dataloader:
    inputs, targets = batch
    inputs, targets = inputs.cuda(), targets.cuda()

    with autocast(dtype=torch.bfloat16):
        logits = model(inputs)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
### FP16 (with Gradient Scaling)
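
A sketch of the standard `torch.cuda.amp` pattern for FP16, where a `GradScaler` scales the loss to avoid gradient underflow; `model`, `optimizer`, `dataloader`, and `vocab_size` are assumed to be defined as above:

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    inputs, targets = batch
    inputs, targets = inputs.cuda(), targets.cuda()

    with autocast(dtype=torch.float16):
        logits = model(inputs)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))

    optimizer.zero_grad()
    scaler.scale(loss).backward()   # scale the loss before backward
    scaler.unscale_(optimizer)      # unscale so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)          # skips the step if gradients overflowed
    scaler.update()                 # adjust the scale factor for the next step
```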