98 lines
3.3 KiB
Python
98 lines
3.3 KiB
Python
import yaml
|
||
import torch
|
||
import os
|
||
|
||
import utils.initializer as init
|
||
from dataloader.loader_selector import get_dataloader
|
||
from trainer.trainer_selector import select_trainer
|
||
|
||
def read_config(config_path, *, device="cpu", seed=2023, epochs=1):
    """Load a YAML experiment config and apply the run-wide overrides.

    The file is parsed with ``yaml.safe_load``. The global settings are then
    written into the config sections that consume them, overriding whatever
    the YAML file specifies for those keys.

    Args:
        config_path: Path to the YAML config file.
        device: Device string (e.g. ``"cpu"`` or ``"cuda:0"``) written to
            ``config["basic"]``, ``config["model"]`` and ``config["train"]``.
        seed: Random seed written to ``config["basic"]["seed"]``.
        epochs: Training epochs written to ``config["train"]["epochs"]``.

    Returns:
        The parsed config dict with the overrides applied.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
        KeyError: If the ``basic``, ``model`` or ``train`` section is missing.
    """
    # YAML is specified as UTF-8 (or UTF-16/32); don't depend on the OS locale.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document; fail with a clear message
    # instead of a TypeError on the first subscript below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping at top level"
        )

    # Every component that places tensors receives the same device.
    for section in ("basic", "model", "train"):
        config[section]["device"] = device
    config["basic"]["seed"] = seed
    config["train"]["epochs"] = epochs
    return config
|
||
|
||
def run(config):
    """Build all components for one experiment, then train or test it.

    Dispatches on ``config["basic"]["mode"]``:

    * ``"train"``: call ``trainer.train()``.
    * ``"test"``: load the state dict from
      ``./pre-trained/<basic.model>/<basic.dataset>.pth`` into the model and
      call ``trainer.test(...)``.

    Args:
        config: Nested config dict (as produced by ``read_config``). Keys read
            here: ``basic.mode``, ``basic.seed``, ``basic.model``,
            ``basic.dataset``, ``basic.device``, ``data.normalizer`` and the
            ``train`` section. The whole dict is also handed to the project
            helpers and the trainer.

    Raises:
        ValueError: If ``basic.mode`` is neither ``"train"`` nor ``"test"``.
            This is checked before any component is built.
    """
    mode = config["basic"]["mode"]
    # Fail fast on a bad mode instead of discovering it only after data
    # loading, model construction and init.create_logs have already run.
    if mode not in ("train", "test"):
        raise ValueError(f"Unsupported mode: {mode}")

    init.init_seed(config["basic"]["seed"])
    model = init.init_model(config)
    train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
        config, normalizer=config["data"]["normalizer"], single=False
    )
    loss = init.init_loss(config, scaler)
    optimizer, lr_scheduler = init.init_optimizer(model, config["train"])
    init.create_logs(config)
    trainer = select_trainer(
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        config,
        lr_scheduler,
        extra_data,
    )

    if mode == "train":
        trainer.train()
        return

    # mode == "test": evaluate the pre-trained checkpoint for this model/dataset.
    device = config["basic"]["device"]
    checkpoint_path = (
        f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth"
    )
    # weights_only=True restricts unpickling to tensors and primitive containers.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    trainer.test(
        model.to(device),
        trainer.args,
        test_loader,
        scaler,
        trainer.logger,
    )
|
||
|
||
def main(model, data, debug=False):
    """Run every (model, dataset) combination.

    For each pair, the config is loaded from
    ``./config/<model>/<dataset>.yaml`` via ``read_config`` and passed to
    ``run``.

    Args:
        model: Iterable of model names (sub-directories of ``./config``).
        data: Iterable of dataset names (YAML file stems inside each model
            directory).
        debug: Sets the ``TRY`` environment variable to ``str(debug)``. When
            that is ``"True"``, each combination is guarded:
            * A failure while loading or running it prints a summary line and
              the loop continues.
            * The full traceback is printed too, except for ``AssertionError``.
            Otherwise the first failure propagates to the caller.
    """
    import traceback

    # NOTE(review): TRY is exported to the environment, presumably read
    # elsewhere in the project; verify before removing.
    os.environ["TRY"] = str(debug)
    keep_going = os.environ["TRY"] == "True"

    # Iterate the parameter, not the script-level ``model_list`` global, so
    # main() also works when imported and called from another module.
    for model_name in model:
        for dataset in data:
            # Adjust shared config items (device, epochs, ...) in read_config.
            config_path = f"./config/{model_name}/{dataset}.yaml"
            try:
                config = read_config(config_path)
                print(f"\nRunning {model_name} on {dataset}")
                run(config)
            except Exception as e:
                if not keep_going:
                    raise
                # AssertionErrors get only the summary line; anything else
                # also prints the full traceback.
                if not isinstance(e, AssertionError):
                    traceback.print_exc()
                print(f"\n===== {model_name} on {dataset} failed with error: {e} =====\n")
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Debug driver: choose which models and datasets to run.
    # Other models used before: "iTransformer", "PatchTST", "HI".
    model_list = ["STNorm"]

    # Other datasets used before: "AirQuality", "PEMS-BAY", "SolarEnergy",
    # "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA".
    dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]

    main(model=model_list, data=dataset_list, debug=True)