TrafficWheel/train.py

64 lines
2.2 KiB
Python

import yaml
import torch
import utils.initializer as init
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
def run(config):
    """Build every component described by ``config`` and run the chosen mode.

    Reads ``config["basic"]`` (``seed``, ``mode`` and, in test mode,
    ``model``, ``dataset`` and ``device``), ``config["data"]["normalizer"]``
    and ``config["train"]``.

    Args:
        config: Nested configuration mapping, as loaded from a YAML file.

    Raises:
        ValueError: If ``config["basic"]["mode"]`` is neither ``"train"``
            nor ``"test"``. This is only checked after all components
            (model, data loaders, logs, trainer) have been created.
    """
    init.init_seed(config["basic"]["seed"])
    model = init.init_model(config)

    # The first four returned items are fixed; anything after them is
    # forwarded to the trainer untouched as a list.
    loaders = get_dataloader(
        config, normalizer=config["data"]["normalizer"], single=False
    )
    train_loader, val_loader, test_loader, scaler, *extra_data = loaders

    loss = init.init_loss(config, scaler)
    optimizer, lr_scheduler = init.init_optimizer(model, config["train"])
    init.create_logs(config)

    trainer = select_trainer(
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        config,
        lr_scheduler,
        extra_data,
    )

    # Dispatch on the requested mode.
    basic = config["basic"]
    mode = basic["mode"]

    if mode == "train":
        trainer.train()
        return

    if mode != "test":
        raise ValueError(f"Unsupported mode: {mode}")

    # Test mode: load the saved weights for this model/dataset pair, then
    # evaluate on the test split.
    device = basic["device"]
    checkpoint_path = f"./pre-trained/{basic['model']}/{basic['dataset']}.pth"
    state_dict = torch.load(
        checkpoint_path,
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    trainer.test(
        model.to(device),
        trainer.args,
        test_loader,
        scaler,
        trainer.logger,
    )
if __name__ == "__main__":
    # Script entry point: run every (model, dataset) combination listed
    # below, loading ./config/<model>/<dataset>.yaml for each one and
    # overriding its device and seed with the values chosen here.

    # Models to run.
    model_list = ["HI"]
    # Datasets to run.
    dataset_list = [
        "AirQuality",
        "SolarEnergy",
        "PEMS-BAY",
        "METR-LA",
        "BJTaxi-Inflow",
        "BJTaxi-Outflow",
        "NYCBike-Inflow",
        "NYCBike-Outflow",
    ]
    device = "cuda:0"  # Target device.
    seed = 2023  # Random seed.

    for model in model_list:
        for dataset in dataset_list:
            config_path = f"./config/{model}/{dataset}.yaml"
            # Decode explicitly as UTF-8 instead of relying on the platform
            # locale encoding, which can fail or mis-decode non-ASCII text
            # in the config (e.g. on Windows).
            with open(config_path, "r", encoding="utf-8") as file:
                config = yaml.safe_load(file)

            # The values chosen in this script override whatever the YAML
            # file specifies.
            config["basic"]["device"] = device
            config["basic"]["seed"] = seed

            print(f"\nRunning {model} on {dataset} with seed {seed} on {device}")
            print(f"config: {config}")
            run(config)