import yaml import torch import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer def run(config): init.init_seed(config["basic"]["seed"]) model = init.init_model(config) train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( config, normalizer=config["data"]["normalizer"], single=False ) loss = init.init_loss(config, scaler) optimizer, lr_scheduler = init.init_optimizer(model, config["train"]) init.create_logs(config) trainer = select_trainer( model, loss, optimizer, train_loader, val_loader, test_loader, scaler, config, lr_scheduler, extra_data, ) # 开始训练 match config["basic"]["mode"]: case "train": trainer.train() case "test": model.load_state_dict( torch.load( f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth", map_location=config["basic"]["device"], weights_only=True, ) ) trainer.test( model.to(config["basic"]["device"]), trainer.args, test_loader, scaler, trainer.logger, ) case _: raise ValueError(f"Unsupported mode: {config['basic']['mode']}") if __name__ == "__main__": # 指定模型 model_list = ["HI"] # 指定数据集 dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] device = "cuda:0" # 指定设备 seed = 2023 # 随机种子 for model in model_list: for dataset in dataset_list: config_path = f"./config/{model}/{dataset}.yaml" with open(config_path, "r") as file: config = yaml.safe_load(file) config["basic"]["device"] = device config["basic"]["seed"] = seed print(f"\nRunning {model} on {dataset} with seed {seed} on {device}") print(f"config: {config}") run(config)