[ PROMPT_NODE_22738 ]
custom-models
[ SKILL_DOCUMENTATION ]
# 向 TorchTitan 添加自定义模型
本指南解释了如何按照既定模式向 TorchTitan 添加新模型。
## 目录结构
torchtitan/models/your_model/
├── model/
│ ├── __init__.py
│ ├── args.py # 模型参数
│ ├── model.py # 模型定义
│ └── state_dict_adapter.py # HF 转换 (可选)
├── infra/
│ ├── __init__.py
│ ├── parallelize.py # TP, FSDP, 编译应用
│ └── pipeline.py # PP 应用 (可选)
├── train_configs/
│ ├── debug_model.toml
│ └── your_model_XB.toml
├── __init__.py # TrainSpec 注册
└── README.md
## 第 1 步:定义模型参数
继承自 `BaseModelArgs`:
python
# model/args.py
from torchtitan.protocols.model import BaseModelArgs
from dataclasses import dataclass
@dataclass
class YourModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
vocab_size: int = 128256
def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
"""返回 (参数量, 每个 token 的 FLOPs) 用于吞吐量计算。"""
nparams = self.vocab_size * self.dim + ... # 计算参数量
flops = 6 * nparams # 近似值: 6 * 参数量 (前向+反向)
return nparams, flops
def update_from_config(self, job_config) -> "YourModelArgs":
"""从训练配置更新参数。"""
# 如有需要,从 job_config 覆盖特定参数
return self
## 第 2 步:定义模型
继承自 `ModelProtocol`:
python
# model/model.py
import torch.nn as nn
from torchtitan.protocols.model import ModelProtocol
from .args import YourModelArgs
class YourModel(ModelProtocol):
def __init__(self, args: YourModelArgs):
super().__init__()
self.args = args
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = nn.ModuleDict({
str(i): TransformerBlock(args) for i in range(args.n_layers)
})
self.norm = RMSNorm(args.dim)
self.output = nn.Linear(args.dim, self.vocab_size, bias=False)
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
h = self.tok_embeddings(tokens)
for layer in self.layers.values():
h = layer(h)
h = self.norm(h)
return self.output(h)
def init_weights(self):
"""递归初始化权重。"""
for module in self.modules():
if hasattr(module, 'init_weights') a