116 lines
3.2 KiB
Python
116 lines
3.2 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
STDEN项目训练示例
|
||
演示如何使用新的项目结构进行模型训练
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录到Python路径
|
||
project_root = Path(__file__).parent.parent
|
||
sys.path.append(str(project_root))
|
||
|
||
from lib.logger import setup_logger
|
||
from dataloader.stden_dataloader import STDENDataloader
|
||
from trainer.stden_trainer import STDENTrainer
|
||
import yaml
|
||
|
||
|
||
def main():
|
||
"""主函数示例"""
|
||
|
||
# 配置字典(也可以从YAML文件加载)
|
||
config = {
|
||
'model_name': 'stde_gt',
|
||
'log_level': 'INFO',
|
||
'log_base_dir': 'logs/example',
|
||
'device': 'cpu', # 或 'cuda'
|
||
|
||
'data': {
|
||
'dataset_dir': 'data/BJ_GM',
|
||
'batch_size': 16,
|
||
'val_batch_size': 16,
|
||
'graph_pkl_filename': 'data/sensor_graph/adj_GM.npy'
|
||
},
|
||
|
||
'model': {
|
||
'seq_len': 12,
|
||
'horizon': 12,
|
||
'input_dim': 1,
|
||
'output_dim': 1,
|
||
'latent_dim': 4,
|
||
'n_traj_samples': 3,
|
||
'ode_method': 'dopri5',
|
||
'odeint_atol': 0.00001,
|
||
'odeint_rtol': 0.00001,
|
||
'rnn_units': 64,
|
||
'num_rnn_layers': 1,
|
||
'gcn_step': 2,
|
||
'filter_type': 'default',
|
||
'recg_type': 'gru',
|
||
'save_latent': False,
|
||
'nfe': False,
|
||
'l1_decay': 0
|
||
},
|
||
|
||
'train': {
|
||
'base_lr': 0.01,
|
||
'dropout': 0,
|
||
'load': 0,
|
||
'epoch': 0,
|
||
'epochs': 50, # 减少训练轮数用于示例
|
||
'epsilon': 1.0e-3,
|
||
'lr_decay_ratio': 0.1,
|
||
'max_grad_norm': 5,
|
||
'min_learning_rate': 2.0e-06,
|
||
'optimizer': 'adam',
|
||
'patience': 10,
|
||
'steps': [10, 20, 30],
|
||
'test_every_n_epochs': 5
|
||
}
|
||
}
|
||
|
||
try:
|
||
# 设置日志
|
||
logger = setup_logger(config)
|
||
logger.info("开始STDEN项目训练示例")
|
||
|
||
# 注意:这里需要实际的邻接矩阵数据
|
||
# 为了示例,我们创建一个虚拟的邻接矩阵
|
||
import numpy as np
|
||
config['adj_matrix'] = np.random.rand(10, 10) # 10x10的随机邻接矩阵
|
||
|
||
# 创建数据加载器
|
||
logger.info("创建数据加载器...")
|
||
try:
|
||
dataloader = STDENDataloader(config)
|
||
logger.info("数据加载器创建成功")
|
||
except FileNotFoundError as e:
|
||
logger.warning(f"数据加载器创建失败(预期行为,因为示例中没有实际数据): {e}")
|
||
logger.info("继续演示项目结构...")
|
||
return
|
||
|
||
# 创建训练器
|
||
logger.info("创建训练器...")
|
||
trainer = STDENTrainer(config, dataloader)
|
||
|
||
# 开始训练
|
||
logger.info("开始训练...")
|
||
trainer.train()
|
||
|
||
# 评估模型
|
||
logger.info("开始评估...")
|
||
metrics = trainer.evaluate(save_predictions=True)
|
||
|
||
logger.info("训练示例完成!")
|
||
|
||
except Exception as e:
|
||
logger.error(f"训练过程中发生错误: {e}")
|
||
raise
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|