126 lines
3.7 KiB
Python
126 lines
3.7 KiB
Python
import yaml
|
||
import torch
|
||
import os
|
||
|
||
import utils.initializer as init
|
||
from dataloader.loader_selector import get_dataloader
|
||
from trainer.trainer_selector import select_trainer
|
||
|
||
|
||
|
||
|
||
|
||
def run(config):
|
||
init.init_seed(config["basic"]["seed"])
|
||
model = init.init_model(config)
|
||
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
|
||
config, normalizer=config["data"]["normalizer"], single=False
|
||
)
|
||
loss = init.init_loss(config, scaler)
|
||
optimizer, lr_scheduler = init.init_optimizer(model, config["train"])
|
||
init.create_logs(config)
|
||
trainer = select_trainer(
|
||
model,
|
||
loss,
|
||
optimizer,
|
||
train_loader,
|
||
val_loader,
|
||
test_loader,
|
||
scaler,
|
||
config,
|
||
lr_scheduler,
|
||
extra_data,
|
||
)
|
||
|
||
# 开始训练
|
||
match config["basic"]["mode"]:
|
||
case "train":
|
||
trainer.train()
|
||
case "test":
|
||
model.load_state_dict(
|
||
torch.load(
|
||
f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth",
|
||
map_location=config["basic"]["device"],
|
||
weights_only=True,
|
||
)
|
||
)
|
||
trainer.test(
|
||
model.to(config["basic"]["device"]),
|
||
trainer.args,
|
||
test_loader,
|
||
scaler,
|
||
trainer.logger,
|
||
)
|
||
case _:
|
||
raise ValueError(f"Unsupported mode: {config['basic']['mode']}")
|
||
|
||
|
||
def main(model_list, data, debug=False):
    """Run every (model, dataset) combination in sequence.

    For each pair the config is read from ``./config/<model>/<dataset>.yaml``
    and executed with ``run``.

    Args:
        model_list: Iterable of model names (sub-directories of ``./config``).
        data: Iterable of dataset names (YAML file stems inside each model's
            config directory).
        debug: Debug switch. ``str(debug)`` is stored in the ``TRY``
            environment variable (other project code may read it -- not
            visible here). When that value is ``"True"``, a failure for one
            pair (including a failure to read its config) is reported and the
            sweep continues with the next pair. Otherwise the first exception
            propagates and stops the sweep.
    """
    # Debug switch; pass debug=False for normal (non-test) runs.
    os.environ["TRY"] = str(debug)

    for model in model_list:
        for dataset in data:
            if os.environ.get("TRY") != "True":
                _run_single(model, dataset)
                continue

            try:
                _run_single(model, dataset)
            except Exception as e:  # top-level boundary: report, keep sweeping
                import traceback

                # Assertion failures are reported by message only; anything
                # else also gets the full traceback.
                if not isinstance(e, AssertionError):
                    traceback.print_exc()
                print(
                    f"\n===== {model} on {dataset} failed with error: {e} =====\n"
                )


def _run_single(model, dataset):
    """Load the config for one (model, dataset) pair, announce it and run it."""
    config_path = f"./config/{model}/{dataset}.yaml"
    # Shared overrides (device, seed, epochs, split ratios) are applied in
    # read_config -- double-check the device and epochs there.
    config = read_config(config_path)
    print(f"\nRunning {model} on {dataset}")
    run(config)
|
||
|
||
|
||
def read_config(
    config_path,
    device="cuda:1",
    seed=2023,
    epochs=100,
    val_ratio=0.01,
    test_ratio=0.98,
):
    """Load a YAML config and apply the shared experiment overrides.

    Every config loaded through this function gets the same device, seed,
    epoch count and split ratios, so runs in one sweep are comparable. Change
    the defaults here (or pass keyword arguments) to adjust them globally.

    Args:
        config_path: Path to the YAML config file.
        device: Written to ``config[s]["device"]`` for each section ``s`` in
            ``"basic"``, ``"model"`` and ``"train"``.
        seed: Written to ``config["basic"]["seed"]``.
        epochs: Written to ``config["train"]["epochs"]``.
        val_ratio: Written to ``config["data"]["val_ratio"]``.
        test_ratio: Written to ``config["data"]["test_ratio"]``. The defaults
            (0.01 validation / 0.98 test) form the few-shot 0.01-0.01-0.98
            split; presumably the training share is the remainder -- confirm
            in the data loader.

    Returns:
        The parsed config dict with the overrides applied.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
        KeyError: If a section written to above is missing from the file.
    """
    # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16 via BOM)
    # instead of depending on the platform's locale encoding.
    with open(config_path, "rb") as file:
        config = yaml.safe_load(file)

    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    # Global overrides shared by every run.
    config["basic"]["seed"] = seed
    config["train"]["epochs"] = epochs
    for section in ("basic", "model", "train"):
        config[section]["device"] = device

    # Few-shot split: 0.01 (train) - 0.01 (val) - 0.98 (test) by default.
    config["data"]["test_ratio"] = test_ratio
    config["data"]["val_ratio"] = val_ratio

    return config
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual sweep entry point (used for debugging).
    # Other models that have been run here: "iTransformer", "PatchTST", "HI".
    models = ["REPST"]

    # Dataset groups as named by the author: big, mid and regular.
    # For a quick single-dataset run, ["AirQuality"] is a handy choice.
    big = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
    mid = ["PEMS-BAY"]
    regular = [
        "AirQuality",
        "SolarEnergy",
        "NYCBike-InFlow",
        "NYCBike-OutFlow",
        "METR-LA",
    ]

    main(models, [*big, *mid, *regular], debug=False)
|