Project-I/examples/train_example.py

116 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
STDEN项目训练示例
演示如何使用新的项目结构进行模型训练
"""
import sys
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from lib.logger import setup_logger
from dataloader.stden_dataloader import STDENDataloader
from trainer.stden_trainer import STDENTrainer
import yaml
def main():
"""主函数示例"""
# 配置字典也可以从YAML文件加载
config = {
'model_name': 'stde_gt',
'log_level': 'INFO',
'log_base_dir': 'logs/example',
'device': 'cpu', # 或 'cuda'
'data': {
'dataset_dir': 'data/BJ_GM',
'batch_size': 16,
'val_batch_size': 16,
'graph_pkl_filename': 'data/sensor_graph/adj_GM.npy'
},
'model': {
'seq_len': 12,
'horizon': 12,
'input_dim': 1,
'output_dim': 1,
'latent_dim': 4,
'n_traj_samples': 3,
'ode_method': 'dopri5',
'odeint_atol': 0.00001,
'odeint_rtol': 0.00001,
'rnn_units': 64,
'num_rnn_layers': 1,
'gcn_step': 2,
'filter_type': 'default',
'recg_type': 'gru',
'save_latent': False,
'nfe': False,
'l1_decay': 0
},
'train': {
'base_lr': 0.01,
'dropout': 0,
'load': 0,
'epoch': 0,
'epochs': 50, # 减少训练轮数用于示例
'epsilon': 1.0e-3,
'lr_decay_ratio': 0.1,
'max_grad_norm': 5,
'min_learning_rate': 2.0e-06,
'optimizer': 'adam',
'patience': 10,
'steps': [10, 20, 30],
'test_every_n_epochs': 5
}
}
try:
# 设置日志
logger = setup_logger(config)
logger.info("开始STDEN项目训练示例")
# 注意:这里需要实际的邻接矩阵数据
# 为了示例,我们创建一个虚拟的邻接矩阵
import numpy as np
config['adj_matrix'] = np.random.rand(10, 10) # 10x10的随机邻接矩阵
# 创建数据加载器
logger.info("创建数据加载器...")
try:
dataloader = STDENDataloader(config)
logger.info("数据加载器创建成功")
except FileNotFoundError as e:
logger.warning(f"数据加载器创建失败(预期行为,因为示例中没有实际数据): {e}")
logger.info("继续演示项目结构...")
return
# 创建训练器
logger.info("创建训练器...")
trainer = STDENTrainer(config, dataloader)
# 开始训练
logger.info("开始训练...")
trainer.train()
# 评估模型
logger.info("开始评估...")
metrics = trainer.evaluate(save_predictions=True)
logger.info("训练示例完成!")
except Exception as e:
logger.error(f"训练过程中发生错误: {e}")
raise
if __name__ == '__main__':
main()