add few-shot

This commit is contained in:
czzhangheng 2025-12-29 10:00:55 +08:00
parent 87fab512bb
commit 3ae65229c1
2 changed files with 24 additions and 17 deletions

@ -1 +0,0 @@
Subproject commit 29f2a739226a509202a092b464163da81fa74960

View File

@ -7,22 +7,7 @@ from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
def read_config(config_path):
with open(config_path, "r") as file:
config = yaml.safe_load(file)
# 全局配置
device = "cpu" # 指定设备为cuda:0
seed = 2023 # 随机种子
epochs = 1 # 训练轮数
# 拷贝项
config["basic"]["device"] = device
config["model"]["device"] = device
config["train"]["device"] = device
config["basic"]["seed"] = seed
config["train"]["epochs"] = epochs
return config
def run(config):
@ -99,6 +84,29 @@ def main(model_list, data, debug=False):
run(config)
def read_config(config_path):
# 设置卡种子epochs
with open(config_path, "r") as file:
config = yaml.safe_load(file)
# 全局配置
device = "cuda:0" # 指定设备为cuda:0
seed = 2023 # 随机种子
epochs = 50 # 训练轮数
# 拷贝项
config["basic"]["seed"] = seed
config["train"]["epochs"] = epochs
for x in ["basic", "model", "train"]:
config[x]["device"] = device
# few-shot 0.05-0.05-0.9
config["data"]["test_ratio"] = 0.9
config["data"]["val_ratio"] = 0.05
return config
if __name__ == "__main__":
# 调试用
# model_list = ["iTransformer", "PatchTST", "HI"]
@ -109,7 +117,7 @@ if __name__ == "__main__":
big_dataset = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
mid_dataset = ["PEMS-BAY"]
regular_dataset = ["AirQuality", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
test_dataset = ["BJTaxi-InFlow"]
test_dataset = ["AirQuality"]
all_dataset = big_dataset + mid_dataset + regular_dataset