# -*- coding: utf-8 -*-
"""
STDEN训练器

负责模型的训练、验证和评估
"""
|
||
import os
|
||
import time
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
import numpy as np
|
||
import logging
|
||
from pathlib import Path
|
||
import json
|
||
|
||
from model.stden_model import STDENModel
|
||
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss
|
||
|
||
|
||
class STDENTrainer:
    """Main STDEN trainer.

    Creates the model, optimizer and LR scheduler, then drives training
    with early stopping, validation, periodic testing, checkpointing and
    TensorBoard logging.
    """

    def __init__(self, config, dataloader):
        """Initialize the trainer.

        Args:
            config: Configuration dict; must provide 'device', 'model',
                'train' and 'adj_matrix' entries.
            dataloader: Data loader object exposing ``train_loader``,
                ``val_loader``, ``test_loader`` and ``get_scaler()``.
        """
        self.config = config
        self.dataloader = dataloader
        self.logger = logging.getLogger('STDEN')

        # Device for the model and every batch.
        self.device = torch.device(config['device'])

        # Sub-configurations.
        self.model_config = config['model']
        self.train_config = config['train']

        # One timestamp shared by the checkpoint and TensorBoard
        # directories so both directories of a single run carry the same
        # suffix.  (Previously each directory called int(time.time())
        # independently and the names could differ by a second.)
        self._experiment_id = f"experiment_{int(time.time())}"

        # Build the model.
        self.model = self._create_model()

        # Optimizer and learning-rate scheduler.
        self.optimizer, self.scheduler = self._create_optimizer()

        # Training loss: masked mean absolute error.
        self.criterion = masked_mae_loss

        # Checkpoint directory (also persists the run config).
        self.checkpoint_dir = self._create_checkpoint_dir()

        # TensorBoard writer.
        self.writer = self._create_tensorboard_writer()

        # Training state.
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0

        self.logger.info("训练器初始化完成")

    def _create_model(self):
        """Create the STDEN model and move it to the target device."""
        # Adjacency matrix of the graph (stored in the config by the caller).
        adj_matrix = self.config['adj_matrix']

        model = STDENModel(
            adj_matrix=adj_matrix,
            logger=self.logger,
            **self.model_config
        )
        model = model.to(self.device)

        self.logger.info(f"模型创建完成,参数数量: {sum(p.numel() for p in model.parameters())}")
        return model

    def _create_optimizer(self):
        """Create the optimizer and LR scheduler.

        Returns:
            Tuple of (optimizer, ReduceLROnPlateau scheduler).

        Raises:
            ValueError: If ``train.optimizer`` names an unsupported optimizer.
        """
        base_lr = self.train_config['base_lr']
        optimizer_name = self.train_config.get('optimizer', 'adam').lower()

        if optimizer_name == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=base_lr)
        elif optimizer_name == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=base_lr)
        else:
            raise ValueError(f"不支持的优化器: {optimizer_name}")

        # NOTE(review): the 'verbose' argument of ReduceLROnPlateau is
        # deprecated (and removed in recent PyTorch releases), so it is
        # not passed here; the effective LR is already logged every
        # epoch in train().
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=self.train_config.get('lr_decay_ratio', 0.1),
            patience=self.train_config.get('patience', 20),
            min_lr=self.train_config.get('min_learning_rate', 1e-6),
        )

        return optimizer, scheduler

    def _create_checkpoint_dir(self):
        """Create the checkpoint directory and persist the run config."""
        checkpoint_dir = Path("checkpoints") / self._experiment_id
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save the config for reproducibility.  default=str stringifies
        # entries that are not JSON-serializable (e.g. the numpy
        # adjacency matrix under 'adj_matrix'), which previously made
        # json.dump raise TypeError.
        config_path = checkpoint_dir / "config.json"
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.config, f, indent=2, ensure_ascii=False, default=str)

        return checkpoint_dir

    def _create_tensorboard_writer(self):
        """Create the TensorBoard writer (same run id as checkpoints)."""
        log_dir = Path("runs") / self._experiment_id
        return SummaryWriter(str(log_dir))

    def train(self):
        """Train the model with early stopping and periodic testing."""
        self.logger.info("开始训练模型")

        epochs = self.train_config['epochs']
        patience = self.train_config['patience']
        test_every_n_epochs = self.train_config.get('test_every_n_epochs', 5)

        train_loader = self.dataloader.train_loader
        val_loader = self.dataloader.val_loader

        for epoch in range(epochs):
            self.current_epoch = epoch

            # One pass over the training set, then validation.
            train_loss = self._train_epoch(train_loader)
            val_loss = self._validate_epoch(val_loader)

            # Scheduler reacts to the validation loss plateauing.
            self.scheduler.step(val_loss)

            # TensorBoard scalars.
            self.writer.add_scalar('Loss/Train', train_loss, epoch)
            self.writer.add_scalar('Loss/Validation', val_loss, epoch)
            self.writer.add_scalar('Learning_Rate', self.optimizer.param_groups[0]['lr'], epoch)

            self.logger.info(
                f"Epoch {epoch+1}/{epochs} - "
                f"Train Loss: {train_loss:.6f}, "
                f"Val Loss: {val_loss:.6f}, "
                f"LR: {self.optimizer.param_groups[0]['lr']:.6f}"
            )

            # Track the best model; reset the early-stopping counter on
            # every improvement.
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                self._save_checkpoint(epoch, is_best=True)
                self.logger.info(f"新的最佳验证损失: {val_loss:.6f}")
            else:
                self.patience_counter += 1

            # Periodic checkpoint every 10 epochs.
            if (epoch + 1) % 10 == 0:
                self._save_checkpoint(epoch, is_best=False)

            # Periodic evaluation on the test set.
            if (epoch + 1) % test_every_n_epochs == 0:
                test_metrics = self._test_epoch()
                self.logger.info(f"测试指标: {test_metrics}")

            # Early stopping.
            if self.patience_counter >= patience:
                self.logger.info(f"早停触发,在epoch {epoch+1}")
                break

        self.logger.info("训练完成")
        self.writer.close()

        # Final evaluation after training ends.
        final_test_metrics = self._test_epoch()
        self.logger.info(f"最终测试指标: {final_test_metrics}")

    def _train_epoch(self, train_loader):
        """Run one training epoch and return the mean batch loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        # Loop-invariant: read the clipping threshold once per epoch.
        max_grad_norm = self.train_config.get('max_grad_norm', 5.0)

        for x, y in train_loader:
            x = x.to(self.device)
            y = y.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(x)

            loss = self.criterion(output, y)
            loss.backward()

            # Gradient clipping stabilizes training.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader (previously ZeroDivisionError).
        return total_loss / max(num_batches, 1)

    def _validate_epoch(self, val_loader):
        """Run one validation epoch and return the mean batch loss."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)
                loss = self.criterion(output, y)

                total_loss += loss.item()
                num_batches += 1

        # Guard against an empty loader (previously ZeroDivisionError).
        return total_loss / max(num_batches, 1)

    def _test_epoch(self):
        """Evaluate on the test set.

        Returns:
            Dict with batch-averaged 'MAE', 'MAPE' and 'RMSE'.
        """
        self.model.eval()
        test_loader = self.dataloader.test_loader

        total_mae = 0.0
        total_mape = 0.0
        total_rmse = 0.0
        num_batches = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)

                # All three metrics on the same (normalized) outputs.
                mae = masked_mae_loss(output, y)
                mape = masked_mape_loss(output, y)
                rmse = masked_rmse_loss(output, y)

                total_mae += mae.item()
                total_mape += mape.item()
                total_rmse += rmse.item()
                num_batches += 1

        # Guard against an empty loader (previously ZeroDivisionError).
        num_batches = max(num_batches, 1)
        metrics = {
            'MAE': total_mae / num_batches,
            'MAPE': total_mape / num_batches,
            'RMSE': total_rmse / num_batches
        }

        return metrics

    def _save_checkpoint(self, epoch, is_best=False):
        """Save training state to disk.

        Always writes 'latest.pth' and a per-epoch 'epoch_N.pth'; also
        writes 'best.pth' when ``is_best`` is True.

        Args:
            epoch: Zero-based epoch index being saved.
            is_best: Whether this checkpoint is the best-so-far model.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config
        }

        # Rolling "latest" checkpoint.
        latest_path = self.checkpoint_dir / "latest.pth"
        torch.save(checkpoint, latest_path)

        # Best-so-far checkpoint.
        if is_best:
            best_path = self.checkpoint_dir / "best.pth"
            torch.save(checkpoint, best_path)

        # Per-epoch snapshot.  NOTE(review): written on every call, so
        # disk usage grows with training length — confirm this is intended.
        epoch_path = self.checkpoint_dir / f"epoch_{epoch+1}.pth"
        torch.save(checkpoint, epoch_path)

        self.logger.info(f"检查点已保存: {epoch_path}")

    def load_checkpoint(self, checkpoint_path):
        """Restore model/optimizer/scheduler state from a checkpoint.

        Args:
            checkpoint_path: Path to a checkpoint written by
                ``_save_checkpoint``.
        """
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which
        # rejects the config dict stored in these checkpoints — may need
        # torch.load(..., weights_only=False) there; verify against the
        # project's PyTorch version.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.current_epoch = checkpoint['epoch']
        self.best_val_loss = checkpoint['best_val_loss']

        self.logger.info(f"检查点已加载: {checkpoint_path}")

    def evaluate(self, save_predictions=False):
        """Evaluate the best saved model on the test set.

        Args:
            save_predictions: If True, also write de-normalized
                predictions/targets to disk.

        Returns:
            Dict of test metrics (see ``_test_epoch``).
        """
        self.logger.info("开始评估模型")

        # Prefer the best checkpoint; fall back to the current weights.
        best_checkpoint_path = self.checkpoint_dir / "best.pth"
        if best_checkpoint_path.exists():
            self.load_checkpoint(best_checkpoint_path)
        else:
            self.logger.warning("未找到最佳检查点,使用当前模型")

        test_metrics = self._test_epoch()

        self.logger.info("评估结果:")
        for metric_name, metric_value in test_metrics.items():
            self.logger.info(f"  {metric_name}: {metric_value:.6f}")

        if save_predictions:
            self._save_predictions()

        return test_metrics

    def _save_predictions(self):
        """Run the model over the test set and save de-normalized
        predictions and targets as .npy files under the checkpoint dir."""
        self.model.eval()
        test_loader = self.dataloader.test_loader

        # Loop-invariant: fetch the scaler once, not once per batch.
        scaler = self.dataloader.get_scaler()

        predictions = []
        targets = []

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(self.device)
                output = self.model(x)

                # De-normalize back to the original value range.
                output_denorm = scaler.inverse_transform(output.cpu().numpy())
                y_denorm = scaler.inverse_transform(y.cpu().numpy())

                predictions.append(output_denorm)
                targets.append(y_denorm)

        # Merge all batches along the batch axis.
        predictions = np.concatenate(predictions, axis=0)
        targets = np.concatenate(targets, axis=0)

        results_dir = self.checkpoint_dir / "results"
        results_dir.mkdir(exist_ok=True)

        np.save(results_dir / "predictions.npy", predictions)
        np.save(results_dir / "targets.npy", targets)

        self.logger.info(f"预测结果已保存到: {results_dir}")

    def __del__(self):
        """Best-effort cleanup: close the TensorBoard writer."""
        # hasattr guards partially-constructed instances; the try/except
        # guards against errors raised during interpreter shutdown,
        # where exceptions in __del__ only produce noise.
        if hasattr(self, 'writer'):
            try:
                self.writer.close()
            except Exception:
                pass