[ PROMPT_NODE_27430 ]
Stable Baselines3 回调
[ SKILL_DOCUMENTATION ]
# Stable Baselines3 回调系统
本文档提供了关于 Stable Baselines3 中回调系统的全面信息,用于监控和控制训练过程。
## 概述
回调是在训练期间特定点调用的函数,用于:
- 监控训练指标
- 保存检查点
- 实现提前停止
- 记录自定义指标
- 动态调整超参数
- 触发评估
## 内置回调
### EvalCallback
定期评估智能体并保存最佳模型。
python
from stable_baselines3.common.callbacks import EvalCallback
eval_callback = EvalCallback(
eval_env, # 独立的评估环境
best_model_save_path="./logs/best_model/", # 保存最佳模型的路径
log_path="./logs/eval/", # 保存评估日志的路径
eval_freq=10000, # 每 N 步评估一次
n_eval_episodes=5, # 每次评估的片段数
deterministic=True, # 使用确定性动作
render=False, # 评估期间是否渲染
verbose=1,
warn=True,
)
model.learn(total_timesteps=100000, callback=eval_callback)
**关键特性:**
- 根据平均奖励自动保存最佳模型
- 将评估指标记录到 TensorBoard
- 如果达到奖励阈值,可以停止训练
**重要提示:** 使用向量化训练环境时,请调整 `eval_freq`:
python
# 如果有 4 个并行环境,将 eval_freq 除以 n_envs
eval_freq = 10000 // 4 # 每 10000 个总环境步数评估一次
### CheckpointCallback
定期保存模型检查点。
python
from stable_baselines3.common.callbacks import CheckpointCallback
checkpoint_callback = CheckpointCallback(
save_freq=10000, # 每 N 步保存一次
save_path="./logs/checkpoints/", # 检查点目录
name_prefix="rl_model", # 检查点文件前缀
save_replay_buffer=True, # 保存重放缓冲区(仅限离策略算法)
save_vecnormalize=True, # 保存 VecNormalize 统计信息
verbose=2,
)
model.learn(total_timesteps=100000, callback=checkpoint_callback)
**输出文件:**
- `rl_model_10000_steps.zip` - 10k 步时的模型
- `rl_model_20000_steps.zip` - 20k 步时的模型
- 等等。
**重要提示:** 针对向量化环境调整 `save_freq`(除以 n_envs)。
### StopTrainingOn