[ PROMPT_NODE_22888 ]
sft-training
[ SKILL_DOCUMENTATION ]
# SFT 训练指南
使用 TRL 进行监督微调 (SFT) 的完整指南,适用于指令微调和特定任务微调。
## 概述
SFT 在输入-输出对上训练模型以最小化交叉熵损失。用于:
- 指令遵循
- 特定任务微调
- 聊天机器人训练
- 领域自适应
## 数据集格式
### 格式 1: 提示词-补全 (Prompt-Completion)
[
{
"prompt": "法国的首都是哪里?",
"completion": "法国的首都是巴黎。"
}
]
### 格式 2: 对话式 (ChatML)
[
{
"messages": [
{"role": "user", "content": "什么是 Python?"},
{"role": "assistant", "content": "Python 是一种编程语言。"}
]
}
]
### 格式 3: 纯文本
[
{"text": "用户: 你好n助手: 你好!有什么我可以帮你的吗?"}
]
## 基础训练
python
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
# 加载模型
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
# 加载数据集
dataset = load_dataset("trl-lib/Capybara", split="train")
# 配置
config = SFTConfig(
output_dir="Qwen2.5-SFT",
per_device_train_batch_size=4,
num_train_epochs=1,
learning_rate=2e-5,
save_strategy="epoch"
)
# 训练
trainer = SFTTrainer(
model=model,
args=config,
train_dataset=dataset,
tokenizer=tokenizer
)
trainer.train()
## 聊天模板
自动应用聊天模板:
python
trainer = SFTTrainer(
model=model,
args=config,
train_dataset=dataset, # 消息格式
tokenizer=tokenizer
# 自动应用聊天模板
)
或手动应用:
python
def format_chat(example):
messages = example["messages"]
text = tokenizer.apply_chat_template(messages, tokenize=False)
return {"text": text}
dataset = dataset.map(format_chat)
## 打包 (Packing) 以提高效率
将多个序列打包在一起以最大化 GPU 利用率:
python
config = SFTConfig(
packing=True, # 启用打包
max_seq_length=2048,
dataset_text_field="text"
)
**优点**: 训练速度提升 2-3 倍
**权衡**: 批处理稍微复杂一些
## 多 GPU 训练
bash
accelerate launch --num_processes 4 train_sft.py
或使用配置:
python
config = SFTConfig(
output_dir="model-sft",
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
n