import yaml import torch import os import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer def run(config): init.init_seed(config["basic"]["seed"]) model = init.init_model(config) train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( config, normalizer=config["data"]["normalizer"], single=False ) loss = init.init_loss(config, scaler) optimizer, lr_scheduler = init.init_optimizer(model, config["train"]) init.create_logs(config) trainer = select_trainer( model, loss, optimizer, train_loader, val_loader, test_loader, scaler, config, lr_scheduler, extra_data, ) # 开始训练 match config["basic"]["mode"]: case "train": trainer.train() case "test": model.load_state_dict( torch.load( f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth", map_location=config["basic"]["device"], weights_only=True, ) ) trainer.test( model.to(config["basic"]["device"]), trainer.args, test_loader, scaler, trainer.logger, ) case _: raise ValueError(f"Unsupported mode: {config['basic']['mode']}") def main(model_list, data, debug=False): # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) os.environ["TRY"] = str(debug) for model in model_list: for dataset in data: config_path = f"./config/{model}/{dataset}.yaml" # 可去这个函数里面调整统一的config项,⚠️注意调设备,epochs config = read_config(config_path) print(f"\nRunning {model} on {dataset}") if os.environ.get("TRY") == "True": try: run(config) except Exception as e: import traceback import sys, traceback tb_lines = traceback.format_exc().splitlines() # 如果不是AssertionError,才打印完整traceback if not tb_lines[-1].startswith("AssertionError"): traceback.print_exc() print( f"\n===== {model} on {dataset} failed with error: {e} =====\n" ) else: run(config) def read_config(config_path): # 设置卡,种子,epochs with open(config_path, "r") as file: config = yaml.safe_load(file) # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 epochs = 10 # 训练轮数 # 拷贝项 config["basic"]["seed"] = seed config["train"]["epochs"] = epochs for x in ["basic", "model", "train"]: config[x]["device"] = device # few-shot 0.01-0.01-0.98 config["data"]["test_ratio"] = 0.98 config["data"]["val_ratio"] = 0.01 return config if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] model_list = ["STID"] # model_list = ["PatchTST"] air = ["AirQuality"] big_dataset = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] mid_dataset = ["PEMS-BAY"] regular_dataset = ["AirQuality", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] test_dataset = ["AirQuality"] all_dataset = big_dataset + mid_dataset + regular_dataset dataset_list = regular_dataset main(model_list, dataset_list, debug=True)