#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ STDEN项目训练示例 演示如何使用新的项目结构进行模型训练 """ import sys from pathlib import Path # 添加项目根目录到Python路径 project_root = Path(__file__).parent.parent sys.path.append(str(project_root)) from lib.logger import setup_logger from dataloader.stden_dataloader import STDENDataloader from trainer.stden_trainer import STDENTrainer import yaml def main(): """主函数示例""" # 配置字典(也可以从YAML文件加载) config = { 'model_name': 'stde_gt', 'log_level': 'INFO', 'log_base_dir': 'logs/example', 'device': 'cpu', # 或 'cuda' 'data': { 'dataset_dir': 'data/BJ_GM', 'batch_size': 16, 'val_batch_size': 16, 'graph_pkl_filename': 'data/sensor_graph/adj_GM.npy' }, 'model': { 'seq_len': 12, 'horizon': 12, 'input_dim': 1, 'output_dim': 1, 'latent_dim': 4, 'n_traj_samples': 3, 'ode_method': 'dopri5', 'odeint_atol': 0.00001, 'odeint_rtol': 0.00001, 'rnn_units': 64, 'num_rnn_layers': 1, 'gcn_step': 2, 'filter_type': 'default', 'recg_type': 'gru', 'save_latent': False, 'nfe': False, 'l1_decay': 0 }, 'train': { 'base_lr': 0.01, 'dropout': 0, 'load': 0, 'epoch': 0, 'epochs': 50, # 减少训练轮数用于示例 'epsilon': 1.0e-3, 'lr_decay_ratio': 0.1, 'max_grad_norm': 5, 'min_learning_rate': 2.0e-06, 'optimizer': 'adam', 'patience': 10, 'steps': [10, 20, 30], 'test_every_n_epochs': 5 } } try: # 设置日志 logger = setup_logger(config) logger.info("开始STDEN项目训练示例") # 注意:这里需要实际的邻接矩阵数据 # 为了示例,我们创建一个虚拟的邻接矩阵 import numpy as np config['adj_matrix'] = np.random.rand(10, 10) # 10x10的随机邻接矩阵 # 创建数据加载器 logger.info("创建数据加载器...") try: dataloader = STDENDataloader(config) logger.info("数据加载器创建成功") except FileNotFoundError as e: logger.warning(f"数据加载器创建失败(预期行为,因为示例中没有实际数据): {e}") logger.info("继续演示项目结构...") return # 创建训练器 logger.info("创建训练器...") trainer = STDENTrainer(config, dataloader) # 开始训练 logger.info("开始训练...") trainer.train() # 评估模型 logger.info("开始评估...") metrics = trainer.evaluate(save_predictions=True) logger.info("训练示例完成!") except Exception as e: logger.error(f"训练过程中发生错误: {e}") raise if __name__ == '__main__': main()