From 3ae65229c1509c1f87a3ab138704db675a2c1521 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 29 Dec 2025 10:00:55 +0800 Subject: [PATCH] add few-shot --- Informer/Informer2020 | 1 - train.py | 40 ++++++++++++++++++++++++---------------- 2 files changed, 24 insertions(+), 17 deletions(-) delete mode 160000 Informer/Informer2020 diff --git a/Informer/Informer2020 b/Informer/Informer2020 deleted file mode 160000 index 29f2a73..0000000 --- a/Informer/Informer2020 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 29f2a739226a509202a092b464163da81fa74960 diff --git a/train.py b/train.py index 8e2c6d3..c78e856 100644 --- a/train.py +++ b/train.py @@ -7,22 +7,7 @@ from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer -def read_config(config_path): - with open(config_path, "r") as file: - config = yaml.safe_load(file) - # 全局配置 - device = "cpu" # 指定设备为cuda:0 - seed = 2023 # 随机种子 - epochs = 1 # 训练轮数 - - # 拷贝项 - config["basic"]["device"] = device - config["model"]["device"] = device - config["train"]["device"] = device - config["basic"]["seed"] = seed - config["train"]["epochs"] = epochs - return config def run(config): @@ -99,6 +84,29 @@ def main(model_list, data, debug=False): run(config) +def read_config(config_path): + # 设置卡,种子,epochs + with open(config_path, "r") as file: + config = yaml.safe_load(file) + + # 全局配置 + device = "cuda:0" # 指定设备为cuda:0 + seed = 2023 # 随机种子 + epochs = 50 # 训练轮数 + + # 拷贝项 + config["basic"]["seed"] = seed + config["train"]["epochs"] = epochs + for x in ["basic", "model", "train"]: + config[x]["device"] = device + + # few-shot 0.05-0.05-0.9 + config["data"]["test_ratio"] = 0.9 + config["data"]["val_ratio"] = 0.05 + + return config + + if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] @@ -109,7 +117,7 @@ if __name__ == "__main__": big_dataset = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] mid_dataset = ["PEMS-BAY"] regular_dataset = ["AirQuality", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] - test_dataset = ["BJTaxi-InFlow"] + test_dataset = ["AirQuality"] all_dataset = big_dataset + mid_dataset + regular_dataset