106 lines
3.1 KiB
Python
106 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
STDEN项目主运行文件
|
|
根据模型名称自动调用对应的配置、数据加载器和训练器
|
|
"""
|
|
|
|
import argparse
|
|
import yaml
|
|
import torch
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# 添加项目根目录到Python路径
|
|
project_root = Path(__file__).parent
|
|
sys.path.append(str(project_root))
|
|
|
|
from lib.logger import setup_logger
|
|
from lib.utils import load_graph_data
|
|
from trainer.stden_trainer import STDENTrainer
|
|
from dataloader.stden_dataloader import STDENDataloader
|
|
|
|
|
|
def load_config(config_path):
|
|
"""加载YAML配置文件"""
|
|
with open(config_path, 'r', encoding='utf-8') as f:
|
|
config = yaml.safe_load(f)
|
|
return config
|
|
|
|
|
|
def setup_environment(config):
|
|
"""设置环境变量和随机种子"""
|
|
# 设置随机种子
|
|
random_seed = config.get('random_seed', 2021)
|
|
torch.manual_seed(random_seed)
|
|
np.random.seed(random_seed)
|
|
|
|
# 设置设备
|
|
device = 'cuda' if torch.cuda.is_available() and not config.get('use_cpu_only', False) else 'cpu'
|
|
config['device'] = device
|
|
|
|
return config
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description='STDEN项目训练和评估')
|
|
parser.add_argument('--model_name', type=str, required=True,
|
|
choices=['stde_gt', 'stde_wrs', 'stde_zgc'],
|
|
help='模型名称,对应配置文件')
|
|
parser.add_argument('--mode', type=str, default='train',
|
|
choices=['train', 'eval'],
|
|
help='运行模式:训练或评估')
|
|
parser.add_argument('--config_dir', type=str, default='configs',
|
|
help='配置文件目录')
|
|
parser.add_argument('--use_cpu_only', action='store_true',
|
|
help='仅使用CPU')
|
|
parser.add_argument('--save_pred', action='store_true',
|
|
help='保存预测结果(仅评估模式)')
|
|
|
|
args = parser.parse_args()
|
|
|
|
# 构建配置文件路径
|
|
config_path = Path(args.config_dir) / f"{args.model_name}.yaml"
|
|
|
|
if not config_path.exists():
|
|
print(f"错误:配置文件 {config_path} 不存在")
|
|
sys.exit(1)
|
|
|
|
# 加载配置
|
|
config = load_config(config_path)
|
|
config['use_cpu_only'] = args.use_cpu_only
|
|
config['mode'] = args.mode
|
|
|
|
# 设置环境
|
|
config = setup_environment(config)
|
|
|
|
# 设置日志
|
|
logger = setup_logger(config)
|
|
logger.info(f"开始运行 {args.model_name} 模型,模式:{args.mode}")
|
|
logger.info(f"使用设备:{config['device']}")
|
|
|
|
# 加载图数据
|
|
graph_pkl_filename = config['data']['graph_pkl_filename']
|
|
adj_matrix = load_graph_data(graph_pkl_filename)
|
|
config['adj_matrix'] = adj_matrix
|
|
|
|
# 创建数据加载器
|
|
dataloader = STDENDataloader(config)
|
|
|
|
# 创建训练器
|
|
trainer = STDENTrainer(config, dataloader)
|
|
|
|
# 根据模式执行
|
|
if args.mode == 'train':
|
|
trainer.train()
|
|
else: # eval mode
|
|
trainer.evaluate(save_predictions=args.save_pred)
|
|
|
|
logger.info(f"{args.model_name} 模型运行完成")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|