Project-I/run.py

106 lines
3.1 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
STDEN项目主运行文件
根据模型名称自动调用对应的配置、数据加载器和训练器
"""
import argparse
import yaml
import torch
import numpy as np
import os
import sys
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).parent
sys.path.append(str(project_root))
from lib.logger import setup_logger
from lib.utils import load_graph_data
from trainer.stden_trainer import STDENTrainer
from dataloader.stden_dataloader import STDENDataloader
def load_config(config_path):
"""加载YAML配置文件"""
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
def setup_environment(config):
"""设置环境变量和随机种子"""
# 设置随机种子
random_seed = config.get('random_seed', 2021)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# 设置设备
device = 'cuda' if torch.cuda.is_available() and not config.get('use_cpu_only', False) else 'cpu'
config['device'] = device
return config
def main():
parser = argparse.ArgumentParser(description='STDEN项目训练和评估')
parser.add_argument('--model_name', type=str, required=True,
choices=['stde_gt', 'stde_wrs', 'stde_zgc'],
help='模型名称,对应配置文件')
parser.add_argument('--mode', type=str, default='train',
choices=['train', 'eval'],
help='运行模式:训练或评估')
parser.add_argument('--config_dir', type=str, default='configs',
help='配置文件目录')
parser.add_argument('--use_cpu_only', action='store_true',
help='仅使用CPU')
parser.add_argument('--save_pred', action='store_true',
help='保存预测结果(仅评估模式)')
args = parser.parse_args()
# 构建配置文件路径
config_path = Path(args.config_dir) / f"{args.model_name}.yaml"
if not config_path.exists():
print(f"错误:配置文件 {config_path} 不存在")
sys.exit(1)
# 加载配置
config = load_config(config_path)
config['use_cpu_only'] = args.use_cpu_only
config['mode'] = args.mode
# 设置环境
config = setup_environment(config)
# 设置日志
logger = setup_logger(config)
logger.info(f"开始运行 {args.model_name} 模型,模式:{args.mode}")
logger.info(f"使用设备:{config['device']}")
# 加载图数据
graph_pkl_filename = config['data']['graph_pkl_filename']
adj_matrix = load_graph_data(graph_pkl_filename)
config['adj_matrix'] = adj_matrix
# 创建数据加载器
dataloader = STDENDataloader(config)
# 创建训练器
trainer = STDENTrainer(config, dataloader)
# 根据模式执行
if args.mode == 'train':
trainer.train()
else: # eval mode
trainer.evaluate(save_predictions=args.save_pred)
logger.info(f"{args.model_name} 模型运行完成")
if __name__ == '__main__':
main()