add few-shot
This commit is contained in:
parent
87fab512bb
commit
3ae65229c1
|
|
@ -1 +0,0 @@
|
||||||
Subproject commit 29f2a739226a509202a092b464163da81fa74960
|
|
||||||
40
train.py
40
train.py
|
|
@ -7,22 +7,7 @@ from dataloader.loader_selector import get_dataloader
|
||||||
from trainer.trainer_selector import select_trainer
|
from trainer.trainer_selector import select_trainer
|
||||||
|
|
||||||
|
|
||||||
def read_config(config_path):
|
|
||||||
with open(config_path, "r") as file:
|
|
||||||
config = yaml.safe_load(file)
|
|
||||||
|
|
||||||
# 全局配置
|
|
||||||
device = "cpu" # 指定设备为cuda:0
|
|
||||||
seed = 2023 # 随机种子
|
|
||||||
epochs = 1 # 训练轮数
|
|
||||||
|
|
||||||
# 拷贝项
|
|
||||||
config["basic"]["device"] = device
|
|
||||||
config["model"]["device"] = device
|
|
||||||
config["train"]["device"] = device
|
|
||||||
config["basic"]["seed"] = seed
|
|
||||||
config["train"]["epochs"] = epochs
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def run(config):
|
def run(config):
|
||||||
|
|
@ -99,6 +84,29 @@ def main(model_list, data, debug=False):
|
||||||
run(config)
|
run(config)
|
||||||
|
|
||||||
|
|
||||||
|
def read_config(config_path):
|
||||||
|
# 设置卡,种子,epochs
|
||||||
|
with open(config_path, "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
# 全局配置
|
||||||
|
device = "cuda:0" # 指定设备为cuda:0
|
||||||
|
seed = 2023 # 随机种子
|
||||||
|
epochs = 50 # 训练轮数
|
||||||
|
|
||||||
|
# 拷贝项
|
||||||
|
config["basic"]["seed"] = seed
|
||||||
|
config["train"]["epochs"] = epochs
|
||||||
|
for x in ["basic", "model", "train"]:
|
||||||
|
config[x]["device"] = device
|
||||||
|
|
||||||
|
# few-shot 0.05-0.05-0.9
|
||||||
|
config["data"]["test_ratio"] = 0.9
|
||||||
|
config["data"]["val_ratio"] = 0.05
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 调试用
|
# 调试用
|
||||||
# model_list = ["iTransformer", "PatchTST", "HI"]
|
# model_list = ["iTransformer", "PatchTST", "HI"]
|
||||||
|
|
@ -109,7 +117,7 @@ if __name__ == "__main__":
|
||||||
big_dataset = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
|
big_dataset = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
|
||||||
mid_dataset = ["PEMS-BAY"]
|
mid_dataset = ["PEMS-BAY"]
|
||||||
regular_dataset = ["AirQuality", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
|
regular_dataset = ["AirQuality", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
|
||||||
test_dataset = ["BJTaxi-InFlow"]
|
test_dataset = ["AirQuality"]
|
||||||
|
|
||||||
all_dataset = big_dataset + mid_dataset + regular_dataset
|
all_dataset = big_dataset + mid_dataset + regular_dataset
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue