# -*- coding: utf-8 -*-
"""
STDEN trainer.

Responsible for training, validating and evaluating the STDEN model.
"""

import inspect
import json
import logging
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from model.stden_model import STDENModel
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss


# torch.load gained ``weights_only`` in PyTorch 1.13 and its default flipped to
# True in 2.6; detect support once so load_checkpoint works on both sides.
_TORCH_LOAD_HAS_WEIGHTS_ONLY = 'weights_only' in inspect.signature(torch.load).parameters


def _json_default(obj):
    """``default`` hook for ``json.dumps`` when snapshotting the run config.

    The config may carry objects JSON cannot encode natively (arrays such as
    ``adj_matrix``, tensors, ``torch.device``, ``Path``).

    Args:
        obj: object ``json`` failed to serialize.

    Returns:
        A JSON-encodable stand-in. Arrays/tensors are reduced to a summary of
        type, shape and dtype. The full values are still stored in every
        checkpoint via ``self.config``, and dumping a potentially large matrix
        element-per-line with ``indent=2`` would bloat the human-readable file.

    Raises:
        TypeError: for any other type, so unexpected objects are not silently
            stringified.
    """
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return {
            'type': type(obj).__name__,
            'shape': list(obj.shape),
            'dtype': str(obj.dtype),
        }
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (torch.device, Path)):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


class STDENTrainer:
    """Main STDEN trainer.

    Builds the model, optimizer, LR scheduler, checkpoint directory and
    TensorBoard writer from a configuration dict, and drives the
    train / validate / test loop with early stopping and checkpointing.
    """

    def __init__(self, config, dataloader):
        """Initialise the trainer.

        Args:
            config: configuration dict. Keys read by this class: ``device``,
                ``model`` (kwargs forwarded to ``STDENModel``), ``train``
                (training hyper-parameters) and ``adj_matrix``.
            dataloader: object exposing ``train_loader``, ``val_loader``,
                ``test_loader`` and ``get_scaler()``.
        """
        self.config = config
        self.dataloader = dataloader
        self.logger = logging.getLogger('STDEN')

        # Target device for the model and every batch.
        self.device = torch.device(config['device'])

        # Sub-configurations.
        self.model_config = config['model']
        self.train_config = config['train']

        # One run name shared by the checkpoint and TensorBoard directories.
        # Two separate time.time() calls could straddle a second boundary and
        # give the two directories different names.
        self.experiment_name = f"experiment_{int(time.time())}"

        # Model, optimizer and LR scheduler.
        self.model = self._create_model()
        self.optimizer, self.scheduler = self._create_optimizer()

        # Training/validation loss.
        self.criterion = masked_mae_loss

        # Output locations.
        self.checkpoint_dir = self._create_checkpoint_dir()
        self.writer = self._create_tensorboard_writer()

        # Mutable training state.
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0

        self.logger.info("训练器初始化完成")

    def _create_model(self):
        """Build the STDEN model on ``self.device`` and log its parameter count."""
        adj_matrix = self.config['adj_matrix']

        model = STDENModel(
            adj_matrix=adj_matrix,
            logger=self.logger,
            **self.model_config
        )
        model = model.to(self.device)

        self.logger.info(f"模型创建完成,参数数量: {sum(p.numel() for p in model.parameters())}")
        return model

    def _create_optimizer(self):
        """Create the optimizer and the ReduceLROnPlateau LR scheduler.

        Reads ``base_lr``, ``optimizer`` ('adam' | 'sgd', default 'adam'),
        ``lr_decay_ratio``, ``patience``, ``lr_patience`` and
        ``min_learning_rate`` from the training config.

        Returns:
            ``(optimizer, scheduler)``.

        Raises:
            ValueError: if the optimizer name is not supported.
        """
        base_lr = self.train_config['base_lr']
        optimizer_name = self.train_config.get('optimizer', 'adam').lower()

        optimizer_classes = {'adam': optim.Adam, 'sgd': optim.SGD}
        if optimizer_name not in optimizer_classes:
            raise ValueError(f"不支持的优化器: {optimizer_name}")
        optimizer = optimizer_classes[optimizer_name](self.model.parameters(), lr=base_lr)

        # The scheduler used to share ``patience`` with early stopping.
        # ReduceLROnPlateau only reduces after ``patience + 1`` bad epochs,
        # while early stopping fires at ``patience``, so the LR effectively
        # never decayed. ``lr_patience`` lets the two be set independently. It
        # falls back to ``patience`` so existing configs keep working.
        early_stop_patience = self.train_config.get('patience', 20)
        lr_patience = self.train_config.get('lr_patience', early_stop_patience)
        if lr_patience + 1 >= early_stop_patience:
            self.logger.warning(
                f"lr_patience ({lr_patience}) 未小于早停 patience - 1 "
                f"({early_stop_patience - 1}),学习率衰减可能在早停前无法触发;"
                f"请在 train 配置中设置更小的 lr_patience"
            )

        # ``verbose`` is deprecated for LR schedulers in recent PyTorch. LR
        # reductions are logged explicitly in train() instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=self.train_config.get('lr_decay_ratio', 0.1),
            patience=lr_patience,
            min_lr=self.train_config.get('min_learning_rate', 1e-6),
        )
        return optimizer, scheduler

    def _create_checkpoint_dir(self):
        """Create ``checkpoints/<experiment>/`` and snapshot the config as JSON.

        A config that cannot be JSON-encoded no longer aborts training. A
        warning is logged and ``config.json`` is skipped. The complete config is
        still stored inside every checkpoint file.

        Returns:
            The checkpoint directory as a ``Path``.
        """
        experiment_name = getattr(self, 'experiment_name', None) or f"experiment_{int(time.time())}"
        checkpoint_dir = Path("checkpoints") / experiment_name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Serialize fully before touching the file, so a failure cannot leave a
        # truncated config.json behind.
        config_path = checkpoint_dir / "config.json"
        try:
            config_text = json.dumps(
                self.config, indent=2, ensure_ascii=False, default=_json_default
            )
        except TypeError as err:
            self.logger.warning(f"配置无法序列化为JSON,跳过保存 config.json: {err}")
        else:
            config_path.write_text(config_text, encoding='utf-8')

        return checkpoint_dir

    def _create_tensorboard_writer(self):
        """Create a TensorBoard writer under ``runs/<experiment>/``."""
        experiment_name = getattr(self, 'experiment_name', None) or f"experiment_{int(time.time())}"
        log_dir = Path("runs") / experiment_name
        return SummaryWriter(str(log_dir))

    @staticmethod
    def _mean(total, count, loader_name):
        """Return ``total / count``, raising a clear error when a loader was empty.

        Raises:
            ValueError: if ``count`` is 0, meaning the loader yielded no batches.
        """
        if count == 0:
            raise ValueError(f"数据加载器没有产生任何批次: {loader_name}")
        return total / count

    def train(self):
        """Run the training loop.

        Each epoch trains, validates, steps the LR scheduler on the validation
        loss and logs to TensorBoard. It saves the best checkpoint and a
        periodic one every 10 epochs. It optionally evaluates on the test set
        and stops early after ``patience`` epochs without improvement. A final
        test pass runs on the last-epoch model.
        """
        self.logger.info("开始训练模型")

        epochs = self.train_config['epochs']
        patience = self.train_config.get('patience', 20)
        # 0 (or None) disables periodic testing instead of dividing by zero.
        test_every_n_epochs = self.train_config.get('test_every_n_epochs', 5)

        train_loader = self.dataloader.train_loader
        val_loader = self.dataloader.val_loader

        try:
            for epoch in range(epochs):
                self.current_epoch = epoch

                train_loss = self._train_epoch(train_loader)
                val_loss = self._validate_epoch(val_loader)

                # Step the plateau scheduler and report any reduction. This
                # replaces the deprecated ``verbose=True``.
                previous_lr = self.optimizer.param_groups[0]['lr']
                self.scheduler.step(val_loss)
                current_lr = self.optimizer.param_groups[0]['lr']
                if current_lr < previous_lr:
                    self.logger.info(f"学习率已降低: {previous_lr:.6g} -> {current_lr:.6g}")

                self.writer.add_scalar('Loss/Train', train_loss, epoch)
                self.writer.add_scalar('Loss/Validation', val_loss, epoch)
                self.writer.add_scalar('Learning_Rate', current_lr, epoch)

                self.logger.info(
                    f"Epoch {epoch+1}/{epochs} - "
                    f"Train Loss: {train_loss:.6f}, "
                    f"Val Loss: {val_loss:.6f}, "
                    f"LR: {current_lr:.6f}"
                )

                # Track the best model and the early-stopping counter.
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.patience_counter = 0
                    self._save_checkpoint(epoch, is_best=True)
                    self.logger.info(f"新的最佳验证损失: {val_loss:.6f}")
                else:
                    self.patience_counter += 1

                # Periodic checkpoint.
                if (epoch + 1) % 10 == 0:
                    self._save_checkpoint(epoch, is_best=False)

                # Periodic test-set evaluation.
                if test_every_n_epochs and (epoch + 1) % test_every_n_epochs == 0:
                    test_metrics = self._test_epoch()
                    self.logger.info(f"测试指标: {test_metrics}")

                if self.patience_counter >= patience:
                    self.logger.info(f"早停触发,在epoch {epoch+1}")
                    break
        finally:
            # Flush and close TensorBoard even if an epoch raises.
            self.writer.close()

        self.logger.info("训练完成")

        final_test_metrics = self._test_epoch()
        self.logger.info(f"最终测试指标: {final_test_metrics}")

    def _train_epoch(self, train_loader):
        """Train for one epoch and return the mean per-batch training loss.

        Gradients are clipped to ``max_grad_norm`` (default 5.0) before each
        optimizer step.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()
        # Loop-invariant, so read it once instead of once per batch.
        max_grad_norm = self.train_config.get('max_grad_norm', 5.0)

        total_loss = 0.0
        num_batches = 0

        for x, y in train_loader:
            x = x.to(self.device)
            y = y.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(x)
            loss = self.criterion(output, y)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        return self._mean(total_loss, num_batches, 'train_loader')

    def _validate_epoch(self, val_loader):
        """Return the mean per-batch validation loss, computed without gradients.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)
                loss = self.criterion(output, y)

                total_loss += loss.item()
                num_batches += 1

        return self._mean(total_loss, num_batches, 'val_loader')

    def _test_epoch(self):
        """Evaluate on the test loader.

        Returns:
            dict with keys 'MAE', 'MAPE', 'RMSE'. Each value is the mean of the
            per-batch masked metric.

        Raises:
            ValueError: if the test loader yields no batches.
        """
        self.model.eval()
        test_loader = self.dataloader.test_loader

        total_mae = 0.0
        total_mape = 0.0
        total_rmse = 0.0
        num_batches = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)

                total_mae += masked_mae_loss(output, y).item()
                total_mape += masked_mape_loss(output, y).item()
                total_rmse += masked_rmse_loss(output, y).item()
                num_batches += 1

        return {
            'MAE': self._mean(total_mae, num_batches, 'test_loader'),
            'MAPE': self._mean(total_mape, num_batches, 'test_loader'),
            'RMSE': self._mean(total_rmse, num_batches, 'test_loader'),
        }

    def _save_checkpoint(self, epoch, is_best=False):
        """Save model/optimizer/scheduler state.

        Always overwrites ``latest.pth`` and writes ``epoch_<n>.pth``. When
        ``is_best`` is true it also overwrites ``best.pth``.

        Args:
            epoch: zero-based epoch index. Files are named with ``epoch + 1``.
            is_best: whether this is the best validation loss so far.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config
        }

        latest_path = self.checkpoint_dir / "latest.pth"
        torch.save(checkpoint, latest_path)

        if is_best:
            best_path = self.checkpoint_dir / "best.pth"
            torch.save(checkpoint, best_path)

        epoch_path = self.checkpoint_dir / f"epoch_{epoch+1}.pth"
        torch.save(checkpoint, epoch_path)

        self.logger.info(f"检查点已保存: {epoch_path}")

    def load_checkpoint(self, checkpoint_path):
        """Restore model/optimizer/scheduler state and training progress.

        Args:
            checkpoint_path: path to a file written by ``_save_checkpoint``.
        """
        # Checkpoints embed ``self.config``, which may contain non-tensor
        # objects (e.g. numpy arrays). Since PyTorch 2.6, torch.load defaults
        # to weights_only=True and would reject them, so request the full
        # unpickler when the argument exists.
        # SECURITY: full unpickling can execute arbitrary code. Only load
        # checkpoints from trusted sources.
        if _TORCH_LOAD_HAS_WEIGHTS_ONLY:
            checkpoint = torch.load(
                checkpoint_path, map_location=self.device, weights_only=False
            )
        else:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.current_epoch = checkpoint['epoch']
        self.best_val_loss = checkpoint['best_val_loss']

        self.logger.info(f"检查点已加载: {checkpoint_path}")

    def evaluate(self, save_predictions=False):
        """Evaluate the best checkpoint (or the current model) on the test set.

        Args:
            save_predictions: if true, also write de-normalized predictions and
                targets via ``_save_predictions``.

        Returns:
            The metrics dict from ``_test_epoch``.
        """
        self.logger.info("开始评估模型")

        best_checkpoint_path = self.checkpoint_dir / "best.pth"
        if best_checkpoint_path.exists():
            self.load_checkpoint(best_checkpoint_path)
        else:
            self.logger.warning("未找到最佳检查点,使用当前模型")

        test_metrics = self._test_epoch()

        self.logger.info("评估结果:")
        for metric_name, metric_value in test_metrics.items():
            self.logger.info(f" {metric_name}: {metric_value:.6f}")

        if save_predictions:
            self._save_predictions()

        return test_metrics

    def _save_predictions(self):
        """Save de-normalized test predictions and targets.

        Outputs go to ``<checkpoint_dir>/results/predictions.npy`` and
        ``targets.npy``. Batches are concatenated along axis 0.

        Raises:
            ValueError: if the test loader yields no batches.
        """
        self.model.eval()
        test_loader = self.dataloader.test_loader
        # Fetch once instead of once per batch.
        scaler = self.dataloader.get_scaler()

        predictions = []
        targets = []

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(self.device)
                output = self.model(x)

                # Move to CPU before converting to numpy. ``y`` may already
                # live on the GPU if the loader places it there.
                predictions.append(scaler.inverse_transform(output.cpu().numpy()))
                targets.append(scaler.inverse_transform(y.cpu().numpy()))

        if not predictions:
            raise ValueError("数据加载器没有产生任何批次: test_loader")

        predictions = np.concatenate(predictions, axis=0)
        targets = np.concatenate(targets, axis=0)

        results_dir = self.checkpoint_dir / "results"
        results_dir.mkdir(exist_ok=True)

        np.save(results_dir / "predictions.npy", predictions)
        np.save(results_dir / "targets.npy", targets)

        self.logger.info(f"预测结果已保存到: {results_dir}")

    def __del__(self):
        """Close the TensorBoard writer if one was created."""
        # __init__ may have failed before the writer existed.
        if hasattr(self, 'writer'):
            self.writer.close()