[ PROMPT_NODE_27104 ]
training_evaluation
[ SKILL_DOCUMENTATION ]
# PyHealth 训练、评估与可解释性
## 概述
PyHealth 提供了全面的工具,用于临床应用中的模型训练、预测评估、确保模型可靠性以及解释结果。
## Trainer 类
### 核心功能
`Trainer` 类通过 PyTorch 集成管理完整的模型训练和评估工作流。
**初始化:**
python
from pyhealth.trainer import Trainer
trainer = Trainer(
model=model, # PyHealth 或 PyTorch 模型
device="cuda", # 或 "cpu"
)
### 训练
**train() 方法**
通过全面的监控和检查点保存功能来训练模型。
**参数:**
- `train_dataloader`: 训练数据加载器
- `val_dataloader`: 验证数据加载器 (可选)
- `test_dataloader`: 测试数据加载器 (可选)
- `epochs`: 训练轮数
- `optimizer`: 优化器实例或类
- `learning_rate`: 学习率 (默认: 1e-3)
- `weight_decay`: L2 正则化 (默认: 0)
- `max_grad_norm`: 梯度裁剪阈值
- `monitor`: 监控指标 (例如: "pr_auc_score")
- `monitor_criterion`: "max" 或 "min"
- `save_path`: 检查点保存目录
**用法:**
python
trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
test_dataloader=test_loader,
epochs=50,
optimizer=torch.optim.Adam,
learning_rate=1e-3,
weight_decay=1e-5,
max_grad_norm=5.0,
monitor="pr_auc_score",
monitor_criterion="max",
save_path="./checkpoints"
)
**训练特性:**
1. **自动检查点 (Automatic Checkpointing)**: 根据监控指标保存最佳模型
2. **早停 (Early Stopping)**: 如果性能不再提升则停止训练
3. **梯度裁剪 (Gradient Clipping)**: 防止梯度爆炸
4. **进度跟踪 (Progress Tracking)**: 显示训练进度和指标
5. **多 GPU 支持 (Multi-GPU Support)**: 自动设备分配
### 推理
**inference() 方法**
在数据集上执行预测。
**参数:**
- `dataloader`: 推理数据加载器
- `additional_outputs`: 需要返回的其他输出列表
- `return_patient_ids`: 返回患者标识符
**用法:**
python
predictions = trainer.inference(
dataloader=test_loader,
additional_outputs=["attention_weights", "embeddings"],
return_patient_ids=True
)
**返回:**
- `y_pred`: 模型预测结果
- `y_true`: 真实标签
- `patient_ids`: 患者标识符 (如果请求)
- 其他输出 (如果指定)
### 评估
**evaluate() 方法**
计算全面的 e