#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ STDEN项目主运行文件 根据模型名称自动调用对应的配置、数据加载器和训练器 """ import argparse import yaml import torch import numpy as np import os import sys from pathlib import Path # 添加项目根目录到Python路径 project_root = Path(__file__).parent sys.path.append(str(project_root)) from lib.logger import setup_logger from lib.utils import load_graph_data from trainer.stden_trainer import STDENTrainer from dataloader.stden_dataloader import STDENDataloader def load_config(config_path): """加载YAML配置文件""" with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) return config def setup_environment(config): """设置环境变量和随机种子""" # 设置随机种子 random_seed = config.get('random_seed', 2021) torch.manual_seed(random_seed) np.random.seed(random_seed) # 设置设备 device = 'cuda' if torch.cuda.is_available() and not config.get('use_cpu_only', False) else 'cpu' config['device'] = device return config def main(): parser = argparse.ArgumentParser(description='STDEN项目训练和评估') parser.add_argument('--model_name', type=str, required=True, choices=['stde_gt', 'stde_wrs', 'stde_zgc'], help='模型名称,对应配置文件') parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval'], help='运行模式:训练或评估') parser.add_argument('--config_dir', type=str, default='configs', help='配置文件目录') parser.add_argument('--use_cpu_only', action='store_true', help='仅使用CPU') parser.add_argument('--save_pred', action='store_true', help='保存预测结果(仅评估模式)') args = parser.parse_args() # 构建配置文件路径 config_path = Path(args.config_dir) / f"{args.model_name}.yaml" if not config_path.exists(): print(f"错误:配置文件 {config_path} 不存在") sys.exit(1) # 加载配置 config = load_config(config_path) config['use_cpu_only'] = args.use_cpu_only config['mode'] = args.mode # 设置环境 config = setup_environment(config) # 设置日志 logger = setup_logger(config) logger.info(f"开始运行 {args.model_name} 模型,模式:{args.mode}") logger.info(f"使用设备:{config['device']}") # 加载图数据 graph_pkl_filename = config['data']['graph_pkl_filename'] adj_matrix = load_graph_data(graph_pkl_filename) config['adj_matrix'] = adj_matrix # 创建数据加载器 dataloader = STDENDataloader(config) # 创建训练器 trainer = STDENTrainer(config, dataloader) # 根据模式执行 if args.mode == 'train': trainer.train() else: # eval mode trainer.evaluate(save_predictions=args.save_pred) logger.info(f"{args.model_name} 模型运行完成") if __name__ == '__main__': main()