[ PROMPT_NODE_27188 ]
Pytorch Lightning 回调
[ SKILL_DOCUMENTATION ]
# 回调 (Callbacks) - 综合指南
## 概述
回调允许在不干扰 LightningModule 研究代码的情况下,向训练过程添加任意独立的程序。它们在训练生命周期的特定钩子处执行自定义逻辑。
## 架构
Lightning 将训练逻辑组织为三个组件:
- **Trainer** - 工程基础设施
- **LightningModule** - 研究代码
- **Callbacks** - 非核心功能(监控、检查点保存、自定义行为)
## 创建自定义回调
基本结构:
python
from lightning.pytorch.callbacks import Callback
class MyCustomCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("训练开始!")
def on_train_end(self, trainer, pl_module):
print("训练结束!")
# 与 Trainer 一起使用
trainer = L.Trainer(callbacks=[MyCustomCallback()])
## 内置回调
### ModelCheckpoint
根据监控的指标保存模型。
**关键参数:**
- `dirpath` - 保存检查点的目录
- `filename` - 检查点文件名模式
- `monitor` - 要监控的指标
- `mode` - 监控指标的 "min" 或 "max"
- `save_top_k` - 保留的最佳模型数量
- `save_last` - 保存最后一个 epoch 的检查点
- `every_n_epochs` - 每 N 个 epoch 保存一次
- `save_on_train_epoch_end` - 在训练 epoch 结束时保存,而非验证结束时
**示例:**
python
from lightning.pytorch.callbacks import ModelCheckpoint
# 根据验证损失保存前 3 个模型
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="model-{epoch:02d}-{val_loss:.2f}",
monitor="val_loss",
mode="min",
save_top_k=3,
save_last=True
)
# 每 10 个 epoch 保存一次
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="model-{epoch:02d}",
every_n_epochs=10,
save_top_k=-1 # 保存所有
)
# 根据准确率保存最佳模型
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints/",
filename="best-model",
monitor="val_acc",
mode="max",
save_top_k=1
)
trainer = L.Trainer(callbacks=[checkpoint_callback])
**访问已保存的检查点:**
python
# 获取最佳模型路径
best_model_path = checkpoint_callback.best_model_path
# 获取最后一个检查点路径
last_checkpoint = checkpoint_callback.last_model_path
# 获取所有检查点路径
all_checkpoints = checkpoint_callback.best_k_models
### EarlyStopping
当监控的指标停止改善时停止训练。
**关键参数:**
- `