64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
import yaml
|
|
import torch
|
|
|
|
import utils.initializer as init
|
|
from dataloader.loader_selector import get_dataloader
|
|
from trainer.trainer_selector import select_trainer
|
|
|
|
def run(config):
    """Build the experiment described by ``config`` and train or test it.

    The function seeds the run, builds the model, data loaders, loss,
    optimizer / LR scheduler and log setup through the ``init`` helpers,
    obtains a trainer from ``select_trainer``, and then dispatches on
    ``config["basic"]["mode"]``.

    Args:
        config: Nested mapping loaded from the experiment YAML file. Keys
            read here: ``basic.seed``, ``basic.mode``, ``basic.model``,
            ``basic.dataset``, ``basic.device``, ``data.normalizer`` and the
            whole ``train`` section (passed to ``init.init_optimizer``).

    Raises:
        ValueError: If ``config["basic"]["mode"]`` is neither ``"train"``
            nor ``"test"``.
    """
    mode = config["basic"]["mode"]
    # Fail fast: reject an unknown mode before paying for model construction,
    # data loading and log creation.
    if mode not in ("train", "test"):
        raise ValueError(f"Unsupported mode: {mode}")

    init.init_seed(config["basic"]["seed"])
    model = init.init_model(config)
    # Anything returned after the scaler is collected into ``extra_data`` and
    # forwarded to the trainer unchanged.
    train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
        config, normalizer=config["data"]["normalizer"], single=False
    )
    loss = init.init_loss(config, scaler)
    optimizer, lr_scheduler = init.init_optimizer(model, config["train"])
    init.create_logs(config)
    trainer = select_trainer(
        model,
        loss, optimizer,
        train_loader, val_loader, test_loader, scaler,
        config,
        lr_scheduler, extra_data,
    )

    # Start training (or evaluate a saved checkpoint in test mode).
    if mode == "train":
        trainer.train()
        return

    # mode == "test": load the pre-trained weights for this model/dataset
    # pair, then evaluate on the test split.
    model.load_state_dict(
        torch.load(
            f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth",
            map_location=config["basic"]["device"],
            weights_only=True,
        )
    )
    trainer.test(
        model.to(config["basic"]["device"]),
        trainer.args, test_loader, scaler,
        trainer.logger,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Models to run.
    model_list = ["HI"]
    # Datasets to run.
    dataset_list = [
        "AirQuality",
        "SolarEnergy",
        "PEMS-BAY",
        "METR-LA",
        "BJTaxi-Inflow",
        "BJTaxi-Outflow",
        "NYCBike-Inflow",
        "NYCBike-Outflow",
    ]
    device = "cuda:0"  # Target device
    seed = 2023  # Random seed

    # Run every (model, dataset) combination sequentially.
    for model in model_list:
        for dataset in dataset_list:
            config_path = f"./config/{model}/{dataset}.yaml"
            # Decode explicitly as UTF-8 so loading does not depend on the
            # platform's locale encoding.
            with open(config_path, "r", encoding="utf-8") as file:
                config = yaml.safe_load(file)
            # Override the device and seed from the file with the values
            # chosen above.
            config["basic"]["device"] = device
            config["basic"]["seed"] = seed
            print(f"\nRunning {model} on {dataset} with seed {seed} on {device}")
            print(f"config: {config}")
            run(config)
|
|
|