diff --git a/config/PatchTST/AirQuality.yaml b/config/PatchTST/AirQuality.yaml index 3cdf977..91a497e 100644 --- a/config/PatchTST/AirQuality.yaml +++ b/config/PatchTST/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-Inflow.yaml index 576dbd6..a4e0308 100644 --- a/config/PatchTST/BJTaxi-Inflow.yaml +++ b/config/PatchTST/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 2048 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-Outflow.yaml index 773ba26..68c8476 100644 --- a/config/PatchTST/BJTaxi-Outflow.yaml +++ b/config/PatchTST/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 2048 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/METR-LA.yaml b/config/PatchTST/METR-LA.yaml index 6b9461a..3f88951 100644 --- a/config/PatchTST/METR-LA.yaml +++ b/config/PatchTST/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-Inflow.yaml index 408995c..0f7bc97 100644 --- a/config/PatchTST/NYCBike-Inflow.yaml +++ b/config/PatchTST/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-Outflow.yaml index c50f4a1..516e1e1 100644 --- a/config/PatchTST/NYCBike-Outflow.yaml +++ b/config/PatchTST/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/PEMS-BAY.yaml b/config/PatchTST/PEMS-BAY.yaml index e798294..ba93575 100644 --- a/config/PatchTST/PEMS-BAY.yaml +++ b/config/PatchTST/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml index b1de602..d31a458 100644 --- a/config/PatchTST/SolarEnergy.yaml +++ b/config/PatchTST/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/AirQuality.yaml b/config/iTransformer/AirQuality.yaml index 23eba27..b27d72c 100644 --- a/config/iTransformer/AirQuality.yaml +++ b/config/iTransformer/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/BJTaxi-Inflow.yaml b/config/iTransformer/BJTaxi-Inflow.yaml index dfc2df2..1df1a67 100644 --- a/config/iTransformer/BJTaxi-Inflow.yaml +++ b/config/iTransformer/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 2048 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/BJTaxi-Outflow.yaml b/config/iTransformer/BJTaxi-Outflow.yaml index d14bed5..8da0e92 100644 --- a/config/iTransformer/BJTaxi-Outflow.yaml +++ b/config/iTransformer/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 2048 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/METR-LA.yaml b/config/iTransformer/METR-LA.yaml index 20c4068..996e44c 100644 --- a/config/iTransformer/METR-LA.yaml +++ b/config/iTransformer/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/NYCBike-Inflow.yaml b/config/iTransformer/NYCBike-Inflow.yaml index 8afa656..fdb4dce 100644 --- a/config/iTransformer/NYCBike-Inflow.yaml +++ b/config/iTransformer/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/NYCBike-Outflow.yaml b/config/iTransformer/NYCBike-Outflow.yaml index 7abba88..7401648 100644 --- a/config/iTransformer/NYCBike-Outflow.yaml +++ b/config/iTransformer/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/PEMS-BAY.yaml b/config/iTransformer/PEMS-BAY.yaml index 17f2fd4..80d354a 100644 --- a/config/iTransformer/PEMS-BAY.yaml +++ b/config/iTransformer/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/SolarEnergy.yaml b/config/iTransformer/SolarEnergy.yaml index cce005a..154be4a 100644 --- a/config/iTransformer/SolarEnergy.yaml +++ b/config/iTransformer/SolarEnergy.yaml @@ -6,11 +6,11 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index c1862df..5ea47fa 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -10,19 +10,19 @@ from dataloader.Informer_loader import get_dataloader as Informer_loader def get_dataloader(config, normalizer, single): TS_model = ["iTransformer", "HI", "PatchTST"] model_name = config["basic"]["model"] - if model_name == "Informer": - return Informer_loader(config, normalizer, single) - elif model_name in TS_model: - return TS_loader(config, normalizer, single) - else : - match model_name: - case "STGNCDE": - return cde_loader(config, normalizer, single) - case "STGNRDE": - return nrde_loader(config, normalizer, single) - case "DCRNN": - return DCRNN_loader(config, normalizer, single) - case "EXP": - return EXP_loader(config, normalizer, single) - case _: - return normal_loader(config, normalizer, single) + # if model_name == "Informer": + # return Informer_loader(config, normalizer, single) + # elif model_name in TS_model: + # return TS_loader(config, normalizer, single) + # else : + match model_name: + case "STGNCDE": + return cde_loader(config, normalizer, single) + case "STGNRDE": + return nrde_loader(config, normalizer, single) + case "DCRNN": + return DCRNN_loader(config, normalizer, single) + case "EXP": + return EXP_loader(config, normalizer, single) + case _: + return normal_loader(config, normalizer, single) diff --git a/test_informer.py b/test_informer.py deleted file mode 100644 index b614533..0000000 --- a/test_informer.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch -from model.model_selector import model_selector -import yaml - -# 读取配置文件 -with open('/user/czzhangheng/code/TrafficWheel/config/Informer/AirQuality.yaml', 'r') as f: - config = yaml.safe_load(f) - -# 初始化模型 -model = model_selector(config) -print('Informer模型初始化成功!') -print(f'模型参数数量: {sum(p.numel() for p in model.parameters())}') - -# 创建测试数据 -B, T, C = 2, 24, 6 -x_enc = torch.randn(B, T, C) - -# 测试1: 完整参数 -print('\n测试1: 完整参数') -x_mark_enc = torch.randn(B, T, 4) # 假设时间特征为4维 -x_dec = torch.randn(B, 12+24, C) # label_len + pred_len -x_mark_dec = torch.randn(B, 12+24, 4) -try: - output = model(x_enc, x_mark_enc, x_dec, x_mark_dec) - print(f'输出形状: {output.shape}') - print('测试1通过!') -except Exception as e: - print(f'测试1失败: {e}') - -# 测试2: 省略x_mark_enc -print('\n测试2: 省略x_mark_enc') -try: - output = model(x_enc, x_dec=x_dec, x_mark_dec=x_mark_dec) - print(f'输出形状: {output.shape}') - print('测试2通过!') -except Exception as e: - print(f'测试2失败: {e}') - -# 测试3: 省略x_dec和x_mark_dec -print('\n测试3: 省略x_dec和x_mark_dec') -try: - output = model(x_enc, x_mark_enc=x_mark_enc) - print(f'输出形状: {output.shape}') - print('测试3通过!') -except Exception as e: - print(f'测试3失败: {e}') - -# 测试4: 仅传入x_enc -print('\n测试4: 仅传入x_enc') -try: - output = model(x_enc) - print(f'输出形状: {output.shape}') - print('测试4通过!') -except Exception as e: - print(f'测试4失败: {e}') - -print('\n所有测试完成!') \ No newline at end of file diff --git a/train.py b/train.py index 5beb472..76ea652 100644 --- a/train.py +++ b/train.py @@ -6,14 +6,16 @@ import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer +import cProfile + def read_config(config_path): with open(config_path, "r") as file: config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cuda:1" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 100 + epochs = 120 # 拷贝项 config["basic"]["device"] = device @@ -60,13 +62,13 @@ def run(config): case _: raise ValueError(f"Unsupported mode: {config['basic']['mode']}") - -if __name__ == "__main__": +def main(debug=False): # 指定模型 model_list = ["iTransformer"] # 指定数据集 - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - # dataset_list = ["PEMS-BAY"] + # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + # dataset_list = ["AirQuality"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) @@ -93,3 +95,8 @@ if __name__ == "__main__": else: run(config) + + +if __name__ == "__main__": + # 调试用 + main(debug = False) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py new file mode 100755 index 0000000..b8def31 --- /dev/null +++ b/trainer/TSTrainer.py @@ -0,0 +1,296 @@ +import math +import os +import time +import copy +import torch +from utils.logger import get_logger +from utils.loss_function import all_metrics +from tqdm import tqdm + +class TSWrapper: + def __init__(self, args): + self.b = args['train']['batch_size'] + self.t = args['data']['lag'] + self.n = args['data']['num_nodes'] + self.c = args['data']['input_dim'] + + + def transpose(self, x : torch.Tensor): + # [b, t, n, c] -> [b*n, t, c] + self.b = x.shape[0] + x = x[..., :-2] + x = x.permute(0, 2, 1, 3) + x = x.reshape(self.b*self.n, self.t, self.c) + return x + + def inv_transpose(self, x : torch.Tensor): + x = x.reshape(self.b, self.n, self.t, self.c) + x = x.permute(0, 2, 1, 3) + return x + + +class Trainer: + """模型训练器,负责整个训练流程的管理""" + + def __init__(self, model, loss, optimizer, + train_loader, val_loader, test_loader, scaler, + args, lr_scheduler=None,): + # 设备和基本参数 + self.config = args + self.device = args["basic"]["device"] + train_args = args["train"] + # 模型和训练相关组件 + self.model = model + self.loss = loss + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + # 数据加载器 + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + # 数据处理工具 + self.scaler = scaler + self.args = train_args + self.ts_wrapper = TSWrapper(args) + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + + def _initialize_paths(self, args): + """初始化模型保存路径""" + self.best_path = os.path.join(args["log_dir"], "best_model.pth") + self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") + self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") + + def _initialize_logger(self, args): + """初始化日志记录器""" + if not os.path.isdir(args["log_dir"]) and not args["debug"]: + os.makedirs(args["log_dir"], exist_ok=True) + self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"]) + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 + if mode == "train": self.model.train(); optimizer_step = True + else: self.model.eval(); optimizer_step = False + + # 初始化变量 + total_loss = 0 + epoch_time = time.time() + y_pred, y_true = [], [] + + # 训练/验证循环 + with torch.set_grad_enabled(optimizer_step): + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + for _, (data, target) in progress_bar: + # 转移数据 + data = data.to(self.device) + target = target.to(self.device) + label = target[..., : self.args["output_dim"]] + # 转换为 [b*n, t, c] + data = self.ts_wrapper.transpose(data) + # 计算loss和反归一化loss + output = self.model(data) + # 转换回[b, t, n, c] + output = self.ts_wrapper.inv_transpose(output) + # 我的调试开关 + if os.environ.get("TRY") == "True": + print(f"[{'✅' if output.shape == label.shape else '❌'}]: output: {output.shape}, label: {label.shape}") + assert False + loss = self.loss(output, label) + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) + d_loss = self.loss(d_output, d_label) + # 累积损失和预测结果 + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + self.optimizer.step() + # 更新进度条 + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + # 计算损失并记录指标 + avg_loss = total_loss / len(dataloader) + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} " + f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + return avg_loss + + def train_epoch(self, epoch): + return self._run_epoch(epoch, self.train_loader, "train") + + def val_epoch(self, epoch): + return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") + + def test_epoch(self, epoch): + return self._run_epoch(epoch, self.test_loader, "test") + + def train(self): + # 初始化记录 + best_model, best_test_model = None, None + best_loss, best_test_loss = float("inf"), float("inf") + not_improved_count = 0 + # 开始训练 + self.logger.info("Training process started") + # 训练循环 + for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch + train_epoch_loss = self.train_epoch(epoch) + val_epoch_loss = self.val_epoch(epoch) + test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 + if train_epoch_loss > 1e6: + self.logger.warning("Gradient explosion detected. Ending...") + break + # 更新最佳验证模型 + if val_epoch_loss < best_loss: + best_loss = val_epoch_loss + not_improved_count = 0 + best_model = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved!") + else: + not_improved_count += 1 + # 早停 + if self._should_early_stop(not_improved_count): + break + # 更新最佳测试模型 + if test_epoch_loss < best_test_loss: + best_test_loss = test_epoch_loss + best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 + if not self.args["debug"]: + self._save_best_models(best_model, best_test_model) + # 最终评估 + self._finalize_training(best_model, best_test_model) + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + + def _finalize_training(self, best_model, best_test_model): + self.model.load_state_dict(best_model) + self.logger.info("Testing on best validation model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + self.model.load_state_dict(best_test_model) + self.logger.info("Testing on best test model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + + @staticmethod + def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 确定设备信息 + device = None + output_dim = None + # 处理不同的参数格式 + if isinstance(args, dict): + if "basic" in args: + # 完整配置情况 + device = args["basic"]["device"] + output_dim = args["train"]["output_dim"] + else: + # 只有train_args情况,从模型获取设备 + device = next(model.parameters()).device + output_dim = args["output_dim"] + else: + raise ValueError(f"Unsupported args type: {type(args)}") + + # 加载模型检查点(如果提供了路径) + if path: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["state_dict"]) + model.to(device) + + # 设置为评估模式 + model.eval() + + # 收集预测和真实标签 + y_pred, y_true = [], [] + + # 不计算梯度的情况下进行预测 + with torch.no_grad(): + for data, target in data_loader: + # 将数据和标签移动到指定设备 + data = data.to(device) + target = target.to(device) + + data = data[..., :-2] + b, t, n, c = data.shape + data = data.permute(0, 2, 1, 3) + data = data.reshape(b*n, t, c) + label = target[..., : output_dim] + output = model(data) + output = output.reshape(b, n, t, c) + output = output.permute(0, 2, 1, 3) + + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) + + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + + # 获取metrics参数 + if "basic" in args: + # 完整配置情况 + mae_thresh = args["train"]["mae_thresh"] + mape_thresh = args["train"]["mape_thresh"] + else: + # 只有train_args情况 + mae_thresh = args["mae_thresh"] + mape_thresh = args["mape_thresh"] + + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): + mae, rmse, mape = all_metrics( + d_y_pred[:, t, ...], + d_y_true[:, t, ...], + mae_thresh, + mape_thresh, + ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + @staticmethod + def _compute_sampling_threshold(global_step, k): + return k / (k + math.exp(global_step / k)) diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 89340ea..723b257 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -5,7 +5,7 @@ from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer from trainer.STMLP_Trainer import Trainer as STMLP_Trainer from trainer.E32Trainer import Trainer as EXP_Trainer from trainer.InformerTrainer import InformerTrainer - +from trainer.TSTrainer import Trainer as TSTrainer def select_trainer( model, @@ -20,6 +20,21 @@ def select_trainer( kwargs, ): model_name = args["basic"]["model"] + TS_model = ["HI", "PatchTST", "iTransformer"] + if model_name in TS_model: + return TSTrainer( + model, + loss, + optimizer, + train_loader, + val_loader, + test_loader, + scaler, + args, + lr_scheduler, + ) + + match model_name: case "STGNCDE": return cdeTrainer(