TrafficWheel/train.py

98 lines
3.5 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import yaml
import torch
import os
import utils.initializer as init
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
def read_config(config_path, device="cuda:0", seed=2023, epochs=1):
    """Load a YAML experiment config and apply the run-wide overrides.

    The overrides are written into the loaded mapping so that every
    section that needs them sees the same value.

    Args:
        config_path: Path of the YAML file to load.
        device: Device string copied into the ``basic``, ``model`` and
            ``train`` sections.
        seed: Random seed stored in ``config["basic"]["seed"]``.
        epochs: Number of training epochs stored in
            ``config["train"]["epochs"]``.

    Returns:
        The parsed configuration with the overrides applied.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the ``basic``, ``model`` or ``train`` section is missing.
        TypeError: If the file is empty (``yaml.safe_load`` returns ``None``).
    """
    # Decode explicitly as UTF-8 so loading does not depend on the platform's
    # locale encoding (config files are assumed to be UTF-8 -- confirm).
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # Copy the global settings into each section that consumes them.
    config["basic"]["device"] = device
    config["model"]["device"] = device
    config["train"]["device"] = device
    config["basic"]["seed"] = seed
    config["train"]["epochs"] = epochs
    return config
def run(config):
    """Build every component for one experiment and run it in the configured mode.

    Seeds the run, creates the model, data loaders, loss, optimizer and
    trainer, then dispatches on ``config["basic"]["mode"]``:

    * ``"train"`` calls ``trainer.train()``.
    * ``"test"`` loads the checkpoint stored under ``./pre-trained`` for the
      configured model and dataset, then calls ``trainer.test(...)``.

    Args:
        config: Parsed configuration mapping (see ``read_config``).

    Raises:
        ValueError: If the mode is neither ``"train"`` nor ``"test"``.
    """
    init.init_seed(config["basic"]["seed"])
    model = init.init_model(config)

    loader_bundle = get_dataloader(
        config,
        normalizer=config["data"]["normalizer"],
        single=False,
    )
    # The first four entries are fixed; anything after them is passed
    # through to the trainer untouched as a list.
    train_loader, val_loader, test_loader, scaler, *extra_data = loader_bundle

    loss = init.init_loss(config, scaler)
    optimizer, lr_scheduler = init.init_optimizer(model, config["train"])
    init.create_logs(config)

    trainer = select_trainer(
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        config,
        lr_scheduler,
        extra_data,
    )

    # Start the run in the requested mode.
    mode = config["basic"]["mode"]
    if mode == "train":
        trainer.train()
        return
    if mode != "test":
        raise ValueError(f"Unsupported mode: {config['basic']['mode']}")

    # Test mode: restore the pre-trained weights before evaluating.
    checkpoint_path = (
        f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth"
    )
    state_dict = torch.load(
        checkpoint_path,
        map_location=config["basic"]["device"],
        weights_only=True,
    )
    model.load_state_dict(state_dict)

    model_on_device = model.to(config["basic"]["device"])
    trainer.test(
        model_on_device,
        trainer.args,
        test_loader,
        scaler,
        trainer.logger,
    )
def main(model, data, debug=False):
    """Run every requested model on every requested dataset.

    For each (model, dataset) pair the config at
    ``./config/<model>/<dataset>.yaml`` is loaded with ``read_config`` and
    executed with ``run``.

    Args:
        model: Iterable of model names to run.
        data: Iterable of dataset names to run each model on.
        debug: When ``True``, a failing run is reported and the sweep moves
            on to the next pair instead of aborting. The value is also
            exported as the ``TRY`` environment variable.
    """
    # Local import keeps this block self-contained; it is only used for the
    # failure report below.
    import traceback

    # Personal debugging switch; pass debug=False when not experimenting.
    # NOTE(review): TRY is presumably read elsewhere in the project as well;
    # confirm before changing how it is set.
    os.environ["TRY"] = str(debug)

    # Iterate the ``model`` argument itself: the previous code looped over the
    # module-level ``model_list``, which only exists when this file is run as
    # a script and ignored what the caller passed in.
    for model_name in model:
        for dataset in data:
            config_path = f"./config/{model_name}/{dataset}.yaml"
            # Shared config items (device, epochs, ...) are adjusted inside
            # read_config.
            config = read_config(config_path)
            print(f"\nRunning {model_name} on {dataset}")

            if os.environ.get("TRY") != "True":
                run(config)
                continue

            # Debug sweep: deliberately broad catch so that one failing pair
            # does not stop the remaining ones.
            try:
                run(config)
            except Exception as e:
                # Assertion failures get only the one-line summary; anything
                # else also gets the full traceback.
                if not isinstance(e, AssertionError):
                    traceback.print_exc()
                print(f"\n===== {model_name} on {dataset} failed with error: {e} =====\n")
if __name__ == "__main__":
    # Sweep used for debugging.
    # Other model presets that have been used here:
    #   ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"]
    #   ["MTGNN"]
    model_list = [
        "iTransformer",
        "PatchTST",
        "HI",
    ]

    # Other dataset presets that have been used here:
    #   ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow",
    #    "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"]
    #   ["AirQuality"]
    dataset_list = [
        "AirQuality",
        "SolarEnergy",
        "METR-LA",
        "NYCBike-InFlow",
        "NYCBike-OutFlow",
    ]

    main(model_list, dataset_list, debug=True)