From d4ee8e309e7b6263b5ed80c8119c4dc58fb3c3e4 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 1 Dec 2025 21:36:37 +0800 Subject: [PATCH 01/41] =?UTF-8?q?trainer=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mypy.ini | 4 + trainer/DCRNN_Trainer.py | 235 +++++++++++++++++++---------- trainer/E32Trainer.py | 194 +++++++++++++++--------- trainer/EXP_trainer.py | 236 +++++++++++++++++++---------- trainer/PDG2SEQ_Trainer.py | 250 ++++++++++++++++++++----------- trainer/STMLP_Trainer.py | 208 +++++++++++++++---------- trainer/cdeTrainer/__init__.py | 0 trainer/cdeTrainer/cdetrainer.py | 248 +++++++++++++++++++----------- utils/training_stats.py | 28 +--- 9 files changed, 900 insertions(+), 503 deletions(-) create mode 100644 mypy.ini create mode 100644 trainer/cdeTrainer/__init__.py diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..c77f418 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] +explicit_package_bases = True +ignore_missing_imports = True +no_site_packages = True diff --git a/trainer/DCRNN_Trainer.py b/trainer/DCRNN_Trainer.py index 417d078..8bb2298 100755 --- a/trainer/DCRNN_Trainer.py +++ b/trainer/DCRNN_Trainer.py @@ -2,6 +2,7 @@ import math import os import time import copy +import psutil from tqdm import tqdm import torch @@ -23,34 +24,56 @@ class Trainer: args, lr_scheduler=None, ): + # 设备和基本参数 + self.device = args["basic"]["device"] + train_args = args["train"] + + # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader + + # 数据处理工具 self.scaler = scaler - self.args = args - self.lr_scheduler = lr_scheduler + self.args = train_args + + # 统计信息 self.train_per_epoch = len(train_loader) self.val_per_epoch = len(val_loader) if val_loader else 0 - # Paths for saving models and logs + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + self._initialize_stats() + + def _initialize_paths(self, args): + """初始化模型保存路径""" self.best_path = os.path.join(args["log_dir"], "best_model.pth") self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") - - # Initialize logger + + def _initialize_logger(self, args): + """初始化日志记录器""" if not os.path.isdir(args["log_dir"]) and not args["debug"]: os.makedirs(args["log_dir"], exist_ok=True) self.logger = get_logger( args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] ) self.logger.info(f"Experiment log path in: {args['log_dir']}") - # Stats tracker - self.stats = TrainingStats(device=args["device"]) + + def _initialize_stats(self): + """初始化统计信息记录器""" + self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 if mode == "train": self.model.train() optimizer_step = True @@ -58,54 +81,77 @@ class Trainer: self.model.eval() optimizer_step = False + # 初始化变量 total_loss = 0 epoch_time = time.time() + y_pred, y_true = [], [] + # 训练/验证循环 with torch.set_grad_enabled(optimizer_step): - with tqdm( - total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" - ) as pbar: - for batch_idx, (data, target) in enumerate(dataloader): - start_time = time.time() - label = target[..., : self.args["output_dim"]] - output = self.model(data, labels=label.clone()).to( - self.args["device"] - ) + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + + for _, (data, target) in progress_bar: + # 记录步骤开始时间 + start_time = time.time() - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) - label = self.scaler.inverse_transform(label) + # 前向传播 + label = target[..., : self.args["output_dim"]] + output = self.model(data, labels=label.clone()).to(self.device) + loss = self.loss(output, label) - loss = self.loss(output, label) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - total_loss += loss.item() - - if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: - self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args["max_grad_norm"] ) + self.optimizer.step() + + # 反归一化的loss + d_loss = self.loss(d_output, d_label) - # 更新 tqdm 的进度 - pbar.update(1) - pbar.set_postfix(loss=loss.item()) + # 记录步骤时间和内存使用 + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) + # 累积损失和预测结果 + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + + # 更新进度条 + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 计算平均损失 avg_loss = total_loss / len(dataloader) - self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] ) - # 记录内存 + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + + # 记录内存使用情况 self.stats.record_memory_usage() + return avg_loss def train_epoch(self, epoch): @@ -118,21 +164,29 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): + """执行完整的训练流程""" + # 初始化最佳模型和损失记录 best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 + # 开始训练 self.stats.start_training() self.logger.info("Training process started") + + # 训练循环 for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch train_epoch_loss = self.train_epoch(epoch) val_epoch_loss = self.val_epoch(epoch) test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break + # 更新最佳验证模型 if val_epoch_loss < best_loss: best_loss = val_epoch_loss not_improved_count = 0 @@ -141,38 +195,55 @@ class Trainer: else: not_improved_count += 1 - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) + # 检查早停条件 + if self._should_early_stop(not_improved_count): break + # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 if not self.args["debug"]: - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) + self._save_best_models(best_model, best_test_model) - # 输出统计与参数 + # 结束训练并输出统计信息 self.stats.end_training() self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + + # 最终评估 self._finalize_training(best_model, best_test_model) + # 输出模型参数量 + self._log_model_params() + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) self.logger.info("Testing on best validation model") @@ -184,44 +255,44 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) - model.to(args["device"]) + model.to(args["basic"]["device"]) + # 设置为评估模式 model.eval() + + # 收集预测和真实标签 y_pred, y_true = [], [] + # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target in data_loader: label = target[..., : args["output_dim"]] - output = model(data, labels=label.clone()).to(args["device"]) - y_pred.append(output) - y_true.append(label) + output = model(data, labels=label.clone()) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) - if args["real_value"]: - y_pred = scaler.inverse_transform(y_pred) - y_true = scaler.inverse_transform(y_true) + # 反归一化 + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - for t in range(y_true.shape[1]): + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k): diff --git a/trainer/E32Trainer.py b/trainer/E32Trainer.py index 07ff01c..5011131 100644 --- a/trainer/E32Trainer.py +++ b/trainer/E32Trainer.py @@ -23,44 +23,65 @@ class Trainer: global_config, lr_scheduler=None, ): + # 设备和基本参数 self.device = global_config["basic"]["device"] train_config = global_config["train"] + + # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader + + # 数据处理工具 self.scaler = scaler self.args = train_config - self.lr_scheduler = lr_scheduler + + # 统计信息 self.train_per_epoch = len(train_loader) self.val_per_epoch = len(val_loader) if val_loader else 0 - # Paths for saving models and logs - self.best_path = os.path.join(train_config["log_dir"], "best_model.pth") - self.best_test_path = os.path.join( - train_config["log_dir"], "best_test_model.pth" - ) - self.loss_figure_path = os.path.join(train_config["log_dir"], "loss.png") - - # Initialize logger - if not os.path.isdir(train_config["log_dir"]) and not train_config["debug"]: - os.makedirs(train_config["log_dir"], exist_ok=True) + # 初始化路径、日志和统计 + self._initialize_paths(train_config) + self._initialize_logger(train_config) + self._initialize_stats() + + def _initialize_paths(self, args): + """初始化模型保存路径""" + self.best_path = os.path.join(args["log_dir"], "best_model.pth") + self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") + self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") + + def _initialize_logger(self, args): + """初始化日志记录器""" + if not os.path.isdir(args["log_dir"]) and not args["debug"]: + os.makedirs(args["log_dir"], exist_ok=True) self.logger = get_logger( - train_config["log_dir"], + args["log_dir"], name=self.model.__class__.__name__, - debug=train_config["debug"], + debug=args["debug"], ) - self.logger.info(f"Experiment log path in: {train_config['log_dir']}") - # Stats tracker + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + def _initialize_stats(self): + """初始化统计信息记录器""" self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 is_train = mode == "train" self.model.train() if is_train else self.model.eval() + + # 初始化变量 total_loss = 0.0 epoch_time = time.time() + y_pred, y_true = [], [] with ( torch.set_grad_enabled(is_train), @@ -85,10 +106,12 @@ class Trainer: # compute loss label = target[..., : self.args["output_dim"]] - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) loss = self.loss(output, label) + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) + # backward / step if is_train: loss.backward() @@ -98,22 +121,39 @@ class Trainer: ) self.optimizer.step() + # 反归一化的loss + d_loss = self.loss(d_output, d_label) + step_time = time.time() - start_time self.stats.record_step_time(step_time, mode) - total_loss += loss.item() + total_loss += d_loss.item() + + # 累积预测结果 + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) # logging if is_train and (batch_idx + 1) % self.args["log_step"] == 0: self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" + f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {d_loss.item():.6f}" ) pbar.update(1) - pbar.set_postfix(loss=loss.item()) + pbar.set_postfix(loss=d_loss.item()) + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 计算平均损失 avg_loss = total_loss / len(dataloader) + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] + ) self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" ) # 记录内存 self.stats.record_memory_usage() @@ -129,21 +169,29 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): + """执行完整的训练流程""" + # 初始化最佳模型和损失记录 best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 + # 开始训练 self.stats.start_training() self.logger.info("Training process started") + + # 训练循环 for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch train_epoch_loss = self.train_epoch(epoch) val_epoch_loss = self.val_epoch(epoch) test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break + # 更新最佳验证模型 if val_epoch_loss < best_loss: best_loss = val_epoch_loss not_improved_count = 0 @@ -152,38 +200,55 @@ class Trainer: else: not_improved_count += 1 - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) + # 检查早停条件 + if self._should_early_stop(not_improved_count): break + # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 if not self.args["debug"]: - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) + self._save_best_models(best_model, best_test_model) - # 输出统计与参数 + # 结束训练并输出统计信息 self.stats.end_training() self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + + # 最终评估 self._finalize_training(best_model, best_test_model) + # 输出模型参数量 + self._log_model_params() + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) self.logger.info("Testing on best validation model") @@ -195,51 +260,44 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger, path=None): - global_config = args - device = global_config["basic"]["device"] - args = global_config["train"] + """对模型进行评估并输出性能指标""" + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) - model.to(device) + model.to(args["basic"]["device"]) + # 设置为评估模式 model.eval() + + # 收集预测和真实标签 y_pred, y_true = [], [] + # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target, cycle_index in data_loader: label = target[..., : args["output_dim"]] output = model(data, cycle_index) - y_pred.append(output) - y_true.append(label) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + # 反归一化 + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - - for t in range(y_true.shape[1]): + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k): diff --git a/trainer/EXP_trainer.py b/trainer/EXP_trainer.py index 0416cc3..b5a48a5 100755 --- a/trainer/EXP_trainer.py +++ b/trainer/EXP_trainer.py @@ -2,6 +2,7 @@ import math import os import time import copy +import psutil from tqdm import tqdm import torch @@ -23,34 +24,56 @@ class Trainer: args, lr_scheduler=None, ): + # 设备和基本参数 + self.device = args["basic"]["device"] + train_args = args["train"] + + # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader + + # 数据处理工具 self.scaler = scaler - self.args = args - self.lr_scheduler = lr_scheduler + self.args = train_args + + # 统计信息 self.train_per_epoch = len(train_loader) self.val_per_epoch = len(val_loader) if val_loader else 0 - # Paths for saving models and logs + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + self._initialize_stats() + + def _initialize_paths(self, args): + """初始化模型保存路径""" self.best_path = os.path.join(args["log_dir"], "best_model.pth") self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") - - # Initialize logger + + def _initialize_logger(self, args): + """初始化日志记录器""" if not os.path.isdir(args["log_dir"]) and not args["debug"]: os.makedirs(args["log_dir"], exist_ok=True) self.logger = get_logger( args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] ) self.logger.info(f"Experiment log path in: {args['log_dir']}") - # Stats tracker - self.stats = TrainingStats(device=args["device"]) + + def _initialize_stats(self): + """初始化统计信息记录器""" + self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 if mode == "train": self.model.train() optimizer_step = True @@ -58,52 +81,77 @@ class Trainer: self.model.eval() optimizer_step = False + # 初始化变量 total_loss = 0 epoch_time = time.time() + y_pred, y_true = [], [] + # 训练/验证循环 with torch.set_grad_enabled(optimizer_step): - with tqdm( - total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" - ) as pbar: - for batch_idx, (data, target) in enumerate(dataloader): - start_time = time.time() - label = target[..., : self.args["output_dim"]] - output = self.model(data).to(self.args["device"]) + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + + for _, (data, target) in progress_bar: + # 记录步骤开始时间 + start_time = time.time() - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) + # 前向传播 + label = target[..., : self.args["output_dim"]] + output = self.model(data).to(self.device) + loss = self.loss(output, label) - loss = self.loss(output, label) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - - total_loss += loss.item() - - if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: - self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args["max_grad_norm"] ) + self.optimizer.step() + + # 反归一化的loss + d_loss = self.loss(d_output, d_label) - # 更新 tqdm 的进度 - pbar.update(1) - pbar.set_postfix(loss=loss.item()) + # 记录步骤时间和内存使用 + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) + # 累积损失和预测结果 + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + + # 更新进度条 + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 计算平均损失 avg_loss = total_loss / len(dataloader) - self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] ) - # 记录内存 + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + + # 记录内存使用情况 self.stats.record_memory_usage() + return avg_loss def train_epoch(self, epoch): @@ -116,21 +164,29 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): + """执行完整的训练流程""" + # 初始化最佳模型和损失记录 best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 + # 开始训练 self.stats.start_training() self.logger.info("Training process started") + + # 训练循环 for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch train_epoch_loss = self.train_epoch(epoch) val_epoch_loss = self.val_epoch(epoch) test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break + # 更新最佳验证模型 if val_epoch_loss < best_loss: best_loss = val_epoch_loss not_improved_count = 0 @@ -139,37 +195,55 @@ class Trainer: else: not_improved_count += 1 - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) + # 检查早停条件 + if self._should_early_stop(not_improved_count): break + # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 if not self.args["debug"]: - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) - # 输出统计与参数 + self._save_best_models(best_model, best_test_model) + + # 结束训练并输出统计信息 self.stats.end_training() self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + + # 最终评估 self._finalize_training(best_model, best_test_model) + # 输出模型参数量 + self._log_model_params() + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) self.logger.info("Testing on best validation model") @@ -181,48 +255,44 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) - model.to(args["device"]) + model.to(args["basic"]["device"]) + # 设置为评估模式 model.eval() + + # 收集预测和真实标签 y_pred, y_true = [], [] + # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target in data_loader: label = target[..., : args["output_dim"]] output = model(data) - y_pred.append(output) - y_true.append(label) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + # 反归一化 + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - - for t in range(y_true.shape[1]): + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k): diff --git a/trainer/PDG2SEQ_Trainer.py b/trainer/PDG2SEQ_Trainer.py index 72d155d..95b7a61 100755 --- a/trainer/PDG2SEQ_Trainer.py +++ b/trainer/PDG2SEQ_Trainer.py @@ -23,35 +23,57 @@ class Trainer: args, lr_scheduler=None, ): + # 设备和基本参数 + self.device = args["basic"]["device"] + train_args = args["train"] + + # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader + + # 数据处理工具 self.scaler = scaler - self.args = args - self.lr_scheduler = lr_scheduler + self.args = train_args + self.batches_seen = 0 + + # 统计信息 self.train_per_epoch = len(train_loader) self.val_per_epoch = len(val_loader) if val_loader else 0 - self.batches_seen = 0 - # Paths for saving models and logs + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + self._initialize_stats() + + def _initialize_paths(self, args): + """初始化模型保存路径""" self.best_path = os.path.join(args["log_dir"], "best_model.pth") self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") - - # Initialize logger + + def _initialize_logger(self, args): + """初始化日志记录器""" if not os.path.isdir(args["log_dir"]) and not args["debug"]: os.makedirs(args["log_dir"], exist_ok=True) self.logger = get_logger( args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] ) self.logger.info(f"Experiment log path in: {args['log_dir']}") - # Stats tracker - self.stats = TrainingStats(device=args["device"]) + + def _initialize_stats(self): + """初始化统计信息记录器""" + self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 if mode == "train": self.model.train() optimizer_step = True @@ -59,55 +81,86 @@ class Trainer: self.model.eval() optimizer_step = False + # 初始化变量 total_loss = 0 epoch_time = time.time() + y_pred, y_true = [], [] with torch.set_grad_enabled(optimizer_step): - with tqdm( - total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" - ) as pbar: - for batch_idx, (data, target) in enumerate(dataloader): - start_time = time.time() - self.batches_seen += 1 - label = target[..., : self.args["output_dim"]].clone() - output = self.model(data, target, self.batches_seen).to( - self.args["device"] + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + + for batch_idx, (data, target) in progress_bar: + start_time = time.time() + self.batches_seen += 1 + label = target[..., : self.args["output_dim"]].clone() + + # 前向传播 + if mode == "train": + output = self.model(data, target, self.batches_seen).to(self.device) + else: + output = self.model(data, target).to(self.device) + + # 计算原始loss + loss = self.loss(output, label) + + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) + + # 反归一化的loss + d_loss = self.loss(d_output, d_label) + + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args["max_grad_norm"] + ) + self.optimizer.step() + + # 记录步骤时间 + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) + total_loss += d_loss.item() + + # 累积预测结果 + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + + if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: + self.logger.info( + f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {d_loss.item():.6f}" ) - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) + # 更新 tqdm 的进度 + progress_bar.update(1) + progress_bar.set_postfix(loss=d_loss.item()) - loss = self.loss(output, label) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() - - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() - - # record step time - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - total_loss += loss.item() - - if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: - self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" - ) - - # 更新 tqdm 的进度 - pbar.update(1) - pbar.set_postfix(loss=loss.item()) + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + # 计算平均损失 avg_loss = total_loss / len(dataloader) - self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] ) - # 记录内存 + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + + # 记录内存使用情况 self.stats.record_memory_usage() + return avg_loss def train_epoch(self, epoch): @@ -120,21 +173,29 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): + """执行完整的训练流程""" + # 初始化最佳模型和损失记录 best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 + # 开始训练 self.stats.start_training() self.logger.info("Training process started") + + # 训练循环 for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch train_epoch_loss = self.train_epoch(epoch) val_epoch_loss = self.val_epoch(epoch) test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break + # 更新最佳验证模型 if val_epoch_loss < best_loss: best_loss = val_epoch_loss not_improved_count = 0 @@ -143,37 +204,54 @@ class Trainer: else: not_improved_count += 1 - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) + # 检查早停条件 + if self._should_early_stop(not_improved_count): break + # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 if not self.args["debug"]: - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) + self._save_best_models(best_model, best_test_model) - # 输出统计与参数 + # 结束训练并输出统计信息 self.stats.end_training() self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + + # 输出模型参数量 + self._log_model_params() + + # 最终评估 self._finalize_training(best_model, best_test_model) + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) @@ -186,44 +264,44 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) - model.to(args["device"]) + model.to(args["basic"]["device"]) + # 设置为评估模式 model.eval() + + # 收集预测和真实标签 y_pred, y_true = [], [] + # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target in data_loader: label = target[..., : args["output_dim"]].clone() output = model(data, target) - y_pred.append(output) - y_true.append(label) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + # 反归一化 + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - for t in range(y_true.shape[1]): + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k): diff --git a/trainer/STMLP_Trainer.py b/trainer/STMLP_Trainer.py index 6b2217a..d1ac02a 100644 --- a/trainer/STMLP_Trainer.py +++ b/trainer/STMLP_Trainer.py @@ -26,42 +26,35 @@ class Trainer: args, lr_scheduler=None, ): + # 设备和基本参数 + self.device = args["basic"]["device"] + train_args = args["train"] + + # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader + + # 数据处理工具 self.scaler = scaler - self.args = args["train"] - self.lr_scheduler = lr_scheduler + self.args = train_args + + # 统计信息 self.train_per_epoch = len(train_loader) self.val_per_epoch = len(val_loader) if val_loader else 0 - # Paths for saving models and logs - self.best_path = os.path.join(self.args["log_dir"], "best_model.pth") - self.best_test_path = os.path.join(self.args["log_dir"], "best_test_model.pth") - self.loss_figure_path = os.path.join(self.args["log_dir"], "loss.png") - self.pretrain_dir = ( - f"./pre-train/{args['model']['type']}/{args['data']['type']}" - ) - self.pretrain_path = os.path.join(self.pretrain_dir, "best_model.pth") - self.pretrain_best_path = os.path.join(self.pretrain_dir, "best_test_model.pth") - - # Initialize logger - if not os.path.isdir(self.args["log_dir"]) and not self.args["debug"]: - os.makedirs(self.args["log_dir"], exist_ok=True) - if not os.path.isdir(self.pretrain_dir) and not self.args["debug"]: - os.makedirs(self.pretrain_dir, exist_ok=True) - self.logger = get_logger( - self.args["log_dir"], - name=self.model.__class__.__name__, - debug=self.args["debug"], - ) - self.logger.info(f"Experiment log path in: {self.args['log_dir']}") - # Stats tracker - self.stats = TrainingStats(device=args["device"]) - + # 初始化路径、日志和统计 + self._initialize_paths(args, train_args) + self._initialize_logger(train_args) + self._initialize_stats() + + # 教师-学生蒸馏相关 if self.args["teacher_stu"]: self.tmodel = self.loadTeacher(args) else: @@ -70,9 +63,41 @@ class Trainer: f"./pre-train/{args['model']['type']}/{args['data']['type']}/best_model.pth" f"然后在config中配置train.teacher_stu模式为True开启蒸馏模式" ) + + def _initialize_paths(self, args, train_args): + """初始化模型保存路径""" + self.best_path = os.path.join(train_args["log_dir"], "best_model.pth") + self.best_test_path = os.path.join(train_args["log_dir"], "best_test_model.pth") + self.loss_figure_path = os.path.join(train_args["log_dir"], "loss.png") + self.pretrain_dir = ( + f"./pre-train/{args['model']['type']}/{args['data']['type']}" + ) + self.pretrain_path = os.path.join(self.pretrain_dir, "best_model.pth") + self.pretrain_best_path = os.path.join(self.pretrain_dir, "best_test_model.pth") + + # 创建预训练目录 + if not os.path.isdir(self.pretrain_dir) and not train_args["debug"]: + os.makedirs(self.pretrain_dir, exist_ok=True) + + def _initialize_logger(self, args): + """初始化日志记录器""" + if not os.path.isdir(args["log_dir"]) and not args["debug"]: + os.makedirs(args["log_dir"], exist_ok=True) + self.logger = get_logger( + args["log_dir"], + name=self.model.__class__.__name__, + debug=args["debug"], + ) + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + def _initialize_stats(self): + """初始化统计信息记录器""" + self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" # self.tmodel.eval() + # 设置模型模式和是否进行优化 if mode == "train": self.model.train() optimizer_step = True @@ -80,8 +105,10 @@ class Trainer: self.model.eval() optimizer_step = False + # 初始化变量 total_loss = 0 epoch_time = time.time() + y_pred, y_true = [], [] with torch.set_grad_enabled(optimizer_step): with tqdm( @@ -89,15 +116,17 @@ class Trainer: ) as pbar: for batch_idx, (data, target) in enumerate(dataloader): start_time = time.time() + label = target[..., : self.args["output_dim"]] + if self.args["teacher_stu"]: - label = target[..., : self.args["output_dim"]] + # 教师-学生蒸馏模式 output, out_, _ = self.model(data) gout, tout, sout = self.tmodel(data) - - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) - + + # 计算原始loss loss1 = self.loss(output, label) + + # 计算蒸馏相关loss scl = self.loss_cls(out_, sout) kl_loss = nn.KLDivLoss( reduction="batchmean", log_target=True @@ -105,17 +134,22 @@ class Trainer: gout = F.log_softmax(gout, dim=-1).cuda() mlp_emb_ = F.log_softmax(output, dim=-1).cuda() tkloss = kl_loss(mlp_emb_.cuda().float(), gout.cuda().float()) + + # 总loss loss = loss1 + 10 * tkloss + 1 * scl - else: - label = target[..., : self.args["output_dim"]] + # 普通训练模式 output, out_, _ = self.model(data) - - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) - loss = self.loss(output, label) + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) + + # 反归一化的loss + d_loss = self.loss(d_output, d_label) + + # 反向传播和优化(仅在训练模式) if optimizer_step and self.optimizer is not None: self.optimizer.zero_grad() loss.backward() @@ -128,20 +162,34 @@ class Trainer: step_time = time.time() - start_time self.stats.record_step_time(step_time, mode) - total_loss += loss.item() + total_loss += d_loss.item() + + # 累积预测结果 + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" + f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {d_loss.item():.6f}" ) # 更新 tqdm 的进度 pbar.update(1) - pbar.set_postfix(loss=loss.item()) + pbar.set_postfix(loss=d_loss.item()) + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 计算平均损失 avg_loss = total_loss / len(dataloader) + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] + ) self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" ) # 记录内存 self.stats.record_memory_usage() @@ -157,6 +205,7 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): + """执行完整的训练流程""" best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 @@ -182,13 +231,7 @@ class Trainer: else: not_improved_count += 1 - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) + if self._should_early_stop(not_improved_count): break if test_epoch_loss < best_test_loss: @@ -207,14 +250,25 @@ class Trainer: # 输出统计与参数 self.stats.end_training() self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + self._log_model_params() self._finalize_training(best_model, best_test_model) + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) @@ -274,48 +328,44 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) - model.to(args["device"]) + model.to(args["basic"]["device"]) + # 设置为评估模式 model.eval() + + # 收集预测和真实标签 y_pred, y_true = [], [] + # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target in data_loader: label = target[..., : args["output_dim"]] output, _, _ = model(data) - y_pred.append(output) - y_true.append(label) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + # 反归一化 + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMSD8/y_true.pt") # [3566,12,170,1] - - for t in range(y_true.shape[1]): + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k): diff --git a/trainer/cdeTrainer/__init__.py b/trainer/cdeTrainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trainer/cdeTrainer/cdetrainer.py b/trainer/cdeTrainer/cdetrainer.py index 5678a7c..84111fb 100755 --- a/trainer/cdeTrainer/cdetrainer.py +++ b/trainer/cdeTrainer/cdetrainer.py @@ -25,37 +25,60 @@ class Trainer: times, w, ): + # 设备和基本参数 + self.device = args["basic"]["device"] + train_args = args["train"] + + # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader + + # 数据处理工具 self.scaler = scaler - self.args = args - self.lr_scheduler = lr_scheduler + self.args = train_args + + # 统计信息 self.train_per_epoch = len(train_loader) self.val_per_epoch = len(val_loader) if val_loader else 0 - self.device = args["device"] - - # Paths for saving models and logs + + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + self._initialize_stats() + + # 模型特定参数 + self.times = times.to(self.device, dtype=torch.float) + self.w = w + + def _initialize_paths(self, args): + """初始化模型保存路径""" self.best_path = os.path.join(args["log_dir"], "best_model.pth") self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") - - # Initialize logger + + def _initialize_logger(self, args): + """初始化日志记录器""" if not os.path.isdir(args["log_dir"]) and not args["debug"]: os.makedirs(args["log_dir"], exist_ok=True) self.logger = get_logger( args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] ) self.logger.info(f"Experiment log path in: {args['log_dir']}") - # Stats tracker - self.stats = TrainingStats(device=args["device"]) - self.times = times.to(self.device, dtype=torch.float) - self.w = w + + def _initialize_stats(self): + """初始化统计信息记录器""" + self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 if mode == "train": self.model.train() optimizer_step = True @@ -63,53 +86,84 @@ class Trainer: self.model.eval() optimizer_step = False + # 初始化变量 total_loss = 0 epoch_time = time.time() + y_pred, y_true = [], [] with torch.set_grad_enabled(optimizer_step): - with tqdm( - total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" - ) as pbar: - for batch_idx, batch in enumerate(dataloader): - start_time = time.time() - batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) - *train_coeffs, target = batch - label = target[..., : self.args["output_dim"]] - output = self.model(self.times, train_coeffs) + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + + for batch_idx, batch in progress_bar: + start_time = time.time() + batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) + *train_coeffs, target = batch + label = target[..., : self.args["output_dim"]] + + # 前向传播 + output = self.model(self.times, train_coeffs) + + # 计算原始loss + loss = self.loss(output, label) - # if self.args['real_value']: - # output = self.scaler.inverse_transform(output) + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) - loss = self.loss(output, label) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() + # 反归一化的loss + d_loss = self.loss(d_output, d_label) - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - total_loss += loss.item() - - if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: - self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args["max_grad_norm"] ) + self.optimizer.step() - # 更新 tqdm 的进度 - pbar.update(1) - pbar.set_postfix(loss=loss.item()) + # 记录步骤时间 + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) + total_loss += d_loss.item() + # 累积预测结果 + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + + if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: + self.logger.info( + f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {d_loss.item():.6f}" + ) + + # 更新 tqdm 的进度 + progress_bar.update(1) + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 计算平均损失 avg_loss = total_loss / len(dataloader) - self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] ) - # 记录内存 + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + + # 记录内存使用情况 self.stats.record_memory_usage() + return avg_loss def train_epoch(self, epoch): @@ -122,21 +176,29 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): + """执行完整的训练流程""" + # 初始化最佳模型和损失记录 best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 + # 开始训练 self.stats.start_training() self.logger.info("Training process started") + + # 训练循环 for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch train_epoch_loss = self.train_epoch(epoch) val_epoch_loss = self.val_epoch(epoch) test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break + # 更新最佳验证模型 if val_epoch_loss < best_loss: best_loss = val_epoch_loss not_improved_count = 0 @@ -145,37 +207,54 @@ class Trainer: else: not_improved_count += 1 - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) + # 检查早停条件 + if self._should_early_stop(not_improved_count): break + # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 if not self.args["debug"]: - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) + self._save_best_models(best_model, best_test_model) - # 输出统计与参数 + # 结束训练并输出统计信息 self.stats.end_training() self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + + # 输出模型参数量 + self._log_model_params() + + # 最终评估 self._finalize_training(best_model, best_test_model) + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) @@ -188,42 +267,41 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger): + """对模型进行评估并输出性能指标""" + # 设置为评估模式 model.eval() + + # 收集预测和真实标签 y_pred, y_true = [], [] times = torch.linspace(0, 11, 12) + # 不计算梯度的情况下进行预测 with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): - batch = tuple(b.to(args["device"], dtype=torch.float) for b in batch) + batch = tuple(b.to(args["basic"]["device"], dtype=torch.float) for b in batch) *test_coeffs, target = batch label = target[..., : args["output_dim"]] - output = model(times.to(args["device"], dtype=torch.float), test_coeffs) - y_true.append(label) - y_pred.append(output) + output = model(times.to(args["basic"]["device"], dtype=torch.float), test_coeffs) + y_true.append(label.detach().cpu()) + y_pred.append(output.detach().cpu()) - # if args['real_value']: - # y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - # else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + # 反归一化 + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - for t in range(y_true.shape[1]): + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k): diff --git a/utils/training_stats.py b/utils/training_stats.py index ecda094..9483354 100644 --- a/utils/training_stats.py +++ b/utils/training_stats.py @@ -47,6 +47,10 @@ class TrainingStats: self.cpu_mem_usage_list.append(cpu_mem) self.gpu_mem_usage_list.append(gpu_mem) + def _calculate_average(self, values_list): + """安全计算平均值,避免除零错误""" + return sum(values_list) / len(values_list) if values_list else 0 + def report(self, logger): """在训练结束时输出汇总统计""" if not self.start_time or not self.end_time: @@ -54,26 +58,10 @@ class TrainingStats: return total_time = self.end_time - self.start_time - avg_gpu_mem = ( - sum(self.gpu_mem_usage_list) / len(self.gpu_mem_usage_list) - if self.gpu_mem_usage_list - else 0 - ) - avg_cpu_mem = ( - sum(self.cpu_mem_usage_list) / len(self.cpu_mem_usage_list) - if self.cpu_mem_usage_list - else 0 - ) - avg_train_time = ( - sum(self.train_time_list) / len(self.train_time_list) - if self.train_time_list - else 0 - ) - avg_infer_time = ( - sum(self.infer_time_list) / len(self.infer_time_list) - if self.infer_time_list - else 0 - ) + avg_gpu_mem = self._calculate_average(self.gpu_mem_usage_list) + avg_cpu_mem = self._calculate_average(self.cpu_mem_usage_list) + avg_train_time = self._calculate_average(self.train_time_list) + avg_infer_time = self._calculate_average(self.infer_time_list) iters_per_sec = self.total_iters / total_time if total_time > 0 else 0 logger.info("===== Training Summary =====") -- 2.40.1 From 77a32104755d3a576c01d59809b33edd10978538 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 1 Dec 2025 22:29:52 +0800 Subject: [PATCH 02/41] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=89=80=E6=9C=89?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E6=98=AF=E5=90=A6=E6=AD=A3=E5=B8=B8=E8=BF=90?= =?UTF-8?q?=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 1848 ++++++++++ generate_launch_configs.py | 134 + test_configs.py | 124 + test_results.txt | 5670 ++++++++++++++++++++++++++++++ trainer/DCRNN_Trainer.py | 10 + trainer/E32Trainer.py | 10 + trainer/PDG2SEQ_Trainer.py | 10 + trainer/STMLP_Trainer.py | 20 + trainer/cdeTrainer/cdetrainer.py | 10 + 9 files changed, 7836 insertions(+) create mode 100644 generate_launch_configs.py create mode 100644 test_configs.py create mode 100644 test_results.txt diff --git a/.vscode/launch.json b/.vscode/launch.json index 6193b75..2b530ca 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -234,5 +234,1853 @@ "console": "integratedTerminal", "args": "--config ./config/AEPSA/v2_SolarEnergy.yaml" }, + { + "name": "EXPB: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXPB/NYCBike-InFlow.yaml" + }, + { + "name": "EXPB: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXPB/PEMSD4.yaml" + }, + { + "name": "EXPB: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXPB/METR-LA.yaml" + }, + { + "name": "EXPB: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXPB/AirQuality.yaml" + }, + { + "name": "EXPB: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXPB/NYCBike-OutFlow.yaml" + }, + { + "name": "EXPB: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXPB/SolarEnergy.yaml" + }, + { + "name": "TWDGCN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/NYCBike-InFlow.yaml" + }, + { + "name": "TWDGCN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/PEMSD4.yaml" + }, + { + "name": "TWDGCN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/METR-LA.yaml" + }, + { + "name": "TWDGCN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/AirQuality.yaml" + }, + { + "name": "TWDGCN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/NYCBike-OutFlow.yaml" + }, + { + "name": "TWDGCN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/PEMSD8.yaml" + }, + { + "name": "TWDGCN: PEMSD7(L)", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/PEMSD7(L).yaml" + }, + { + "name": "TWDGCN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/PEMSD3.yaml" + }, + { + "name": "TWDGCN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/SolarEnergy.yaml" + }, + { + "name": "TWDGCN: Hainan", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/Hainan.yaml" + }, + { + "name": "TWDGCN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/PEMSD7.yaml" + }, + { + "name": "TWDGCN: PEMSD7(M)", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TWDGCN/PEMSD7(M).yaml" + }, + { + "name": "STSGCN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/NYCBike-InFlow.yaml" + }, + { + "name": "STSGCN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/PEMSD4.yaml" + }, + { + "name": "STSGCN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/METR-LA.yaml" + }, + { + "name": "STSGCN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/AirQuality.yaml" + }, + { + "name": "STSGCN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/NYCBike-OutFlow.yaml" + }, + { + "name": "STSGCN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/PEMSD8.yaml" + }, + { + "name": "STSGCN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/PEMSD3.yaml" + }, + { + "name": "STSGCN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/SolarEnergy.yaml" + }, + { + "name": "STSGCN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STSGCN/PEMSD7.yaml" + }, + { + "name": "STID: NYCBike_Inflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/NYCBike_Inflow.yaml" + }, + { + "name": "STID: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/AirQuality.yaml" + }, + { + "name": "STID: NYCBike_Outflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/NYCBike_Outflow.yaml" + }, + { + "name": "STAWnet: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/NYCBike-InFlow.yaml" + }, + { + "name": "STAWnet: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/PEMSD4.yaml" + }, + { + "name": "STAWnet: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/METR-LA.yaml" + }, + { + "name": "STAWnet: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/AirQuality.yaml" + }, + { + "name": "STAWnet: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/NYCBike-OutFlow.yaml" + }, + { + "name": "STAWnet: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/PEMSD8.yaml" + }, + { + "name": "STAWnet: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/PEMSD3.yaml" + }, + { + "name": "STAWnet: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/SolarEnergy.yaml" + }, + { + "name": "STAWnet: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAWnet/PEMSD7.yaml" + }, + { + "name": "DCRNN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/NYCBike-InFlow.yaml" + }, + { + "name": "DCRNN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/PEMSD4.yaml" + }, + { + "name": "DCRNN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/METR-LA.yaml" + }, + { + "name": "DCRNN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/AirQuality.yaml" + }, + { + "name": "DCRNN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/NYCBike-OutFlow.yaml" + }, + { + "name": "DCRNN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/PEMSD8.yaml" + }, + { + "name": "DCRNN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/PEMSD3.yaml" + }, + { + "name": "DCRNN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/SolarEnergy.yaml" + }, + { + "name": "DCRNN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DCRNN/PEMSD7.yaml" + }, + { + "name": "STAEFormer: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/NYCBike-InFlow.yaml" + }, + { + "name": "STAEFormer: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/PEMSD4.yaml" + }, + { + "name": "STAEFormer: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/METR-LA.yaml" + }, + { + "name": "STAEFormer: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/AirQuality.yaml" + }, + { + "name": "STAEFormer: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/NYCBike-OutFlow.yaml" + }, + { + "name": "STAEFormer: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/PEMSD8.yaml" + }, + { + "name": "STAEFormer: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/PEMSD3.yaml" + }, + { + "name": "STAEFormer: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/SolarEnergy.yaml" + }, + { + "name": "STAEFormer: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STAEFormer/PEMSD7.yaml" + }, + { + "name": "STGODE: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/NYCBike-InFlow.yaml" + }, + { + "name": "STGODE: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/PEMSD4.yaml" + }, + { + "name": "STGODE: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/METR-LA.yaml" + }, + { + "name": "STGODE: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/AirQuality.yaml" + }, + { + "name": "STGODE: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/NYCBike-OutFlow.yaml" + }, + { + "name": "STGODE: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/PEMSD8.yaml" + }, + { + "name": "STGODE: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/PEMSD3.yaml" + }, + { + "name": "STGODE: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/SolarEnergy.yaml" + }, + { + "name": "STGODE: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGODE/PEMSD7.yaml" + }, + { + "name": "STGNCDE: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/NYCBike-InFlow.yaml" + }, + { + "name": "STGNCDE: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/PEMSD4.yaml" + }, + { + "name": "STGNCDE: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/METR-LA.yaml" + }, + { + "name": "STGNCDE: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/AirQuality.yaml" + }, + { + "name": "STGNCDE: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/NYCBike-OutFlow.yaml" + }, + { + "name": "STGNCDE: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/PEMSD8.yaml" + }, + { + "name": "STGNCDE: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/PEMSD3.yaml" + }, + { + "name": "STGNCDE: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/SolarEnergy.yaml" + }, + { + "name": "STGNCDE: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNCDE/PEMSD7.yaml" + }, + { + "name": "AEPSA: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/NYCBike-InFlow.yaml" + }, + { + "name": "AEPSA: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/NYCBike-OutFlow.yaml" + }, + { + "name": "ST_SSL: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/NYCBike-InFlow.yaml" + }, + { + "name": "ST_SSL: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/PEMSD4.yaml" + }, + { + "name": "ST_SSL: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/METR-LA.yaml" + }, + { + "name": "ST_SSL: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/AirQuality.yaml" + }, + { + "name": "ST_SSL: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/NYCBike-OutFlow.yaml" + }, + { + "name": "ST_SSL: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/PEMSD8.yaml" + }, + { + "name": "ST_SSL: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/PEMSD3.yaml" + }, + { + "name": "ST_SSL: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/SolarEnergy.yaml" + }, + { + "name": "ST_SSL: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ST_SSL/PEMSD7.yaml" + }, + { + "name": "TCN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/NYCBike-InFlow.yaml" + }, + { + "name": "TCN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/PEMSD4.yaml" + }, + { + "name": "TCN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/METR-LA.yaml" + }, + { + "name": "TCN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/AirQuality.yaml" + }, + { + "name": "TCN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/NYCBike-OutFlow.yaml" + }, + { + "name": "TCN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/PEMSD8.yaml" + }, + { + "name": "TCN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/PEMSD3.yaml" + }, + { + "name": "TCN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/SolarEnergy.yaml" + }, + { + "name": "TCN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/TCN/PEMSD7.yaml" + }, + { + "name": "EXP: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/NYCBike-InFlow.yaml" + }, + { + "name": "EXP: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/PEMSD4.yaml" + }, + { + "name": "EXP: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/METR-LA.yaml" + }, + { + "name": "EXP: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/AirQuality.yaml" + }, + { + "name": "EXP: SD", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/SD.yaml" + }, + { + "name": "EXP: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/NYCBike-OutFlow.yaml" + }, + { + "name": "EXP: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/PEMSD8.yaml" + }, + { + "name": "EXP: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/PEMSD3.yaml" + }, + { + "name": "EXP: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/SolarEnergy.yaml" + }, + { + "name": "EXP: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/EXP/PEMSD7.yaml" + }, + { + "name": "DDGCRN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/NYCBike-InFlow.yaml" + }, + { + "name": "DDGCRN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/PEMSD4.yaml" + }, + { + "name": "DDGCRN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/AirQuality.yaml" + }, + { + "name": "DDGCRN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/NYCBike-OutFlow.yaml" + }, + { + "name": "DDGCRN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/PEMSD8.yaml" + }, + { + "name": "DDGCRN: PEMSD7(L)", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/PEMSD7(L).yaml" + }, + { + "name": "DDGCRN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/PEMSD3.yaml" + }, + { + "name": "DDGCRN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/SolarEnergy.yaml" + }, + { + "name": "DDGCRN: Hainan", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/Hainan.yaml" + }, + { + "name": "DDGCRN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/PEMSD7.yaml" + }, + { + "name": "DDGCRN: PEMSD7(M)", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/PEMSD7(M).yaml" + }, + { + "name": "DSANET: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/NYCBike-InFlow.yaml" + }, + { + "name": "DSANET: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/PEMSD4.yaml" + }, + { + "name": "DSANET: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/METR-LA.yaml" + }, + { + "name": "DSANET: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/AirQuality.yaml" + }, + { + "name": "DSANET: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/NYCBike-OutFlow.yaml" + }, + { + "name": "DSANET: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/PEMSD8.yaml" + }, + { + "name": "DSANET: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/PEMSD3.yaml" + }, + { + "name": "DSANET: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/SolarEnergy.yaml" + }, + { + "name": "DSANET: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DSANET/PEMSD7.yaml" + }, + { + "name": "STFGNN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/NYCBike-InFlow.yaml" + }, + { + "name": "STFGNN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/PEMSD4.yaml" + }, + { + "name": "STFGNN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/METR-LA.yaml" + }, + { + "name": "STFGNN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/AirQuality.yaml" + }, + { + "name": "STFGNN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/NYCBike-OutFlow.yaml" + }, + { + "name": "STFGNN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/PEMSD8.yaml" + }, + { + "name": "STFGNN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/PEMSD3.yaml" + }, + { + "name": "STFGNN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/SolarEnergy.yaml" + }, + { + "name": "STFGNN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STFGNN/PEMSD7.yaml" + }, + { + "name": "AGCRN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/NYCBike-InFlow.yaml" + }, + { + "name": "AGCRN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/PEMSD4.yaml" + }, + { + "name": "AGCRN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/METR-LA.yaml" + }, + { + "name": "AGCRN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/AirQuality.yaml" + }, + { + "name": "AGCRN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/NYCBike-OutFlow.yaml" + }, + { + "name": "AGCRN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/PEMSD8.yaml" + }, + { + "name": "AGCRN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/PEMSD3.yaml" + }, + { + "name": "AGCRN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/SolarEnergy.yaml" + }, + { + "name": "AGCRN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AGCRN/PEMSD7.yaml" + }, + { + "name": "STGNRDE: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/NYCBike-InFlow.yaml" + }, + { + "name": "STGNRDE: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/PEMSD4.yaml" + }, + { + "name": "STGNRDE: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/METR-LA.yaml" + }, + { + "name": "STGNRDE: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/AirQuality.yaml" + }, + { + "name": "STGNRDE: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/NYCBike-OutFlow.yaml" + }, + { + "name": "STGNRDE: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/PEMSD8.yaml" + }, + { + "name": "STGNRDE: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/PEMSD3.yaml" + }, + { + "name": "STGNRDE: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/SolarEnergy.yaml" + }, + { + "name": "STGNRDE: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGNRDE/PEMSD7.yaml" + }, + { + "name": "REPST: PEMS-BAY_paper", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/PEMS-BAY_paper.yaml" + }, + { + "name": "REPST: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/NYCBike-InFlow.yaml" + }, + { + "name": "REPST: BeijingAirQuality(Deprecated)", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/BeijingAirQuality(Deprecated).yaml" + }, + { + "name": "REPST: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/NYCBike-OutFlow.yaml" + }, + { + "name": "STIDGCN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/NYCBike-InFlow.yaml" + }, + { + "name": "STIDGCN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/PEMSD4.yaml" + }, + { + "name": "STIDGCN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/METR-LA.yaml" + }, + { + "name": "STIDGCN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/AirQuality.yaml" + }, + { + "name": "STIDGCN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/NYCBike-OutFlow.yaml" + }, + { + "name": "STIDGCN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/PEMSD8.yaml" + }, + { + "name": "STIDGCN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/PEMSD3.yaml" + }, + { + "name": "STIDGCN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/SolarEnergy.yaml" + }, + { + "name": "STIDGCN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STIDGCN/PEMSD7.yaml" + }, + { + "name": "PDG2SEQ: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/NYCBike-InFlow.yaml" + }, + { + "name": "PDG2SEQ: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/PEMSD4.yaml" + }, + { + "name": "PDG2SEQ: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/METR-LA.yaml" + }, + { + "name": "PDG2SEQ: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/AirQuality.yaml" + }, + { + "name": "PDG2SEQ: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/NYCBike-OutFlow.yaml" + }, + { + "name": "PDG2SEQ: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/PEMSD8.yaml" + }, + { + "name": "PDG2SEQ: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/PEMSD3.yaml" + }, + { + "name": "PDG2SEQ: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/SolarEnergy.yaml" + }, + { + "name": "PDG2SEQ: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/PDG2SEQ/PEMSD7.yaml" + }, + { + "name": "NLT: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/NYCBike-InFlow.yaml" + }, + { + "name": "NLT: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/PEMSD4.yaml" + }, + { + "name": "NLT: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/METR-LA.yaml" + }, + { + "name": "NLT: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/AirQuality.yaml" + }, + { + "name": "NLT: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/NYCBike-OutFlow.yaml" + }, + { + "name": "NLT: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/PEMSD8.yaml" + }, + { + "name": "NLT: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/PEMSD3.yaml" + }, + { + "name": "NLT: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/SolarEnergy.yaml" + }, + { + "name": "NLT: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/NLT/PEMSD7.yaml" + }, + { + "name": "ARIMA: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/NYCBike-InFlow.yaml" + }, + { + "name": "ARIMA: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/PEMSD4.yaml" + }, + { + "name": "ARIMA: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/METR-LA.yaml" + }, + { + "name": "ARIMA: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/AirQuality.yaml" + }, + { + "name": "ARIMA: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/NYCBike-OutFlow.yaml" + }, + { + "name": "ARIMA: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/PEMSD8.yaml" + }, + { + "name": "ARIMA: PEMSD7(L)", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/PEMSD7(L).yaml" + }, + { + "name": "ARIMA: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/PEMSD3.yaml" + }, + { + "name": "ARIMA: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/SolarEnergy.yaml" + }, + { + "name": "ARIMA: Hainan", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/Hainan.yaml" + }, + { + "name": "ARIMA: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/PEMSD7.yaml" + }, + { + "name": "ARIMA: PEMSD7(M)", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/ARIMA/PEMSD7(M).yaml" + }, + { + "name": "STMLP: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/NYCBike-InFlow.yaml" + }, + { + "name": "STMLP: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/PEMSD4.yaml" + }, + { + "name": "STMLP: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/METR-LA.yaml" + }, + { + "name": "STMLP: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/AirQuality.yaml" + }, + { + "name": "STMLP: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/NYCBike-OutFlow.yaml" + }, + { + "name": "STMLP: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/PEMSD8.yaml" + }, + { + "name": "STMLP: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/PEMSD3.yaml" + }, + { + "name": "STMLP: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/SolarEnergy.yaml" + }, + { + "name": "STMLP: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STMLP/PEMSD7.yaml" + }, + { + "name": "MegaCRN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/NYCBike-InFlow.yaml" + }, + { + "name": "MegaCRN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/PEMSD4.yaml" + }, + { + "name": "MegaCRN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/METR-LA.yaml" + }, + { + "name": "MegaCRN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/AirQuality.yaml" + }, + { + "name": "MegaCRN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/NYCBike-OutFlow.yaml" + }, + { + "name": "MegaCRN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/PEMSD8.yaml" + }, + { + "name": "MegaCRN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/PEMSD3.yaml" + }, + { + "name": "MegaCRN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/SolarEnergy.yaml" + }, + { + "name": "MegaCRN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/MegaCRN/PEMSD7.yaml" + }, + { + "name": "GWN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/NYCBike-InFlow.yaml" + }, + { + "name": "GWN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/PEMSD4.yaml" + }, + { + "name": "GWN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/METR-LA.yaml" + }, + { + "name": "GWN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/AirQuality.yaml" + }, + { + "name": "GWN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/NYCBike-OutFlow.yaml" + }, + { + "name": "GWN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/PEMSD8.yaml" + }, + { + "name": "GWN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/PEMSD3.yaml" + }, + { + "name": "GWN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/SolarEnergy.yaml" + }, + { + "name": "GWN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/GWN/PEMSD7.yaml" + }, + { + "name": "STGCN: NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/NYCBike-InFlow.yaml" + }, + { + "name": "STGCN: PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/PEMSD4.yaml" + }, + { + "name": "STGCN: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/METR-LA.yaml" + }, + { + "name": "STGCN: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/AirQuality.yaml" + }, + { + "name": "STGCN: NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/NYCBike-OutFlow.yaml" + }, + { + "name": "STGCN: PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/PEMSD8.yaml" + }, + { + "name": "STGCN: PEMSD3", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/PEMSD3.yaml" + }, + { + "name": "STGCN: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/SolarEnergy.yaml" + }, + { + "name": "STGCN: PEMSD7", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STGCN/PEMSD7.yaml" + } ] } \ No newline at end of file diff --git a/generate_launch_configs.py b/generate_launch_configs.py new file mode 100644 index 0000000..6477e16 --- /dev/null +++ b/generate_launch_configs.py @@ -0,0 +1,134 @@ +import os +import re + +# 配置路径 +CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config" +LAUNCH_FILE = "/user/czzhangheng/code/TrafficWheel/.vscode/launch.json" + +# 遍历所有yaml文件 +def find_all_yaml_files(directory): + yaml_files = [] + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith(".yaml") and not file.startswith("BJTaxi"): + yaml_files.append(os.path.join(root, file)) + return yaml_files + +# 生成launch配置字符串 +def generate_launch_config_string(yaml_files): + config_strings = [] + + for file_path in yaml_files: + # 提取模型名和数据集名 + relative_path = os.path.relpath(file_path, CONFIG_DIR) + model_name = relative_path.split(os.sep)[0] + dataset_name = os.path.splitext(os.path.basename(file_path))[0] + + # 处理v2版本 + if "v2_" in dataset_name: + model_display_name = f"{model_name}_v2" + dataset_display_name = dataset_name.replace("v2_", "") + else: + model_display_name = model_name + dataset_display_name = dataset_name + + # 生成配置字符串 + config_string = f''' + {{ + "name": "{model_display_name}: {dataset_display_name}", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/{model_name}/{os.path.basename(file_path)}" + }}''' + + config_strings.append(config_string) + + return ",".join(config_strings) + +# 读取现有的launch.json文件,提取配置名称 +def get_existing_config_names(): + with open(LAUNCH_FILE, 'r') as f: + content = f.read() + + # 提取所有配置名称 + name_pattern = re.compile(r'"name"\s*:\s*"([^"]+)"') + matches = name_pattern.findall(content) + + return set(matches) + +# 生成新的配置,过滤掉已存在的 +def generate_new_configs(yaml_files, existing_names): + new_configs = [] + + for file_path in yaml_files: + # 提取模型名和数据集名 + relative_path = os.path.relpath(file_path, CONFIG_DIR) + model_name = relative_path.split(os.sep)[0] + dataset_name = os.path.splitext(os.path.basename(file_path))[0] + + # 处理v2版本 + if "v2_" in dataset_name: + model_display_name = f"{model_name}_v2" + dataset_display_name = dataset_name.replace("v2_", "") + else: + model_display_name = model_name + dataset_display_name = dataset_name + + # 生成配置名称 + config_name = f"{model_display_name}: {dataset_display_name}" + + # 如果配置不存在,则添加 + if config_name not in existing_names: + new_configs.append(file_path) + + return new_configs + +# 更新launch.json文件 +def update_launch_json(new_configs_string): + with open(LAUNCH_FILE, 'r') as f: + content = f.read() + + # 找到configurations数组的结束位置 + configs_end_match = re.search(r'\s*\]\s*\}', content) + if not configs_end_match: + return False + + # 插入新的配置 + insert_pos = configs_end_match.start() + new_content = content[:insert_pos] + new_configs_string + content[insert_pos:] + + # 保存文件 + with open(LAUNCH_FILE, 'w') as f: + f.write(new_content) + + return True + +# 主函数 +def main(): + # 查找所有yaml文件 + yaml_files = find_all_yaml_files(CONFIG_DIR) + + # 获取现有配置名称 + existing_names = get_existing_config_names() + + # 生成新的配置,过滤掉已存在的 + new_config_files = generate_new_configs(yaml_files, existing_names) + + if not new_config_files: + print("No new configurations to add") + return + + # 生成新的配置字符串 + new_configs_string = generate_launch_config_string(new_config_files) + + # 更新launch.json文件 + if update_launch_json(new_configs_string): + print(f"Added {len(new_config_files)} new launch configurations") + print(f"Total configurations: {len(existing_names) + len(new_config_files)}") + else: + print("Failed to update launch.json") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_configs.py b/test_configs.py new file mode 100644 index 0000000..4fb4bbe --- /dev/null +++ b/test_configs.py @@ -0,0 +1,124 @@ +import os +import subprocess +import yaml +import time + +# 配置路径 +CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config" +RUN_SCRIPT = "/user/czzhangheng/code/TrafficWheel/run.py" +RESULTS_FILE = "/user/czzhangheng/code/TrafficWheel/test_results.txt" + +# 记录测试结果的字典 +results = { + "passed": [], + "failed": [], + "error": [] +} + +# 遍历所有yaml文件 +def find_all_yaml_files(directory): + yaml_files = [] + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith(".yaml") and not file.startswith("BJTaxi"): + yaml_files.append(os.path.join(root, file)) + return yaml_files + +# 测试单个yaml文件 +def test_yaml_file(yaml_path): + print(f"\n=== Testing {yaml_path} ===") + + # 检查文件是否存在 + if not os.path.exists(yaml_path): + print(f"File not found: {yaml_path}") + return "error", f"File not found: {yaml_path}" + + # 运行测试命令 + command = ["python", RUN_SCRIPT, "--config", yaml_path] + try: + result = subprocess.run( + command, + capture_output=True, + text=True, + timeout=600 # 10分钟超时 + ) + + # 分析结果 + if result.returncode == 0: + if "Test passed" in result.stdout: + print(f"✓ PASSED: {yaml_path}") + return "passed", result.stdout.strip() + else: + print(f"✗ FAILED: {yaml_path}") + return "failed", result.stdout.strip() + "\n" + result.stderr.strip() + else: + print(f"✗ ERROR: {yaml_path}") + return "error", result.stdout.strip() + "\n" + result.stderr.strip() + except subprocess.TimeoutExpired: + print(f"✗ TIMEOUT: {yaml_path}") + return "error", "Timeout after 10 minutes" + except Exception as e: + print(f"✗ EXCEPTION: {yaml_path}") + return "error", str(e) + +# 生成测试报告 +def generate_report(results): + total = len(results["passed"]) + len(results["failed"]) + len(results["error"]) + + report = f"""# 测试报告 + +## 测试概述 +- 测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')} +- 总测试文件数: {total} +- 通过: {len(results['passed'])} +- 失败: {len(results['failed'])} +- 错误: {len(results['error'])} + +## 通过的配置文件 +""" + + for file_path, output in results["passed"]: + report += f"- ✅ {file_path}\n" + + report += "\n## 失败的配置文件\n" + for file_path, output in results["failed"]: + report += f"- ❌ {file_path}\n" + + report += "\n## 出错的配置文件\n" + for file_path, output in results["error"]: + report += f"- ⚠️ {file_path}\n" + + report += "\n## 详细输出\n" + + for status, files in results.items(): + report += f"\n### {status.upper()}\n\n" + for file_path, output in files: + report += f"#### {file_path}\n\n```\n{output}\n```\n\n" + + return report + +# 主函数 +def main(): + # 找到所有符合条件的yaml文件 + yaml_files = find_all_yaml_files(CONFIG_DIR) + print(f"Found {len(yaml_files)} yaml files to test") + + # 测试每个文件 + for yaml_file in yaml_files: + status, output = test_yaml_file(yaml_file) + results[status].append((yaml_file, output)) + + # 生成并保存报告 + report = generate_report(results) + with open(RESULTS_FILE, "w") as f: + f.write(report) + + print(f"\n=== Test Results ===") + print(f"Total: {len(yaml_files)}") + print(f"Passed: {len(results['passed'])}") + print(f"Failed: {len(results['failed'])}") + print(f"Error: {len(results['error'])}") + print(f"Report saved to: {RESULTS_FILE}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_results.txt b/test_results.txt new file mode 100644 index 0000000..6b5c248 --- /dev/null +++ b/test_results.txt @@ -0,0 +1,5670 @@ +# 测试报告 + +## 测试概述 +- 测试时间: 2025-12-01 22:20:35 +- 总测试文件数: 252 +- 通过: 41 +- 失败: 0 +- 错误: 211 + +## 通过的配置文件 +- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Inflow.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/AirQuality.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/PEMS-BAY.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Outflow.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/SolarEnergy.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD4.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD8.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD3.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD7.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-inflow.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/AirQuality.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/PEMS-BAY.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/SolarEnergy.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-outflow.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_SolarEnergy.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-InFlow.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD4.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-OutFlow.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD8.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD3.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/SolarEnergy.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD7.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD8.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD4.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD8.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD3.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD7.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-inflow.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/BeijingAirQuality(Deprecated).yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/METR-LA.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/AirQuality.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/SolarEnergy.yaml +- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-outflow.yaml + +## 失败的配置文件 + +## 出错的配置文件 +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(L).yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/Hainan.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(M).yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TCN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/SD.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(L).yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/Hainan.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(M).yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY_paper.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(L).yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/Hainan.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(M).yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD7.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-InFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD4.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/METR-LA.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/AirQuality.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-OutFlow.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD8.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD3.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/SolarEnergy.yaml +- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD7.yaml + +## 详细输出 + +### PASSED + +#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Inflow.yaml + +``` +模型参数量: 118040 +加载 NYCBike-InFlow 数据集中... +✓ Test passed: output shape torch.Size([64, 24, 128, 1]) matches label shape torch.Size([64, 24, 128, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/METR-LA.yaml + +``` +模型参数量: 120568 +加载 METR-LA 数据集中... +✓ Test passed: output shape torch.Size([64, 24, 207, 1]) matches label shape torch.Size([64, 24, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/AirQuality.yaml + +``` +模型参数量: 115064 +加载 AirQuality 数据集中... +✓ Test passed: output shape torch.Size([64, 24, 35, 1]) matches label shape torch.Size([64, 24, 35, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/PEMS-BAY.yaml + +``` +模型参数量: 124344 +加载 PEMS-BAY 数据集中... +✓ Test passed: output shape torch.Size([64, 24, 325, 1]) matches label shape torch.Size([64, 24, 325, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Outflow.yaml + +``` +模型参数量: 118040 +加载 NYCBike-OutFlow 数据集中... +✓ Test passed: output shape torch.Size([64, 24, 128, 1]) matches label shape torch.Size([64, 24, 128, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/SolarEnergy.yaml + +``` +模型参数量: 118328 +加载 SolarEnergy 数据集中... +✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD4.yaml + +``` +模型参数量: 1354932 +加载 PEMSD4 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD4/2025-12-01_21-52-10/run.log +✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/METR-LA.yaml + +``` +模型参数量: 1258932 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/METR-LA/2025-12-01_21-52-24/run.log +✓ Test passed: output shape torch.Size([16, 12, 207, 1]) matches label shape torch.Size([16, 12, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD8.yaml + +``` +模型参数量: 1223412 +加载 PEMSD8 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD8/2025-12-01_21-52-49/run.log +✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD3.yaml + +``` +模型参数量: 1403892 +加载 PEMSD3 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD3/2025-12-01_21-53-06/run.log +✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD7.yaml + +``` +模型参数量: 1907892 +加载 PEMSD7 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD7/2025-12-01_21-54-04/run.log +✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-inflow.yaml + +``` +模型参数量: 103504579 +加载 NYCBike-InFlow 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/NYCBike-InFlow/2025-12-01_21-55-58/run.log +✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/METR-LA.yaml + +``` +模型参数量: 103505369 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/METR-LA/2025-12-01_21-56-29/run.log +✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/AirQuality.yaml + +``` +模型参数量: 103503669 +加载 AirQuality 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/AirQuality/2025-12-01_21-56-40/run.log +✓ Test passed: output shape torch.Size([16, 24, 35, 6]) matches label shape torch.Size([16, 24, 35, 6]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/PEMS-BAY.yaml + +``` +模型参数量: 103506549 +加载 PEMS-BAY 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/PEMS-BAY/2025-12-01_21-57-30/run.log +✓ Test passed: output shape torch.Size([16, 24, 325, 1]) matches label shape torch.Size([16, 24, 325, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/SolarEnergy.yaml + +``` +模型参数量: 103504669 +加载 SolarEnergy 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/SolarEnergy/2025-12-01_21-57-55/run.log +✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_METR-LA.yaml + +``` +模型参数量: 103524820 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA_v2/METR-LA/2025-12-01_21-58-18/run.log +✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-outflow.yaml + +``` +模型参数量: 103504579 +加载 NYCBike-OutFlow 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/NYCBike-OutFlow/2025-12-01_21-58-29/run.log +✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_SolarEnergy.yaml + +``` +模型参数量: 103524120 +加载 SolarEnergy 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA_v2/SolarEnergy/2025-12-01_21-58-54/run.log +✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-InFlow.yaml + +``` +模型参数量: 35873 +加载 NYCBike-InFlow 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/NYCBike-InFlow/2025-12-01_21-59-55/run.log +✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD4.yaml + +``` +模型参数量: 35873 +加载 PEMSD4 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD4/2025-12-01_22-00-07/run.log +✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/METR-LA.yaml + +``` +模型参数量: 35873 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/METR-LA/2025-12-01_22-00-28/run.log +✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-OutFlow.yaml + +``` +模型参数量: 35873 +加载 NYCBike-OutFlow 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/NYCBike-OutFlow/2025-12-01_22-00-44/run.log +✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD8.yaml + +``` +模型参数量: 35873 +加载 PEMSD8 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD8/2025-12-01_22-00-54/run.log +✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD3.yaml + +``` +模型参数量: 35873 +加载 PEMSD3 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD3/2025-12-01_22-01-10/run.log +✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/SolarEnergy.yaml + +``` +模型参数量: 35873 +加载 SolarEnergy 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/SolarEnergy/2025-12-01_22-01-33/run.log +✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD7.yaml + +``` +模型参数量: 35873 +加载 PEMSD7 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD7/2025-12-01_22-02-05/run.log +✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/METR-LA.yaml + +``` +模型参数量: 671644 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DDGCRN/METR-LA/2025-12-01_22-03-35/run.log +✓ Test passed: output shape torch.Size([64, 24, 207, 1]) matches label shape torch.Size([64, 24, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD8.yaml + +``` +模型参数量: 311759 +加载 PEMSD8 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DDGCRN/PEMSD8/2025-12-01_22-03-57/run.log +✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD4.yaml + +``` +模型参数量: 37896712 +加载 PEMSD4 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD4/2025-12-01_22-04-52/run.log +✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/METR-LA.yaml + +``` +模型参数量: 37896712 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/METR-LA/2025-12-01_22-05-06/run.log +✓ Test passed: output shape torch.Size([16, 12, 207, 1]) matches label shape torch.Size([16, 12, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD8.yaml + +``` +模型参数量: 37896712 +加载 PEMSD8 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD8/2025-12-01_22-05-33/run.log +✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD3.yaml + +``` +模型参数量: 37896712 +加载 PEMSD3 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD3/2025-12-01_22-05-49/run.log +✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD7.yaml + +``` +模型参数量: 615304 +加载 PEMSD7 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD7/2025-12-01_22-06-48/run.log +✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-inflow.yaml + +``` +模型参数量: 103481647 +加载 NYCBike-InFlow 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/NYCBike-InFlow/2025-12-01_22-09-34/run.log +✓ Test passed: output shape torch.Size([16, 24, 128, 1]) matches label shape torch.Size([16, 24, 128, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/BeijingAirQuality(Deprecated).yaml + +``` +模型参数量: 103815937 +加载 BeijingAirQuality 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/BeijingAirQuality/2025-12-01_22-09-59/run.log +✓ Test passed: output shape torch.Size([16, 24, 7, 3]) matches label shape torch.Size([16, 24, 7, 3]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/METR-LA.yaml + +``` +模型参数量: 103481647 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/METR-LA/2025-12-01_22-10-22/run.log +✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/AirQuality.yaml + +``` +模型参数量: 103815973 +加载 AirQuality 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/AirQuality/2025-12-01_22-10-33/run.log +✓ Test passed: output shape torch.Size([16, 24, 35, 3]) matches label shape torch.Size([16, 24, 35, 3]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY.yaml + +``` +模型参数量: 103481647 +加载 PEMS-BAY 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/PEMS-BAY/2025-12-01_22-11-23/run.log +✓ Test passed: output shape torch.Size([16, 24, 325, 1]) matches label shape torch.Size([16, 24, 325, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/SolarEnergy.yaml + +``` +模型参数量: 103481647 +加载 SolarEnergy 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/SolarEnergy/2025-12-01_22-11-48/run.log +✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-outflow.yaml + +``` +模型参数量: 103481647 +加载 NYCBike-OutFlow 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/NYCBike-OutFlow/2025-12-01_22-11-58/run.log +✓ Test passed: output shape torch.Size([16, 24, 128, 1]) matches label shape torch.Size([16, 24, 128, 1]) +``` + + +### FAILED + + +### ERROR + +#### /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^^^^ +AttributeError: 'NoneType' object has no attribute 'to' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXPB/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^^^^ +AttributeError: 'NoneType' object has no attribute 'to' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXPB/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^^^^ +AttributeError: 'NoneType' object has no attribute 'to' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXPB/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^^^^ +AttributeError: 'NoneType' object has no attribute 'to' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^^^^ +AttributeError: 'NoneType' object has no attribute 'to' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXPB/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^^^^ +AttributeError: 'NoneType' object has no attribute 'to' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(L).yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/Hainan.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(M).yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector + return TWDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector + return STSGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ + self.adj = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj + match args["num_nodes"]: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-InFlow.yaml + +``` +模型参数量: 146712 +加载 NYCBike-InFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 70, in model_selector + return STID(model_config) + ^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STID/STID.py", line 13, in __init__ + self.embed_dim = model_args["embed_dim"] + ~~~~~~~~~~^^^^^^^^^^^^^ +KeyError: 'embed_dim' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-OutFlow.yaml + +``` +模型参数量: 146712 +加载 NYCBike-OutFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector + return STAWnet(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 29, in get_adj + return adj + ^^^ +UnboundLocalError: cannot access local variable 'adj' where it is not associated with a value +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector + return DCRNNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ + adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-InFlow.yaml + +``` +模型参数量: 3086208 +加载 NYCBike-InFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/AirQuality.yaml + +``` +模型参数量: 1624752 +加载 AirQuality 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/AirQuality/2025-12-01_21-52-33/run.log +2025/12/01 21:52:33 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/AirQuality/2025-12-01_21-52-33 +2025/12/01 21:52:33 - Training process started + +Train Epoch 1: 0%| | 0/325 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAEFormer/STAEFormer.py", line 195, in forward + x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) + ^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward + return F.linear(input, self.weight, self.bias) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: mat1 and mat2 shapes cannot be multiplied (13440x1 and 6x24) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-OutFlow.yaml + +``` +模型参数量: 3086208 +加载 NYCBike-OutFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/SolarEnergy.yaml + +``` +模型参数量: 13296192 +加载 SolarEnergy 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/SolarEnergy/2025-12-01_21-53-32/run.log +2025/12/01 21:53:32 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/SolarEnergy/2025-12-01_21-53-32 +2025/12/01 21:53:32 - Training process started + +Train Epoch 1: 0%| | 0/1970 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STAEFormer/STAEFormer.py", line 195, in forward + x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) + ^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward + return F.linear(input, self.weight, self.bias) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: mat1 and mat2 shapes cannot be multiplied (52608x1 and 137x24) +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector + return ODEGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ + num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector + return make_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-InFlow.yaml + +``` +模型参数量: 103513539 +加载 NYCBike-InFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-OutFlow.yaml + +``` +模型参数量: 103513539 +加载 NYCBike-OutFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector + return STSSLModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/TCN/AirQuality.yaml + +``` +模型参数量: 36678 +加载 AirQuality 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/AirQuality/2025-12-01_22-00-37/run.log +2025/12/01 22:00:37 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/AirQuality/2025-12-01_22-00-37 +2025/12/01 22:00:37 - Training process started + +Train Epoch 1: 0%| | 0/325 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TCN/TCN.py", line 43, in forward + x = self.network(x) + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward + input = module(input) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/TCN/TCN.py", line 89, in forward + res = x if self.downsample is None else self.downsample(x) + ^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 554, in forward + return self._conv_forward(input, self.weight, self.bias) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward + return F.conv2d( + ^^^^^^^^^ +RuntimeError: Given groups=1, weight of size [32, 6, 1, 1], expected input[16, 1, 35, 24] to have 6 channels, but got 1 channels instead +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/SD.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD8.yaml + +``` +模型参数量: 235788 +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 18, in get_dataloader + return EXP_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/EXPdataloader.py", line 8, in get_dataloader + data = load_st_dataset(args["type"], args["sample"]) # [T, N, F] + ~~~~^^^^^^^^ +KeyError: 'type' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector + return EXP(model_config) + ^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ + self.horizon = args["horizon"] # 预测步长 + ~~~~^^^^^^^^^^^ +KeyError: 'horizon' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(L).yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/Hainan.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(M).yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector + return DDGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-InFlow.yaml + +``` +模型参数量: 37897240 +加载 NYCBike-InFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/AirQuality.yaml + +``` +模型参数量: 37897240 +加载 AirQuality 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/AirQuality/2025-12-01_22-05-16/run.log +2025/12/01 22:05:16 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/AirQuality/2025-12-01_22-05-16 +2025/12/01 22:05:16 - Training process started + +Train Epoch 1: 0%| | 0/325 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 180, in _run_epoch + loss = self.loss(output, label) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 128, in forward + return F.l1_loss(input, target, reduction=self.reduction) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/functional.py", line 3810, in l1_loss + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/functional.py", line 76, in broadcast_tensors + return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: The size of tensor a (12) must match the size of tensor b (24) at non-singleton dimension 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-OutFlow.yaml + +``` +模型参数量: 37897240 +加载 NYCBike-OutFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/DSANET/SolarEnergy.yaml + +``` +模型参数量: 37897240 +加载 SolarEnergy 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/SolarEnergy/2025-12-01_22-06-16/run.log +2025/12/01 22:06:16 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/SolarEnergy/2025-12-01_22-06-16 +2025/12/01 22:06:16 - Training process started + +Train Epoch 1: 0%| | 0/1970 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 180, in _run_epoch + loss = self.loss(output, label) + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 128, in forward + return F.l1_loss(input, target, reduction=self.reduction) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/functional.py", line 3810, in l1_loss + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/functional.py", line 76, in broadcast_tensors + return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: The size of tensor a (12) must match the size of tensor b (24) at non-singleton dimension 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector + return STFGNN(model_config) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ + adj = torch.tensor(get_adj(args)) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 53, in __init__ + self.input_dim = args["input_dim"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'input_dim' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector + return AGCRN(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ + self.num_node = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector + return make_nrde_model(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY_paper.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 84, in model_selector + return REPST(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/REPST/repst.py", line 24, in __init__ + self.word_choice = GumbelSoftmax(configs['word_num']) + ~~~~~~~^^^^^^^^^^^^ +KeyError: 'word_num' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-InFlow.yaml + +``` +模型参数量: 103481647 +加载 NYCBike-InFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-OutFlow.yaml + +``` +模型参数量: 103481647 +加载 NYCBike-OutFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector + return STIDGCN(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector + return PDG2Seq(model_config) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ + self.num_nodes = args["num_nodes"] + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector + return HierAttnLstm(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ + args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-InFlow.yaml + +``` +模型参数量: 4 +加载 NYCBike-InFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD4.yaml + +``` +模型参数量: 4 +加载 PEMSD4 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD4/2025-12-01_22-14-54/run.log +2025/12/01 22:14:54 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD4/2025-12-01_22-14-54 +2025/12/01 22:14:54 - Training process started + +Train Epoch 1: 0%| | 0/160 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward + drift = self.drift[n] if self.drift is not None else None + ~~~~~~~~~~^^^ +IndexError: index 1 is out of bounds for dimension 0 with size 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/METR-LA.yaml + +``` +模型参数量: 4 +加载 METR-LA 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/METR-LA/2025-12-01_22-15-07/run.log +2025/12/01 22:15:07 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/METR-LA/2025-12-01_22-15-07 +2025/12/01 22:15:07 - Training process started + +Train Epoch 1: 0%| | 0/1285 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward + drift = self.drift[n] if self.drift is not None else None + ~~~~~~~~~~^^^ +IndexError: index 1 is out of bounds for dimension 0 with size 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/AirQuality.yaml + +``` +模型参数量: 4 +加载 AirQuality 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/AirQuality/2025-12-01_22-15-15/run.log +2025/12/01 22:15:15 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/AirQuality/2025-12-01_22-15-15 +2025/12/01 22:15:15 - Training process started + +Train Epoch 1: 0%| | 0/325 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward + drift = self.drift[n] if self.drift is not None else None + ~~~~~~~~~~^^^ +IndexError: index 1 is out of bounds for dimension 0 with size 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-OutFlow.yaml + +``` +模型参数量: 4 +加载 NYCBike-OutFlow 数据集中... +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + ^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader + return normal_loader(config, normalizer, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader + x, y = _prepare_data_with_windows(data, args, single) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features + return np.concatenate([data, time_day, time_week], axis=-1) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD8.yaml + +``` +模型参数量: 4 +加载 PEMSD8 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD8/2025-12-01_22-15-31/run.log +2025/12/01 22:15:31 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD8/2025-12-01_22-15-31 +2025/12/01 22:15:31 - Training process started + +Train Epoch 1: 0%| | 0/168 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward + drift = self.drift[n] if self.drift is not None else None + ~~~~~~~~~~^^^ +IndexError: index 1 is out of bounds for dimension 0 with size 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(L).yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector + return ARIMA(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ + self.p = args["p"] # 自回归阶数 + ~~~~^^^^^ +KeyError: 'p' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD3.yaml + +``` +模型参数量: 4 +加载 PEMSD3 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD3/2025-12-01_22-15-53/run.log +2025/12/01 22:15:53 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD3/2025-12-01_22-15-53 +2025/12/01 22:15:53 - Training process started + +Train Epoch 1: 0%| | 0/982 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward + drift = self.drift[n] if self.drift is not None else None + ~~~~~~~~~~^^^ +IndexError: index 1 is out of bounds for dimension 0 with size 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/SolarEnergy.yaml + +``` +模型参数量: 4 +加载 SolarEnergy 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/SolarEnergy/2025-12-01_22-16-18/run.log +2025/12/01 22:16:18 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/SolarEnergy/2025-12-01_22-16-18 +2025/12/01 22:16:18 - Training process started + +Train Epoch 1: 0%| | 0/1970 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward + drift = self.drift[n] if self.drift is not None else None + ~~~~~~~~~~^^^ +IndexError: index 1 is out of bounds for dimension 0 with size 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/Hainan.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector + return ARIMA(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ + self.p = args["p"] # 自回归阶数 + ~~~~^^^^^ +KeyError: 'p' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7.yaml + +``` +模型参数量: 4 +加载 PEMSD7 数据集中... +Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD7/2025-12-01_22-16-56/run.log +2025/12/01 22:16:56 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD7/2025-12-01_22-16-56 +2025/12/01 22:16:56 - Training process started + +Train Epoch 1: 0%| | 0/1058 [00:00 + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main + trainer.train() + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train + train_epoch_loss = self.train_epoch(epoch) + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch + return self._run_epoch(epoch, self.train_loader, "train") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch + output = self.model(data).to(self.device) + ^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward + drift = self.drift[n] if self.drift is not None else None + ~~~~~~~~~~^^^ +IndexError: index 1 is out of bounds for dimension 0 with size 1 +``` + +#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(M).yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector + return ARIMA(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ + self.p = args["p"] # 自回归阶数 + ~~~~^^^^^ +KeyError: 'p' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector + return STMLP(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ + self.adj_mx = get_adj(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector + return MegaCRNModel(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ + num_nodes=args["num_nodes"], + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector + return gwnet(model_config) + ^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ + torch.randn(args["num_nodes"], 10, device=args["device"]) + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-InFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD4.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/METR-LA.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/AirQuality.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-OutFlow.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD8.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD3.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/SolarEnergy.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + +#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD7.yaml + +``` + +Traceback (most recent call last): + File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in + main() + File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main + model = init.init_model(args) + ^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model + model = model_selector(args).to(device) + ^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector + return STGCNChebGraphConv(model_config) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ + gso = get_gso(args) + ^^^^^^^^^^^^^ + File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso + match args['num_nodes']: + ~~~~^^^^^^^^^^^^^ +KeyError: 'num_nodes' +``` + diff --git a/trainer/DCRNN_Trainer.py b/trainer/DCRNN_Trainer.py index 8bb2298..a60eddb 100755 --- a/trainer/DCRNN_Trainer.py +++ b/trainer/DCRNN_Trainer.py @@ -103,6 +103,16 @@ class Trainer: output = self.model(data, labels=label.clone()).to(self.device) loss = self.loss(output, label) + # 检查output和label的shape是否一致 + if output.shape == label.shape: + print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + import sys + sys.exit(0) + else: + print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + import sys + sys.exit(1) + # 反归一化 d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) diff --git a/trainer/E32Trainer.py b/trainer/E32Trainer.py index 5011131..4bad8dd 100644 --- a/trainer/E32Trainer.py +++ b/trainer/E32Trainer.py @@ -108,6 +108,16 @@ class Trainer: label = target[..., : self.args["output_dim"]] loss = self.loss(output, label) + # 检查output和label的shape是否一致 + if output.shape == label.shape: + print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + import sys + sys.exit(0) + else: + print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + import sys + sys.exit(1) + # 反归一化 d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) diff --git a/trainer/PDG2SEQ_Trainer.py b/trainer/PDG2SEQ_Trainer.py index 95b7a61..e42e841 100755 --- a/trainer/PDG2SEQ_Trainer.py +++ b/trainer/PDG2SEQ_Trainer.py @@ -107,6 +107,16 @@ class Trainer: # 计算原始loss loss = self.loss(output, label) + # 检查output和label的shape是否一致 + if output.shape == label.shape: + print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + import sys + sys.exit(0) + else: + print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + import sys + sys.exit(1) + # 反归一化 d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) diff --git a/trainer/STMLP_Trainer.py b/trainer/STMLP_Trainer.py index d1ac02a..4f3d576 100644 --- a/trainer/STMLP_Trainer.py +++ b/trainer/STMLP_Trainer.py @@ -137,11 +137,31 @@ class Trainer: # 总loss loss = loss1 + 10 * tkloss + 1 * scl + + # 检查output和label的shape是否一致 + if output.shape == label.shape: + print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + import sys + sys.exit(0) + else: + print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + import sys + sys.exit(1) else: # 普通训练模式 output, out_, _ = self.model(data) loss = self.loss(output, label) + # 检查output和label的shape是否一致 + if output.shape == label.shape: + print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + import sys + sys.exit(0) + else: + print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + import sys + sys.exit(1) + # 反归一化 d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) diff --git a/trainer/cdeTrainer/cdetrainer.py b/trainer/cdeTrainer/cdetrainer.py index 84111fb..cda4a10 100755 --- a/trainer/cdeTrainer/cdetrainer.py +++ b/trainer/cdeTrainer/cdetrainer.py @@ -110,6 +110,16 @@ class Trainer: # 计算原始loss loss = self.loss(output, label) + # 检查output和label的shape是否一致 + if output.shape == label.shape: + print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + import sys + sys.exit(0) + else: + print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + import sys + sys.exit(1) + # 反归一化 d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) -- 2.40.1 From c4414dd5d9ac043e64b266790058859ef80e09b4 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 2 Dec 2025 09:12:09 +0800 Subject: [PATCH 03/41] fix twdgcn config --- config/TWDGCN/AirQuality.yaml | 16 +++++++++------- config/TWDGCN/BJTaxi-InFlow.yaml | 4 +++- config/TWDGCN/BJTaxi-OutFlow.yaml | 4 +++- config/TWDGCN/Hainan.yaml | 1 + config/TWDGCN/METR-LA.yaml | 8 +++++--- config/TWDGCN/NYCBike-InFlow.yaml | 12 +++++++----- config/TWDGCN/NYCBike-OutFlow.yaml | 12 +++++++----- config/TWDGCN/PEMSD3.yaml | 2 ++ config/TWDGCN/PEMSD4.yaml | 6 ++++-- config/TWDGCN/PEMSD7.yaml | 2 ++ config/TWDGCN/PEMSD8.yaml | 2 ++ config/TWDGCN/SolarEnergy.yaml | 18 ++++++++++-------- model/TWDGCN/TWDGCN.py | 7 +++---- 13 files changed, 58 insertions(+), 36 deletions(-) diff --git a/config/TWDGCN/AirQuality.yaml b/config/TWDGCN/AirQuality.yaml index b57aa40..97f31a1 100644 --- a/config/TWDGCN/AirQuality.yaml +++ b/config/TWDGCN/AirQuality.yaml @@ -5,11 +5,11 @@ basic: model: TWDGCN seed: 2023 data: - batch_size: 16 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 35 @@ -19,14 +19,16 @@ data: model: cheb_order: 2 embed_dim: 12 - input_dim: 6 + horizon: 24 + input_dim: 1 num_layers: 1 - output_dim: 6 + num_nodes: 35 + output_dim: 1 rnn_units: 64 use_day: true use_week: false train: - batch_size: 16 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 @@ -38,10 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.001 mape_thresh: 0.0 max_grad_norm: 5 - output_dim: 6 + output_dim: 1 plot: false real_value: true seed: 10 diff --git a/config/TWDGCN/BJTaxi-InFlow.yaml b/config/TWDGCN/BJTaxi-InFlow.yaml index 1ee9c33..cf543c8 100644 --- a/config/TWDGCN/BJTaxi-InFlow.yaml +++ b/config/TWDGCN/BJTaxi-InFlow.yaml @@ -19,8 +19,10 @@ data: model: cheb_order: 2 embed_dim: 12 + horizon: 24 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 use_day: true @@ -38,7 +40,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/TWDGCN/BJTaxi-OutFlow.yaml b/config/TWDGCN/BJTaxi-OutFlow.yaml index bb2933b..a9ff5f9 100644 --- a/config/TWDGCN/BJTaxi-OutFlow.yaml +++ b/config/TWDGCN/BJTaxi-OutFlow.yaml @@ -19,8 +19,10 @@ data: model: cheb_order: 2 embed_dim: 12 + horizon: 24 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 use_day: true @@ -38,7 +40,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/TWDGCN/Hainan.yaml b/config/TWDGCN/Hainan.yaml index d32a56b..7774f92 100755 --- a/config/TWDGCN/Hainan.yaml +++ b/config/TWDGCN/Hainan.yaml @@ -25,6 +25,7 @@ model: horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 13 output_dim: 1 rnn_units: 32 use_day: true diff --git a/config/TWDGCN/METR-LA.yaml b/config/TWDGCN/METR-LA.yaml index 42eb251..0788a9d 100644 --- a/config/TWDGCN/METR-LA.yaml +++ b/config/TWDGCN/METR-LA.yaml @@ -8,9 +8,9 @@ data: batch_size: 16 column_wise: false days_per_week: 7 - horizon: 12 + horizon: 24 input_dim: 1 - lag: 12 + lag: 24 normalizer: std num_nodes: 207 steps_per_day: 288 @@ -19,8 +19,10 @@ data: model: cheb_order: 2 embed_dim: 12 + horizon: 24 input_dim: 1 num_layers: 1 + num_nodes: 207 output_dim: 1 rnn_units: 64 use_day: true @@ -38,7 +40,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/TWDGCN/NYCBike-InFlow.yaml b/config/TWDGCN/NYCBike-InFlow.yaml index 060bdeb..af27b83 100644 --- a/config/TWDGCN/NYCBike-InFlow.yaml +++ b/config/TWDGCN/NYCBike-InFlow.yaml @@ -8,19 +8,21 @@ data: batch_size: 32 column_wise: false days_per_week: 7 - horizon: 24 + horizon: 12 input_dim: 1 - lag: 24 + lag: 12 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 model: cheb_order: 2 embed_dim: 12 + horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 use_day: true @@ -38,8 +40,8 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' - mape_thresh: 0.0 + mae_thresh: 0.0 + mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false diff --git a/config/TWDGCN/NYCBike-OutFlow.yaml b/config/TWDGCN/NYCBike-OutFlow.yaml index fd50df1..2b509a1 100644 --- a/config/TWDGCN/NYCBike-OutFlow.yaml +++ b/config/TWDGCN/NYCBike-OutFlow.yaml @@ -8,19 +8,21 @@ data: batch_size: 32 column_wise: false days_per_week: 7 - horizon: 24 + horizon: 12 input_dim: 1 - lag: 24 + lag: 12 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 model: cheb_order: 2 embed_dim: 12 + horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 use_day: true @@ -38,8 +40,8 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' - mape_thresh: 0.0 + mae_thresh: 0.0 + mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false diff --git a/config/TWDGCN/PEMSD3.yaml b/config/TWDGCN/PEMSD3.yaml index 7227a76..196970e 100755 --- a/config/TWDGCN/PEMSD3.yaml +++ b/config/TWDGCN/PEMSD3.yaml @@ -21,8 +21,10 @@ data: model: cheb_order: 2 embed_dim: 12 + horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 358 output_dim: 1 rnn_units: 64 use_day: true diff --git a/config/TWDGCN/PEMSD4.yaml b/config/TWDGCN/PEMSD4.yaml index 22d540b..c8e14a6 100755 --- a/config/TWDGCN/PEMSD4.yaml +++ b/config/TWDGCN/PEMSD4.yaml @@ -21,8 +21,10 @@ data: model: cheb_order: 2 embed_dim: 12 + horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 307 output_dim: 1 rnn_units: 64 use_day: true @@ -41,8 +43,8 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: - mape_thresh: 0.0 + mae_thresh: 0.0 + mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false diff --git a/config/TWDGCN/PEMSD7.yaml b/config/TWDGCN/PEMSD7.yaml index 6854017..2f79918 100755 --- a/config/TWDGCN/PEMSD7.yaml +++ b/config/TWDGCN/PEMSD7.yaml @@ -21,8 +21,10 @@ data: model: cheb_order: 2 embed_dim: 12 + horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 883 output_dim: 1 rnn_units: 64 use_day: true diff --git a/config/TWDGCN/PEMSD8.yaml b/config/TWDGCN/PEMSD8.yaml index 857b9eb..6dac03a 100755 --- a/config/TWDGCN/PEMSD8.yaml +++ b/config/TWDGCN/PEMSD8.yaml @@ -21,8 +21,10 @@ data: model: cheb_order: 2 embed_dim: 12 + horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 170 output_dim: 1 rnn_units: 64 use_day: true diff --git a/config/TWDGCN/SolarEnergy.yaml b/config/TWDGCN/SolarEnergy.yaml index 2403f5c..859116a 100644 --- a/config/TWDGCN/SolarEnergy.yaml +++ b/config/TWDGCN/SolarEnergy.yaml @@ -5,11 +5,11 @@ basic: model: TWDGCN seed: 2023 data: - batch_size: 16 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 137 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 @@ -19,14 +19,16 @@ data: model: cheb_order: 2 embed_dim: 12 - input_dim: 137 + horizon: 24 + input_dim: 1 num_layers: 1 - output_dim: 137 + num_nodes: 137 + output_dim: 1 rnn_units: 64 use_day: true use_week: false train: - batch_size: 16 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 @@ -38,10 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' - mape_thresh: 0.0 + mae_thresh: 0.0 + mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 137 + output_dim: 1 plot: false real_value: true seed: 10 diff --git a/model/TWDGCN/TWDGCN.py b/model/TWDGCN/TWDGCN.py index b360b57..bdf1186 100755 --- a/model/TWDGCN/TWDGCN.py +++ b/model/TWDGCN/TWDGCN.py @@ -89,7 +89,6 @@ class TWDGCN(nn.Module): self.num_layers = args["num_layers"] self.use_day = args["use_day"] self.use_week = args["use_week"] - self.default_graph = args["default_graph"] self.node_embeddings1 = nn.Parameter( torch.randn(self.num_node, args["embed_dim"]), requires_grad=True @@ -154,17 +153,17 @@ class TWDGCN(nn.Module): node_embedding1 = self.node_embeddings1 if self.use_day: - t_i_d_data = source[..., 1] + t_i_d_data = source[..., -2] T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()] node_embedding1 = node_embedding1 * T_i_D_emb if self.use_week: - d_i_w_data = source[..., 2] + d_i_w_data = source[..., -1] D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()] node_embedding1 = node_embedding1 * D_i_W_emb node_embeddings = [node_embedding1, self.node_embeddings1] - source = source[..., 0].unsqueeze(-1) + source = source[..., 0:self.input_dim] init_state1 = self.encoder1.init_hidden(source.shape[0]) output, _ = self.encoder1(source, init_state1, node_embeddings) -- 2.40.1 From e50f44347013574f8441cf4726811e8cb3863bf0 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 2 Dec 2025 09:40:24 +0800 Subject: [PATCH 04/41] add astra-pemsbay v2 --- config/AEPSA/v2_PEMS-BAY.yaml | 54 +++++++ test_results.txt | 264 ---------------------------------- 2 files changed, 54 insertions(+), 264 deletions(-) create mode 100755 config/AEPSA/v2_PEMS-BAY.yaml diff --git a/config/AEPSA/v2_PEMS-BAY.yaml b/config/AEPSA/v2_PEMS-BAY.yaml new file mode 100755 index 0000000..7b5c97e --- /dev/null +++ b/config/AEPSA/v2_PEMS-BAY.yaml @@ -0,0 +1,54 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: AEPSA_v2 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 325 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/test_results.txt b/test_results.txt index 6b5c248..6116217 100644 --- a/test_results.txt +++ b/test_results.txt @@ -734,270 +734,6 @@ Traceback (most recent call last): AttributeError: 'NoneType' object has no attribute 'to' ``` -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(L).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/Hainan.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(M).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 38, in model_selector - return TWDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TWDGCN/TWDGCN.py", line 84, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - #### /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-InFlow.yaml ``` -- 2.40.1 From 140ead397528102905210d13c6a6fd0b5303a376 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Dec 2025 08:40:00 +0800 Subject: [PATCH 05/41] =?UTF-8?q?=E4=B8=BAmodel=E6=B7=BB=E5=8A=A0num=5Fnod?= =?UTF-8?q?es?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/AEPSA/AirQuality.yaml | 4 +- config/AEPSA/BJTaxi-InFlow.yaml | 3 + config/AEPSA/BJTaxi-OutFlow.yaml | 3 + config/AEPSA/NYCBike-InFlow.yaml | 7 +- config/AEPSA/NYCBike-OutFlow.yaml | 7 +- config/AGCRN/AirQuality.yaml | 14 ++- config/AGCRN/BJTaxi-InFlow.yaml | 12 ++- config/AGCRN/BJTaxi-OutFlow.yaml | 12 ++- config/AGCRN/METR-LA.yaml | 12 ++- config/AGCRN/NYCBike-InFlow.yaml | 12 ++- config/AGCRN/NYCBike-OutFlow.yaml | 12 ++- config/AGCRN/PEMSD3.yaml | 10 +- config/AGCRN/PEMSD4.yaml | 10 +- config/AGCRN/PEMSD7.yaml | 10 +- config/AGCRN/PEMSD8.yaml | 10 +- config/AGCRN/SolarEnergy.yaml | 5 +- config/ARIMA/AirQuality.yaml | 9 +- config/ARIMA/BJTaxi-InFlow.yaml | 7 +- config/ARIMA/BJTaxi-OutFlow.yaml | 7 +- config/ARIMA/Hainan.yaml | 5 +- config/ARIMA/METR-LA.yaml | 7 +- config/ARIMA/NYCBike-InFlow.yaml | 9 +- config/ARIMA/NYCBike-OutFlow.yaml | 9 +- config/ARIMA/PEMSD3.yaml | 4 +- config/ARIMA/PEMSD4.yaml | 10 +- config/ARIMA/PEMSD7(L).yaml | 1 - config/ARIMA/PEMSD7(M).yaml | 1 - config/ARIMA/PEMSD7.yaml | 2 +- config/ARIMA/PEMSD8.yaml | 8 +- config/ARIMA/SolarEnergy.yaml | 7 +- config/DCRNN/AirQuality.yaml | 8 +- config/DCRNN/BJTaxi-InFlow.yaml | 6 +- config/DCRNN/BJTaxi-OutFlow.yaml | 6 +- config/DCRNN/METR-LA.yaml | 6 +- config/DCRNN/NYCBike-InFlow.yaml | 6 +- config/DCRNN/NYCBike-OutFlow.yaml | 6 +- config/DCRNN/PEMSD3.yaml | 4 +- config/DCRNN/PEMSD4.yaml | 4 +- config/DCRNN/PEMSD7.yaml | 4 +- config/DCRNN/PEMSD8.yaml | 4 +- config/DCRNN/SolarEnergy.yaml | 5 +- config/DDGCRN/AirQuality.yaml | 9 +- config/DDGCRN/BJTaxi-InFlow.yaml | 7 +- config/DDGCRN/BJTaxi-OutFlow.yaml | 7 +- config/DDGCRN/Hainan.yaml | 5 +- config/DDGCRN/METR-LA.yaml | 1 - config/DDGCRN/NYCBike-InFlow.yaml | 9 +- config/DDGCRN/NYCBike-OutFlow.yaml | 9 +- config/DDGCRN/PEMSD3.yaml | 4 +- config/DDGCRN/PEMSD4.yaml | 4 +- config/DDGCRN/PEMSD7(L).yaml | 1 - config/DDGCRN/PEMSD7(M).yaml | 1 - config/DDGCRN/PEMSD7.yaml | 2 +- config/DDGCRN/PEMSD8.yaml | 1 - config/DDGCRN/SolarEnergy.yaml | 7 +- config/DSANET/AirQuality.yaml | 9 +- config/DSANET/BJTaxi-InFlow.yaml | 7 +- config/DSANET/BJTaxi-OutFlow.yaml | 7 +- config/DSANET/METR-LA.yaml | 7 +- config/DSANET/NYCBike-InFlow.yaml | 9 +- config/DSANET/NYCBike-OutFlow.yaml | 9 +- config/DSANET/PEMSD3.yaml | 4 +- config/DSANET/PEMSD4.yaml | 4 +- config/DSANET/PEMSD7.yaml | 4 +- config/DSANET/PEMSD8.yaml | 4 +- config/DSANET/SolarEnergy.yaml | 7 +- config/EXP/AirQuality.yaml | 9 +- config/EXP/BJTaxi-InFlow.yaml | 7 +- config/EXP/BJTaxi-OutFlow.yaml | 7 +- config/EXP/METR-LA.yaml | 7 +- config/EXP/NYCBike-InFlow.yaml | 9 +- config/EXP/NYCBike-OutFlow.yaml | 9 +- config/EXP/PEMSD3.yaml | 4 +- config/EXP/PEMSD4.yaml | 4 +- config/EXP/PEMSD7.yaml | 4 +- config/EXP/PEMSD8.yaml | 3 +- config/EXP/SD.yaml | 5 +- config/EXP/SolarEnergy.yaml | 7 +- config/EXPB/AirQuality.yaml | 9 +- config/EXPB/BJTaxi-InFlow.yaml | 7 +- config/EXPB/BJTaxi-OutFlow.yaml | 7 +- config/EXPB/METR-LA.yaml | 7 +- config/EXPB/NYCBike-InFlow.yaml | 9 +- config/EXPB/NYCBike-OutFlow.yaml | 9 +- config/EXPB/PEMSD4.yaml | 4 +- config/EXPB/SolarEnergy.yaml | 7 +- config/GWN/AirQuality.yaml | 13 ++- config/GWN/BJTaxi-InFlow.yaml | 11 +-- config/GWN/BJTaxi-OutFlow.yaml | 11 +-- config/GWN/METR-LA.yaml | 11 +-- config/GWN/NYCBike-InFlow.yaml | 11 +-- config/GWN/NYCBike-OutFlow.yaml | 11 +-- config/GWN/PEMSD3.yaml | 16 +--- config/GWN/PEMSD4.yaml | 16 +--- config/GWN/PEMSD7.yaml | 16 +--- config/GWN/PEMSD8.yaml | 16 +--- config/GWN/SolarEnergy.yaml | 10 +- config/MegaCRN/AirQuality.yaml | 9 +- config/MegaCRN/BJTaxi-InFlow.yaml | 7 +- config/MegaCRN/BJTaxi-OutFlow.yaml | 7 +- config/MegaCRN/METR-LA.yaml | 7 +- config/MegaCRN/NYCBike-InFlow.yaml | 9 +- config/MegaCRN/NYCBike-OutFlow.yaml | 9 +- config/MegaCRN/PEMSD3.yaml | 4 +- config/MegaCRN/PEMSD4.yaml | 4 +- config/MegaCRN/PEMSD7.yaml | 4 +- config/MegaCRN/PEMSD8.yaml | 4 +- config/MegaCRN/SolarEnergy.yaml | 7 +- config/NLT/AirQuality.yaml | 9 +- config/NLT/BJTaxi-InFlow.yaml | 7 +- config/NLT/BJTaxi-OutFlow.yaml | 7 +- config/NLT/METR-LA.yaml | 7 +- config/NLT/NYCBike-InFlow.yaml | 9 +- config/NLT/NYCBike-OutFlow.yaml | 9 +- config/NLT/PEMSD3.yaml | 4 +- config/NLT/PEMSD4.yaml | 4 +- config/NLT/PEMSD7.yaml | 4 +- config/NLT/PEMSD8.yaml | 4 +- config/NLT/SolarEnergy.yaml | 7 +- config/PDG2SEQ/AirQuality.yaml | 9 +- config/PDG2SEQ/BJTaxi-InFlow.yaml | 7 +- config/PDG2SEQ/BJTaxi-OutFlow.yaml | 7 +- config/PDG2SEQ/METR-LA.yaml | 7 +- config/PDG2SEQ/NYCBike-InFlow.yaml | 9 +- config/PDG2SEQ/NYCBike-OutFlow.yaml | 9 +- config/PDG2SEQ/PEMSD3.yaml | 4 +- config/PDG2SEQ/PEMSD4.yaml | 4 +- config/PDG2SEQ/PEMSD7.yaml | 2 +- config/PDG2SEQ/PEMSD8.yaml | 2 +- config/PDG2SEQ/SolarEnergy.yaml | 7 +- config/REPST/AirQuality.yaml | 4 +- config/REPST/BJTaxi-InFlow.yaml | 3 + config/REPST/BJTaxi-OutFlow.yaml | 3 + config/REPST/NYCBike-InFlow.yaml | 7 +- config/REPST/NYCBike-OutFlow.yaml | 7 +- config/STAEFormer/AirQuality.yaml | 10 +- config/STAEFormer/BJTaxi-InFlow.yaml | 6 +- config/STAEFormer/BJTaxi-OutFlow.yaml | 6 +- config/STAEFormer/METR-LA.yaml | 6 +- config/STAEFormer/NYCBike-InFlow.yaml | 10 +- config/STAEFormer/NYCBike-OutFlow.yaml | 10 +- config/STAEFormer/PEMSD3.yaml | 3 +- config/STAEFormer/PEMSD4.yaml | 3 +- config/STAEFormer/PEMSD7.yaml | 3 +- config/STAEFormer/PEMSD8.yaml | 3 +- config/STAEFormer/SolarEnergy.yaml | 6 +- config/STAWnet/AirQuality.yaml | 9 +- config/STAWnet/BJTaxi-InFlow.yaml | 7 +- config/STAWnet/BJTaxi-OutFlow.yaml | 7 +- config/STAWnet/METR-LA.yaml | 7 +- config/STAWnet/NYCBike-InFlow.yaml | 9 +- config/STAWnet/NYCBike-OutFlow.yaml | 9 +- config/STAWnet/PEMSD3.yaml | 4 +- config/STAWnet/PEMSD4.yaml | 4 +- config/STAWnet/PEMSD7.yaml | 4 +- config/STAWnet/PEMSD8.yaml | 4 +- config/STAWnet/SolarEnergy.yaml | 7 +- config/STFGNN/AirQuality.yaml | 27 +++--- config/STFGNN/BJTaxi-InFlow.yaml | 25 ++--- config/STFGNN/BJTaxi-OutFlow.yaml | 25 ++--- config/STFGNN/METR-LA.yaml | 25 ++--- config/STFGNN/NYCBike-InFlow.yaml | 27 +++--- config/STFGNN/NYCBike-OutFlow.yaml | 27 +++--- config/STFGNN/PEMSD3.yaml | 15 ++- config/STFGNN/PEMSD4.yaml | 15 ++- config/STFGNN/PEMSD7.yaml | 15 ++- config/STFGNN/PEMSD8.yaml | 15 ++- config/STFGNN/SolarEnergy.yaml | 25 ++--- config/STGCN/AirQuality.yaml | 8 +- config/STGCN/BJTaxi-InFlow.yaml | 6 +- config/STGCN/BJTaxi-OutFlow.yaml | 6 +- config/STGCN/METR-LA.yaml | 6 +- config/STGCN/NYCBike-InFlow.yaml | 6 +- config/STGCN/NYCBike-OutFlow.yaml | 6 +- config/STGCN/PEMSD3.yaml | 4 +- config/STGCN/PEMSD4.yaml | 4 +- config/STGCN/PEMSD7.yaml | 4 +- config/STGCN/PEMSD8.yaml | 4 +- config/STGCN/SolarEnergy.yaml | 6 +- config/STGNCDE/AirQuality.yaml | 17 ++-- config/STGNCDE/BJTaxi-InFlow.yaml | 15 +-- config/STGNCDE/BJTaxi-OutFlow.yaml | 15 +-- config/STGNCDE/METR-LA.yaml | 15 +-- config/STGNCDE/NYCBike-InFlow.yaml | 17 ++-- config/STGNCDE/NYCBike-OutFlow.yaml | 17 ++-- config/STGNCDE/PEMSD3.yaml | 10 +- config/STGNCDE/PEMSD4.yaml | 10 +- config/STGNCDE/PEMSD7.yaml | 10 +- config/STGNCDE/PEMSD8.yaml | 10 +- config/STGNCDE/SolarEnergy.yaml | 15 +-- config/STGNRDE/AirQuality.yaml | 17 ++-- config/STGNRDE/BJTaxi-InFlow.yaml | 15 +-- config/STGNRDE/BJTaxi-OutFlow.yaml | 15 +-- config/STGNRDE/METR-LA.yaml | 15 +-- config/STGNRDE/NYCBike-InFlow.yaml | 17 ++-- config/STGNRDE/NYCBike-OutFlow.yaml | 17 ++-- config/STGNRDE/PEMSD3.yaml | 10 +- config/STGNRDE/PEMSD4.yaml | 10 +- config/STGNRDE/PEMSD7.yaml | 10 +- config/STGNRDE/PEMSD8.yaml | 10 +- config/STGNRDE/SolarEnergy.yaml | 15 +-- config/STGODE/AirQuality.yaml | 9 +- config/STGODE/BJTaxi-InFlow.yaml | 7 +- config/STGODE/BJTaxi-OutFlow.yaml | 7 +- config/STGODE/METR-LA.yaml | 7 +- config/STGODE/NYCBike-InFlow.yaml | 9 +- config/STGODE/NYCBike-OutFlow.yaml | 9 +- config/STGODE/PEMSD3.yaml | 4 +- config/STGODE/PEMSD4.yaml | 4 +- config/STGODE/PEMSD7.yaml | 4 +- config/STGODE/PEMSD8.yaml | 4 +- config/STGODE/SolarEnergy.yaml | 7 +- config/STID/AirQuality.yaml | 6 +- config/STID/BJTaxi-InFlow.yaml | 5 +- config/STID/BJTaxi-OutFlow.yaml | 5 +- config/STID/BJTaxi_Inflow.yaml | 3 +- config/STID/BJTaxi_Outflow.yaml | 3 +- config/STID/METR-LA.yaml | 3 +- config/STID/NYCBike-InFlow.yaml | 9 +- config/STID/NYCBike-OutFlow.yaml | 9 +- config/STID/NYCBike_Inflow.yaml | 3 +- config/STID/NYCBike_Outflow.yaml | 3 +- config/STID/PEMS-BAY.yaml | 2 +- config/STID/PEMSD4.yaml | 3 +- config/STID/SolarEnergy.yaml | 3 +- config/STIDGCN/AirQuality.yaml | 9 +- config/STIDGCN/BJTaxi-InFlow.yaml | 7 +- config/STIDGCN/BJTaxi-OutFlow.yaml | 7 +- config/STIDGCN/METR-LA.yaml | 7 +- config/STIDGCN/NYCBike-InFlow.yaml | 9 +- config/STIDGCN/NYCBike-OutFlow.yaml | 9 +- config/STIDGCN/PEMSD3.yaml | 4 +- config/STIDGCN/PEMSD4.yaml | 4 +- config/STIDGCN/PEMSD7.yaml | 4 +- config/STIDGCN/PEMSD8.yaml | 4 +- config/STIDGCN/SolarEnergy.yaml | 7 +- config/STMLP/AirQuality.yaml | 9 +- config/STMLP/BJTaxi-InFlow.yaml | 7 +- config/STMLP/BJTaxi-OutFlow.yaml | 7 +- config/STMLP/METR-LA.yaml | 7 +- config/STMLP/NYCBike-InFlow.yaml | 9 +- config/STMLP/NYCBike-OutFlow.yaml | 9 +- config/STMLP/PEMSD3.yaml | 4 +- config/STMLP/PEMSD4.yaml | 4 +- config/STMLP/PEMSD7.yaml | 4 +- config/STMLP/PEMSD8.yaml | 4 +- config/STMLP/SolarEnergy.yaml | 7 +- config/STSGCN/AirQuality.yaml | 33 ++++--- config/STSGCN/BJTaxi-InFlow.yaml | 31 ++++--- config/STSGCN/BJTaxi-OutFlow.yaml | 31 ++++--- config/STSGCN/METR-LA.yaml | 31 ++++--- config/STSGCN/NYCBike-InFlow.yaml | 33 ++++--- config/STSGCN/NYCBike-OutFlow.yaml | 33 ++++--- config/STSGCN/PEMSD3.yaml | 18 +++- config/STSGCN/PEMSD4.yaml | 18 +++- config/STSGCN/PEMSD7.yaml | 18 +++- config/STSGCN/PEMSD8.yaml | 18 +++- config/STSGCN/SolarEnergy.yaml | 31 ++++--- config/ST_SSL/AirQuality.yaml | 9 +- config/ST_SSL/BJTaxi-InFlow.yaml | 7 +- config/ST_SSL/BJTaxi-OutFlow.yaml | 7 +- config/ST_SSL/METR-LA.yaml | 7 +- config/ST_SSL/NYCBike-InFlow.yaml | 9 +- config/ST_SSL/NYCBike-OutFlow.yaml | 9 +- config/ST_SSL/PEMSD3.yaml | 4 +- config/ST_SSL/PEMSD4.yaml | 4 +- config/ST_SSL/PEMSD7.yaml | 4 +- config/ST_SSL/PEMSD8.yaml | 4 +- config/ST_SSL/SolarEnergy.yaml | 7 +- config/TCN/AirQuality.yaml | 13 ++- config/TCN/BJTaxi-InFlow.yaml | 11 ++- config/TCN/BJTaxi-OutFlow.yaml | 11 ++- config/TCN/METR-LA.yaml | 11 ++- config/TCN/NYCBike-InFlow.yaml | 11 ++- config/TCN/NYCBike-OutFlow.yaml | 11 ++- config/TCN/PEMSD3.yaml | 9 +- config/TCN/PEMSD4.yaml | 9 +- config/TCN/PEMSD7.yaml | 9 +- config/TCN/PEMSD8.yaml | 9 +- config/TCN/SolarEnergy.yaml | 11 ++- config/TWDGCN/AirQuality.yaml | 8 +- config/TWDGCN/BJTaxi-InFlow.yaml | 4 +- config/TWDGCN/BJTaxi-OutFlow.yaml | 4 +- config/TWDGCN/Hainan.yaml | 6 +- config/TWDGCN/METR-LA.yaml | 4 +- config/TWDGCN/NYCBike-InFlow.yaml | 4 +- config/TWDGCN/NYCBike-OutFlow.yaml | 4 +- config/TWDGCN/PEMSD3.yaml | 3 +- config/TWDGCN/PEMSD4.yaml | 3 +- config/TWDGCN/PEMSD7(L).yaml | 1 - config/TWDGCN/PEMSD7(M).yaml | 1 - config/TWDGCN/PEMSD7.yaml | 1 - config/TWDGCN/PEMSD8.yaml | 1 - config/TWDGCN/SolarEnergy.yaml | 4 +- test_configs.py | 124 ------------------------- 295 files changed, 1514 insertions(+), 1084 deletions(-) delete mode 100644 test_configs.py diff --git a/config/AEPSA/AirQuality.yaml b/config/AEPSA/AirQuality.yaml index d6061d9..c2d905a 100644 --- a/config/AEPSA/AirQuality.yaml +++ b/config/AEPSA/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 @@ -26,7 +26,7 @@ model: gpt_path: ./GPT-2 input_dim: 6 n_heads: 1 - num_nodes: 35 + num_nodes: 12 patch_len: 6 pred_len: 24 seq_len: 24 diff --git a/config/AEPSA/BJTaxi-InFlow.yaml b/config/AEPSA/BJTaxi-InFlow.yaml index 64b53dc..a453b38 100644 --- a/config/AEPSA/BJTaxi-InFlow.yaml +++ b/config/AEPSA/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: AEPSA seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -30,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 32 debug: false diff --git a/config/AEPSA/BJTaxi-OutFlow.yaml b/config/AEPSA/BJTaxi-OutFlow.yaml index d0cf19d..9fa0f5f 100644 --- a/config/AEPSA/BJTaxi-OutFlow.yaml +++ b/config/AEPSA/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: AEPSA seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -30,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 32 debug: false diff --git a/config/AEPSA/NYCBike-InFlow.yaml b/config/AEPSA/NYCBike-InFlow.yaml index 2384c58..b561493 100644 --- a/config/AEPSA/NYCBike-InFlow.yaml +++ b/config/AEPSA/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: AEPSA seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -24,12 +26,13 @@ model: gpt_path: ./GPT-2 input_dim: 1 n_heads: 1 - num_nodes: 1024 + num_nodes: 128 patch_len: 6 pred_len: 24 seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 32 debug: false diff --git a/config/AEPSA/NYCBike-OutFlow.yaml b/config/AEPSA/NYCBike-OutFlow.yaml index 0b3597f..5c4da71 100644 --- a/config/AEPSA/NYCBike-OutFlow.yaml +++ b/config/AEPSA/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: AEPSA seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -24,12 +26,13 @@ model: gpt_path: ./GPT-2 input_dim: 1 n_heads: 1 - num_nodes: 1024 + num_nodes: 128 patch_len: 6 pred_len: 24 seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 32 debug: false diff --git a/config/AGCRN/AirQuality.yaml b/config/AGCRN/AirQuality.yaml index b1b904b..e400582 100644 --- a/config/AGCRN/AirQuality.yaml +++ b/config/AGCRN/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 6 num_layers: 2 + num_nodes: 12 output_dim: 6 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/AGCRN/BJTaxi-InFlow.yaml b/config/AGCRN/BJTaxi-InFlow.yaml index 5206e36..b01c8bf 100644 --- a/config/AGCRN/BJTaxi-InFlow.yaml +++ b/config/AGCRN/BJTaxi-InFlow.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 1024 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/AGCRN/BJTaxi-OutFlow.yaml b/config/AGCRN/BJTaxi-OutFlow.yaml index 7b5dc61..c7b687a 100644 --- a/config/AGCRN/BJTaxi-OutFlow.yaml +++ b/config/AGCRN/BJTaxi-OutFlow.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 1024 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/AGCRN/METR-LA.yaml b/config/AGCRN/METR-LA.yaml index b24e57e..20eb587 100644 --- a/config/AGCRN/METR-LA.yaml +++ b/config/AGCRN/METR-LA.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 207 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/AGCRN/NYCBike-InFlow.yaml b/config/AGCRN/NYCBike-InFlow.yaml index c1abc45..d33cab3 100644 --- a/config/AGCRN/NYCBike-InFlow.yaml +++ b/config/AGCRN/NYCBike-InFlow.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 128 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/AGCRN/NYCBike-OutFlow.yaml b/config/AGCRN/NYCBike-OutFlow.yaml index 9a5a846..1e1044f 100644 --- a/config/AGCRN/NYCBike-OutFlow.yaml +++ b/config/AGCRN/NYCBike-OutFlow.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 128 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/AGCRN/PEMSD3.yaml b/config/AGCRN/PEMSD3.yaml index aa68819..53621fd 100755 --- a/config/AGCRN/PEMSD3.yaml +++ b/config/AGCRN/PEMSD3.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 358 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/AGCRN/PEMSD4.yaml b/config/AGCRN/PEMSD4.yaml index 317b767..3aa586c 100755 --- a/config/AGCRN/PEMSD4.yaml +++ b/config/AGCRN/PEMSD4.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 307 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/AGCRN/PEMSD7.yaml b/config/AGCRN/PEMSD7.yaml index 0c0d96f..5859001 100755 --- a/config/AGCRN/PEMSD7.yaml +++ b/config/AGCRN/PEMSD7.yaml @@ -24,6 +24,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 2 + num_nodes: 883 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/AGCRN/PEMSD8.yaml b/config/AGCRN/PEMSD8.yaml index 7725af7..93a4250 100755 --- a/config/AGCRN/PEMSD8.yaml +++ b/config/AGCRN/PEMSD8.yaml @@ -24,6 +24,7 @@ model: embed_dim: 2 input_dim: 1 num_layers: 2 + num_nodes: 170 output_dim: 1 rnn_units: 64 @@ -38,13 +39,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 12 weight_decay: 0 diff --git a/config/AGCRN/SolarEnergy.yaml b/config/AGCRN/SolarEnergy.yaml index 094aec9..c6a666a 100644 --- a/config/AGCRN/SolarEnergy.yaml +++ b/config/AGCRN/SolarEnergy.yaml @@ -46,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/ARIMA/AirQuality.yaml b/config/ARIMA/AirQuality.yaml index cc2885e..45ce496 100644 --- a/config/ARIMA/AirQuality.yaml +++ b/config/ARIMA/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ARIMA seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,17 +13,20 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: d: 1 drift: true input_dim: 6 + num_nodes: 12 output_dim: 6 p: 2 q: 1 + train: batch_size: 16 debug: false @@ -36,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/BJTaxi-InFlow.yaml b/config/ARIMA/BJTaxi-InFlow.yaml index 0be9d12..85254a9 100644 --- a/config/ARIMA/BJTaxi-InFlow.yaml +++ b/config/ARIMA/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ARIMA seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,13 +17,16 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d: 1 drift: true input_dim: 1 + num_nodes: 1024 output_dim: 1 p: 2 q: 1 + train: batch_size: 32 debug: false @@ -36,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/BJTaxi-OutFlow.yaml b/config/ARIMA/BJTaxi-OutFlow.yaml index 14f9578..3c1d233 100644 --- a/config/ARIMA/BJTaxi-OutFlow.yaml +++ b/config/ARIMA/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ARIMA seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,13 +17,16 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d: 1 drift: true input_dim: 1 + num_nodes: 1024 output_dim: 1 p: 2 q: 1 + train: batch_size: 32 debug: false @@ -36,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/Hainan.yaml b/config/ARIMA/Hainan.yaml index 90cfda2..ed1a12a 100755 --- a/config/ARIMA/Hainan.yaml +++ b/config/ARIMA/Hainan.yaml @@ -13,13 +13,14 @@ data: input_dim: 1 lag: 12 normalizer: std - num_nodes: 13 + num_nodes: 200 steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 model: input_dim: 1 + num_nodes: 200 output_dim: 1 train: @@ -39,7 +40,7 @@ train: - 40 - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: null mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/ARIMA/METR-LA.yaml b/config/ARIMA/METR-LA.yaml index 2b8598a..084e20e 100644 --- a/config/ARIMA/METR-LA.yaml +++ b/config/ARIMA/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ARIMA seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,13 +17,16 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: d: 1 drift: true input_dim: 1 + num_nodes: 207 output_dim: 1 p: 2 q: 1 + train: batch_size: 16 debug: false @@ -36,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/NYCBike-InFlow.yaml b/config/ARIMA/NYCBike-InFlow.yaml index 127493b..0da5634 100644 --- a/config/ARIMA/NYCBike-InFlow.yaml +++ b/config/ARIMA/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ARIMA seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,17 +13,20 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d: 1 drift: true input_dim: 1 + num_nodes: 128 output_dim: 1 p: 2 q: 1 + train: batch_size: 32 debug: false @@ -36,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/NYCBike-OutFlow.yaml b/config/ARIMA/NYCBike-OutFlow.yaml index a1e3819..ddb85a2 100644 --- a/config/ARIMA/NYCBike-OutFlow.yaml +++ b/config/ARIMA/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ARIMA seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,17 +13,20 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d: 1 drift: true input_dim: 1 + num_nodes: 128 output_dim: 1 p: 2 q: 1 + train: batch_size: 32 debug: false @@ -36,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/PEMSD3.yaml b/config/ARIMA/PEMSD3.yaml index 37bae7a..27b8605 100755 --- a/config/ARIMA/PEMSD3.yaml +++ b/config/ARIMA/PEMSD3.yaml @@ -22,6 +22,7 @@ model: d: 1 drift: true input_dim: 1 + num_nodes: 358 output_dim: 1 p: 2 q: 1 @@ -39,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/PEMSD4.yaml b/config/ARIMA/PEMSD4.yaml index 714973b..b4f0439 100755 --- a/config/ARIMA/PEMSD4.yaml +++ b/config/ARIMA/PEMSD4.yaml @@ -22,6 +22,7 @@ model: d: 1 drift: true input_dim: 1 + num_nodes: 307 output_dim: 1 p: 2 q: 1 @@ -37,13 +38,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/PEMSD7(L).yaml b/config/ARIMA/PEMSD7(L).yaml index 3eedd0d..25653c3 100755 --- a/config/ARIMA/PEMSD7(L).yaml +++ b/config/ARIMA/PEMSD7(L).yaml @@ -51,5 +51,4 @@ train: output_dim: 1 plot: true real_value: true - seed: 12 weight_decay: 0 diff --git a/config/ARIMA/PEMSD7(M).yaml b/config/ARIMA/PEMSD7(M).yaml index 5992f68..24ef88e 100755 --- a/config/ARIMA/PEMSD7(M).yaml +++ b/config/ARIMA/PEMSD7(M).yaml @@ -51,5 +51,4 @@ train: output_dim: 1 plot: true real_value: true - seed: 12 weight_decay: 0 diff --git a/config/ARIMA/PEMSD7.yaml b/config/ARIMA/PEMSD7.yaml index 6c027c6..f5a5255 100755 --- a/config/ARIMA/PEMSD7.yaml +++ b/config/ARIMA/PEMSD7.yaml @@ -22,6 +22,7 @@ model: d: 1 drift: true input_dim: 1 + num_nodes: 883 output_dim: 1 p: 2 q: 1 @@ -49,5 +50,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/PEMSD8.yaml b/config/ARIMA/PEMSD8.yaml index cbe3ed3..16e0339 100755 --- a/config/ARIMA/PEMSD8.yaml +++ b/config/ARIMA/PEMSD8.yaml @@ -22,6 +22,7 @@ model: d: 1 drift: true input_dim: 1 + num_nodes: 170 output_dim: 1 p: 2 q: 1 @@ -37,7 +38,11 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 mae_thresh: None mape_thresh: 0.001 @@ -45,5 +50,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ARIMA/SolarEnergy.yaml b/config/ARIMA/SolarEnergy.yaml index f1aeb63..9e921f0 100644 --- a/config/ARIMA/SolarEnergy.yaml +++ b/config/ARIMA/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ARIMA seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,13 +17,16 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: d: 1 drift: true input_dim: 137 + num_nodes: 137 output_dim: 137 p: 2 q: 1 + train: batch_size: 16 debug: false @@ -36,11 +40,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DCRNN/AirQuality.yaml b/config/DCRNN/AirQuality.yaml index 4922365..387b89b 100644 --- a/config/DCRNN/AirQuality.yaml +++ b/config/DCRNN/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 @@ -25,6 +25,7 @@ model: input_dim: 6 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 12 num_rnn_layers: 2 output_dim: 6 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.1 lr_decay_step: 10,20,40,80 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: false - seed: 10 - weight_decay: 0.0001 \ No newline at end of file + weight_decay: 0.0001 diff --git a/config/DCRNN/BJTaxi-InFlow.yaml b/config/DCRNN/BJTaxi-InFlow.yaml index b81dc86..16a3f91 100644 --- a/config/DCRNN/BJTaxi-InFlow.yaml +++ b/config/DCRNN/BJTaxi-InFlow.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 1024 num_rnn_layers: 2 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.1 lr_decay_step: 10,20,40,80 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: false - seed: 10 - weight_decay: 0.0001 \ No newline at end of file + weight_decay: 0.0001 diff --git a/config/DCRNN/BJTaxi-OutFlow.yaml b/config/DCRNN/BJTaxi-OutFlow.yaml index dfffb51..339e1ec 100644 --- a/config/DCRNN/BJTaxi-OutFlow.yaml +++ b/config/DCRNN/BJTaxi-OutFlow.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 1024 num_rnn_layers: 2 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.1 lr_decay_step: 10,20,40,80 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: false - seed: 10 - weight_decay: 0.0001 \ No newline at end of file + weight_decay: 0.0001 diff --git a/config/DCRNN/METR-LA.yaml b/config/DCRNN/METR-LA.yaml index 18fb223..15fcb44 100644 --- a/config/DCRNN/METR-LA.yaml +++ b/config/DCRNN/METR-LA.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 207 num_rnn_layers: 2 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.1 lr_decay_step: 10,20,40,80 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: false - seed: 10 - weight_decay: 0.0001 \ No newline at end of file + weight_decay: 0.0001 diff --git a/config/DCRNN/NYCBike-InFlow.yaml b/config/DCRNN/NYCBike-InFlow.yaml index bf7d773..e53a839 100644 --- a/config/DCRNN/NYCBike-InFlow.yaml +++ b/config/DCRNN/NYCBike-InFlow.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 128 num_rnn_layers: 2 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.1 lr_decay_step: 10,20,40,80 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: false - seed: 10 - weight_decay: 0.0001 \ No newline at end of file + weight_decay: 0.0001 diff --git a/config/DCRNN/NYCBike-OutFlow.yaml b/config/DCRNN/NYCBike-OutFlow.yaml index 7472459..a9ba532 100644 --- a/config/DCRNN/NYCBike-OutFlow.yaml +++ b/config/DCRNN/NYCBike-OutFlow.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 128 num_rnn_layers: 2 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.1 lr_decay_step: 10,20,40,80 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: false - seed: 10 - weight_decay: 0.0001 \ No newline at end of file + weight_decay: 0.0001 diff --git a/config/DCRNN/PEMSD3.yaml b/config/DCRNN/PEMSD3.yaml index 75f7dde..7d0e4a8 100755 --- a/config/DCRNN/PEMSD3.yaml +++ b/config/DCRNN/PEMSD3.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 358 num_rnn_layers: 1 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DCRNN/PEMSD4.yaml b/config/DCRNN/PEMSD4.yaml index 803d032..ddf3156 100755 --- a/config/DCRNN/PEMSD4.yaml +++ b/config/DCRNN/PEMSD4.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 307 num_rnn_layers: 2 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.1 lr_decay_step: 10,20,40,80 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: false - seed: 10 weight_decay: 0.0001 diff --git a/config/DCRNN/PEMSD7.yaml b/config/DCRNN/PEMSD7.yaml index e940611..8b8e43b 100755 --- a/config/DCRNN/PEMSD7.yaml +++ b/config/DCRNN/PEMSD7.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 883 num_rnn_layers: 1 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DCRNN/PEMSD8.yaml b/config/DCRNN/PEMSD8.yaml index cde60d3..709b392 100755 --- a/config/DCRNN/PEMSD8.yaml +++ b/config/DCRNN/PEMSD8.yaml @@ -25,6 +25,7 @@ model: input_dim: 1 l1_decay: 0 max_diffusion_step: 2 + num_nodes: 170 num_rnn_layers: 1 output_dim: 1 rnn_units: 64 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DCRNN/SolarEnergy.yaml b/config/DCRNN/SolarEnergy.yaml index 3bc9fc2..434abdd 100644 --- a/config/DCRNN/SolarEnergy.yaml +++ b/config/DCRNN/SolarEnergy.yaml @@ -45,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/DDGCRN/AirQuality.yaml b/config/DDGCRN/AirQuality.yaml index 954728b..7090253 100644 --- a/config/DDGCRN/AirQuality.yaml +++ b/config/DDGCRN/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DDGCRN seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,19 +13,22 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 input_dim: 6 num_layers: 1 + num_nodes: 12 output_dim: 6 rnn_units: 64 use_day: true use_week: false + train: batch_size: 16 debug: false @@ -38,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/BJTaxi-InFlow.yaml b/config/DDGCRN/BJTaxi-InFlow.yaml index ebd58a2..12dffa4 100644 --- a/config/DDGCRN/BJTaxi-InFlow.yaml +++ b/config/DDGCRN/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DDGCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,15 +17,18 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -38,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/BJTaxi-OutFlow.yaml b/config/DDGCRN/BJTaxi-OutFlow.yaml index 89a64b6..eb88c12 100644 --- a/config/DDGCRN/BJTaxi-OutFlow.yaml +++ b/config/DDGCRN/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DDGCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,15 +17,18 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -38,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/Hainan.yaml b/config/DDGCRN/Hainan.yaml index e22e71f..c02cd76 100755 --- a/config/DDGCRN/Hainan.yaml +++ b/config/DDGCRN/Hainan.yaml @@ -13,7 +13,7 @@ data: input_dim: 1 lag: 12 normalizer: std - num_nodes: 13 + num_nodes: 200 steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 @@ -25,6 +25,7 @@ model: horizon: 12 input_dim: 1 num_layers: 1 + num_nodes: 200 output_dim: 1 rnn_units: 32 use_day: true @@ -47,7 +48,7 @@ train: - 40 - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: null mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/DDGCRN/METR-LA.yaml b/config/DDGCRN/METR-LA.yaml index 013fb5e..c18ffb7 100755 --- a/config/DDGCRN/METR-LA.yaml +++ b/config/DDGCRN/METR-LA.yaml @@ -48,5 +48,4 @@ train: max_grad_norm: 5 output_dim: 1 plot: false - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/NYCBike-InFlow.yaml b/config/DDGCRN/NYCBike-InFlow.yaml index 30846fb..1e9f2fb 100644 --- a/config/DDGCRN/NYCBike-InFlow.yaml +++ b/config/DDGCRN/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DDGCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,19 +13,22 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -38,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/NYCBike-OutFlow.yaml b/config/DDGCRN/NYCBike-OutFlow.yaml index b48986f..227d00e 100644 --- a/config/DDGCRN/NYCBike-OutFlow.yaml +++ b/config/DDGCRN/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DDGCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,19 +13,22 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -38,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/PEMSD3.yaml b/config/DDGCRN/PEMSD3.yaml index 98bebd0..3064f54 100755 --- a/config/DDGCRN/PEMSD3.yaml +++ b/config/DDGCRN/PEMSD3.yaml @@ -23,6 +23,7 @@ model: embed_dim: 12 input_dim: 1 num_layers: 1 + num_nodes: 358 output_dim: 1 rnn_units: 64 use_day: true @@ -41,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/PEMSD4.yaml b/config/DDGCRN/PEMSD4.yaml index fcf818e..b6c4073 100755 --- a/config/DDGCRN/PEMSD4.yaml +++ b/config/DDGCRN/PEMSD4.yaml @@ -23,6 +23,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 1 + num_nodes: 307 output_dim: 1 rnn_units: 64 use_day: true @@ -41,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/PEMSD7(L).yaml b/config/DDGCRN/PEMSD7(L).yaml index f9063ef..2ddc530 100755 --- a/config/DDGCRN/PEMSD7(L).yaml +++ b/config/DDGCRN/PEMSD7(L).yaml @@ -51,5 +51,4 @@ train: output_dim: 1 plot: true real_value: true - seed: 12 weight_decay: 0 diff --git a/config/DDGCRN/PEMSD7(M).yaml b/config/DDGCRN/PEMSD7(M).yaml index e7d87c3..a907f41 100755 --- a/config/DDGCRN/PEMSD7(M).yaml +++ b/config/DDGCRN/PEMSD7(M).yaml @@ -51,5 +51,4 @@ train: output_dim: 1 plot: true real_value: true - seed: 12 weight_decay: 0 diff --git a/config/DDGCRN/PEMSD7.yaml b/config/DDGCRN/PEMSD7.yaml index ef828cb..48c2129 100755 --- a/config/DDGCRN/PEMSD7.yaml +++ b/config/DDGCRN/PEMSD7.yaml @@ -23,6 +23,7 @@ model: embed_dim: 12 input_dim: 1 num_layers: 1 + num_nodes: 883 output_dim: 1 rnn_units: 64 use_day: true @@ -51,5 +52,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DDGCRN/PEMSD8.yaml b/config/DDGCRN/PEMSD8.yaml index d467cf8..05b469e 100755 --- a/config/DDGCRN/PEMSD8.yaml +++ b/config/DDGCRN/PEMSD8.yaml @@ -49,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 12 weight_decay: 0 diff --git a/config/DDGCRN/SolarEnergy.yaml b/config/DDGCRN/SolarEnergy.yaml index b23ea03..902aa98 100644 --- a/config/DDGCRN/SolarEnergy.yaml +++ b/config/DDGCRN/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DDGCRN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,15 +17,18 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 input_dim: 137 num_layers: 1 + num_nodes: 137 output_dim: 137 rnn_units: 64 use_day: true use_week: false + train: batch_size: 16 debug: false @@ -38,11 +42,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/AirQuality.yaml b/config/DSANET/AirQuality.yaml index 2147269..f5d0b7d 100644 --- a/config/DSANET/AirQuality.yaml +++ b/config/DSANET/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DSANET seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 d_inner: 2048 @@ -29,9 +31,11 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 35 + num_nodes: 12 output_dim: 6 w_kernel: 1 window: 24 + train: batch_size: 16 debug: false @@ -45,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/BJTaxi-InFlow.yaml b/config/DSANET/BJTaxi-InFlow.yaml index 7d40eff..2f81ff6 100644 --- a/config/DSANET/BJTaxi-InFlow.yaml +++ b/config/DSANET/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DSANET seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 d_inner: 2048 @@ -29,9 +31,11 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 1024 + num_nodes: 1024 output_dim: 1 w_kernel: 1 window: 24 + train: batch_size: 32 debug: false @@ -45,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/BJTaxi-OutFlow.yaml b/config/DSANET/BJTaxi-OutFlow.yaml index 38e1e4e..dc5c1bc 100644 --- a/config/DSANET/BJTaxi-OutFlow.yaml +++ b/config/DSANET/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DSANET seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 d_inner: 2048 @@ -29,9 +31,11 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 1024 + num_nodes: 1024 output_dim: 1 w_kernel: 1 window: 24 + train: batch_size: 32 debug: false @@ -45,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/METR-LA.yaml b/config/DSANET/METR-LA.yaml index 108931b..6e920c5 100644 --- a/config/DSANET/METR-LA.yaml +++ b/config/DSANET/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DSANET seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 d_inner: 2048 @@ -29,9 +31,11 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 207 + num_nodes: 207 output_dim: 1 w_kernel: 1 window: 12 + train: batch_size: 16 debug: false @@ -45,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/NYCBike-InFlow.yaml b/config/DSANET/NYCBike-InFlow.yaml index 2534078..f3cc3f8 100644 --- a/config/DSANET/NYCBike-InFlow.yaml +++ b/config/DSANET/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DSANET seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 d_inner: 2048 @@ -29,9 +31,11 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 1024 + num_nodes: 128 output_dim: 1 w_kernel: 1 window: 24 + train: batch_size: 32 debug: false @@ -45,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/NYCBike-OutFlow.yaml b/config/DSANET/NYCBike-OutFlow.yaml index 3131ccc..eb6c116 100644 --- a/config/DSANET/NYCBike-OutFlow.yaml +++ b/config/DSANET/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DSANET seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 d_inner: 2048 @@ -29,9 +31,11 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 1024 + num_nodes: 128 output_dim: 1 w_kernel: 1 window: 24 + train: batch_size: 32 debug: false @@ -45,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/PEMSD3.yaml b/config/DSANET/PEMSD3.yaml index 38dccec..c9cb07e 100755 --- a/config/DSANET/PEMSD3.yaml +++ b/config/DSANET/PEMSD3.yaml @@ -31,6 +31,7 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 358 + num_nodes: 358 output_dim: 1 w_kernel: 1 window: 12 @@ -48,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/PEMSD4.yaml b/config/DSANET/PEMSD4.yaml index bba0aa4..6676526 100755 --- a/config/DSANET/PEMSD4.yaml +++ b/config/DSANET/PEMSD4.yaml @@ -31,6 +31,7 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 307 + num_nodes: 307 output_dim: 1 w_kernel: 1 window: 12 @@ -48,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/PEMSD7.yaml b/config/DSANET/PEMSD7.yaml index a04d2de..8d51681 100755 --- a/config/DSANET/PEMSD7.yaml +++ b/config/DSANET/PEMSD7.yaml @@ -31,6 +31,7 @@ model: n_kernels: 32 n_layers: 3 n_multiv: 883 + num_nodes: 883 output_dim: 1 w_kernel: 1 window: 12 @@ -48,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/PEMSD8.yaml b/config/DSANET/PEMSD8.yaml index 02a46bd..5f5ce7e 100755 --- a/config/DSANET/PEMSD8.yaml +++ b/config/DSANET/PEMSD8.yaml @@ -31,6 +31,7 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 170 + num_nodes: 170 output_dim: 1 w_kernel: 1 window: 12 @@ -48,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/DSANET/SolarEnergy.yaml b/config/DSANET/SolarEnergy.yaml index c2dd6eb..cc44e42 100644 --- a/config/DSANET/SolarEnergy.yaml +++ b/config/DSANET/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: DSANET seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 d_inner: 2048 @@ -29,9 +31,11 @@ model: n_kernels: 32 n_layers: 6 n_multiv: 137 + num_nodes: 137 output_dim: 137 w_kernel: 1 window: 24 + train: batch_size: 16 debug: false @@ -45,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/AirQuality.yaml b/config/EXP/AirQuality.yaml index ff7cf8d..f8dbaba 100644 --- a/config/EXP/AirQuality.yaml +++ b/config/EXP/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXP seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 @@ -24,11 +26,13 @@ model: in_len: 24 input_dim: 6 num_layers: 1 + num_nodes: 12 output_dim: 6 rnn_units: 64 top_k: 2 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -42,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/BJTaxi-InFlow.yaml b/config/EXP/BJTaxi-InFlow.yaml index c924453..01f1e63 100644 --- a/config/EXP/BJTaxi-InFlow.yaml +++ b/config/EXP/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXP seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 @@ -24,11 +26,13 @@ model: in_len: 24 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 top_k: 2 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -42,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/BJTaxi-OutFlow.yaml b/config/EXP/BJTaxi-OutFlow.yaml index 6377e0b..4acad90 100644 --- a/config/EXP/BJTaxi-OutFlow.yaml +++ b/config/EXP/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXP seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 @@ -24,11 +26,13 @@ model: in_len: 24 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 top_k: 2 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -42,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/METR-LA.yaml b/config/EXP/METR-LA.yaml index 28ef4c1..e0ecd55 100644 --- a/config/EXP/METR-LA.yaml +++ b/config/EXP/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXP seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 @@ -24,11 +26,13 @@ model: in_len: 12 input_dim: 1 num_layers: 1 + num_nodes: 207 output_dim: 1 rnn_units: 64 top_k: 2 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -42,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/NYCBike-InFlow.yaml b/config/EXP/NYCBike-InFlow.yaml index 34876bc..3210b53 100644 --- a/config/EXP/NYCBike-InFlow.yaml +++ b/config/EXP/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXP seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 @@ -24,11 +26,13 @@ model: in_len: 24 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 top_k: 2 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -42,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/NYCBike-OutFlow.yaml b/config/EXP/NYCBike-OutFlow.yaml index 79e06ac..28e74d2 100644 --- a/config/EXP/NYCBike-OutFlow.yaml +++ b/config/EXP/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXP seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 @@ -24,11 +26,13 @@ model: in_len: 24 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 top_k: 2 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -42,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/PEMSD3.yaml b/config/EXP/PEMSD3.yaml index 7e00b5f..def7295 100755 --- a/config/EXP/PEMSD3.yaml +++ b/config/EXP/PEMSD3.yaml @@ -26,6 +26,7 @@ model: in_len: 12 input_dim: 1 num_layers: 1 + num_nodes: 358 output_dim: 1 rnn_units: 64 top_k: 2 @@ -45,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/PEMSD4.yaml b/config/EXP/PEMSD4.yaml index 560d6da..cdb81aa 100755 --- a/config/EXP/PEMSD4.yaml +++ b/config/EXP/PEMSD4.yaml @@ -23,6 +23,7 @@ model: cycle_len: 288 in_len: 12 input_dim: 1 + num_nodes: 307 output_dim: 1 train: @@ -38,11 +39,10 @@ train: lr_decay_rate: 0.5 lr_decay_step: 5,20,40,65 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/PEMSD7.yaml b/config/EXP/PEMSD7.yaml index 029e356..9233f08 100755 --- a/config/EXP/PEMSD7.yaml +++ b/config/EXP/PEMSD7.yaml @@ -22,6 +22,7 @@ model: batch_size: 64 in_len: 12 input_dim: 1 + num_nodes: 883 output_dim: 1 train: @@ -37,11 +38,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/PEMSD8.yaml b/config/EXP/PEMSD8.yaml index 5061050..e8af5dc 100755 --- a/config/EXP/PEMSD8.yaml +++ b/config/EXP/PEMSD8.yaml @@ -43,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXP/SD.yaml b/config/EXP/SD.yaml index f61120b..493b443 100755 --- a/config/EXP/SD.yaml +++ b/config/EXP/SD.yaml @@ -13,7 +13,7 @@ data: input_dim: 1 lag: 12 normalizer: std - num_nodes: 716 + num_nodes: 307 steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 @@ -22,6 +22,7 @@ model: batch_size: 64 in_len: 12 input_dim: 1 + num_nodes: 307 output_dim: 1 train: @@ -37,7 +38,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: null mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/EXP/SolarEnergy.yaml b/config/EXP/SolarEnergy.yaml index 79e1496..de5edfe 100644 --- a/config/EXP/SolarEnergy.yaml +++ b/config/EXP/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXP seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 @@ -24,11 +26,13 @@ model: in_len: 24 input_dim: 137 num_layers: 1 + num_nodes: 137 output_dim: 137 rnn_units: 64 top_k: 2 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -42,11 +46,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/AirQuality.yaml b/config/EXPB/AirQuality.yaml index 238f7f7..4b8082c 100644 --- a/config/EXPB/AirQuality.yaml +++ b/config/EXPB/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXPB seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,20 +13,23 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 input_dim: 6 num_layers: 1 + num_nodes: 12 output_dim: 6 patch_size: 3 rnn_units: 64 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/BJTaxi-InFlow.yaml b/config/EXPB/BJTaxi-InFlow.yaml index 1eb34f4..0fb7614 100644 --- a/config/EXPB/BJTaxi-InFlow.yaml +++ b/config/EXPB/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXPB seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 patch_size: 3 rnn_units: 64 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/BJTaxi-OutFlow.yaml b/config/EXPB/BJTaxi-OutFlow.yaml index b913f8f..798f5ca 100644 --- a/config/EXPB/BJTaxi-OutFlow.yaml +++ b/config/EXPB/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXPB seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 input_dim: 1 num_layers: 1 + num_nodes: 1024 output_dim: 1 patch_size: 3 rnn_units: 64 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/METR-LA.yaml b/config/EXPB/METR-LA.yaml index 3416970..343b252 100644 --- a/config/EXPB/METR-LA.yaml +++ b/config/EXPB/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXPB seed: 2023 + data: batch_size: 64 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 input_dim: 1 num_layers: 1 + num_nodes: 207 output_dim: 1 patch_size: 3 rnn_units: 64 use_day: true use_week: true + train: batch_size: 64 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/NYCBike-InFlow.yaml b/config/EXPB/NYCBike-InFlow.yaml index 2642db3..ac7b0d9 100644 --- a/config/EXPB/NYCBike-InFlow.yaml +++ b/config/EXPB/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXPB seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,20 +13,23 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 patch_size: 3 rnn_units: 64 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/NYCBike-OutFlow.yaml b/config/EXPB/NYCBike-OutFlow.yaml index 3501ece..a4d845e 100644 --- a/config/EXPB/NYCBike-OutFlow.yaml +++ b/config/EXPB/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXPB seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,20 +13,23 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 input_dim: 1 num_layers: 1 + num_nodes: 128 output_dim: 1 patch_size: 3 rnn_units: 64 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/PEMSD4.yaml b/config/EXPB/PEMSD4.yaml index 4e2c908..cf301c2 100755 --- a/config/EXPB/PEMSD4.yaml +++ b/config/EXPB/PEMSD4.yaml @@ -23,6 +23,7 @@ model: embed_dim: 10 input_dim: 1 num_layers: 1 + num_nodes: 307 output_dim: 1 patch_size: 3 rnn_units: 64 @@ -42,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/EXPB/SolarEnergy.yaml b/config/EXPB/SolarEnergy.yaml index 8e1a595..2d1f64e 100644 --- a/config/EXPB/SolarEnergy.yaml +++ b/config/EXPB/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: EXPB seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 10 input_dim: 137 num_layers: 1 + num_nodes: 137 output_dim: 137 patch_size: 3 rnn_units: 64 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/GWN/AirQuality.yaml b/config/GWN/AirQuality.yaml index e1d8f4b..786219f 100644 --- a/config/GWN/AirQuality.yaml +++ b/config/GWN/AirQuality.yaml @@ -13,14 +13,14 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 model: addaptadj: true - aptinit: + aptinit: null batch_size: 16 blocks: 4 dilation_channels: 32 @@ -31,12 +31,12 @@ model: input_dim: 6 kernel_size: 2 layers: 2 + num_nodes: 12 out_dim: 12 output_dim: 6 residual_channels: 32 skip_channels: 256 - supports: - + supports: null train: batch_size: 16 @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/GWN/BJTaxi-InFlow.yaml b/config/GWN/BJTaxi-InFlow.yaml index 54a5631..f2f10c8 100644 --- a/config/GWN/BJTaxi-InFlow.yaml +++ b/config/GWN/BJTaxi-InFlow.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 @@ -31,12 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 1024 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - + supports: null train: batch_size: 32 @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/GWN/BJTaxi-OutFlow.yaml b/config/GWN/BJTaxi-OutFlow.yaml index ea133e8..cef9af4 100644 --- a/config/GWN/BJTaxi-OutFlow.yaml +++ b/config/GWN/BJTaxi-OutFlow.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 @@ -31,12 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 1024 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - + supports: null train: batch_size: 32 @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/GWN/METR-LA.yaml b/config/GWN/METR-LA.yaml index 96faa45..9ffb5d1 100644 --- a/config/GWN/METR-LA.yaml +++ b/config/GWN/METR-LA.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 16 blocks: 4 dilation_channels: 32 @@ -31,12 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 207 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - + supports: null train: batch_size: 16 @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/GWN/NYCBike-InFlow.yaml b/config/GWN/NYCBike-InFlow.yaml index 1f4c646..c536802 100644 --- a/config/GWN/NYCBike-InFlow.yaml +++ b/config/GWN/NYCBike-InFlow.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 @@ -31,12 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 128 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - + supports: null train: batch_size: 32 @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/GWN/NYCBike-OutFlow.yaml b/config/GWN/NYCBike-OutFlow.yaml index a73d3fc..c67790b 100644 --- a/config/GWN/NYCBike-OutFlow.yaml +++ b/config/GWN/NYCBike-OutFlow.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 @@ -31,12 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 128 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - + supports: null train: batch_size: 32 @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/GWN/PEMSD3.yaml b/config/GWN/PEMSD3.yaml index f3d78ac..9e75da7 100755 --- a/config/GWN/PEMSD3.yaml +++ b/config/GWN/PEMSD3.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 64 blocks: 4 dilation_channels: 32 @@ -31,19 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 358 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - - - - - - - - + supports: null train: batch_size: 16 @@ -58,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/GWN/PEMSD4.yaml b/config/GWN/PEMSD4.yaml index ceccee3..5435727 100755 --- a/config/GWN/PEMSD4.yaml +++ b/config/GWN/PEMSD4.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 64 blocks: 4 dilation_channels: 32 @@ -31,19 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 307 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - - - - - - - - + supports: null train: batch_size: 64 @@ -58,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/GWN/PEMSD7.yaml b/config/GWN/PEMSD7.yaml index 2cbfc62..7330998 100755 --- a/config/GWN/PEMSD7.yaml +++ b/config/GWN/PEMSD7.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 64 blocks: 4 dilation_channels: 32 @@ -31,19 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 883 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - - - - - - - - + supports: null train: batch_size: 16 @@ -58,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/GWN/PEMSD8.yaml b/config/GWN/PEMSD8.yaml index 88a5090..cebe500 100755 --- a/config/GWN/PEMSD8.yaml +++ b/config/GWN/PEMSD8.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 64 blocks: 4 dilation_channels: 32 @@ -31,19 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 170 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: - - - - - - - - + supports: null train: batch_size: 64 @@ -58,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/GWN/SolarEnergy.yaml b/config/GWN/SolarEnergy.yaml index 76110e1..afdce7a 100644 --- a/config/GWN/SolarEnergy.yaml +++ b/config/GWN/SolarEnergy.yaml @@ -20,7 +20,7 @@ data: model: addaptadj: true - aptinit: + aptinit: null batch_size: 64 blocks: 4 dilation_channels: 32 @@ -31,11 +31,12 @@ model: input_dim: 1 kernel_size: 2 layers: 2 + num_nodes: 137 out_dim: 12 output_dim: 1 residual_channels: 32 skip_channels: 256 - supports: + supports: null train: batch_size: 64 @@ -50,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/MegaCRN/AirQuality.yaml b/config/MegaCRN/AirQuality.yaml index 66583fe..c7fdfe8 100644 --- a/config/MegaCRN/AirQuality.yaml +++ b/config/MegaCRN/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: MegaCRN seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 3 cl_decay_steps: 2000 @@ -23,10 +25,12 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 12 output_dim: 6 rnn_units: 64 use_curriculum_learning: true ycov_dim: 1 + train: batch_size: 16 debug: false @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/BJTaxi-InFlow.yaml b/config/MegaCRN/BJTaxi-InFlow.yaml index c1e5954..b6b0fd5 100644 --- a/config/MegaCRN/BJTaxi-InFlow.yaml +++ b/config/MegaCRN/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: MegaCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 3 cl_decay_steps: 2000 @@ -23,10 +25,12 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 use_curriculum_learning: true ycov_dim: 1 + train: batch_size: 32 debug: false @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/BJTaxi-OutFlow.yaml b/config/MegaCRN/BJTaxi-OutFlow.yaml index df43640..41602a5 100644 --- a/config/MegaCRN/BJTaxi-OutFlow.yaml +++ b/config/MegaCRN/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: MegaCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 3 cl_decay_steps: 2000 @@ -23,10 +25,12 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 use_curriculum_learning: true ycov_dim: 1 + train: batch_size: 32 debug: false @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/METR-LA.yaml b/config/MegaCRN/METR-LA.yaml index 9be97b9..c3e7805 100644 --- a/config/MegaCRN/METR-LA.yaml +++ b/config/MegaCRN/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: MegaCRN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 3 cl_decay_steps: 2000 @@ -23,10 +25,12 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 207 output_dim: 1 rnn_units: 64 use_curriculum_learning: true ycov_dim: 1 + train: batch_size: 16 debug: false @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/NYCBike-InFlow.yaml b/config/MegaCRN/NYCBike-InFlow.yaml index ef35650..de90784 100644 --- a/config/MegaCRN/NYCBike-InFlow.yaml +++ b/config/MegaCRN/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: MegaCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 3 cl_decay_steps: 2000 @@ -23,10 +25,12 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 use_curriculum_learning: true ycov_dim: 1 + train: batch_size: 32 debug: false @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/NYCBike-OutFlow.yaml b/config/MegaCRN/NYCBike-OutFlow.yaml index 85465f7..ec0487b 100644 --- a/config/MegaCRN/NYCBike-OutFlow.yaml +++ b/config/MegaCRN/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: MegaCRN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 3 cl_decay_steps: 2000 @@ -23,10 +25,12 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 use_curriculum_learning: true ycov_dim: 1 + train: batch_size: 32 debug: false @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/PEMSD3.yaml b/config/MegaCRN/PEMSD3.yaml index 2716192..5814af0 100644 --- a/config/MegaCRN/PEMSD3.yaml +++ b/config/MegaCRN/PEMSD3.yaml @@ -25,6 +25,7 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 358 output_dim: 1 rnn_units: 64 use_curriculum_learning: true @@ -43,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/PEMSD4.yaml b/config/MegaCRN/PEMSD4.yaml index 2ed68ca..d3c06c9 100644 --- a/config/MegaCRN/PEMSD4.yaml +++ b/config/MegaCRN/PEMSD4.yaml @@ -25,6 +25,7 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 307 output_dim: 1 rnn_units: 64 use_curriculum_learning: true @@ -43,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/PEMSD7.yaml b/config/MegaCRN/PEMSD7.yaml index 47e34f4..b83d7b3 100644 --- a/config/MegaCRN/PEMSD7.yaml +++ b/config/MegaCRN/PEMSD7.yaml @@ -25,6 +25,7 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 883 output_dim: 1 rnn_units: 64 use_curriculum_learning: true @@ -43,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/PEMSD8.yaml b/config/MegaCRN/PEMSD8.yaml index aeda484..ae40736 100644 --- a/config/MegaCRN/PEMSD8.yaml +++ b/config/MegaCRN/PEMSD8.yaml @@ -25,6 +25,7 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 170 output_dim: 1 rnn_units: 64 use_curriculum_learning: true @@ -43,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/MegaCRN/SolarEnergy.yaml b/config/MegaCRN/SolarEnergy.yaml index ae10bdc..669c0c8 100644 --- a/config/MegaCRN/SolarEnergy.yaml +++ b/config/MegaCRN/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: MegaCRN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 3 cl_decay_steps: 2000 @@ -23,10 +25,12 @@ model: mem_dim: 64 mem_num: 20 num_layers: 1 + num_nodes: 137 output_dim: 137 rnn_units: 64 use_curriculum_learning: true ycov_dim: 1 + train: batch_size: 16 debug: false @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/AirQuality.yaml b/config/NLT/AirQuality.yaml index e5c6a67..c6fa211 100644 --- a/config/NLT/AirQuality.yaml +++ b/config/NLT/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: NLT seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: embed_dim: 10 feature_dim: 6 @@ -26,10 +28,12 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 12 output_dim: 6 output_window: 24 use_day: false use_week: false + train: batch_size: 16 debug: false @@ -43,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/BJTaxi-InFlow.yaml b/config/NLT/BJTaxi-InFlow.yaml index 8f54e18..8c918ab 100644 --- a/config/NLT/BJTaxi-InFlow.yaml +++ b/config/NLT/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: NLT seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: embed_dim: 10 feature_dim: 1 @@ -26,10 +28,12 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 1024 output_dim: 1 output_window: 24 use_day: false use_week: false + train: batch_size: 32 debug: false @@ -43,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/BJTaxi-OutFlow.yaml b/config/NLT/BJTaxi-OutFlow.yaml index 5e989ba..e537d52 100644 --- a/config/NLT/BJTaxi-OutFlow.yaml +++ b/config/NLT/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: NLT seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: embed_dim: 10 feature_dim: 1 @@ -26,10 +28,12 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 1024 output_dim: 1 output_window: 24 use_day: false use_week: false + train: batch_size: 32 debug: false @@ -43,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/METR-LA.yaml b/config/NLT/METR-LA.yaml index bcfc403..03601e9 100644 --- a/config/NLT/METR-LA.yaml +++ b/config/NLT/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: NLT seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: embed_dim: 10 feature_dim: 1 @@ -26,10 +28,12 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 207 output_dim: 1 output_window: 12 use_day: false use_week: false + train: batch_size: 16 debug: false @@ -43,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/NYCBike-InFlow.yaml b/config/NLT/NYCBike-InFlow.yaml index b6ac09a..bde93dc 100644 --- a/config/NLT/NYCBike-InFlow.yaml +++ b/config/NLT/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: NLT seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: embed_dim: 10 feature_dim: 1 @@ -26,10 +28,12 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 128 output_dim: 1 output_window: 24 use_day: false use_week: false + train: batch_size: 32 debug: false @@ -43,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/NYCBike-OutFlow.yaml b/config/NLT/NYCBike-OutFlow.yaml index 5e801b2..8c24df4 100644 --- a/config/NLT/NYCBike-OutFlow.yaml +++ b/config/NLT/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: NLT seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: embed_dim: 10 feature_dim: 1 @@ -26,10 +28,12 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 128 output_dim: 1 output_window: 24 use_day: false use_week: false + train: batch_size: 32 debug: false @@ -43,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/PEMSD3.yaml b/config/NLT/PEMSD3.yaml index 7212f76..8086056 100755 --- a/config/NLT/PEMSD3.yaml +++ b/config/NLT/PEMSD3.yaml @@ -28,6 +28,7 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 358 output_dim: 1 output_window: 12 use_day: false @@ -46,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/PEMSD4.yaml b/config/NLT/PEMSD4.yaml index b924b9f..6d41a7c 100755 --- a/config/NLT/PEMSD4.yaml +++ b/config/NLT/PEMSD4.yaml @@ -28,6 +28,7 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 307 output_dim: 1 output_window: 12 use_day: false @@ -46,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/PEMSD7.yaml b/config/NLT/PEMSD7.yaml index c07708c..7a9783d 100755 --- a/config/NLT/PEMSD7.yaml +++ b/config/NLT/PEMSD7.yaml @@ -28,6 +28,7 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 883 output_dim: 1 output_window: 12 use_day: false @@ -46,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/PEMSD8.yaml b/config/NLT/PEMSD8.yaml index 9a3441f..de8494b 100755 --- a/config/NLT/PEMSD8.yaml +++ b/config/NLT/PEMSD8.yaml @@ -28,6 +28,7 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 170 output_dim: 1 output_window: 12 use_day: false @@ -46,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/NLT/SolarEnergy.yaml b/config/NLT/SolarEnergy.yaml index 9bcd7af..f9da4af 100644 --- a/config/NLT/SolarEnergy.yaml +++ b/config/NLT/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: NLT seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: embed_dim: 10 feature_dim: 137 @@ -26,10 +28,12 @@ model: natt_hops: 4 nfc: 256 num_layers: 2 + num_nodes: 137 output_dim: 137 output_window: 24 use_day: false use_week: false + train: batch_size: 16 debug: false @@ -43,11 +47,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/AirQuality.yaml b/config/PDG2SEQ/AirQuality.yaml index 27ec4a2..a2ad31c 100644 --- a/config/PDG2SEQ/AirQuality.yaml +++ b/config/PDG2SEQ/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: PDG2SEQ seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 12 @@ -23,11 +25,13 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 12 output_dim: 6 rnn_units: 64 time_dim: 8 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/BJTaxi-InFlow.yaml b/config/PDG2SEQ/BJTaxi-InFlow.yaml index 5cbdf37..917b505 100644 --- a/config/PDG2SEQ/BJTaxi-InFlow.yaml +++ b/config/PDG2SEQ/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: PDG2SEQ seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 12 @@ -23,11 +25,13 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 time_dim: 8 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/BJTaxi-OutFlow.yaml b/config/PDG2SEQ/BJTaxi-OutFlow.yaml index f50e98a..a5ccc47 100644 --- a/config/PDG2SEQ/BJTaxi-OutFlow.yaml +++ b/config/PDG2SEQ/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: PDG2SEQ seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 12 @@ -23,11 +25,13 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 1024 output_dim: 1 rnn_units: 64 time_dim: 8 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/METR-LA.yaml b/config/PDG2SEQ/METR-LA.yaml index 0a52a6c..6999a03 100644 --- a/config/PDG2SEQ/METR-LA.yaml +++ b/config/PDG2SEQ/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: PDG2SEQ seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 12 @@ -23,11 +25,13 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 207 output_dim: 1 rnn_units: 64 time_dim: 8 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/NYCBike-InFlow.yaml b/config/PDG2SEQ/NYCBike-InFlow.yaml index d898dcc..56c3abe 100644 --- a/config/PDG2SEQ/NYCBike-InFlow.yaml +++ b/config/PDG2SEQ/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: PDG2SEQ seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 12 @@ -23,11 +25,13 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 time_dim: 8 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/NYCBike-OutFlow.yaml b/config/PDG2SEQ/NYCBike-OutFlow.yaml index 52dee49..39dc207 100644 --- a/config/PDG2SEQ/NYCBike-OutFlow.yaml +++ b/config/PDG2SEQ/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: PDG2SEQ seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 12 @@ -23,11 +25,13 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 128 output_dim: 1 rnn_units: 64 time_dim: 8 use_day: true use_week: true + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/PEMSD3.yaml b/config/PDG2SEQ/PEMSD3.yaml index 015116b..f0e6730 100755 --- a/config/PDG2SEQ/PEMSD3.yaml +++ b/config/PDG2SEQ/PEMSD3.yaml @@ -25,6 +25,7 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 358 output_dim: 1 rnn_units: 64 time_dim: 8 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/PEMSD4.yaml b/config/PDG2SEQ/PEMSD4.yaml index a4cb033..3f28b2e 100755 --- a/config/PDG2SEQ/PEMSD4.yaml +++ b/config/PDG2SEQ/PEMSD4.yaml @@ -25,6 +25,7 @@ model: lr_decay_step: 1500 lr_decay_step1: 60,75,90,120 num_layers: 1 + num_nodes: 307 output_dim: 1 rnn_units: 64 time_dim: 16 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/PEMSD7.yaml b/config/PDG2SEQ/PEMSD7.yaml index 5cd0707..4922dfc 100755 --- a/config/PDG2SEQ/PEMSD7.yaml +++ b/config/PDG2SEQ/PEMSD7.yaml @@ -25,6 +25,7 @@ model: lr_decay_step: 12000 lr_decay_step1: 80,100,120 num_layers: 1 + num_nodes: 883 output_dim: 1 rnn_units: 64 time_dim: 20 @@ -54,5 +55,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/PDG2SEQ/PEMSD8.yaml b/config/PDG2SEQ/PEMSD8.yaml index f250216..3c25095 100755 --- a/config/PDG2SEQ/PEMSD8.yaml +++ b/config/PDG2SEQ/PEMSD8.yaml @@ -25,6 +25,7 @@ model: lr_decay_step: 2000 lr_decay_step1: 50,75 num_layers: 1 + num_nodes: 170 output_dim: 1 rnn_units: 64 time_dim: 16 @@ -50,5 +51,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 12 weight_decay: 0 diff --git a/config/PDG2SEQ/SolarEnergy.yaml b/config/PDG2SEQ/SolarEnergy.yaml index 42c6d14..a04a56f 100644 --- a/config/PDG2SEQ/SolarEnergy.yaml +++ b/config/PDG2SEQ/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: PDG2SEQ seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 12 @@ -23,11 +25,13 @@ model: lr_decay_step: 10000 lr_decay_step1: 75,90,120 num_layers: 1 + num_nodes: 137 output_dim: 137 rnn_units: 64 time_dim: 8 use_day: true use_week: true + train: batch_size: 16 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index f192382..a40e11e 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 @@ -26,7 +26,7 @@ model: gpt_path: ./GPT-2 input_dim: 6 n_heads: 1 - num_nodes: 35 + num_nodes: 12 output_dim: 3 patch_len: 6 pred_len: 24 diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml index 56ccf66..3191eba 100644 --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: REPST seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -31,6 +33,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 16 debug: false diff --git a/config/REPST/BJTaxi-OutFlow.yaml b/config/REPST/BJTaxi-OutFlow.yaml index 36dae39..2c251e6 100644 --- a/config/REPST/BJTaxi-OutFlow.yaml +++ b/config/REPST/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: REPST seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -31,6 +33,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 16 debug: false diff --git a/config/REPST/NYCBike-InFlow.yaml b/config/REPST/NYCBike-InFlow.yaml index b63b151..3ed89c8 100644 --- a/config/REPST/NYCBike-InFlow.yaml +++ b/config/REPST/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: REPST seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -24,13 +26,14 @@ model: gpt_path: ./GPT-2 input_dim: 1 n_heads: 1 - num_nodes: 1024 + num_nodes: 128 output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 16 debug: false diff --git a/config/REPST/NYCBike-OutFlow.yaml b/config/REPST/NYCBike-OutFlow.yaml index 9ab3c6d..59d4364 100644 --- a/config/REPST/NYCBike-OutFlow.yaml +++ b/config/REPST/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: REPST seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_ff: 128 d_model: 64 @@ -24,13 +26,14 @@ model: gpt_path: ./GPT-2 input_dim: 1 n_heads: 1 - num_nodes: 1024 + num_nodes: 128 output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 stride: 7 word_num: 1000 + train: batch_size: 16 debug: false diff --git a/config/STAEFormer/AirQuality.yaml b/config/STAEFormer/AirQuality.yaml index b622adc..e5f07e8 100644 --- a/config/STAEFormer/AirQuality.yaml +++ b/config/STAEFormer/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAEFormer seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: adaptive_embedding_dim: 80 dow_embedding_dim: 24 @@ -26,13 +28,14 @@ model: input_embedding_dim: 24 num_heads: 4 num_layers: 3 - num_nodes: 35 + num_nodes: 12 out_steps: 24 output_dim: 6 spatial_embedding_dim: 0 steps_per_day: 24 tod_embedding_dim: 24 use_mixed_proj: true + train: batch_size: 16 debug: false @@ -46,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/BJTaxi-InFlow.yaml b/config/STAEFormer/BJTaxi-InFlow.yaml index 3404e8a..7eb24c1 100644 --- a/config/STAEFormer/BJTaxi-InFlow.yaml +++ b/config/STAEFormer/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAEFormer seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adaptive_embedding_dim: 80 dow_embedding_dim: 24 @@ -33,6 +35,7 @@ model: steps_per_day: 48 tod_embedding_dim: 24 use_mixed_proj: true + train: batch_size: 32 debug: false @@ -46,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/BJTaxi-OutFlow.yaml b/config/STAEFormer/BJTaxi-OutFlow.yaml index 76c7369..fbc5d56 100644 --- a/config/STAEFormer/BJTaxi-OutFlow.yaml +++ b/config/STAEFormer/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAEFormer seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adaptive_embedding_dim: 80 dow_embedding_dim: 24 @@ -33,6 +35,7 @@ model: steps_per_day: 48 tod_embedding_dim: 24 use_mixed_proj: true + train: batch_size: 32 debug: false @@ -46,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/METR-LA.yaml b/config/STAEFormer/METR-LA.yaml index e982b2b..003e50e 100644 --- a/config/STAEFormer/METR-LA.yaml +++ b/config/STAEFormer/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAEFormer seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: adaptive_embedding_dim: 80 dow_embedding_dim: 24 @@ -33,6 +35,7 @@ model: steps_per_day: 288 tod_embedding_dim: 24 use_mixed_proj: true + train: batch_size: 16 debug: false @@ -46,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/NYCBike-InFlow.yaml b/config/STAEFormer/NYCBike-InFlow.yaml index 4f88780..a96571c 100644 --- a/config/STAEFormer/NYCBike-InFlow.yaml +++ b/config/STAEFormer/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAEFormer seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adaptive_embedding_dim: 80 dow_embedding_dim: 24 @@ -26,13 +28,14 @@ model: input_embedding_dim: 24 num_heads: 4 num_layers: 3 - num_nodes: 1024 + num_nodes: 128 out_steps: 24 output_dim: 1 spatial_embedding_dim: 0 steps_per_day: 48 tod_embedding_dim: 24 use_mixed_proj: true + train: batch_size: 32 debug: false @@ -46,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/NYCBike-OutFlow.yaml b/config/STAEFormer/NYCBike-OutFlow.yaml index ee13784..dc9d430 100644 --- a/config/STAEFormer/NYCBike-OutFlow.yaml +++ b/config/STAEFormer/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAEFormer seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adaptive_embedding_dim: 80 dow_embedding_dim: 24 @@ -26,13 +28,14 @@ model: input_embedding_dim: 24 num_heads: 4 num_layers: 3 - num_nodes: 1024 + num_nodes: 128 out_steps: 24 output_dim: 1 spatial_embedding_dim: 0 steps_per_day: 48 tod_embedding_dim: 24 use_mixed_proj: true + train: batch_size: 32 debug: false @@ -46,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/PEMSD3.yaml b/config/STAEFormer/PEMSD3.yaml index 79eb4de..7497a8b 100755 --- a/config/STAEFormer/PEMSD3.yaml +++ b/config/STAEFormer/PEMSD3.yaml @@ -49,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/PEMSD4.yaml b/config/STAEFormer/PEMSD4.yaml index a832b53..b248ffc 100755 --- a/config/STAEFormer/PEMSD4.yaml +++ b/config/STAEFormer/PEMSD4.yaml @@ -49,10 +49,9 @@ train: lr_decay_rate: 0.1 lr_decay_step: 5,20,40,70 lr_init: 0.001 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0.0003 diff --git a/config/STAEFormer/PEMSD7.yaml b/config/STAEFormer/PEMSD7.yaml index e41e643..be99282 100755 --- a/config/STAEFormer/PEMSD7.yaml +++ b/config/STAEFormer/PEMSD7.yaml @@ -49,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/PEMSD8.yaml b/config/STAEFormer/PEMSD8.yaml index dbee2c7..d9c91a9 100755 --- a/config/STAEFormer/PEMSD8.yaml +++ b/config/STAEFormer/PEMSD8.yaml @@ -49,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAEFormer/SolarEnergy.yaml b/config/STAEFormer/SolarEnergy.yaml index fafffd6..fd97a63 100644 --- a/config/STAEFormer/SolarEnergy.yaml +++ b/config/STAEFormer/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAEFormer seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: adaptive_embedding_dim: 80 dow_embedding_dim: 24 @@ -33,6 +35,7 @@ model: steps_per_day: 24 tod_embedding_dim: 24 use_mixed_proj: true + train: batch_size: 16 debug: false @@ -46,11 +49,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/AirQuality.yaml b/config/STAWnet/AirQuality.yaml index 8f9f94f..6d3e0d0 100644 --- a/config/STAWnet/AirQuality.yaml +++ b/config/STAWnet/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAWnet seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: addaptadj: true aptonly: false @@ -30,9 +32,11 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 12 output_dim: 6 residual_channels: 32 skip_channels: 256 + train: batch_size: 16 debug: false @@ -46,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/BJTaxi-InFlow.yaml b/config/STAWnet/BJTaxi-InFlow.yaml index 029930a..edd919a 100644 --- a/config/STAWnet/BJTaxi-InFlow.yaml +++ b/config/STAWnet/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAWnet seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: addaptadj: true aptonly: false @@ -30,9 +32,11 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 1024 output_dim: 1 residual_channels: 32 skip_channels: 256 + train: batch_size: 32 debug: false @@ -46,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/BJTaxi-OutFlow.yaml b/config/STAWnet/BJTaxi-OutFlow.yaml index f3856e8..e40975b 100644 --- a/config/STAWnet/BJTaxi-OutFlow.yaml +++ b/config/STAWnet/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAWnet seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: addaptadj: true aptonly: false @@ -30,9 +32,11 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 1024 output_dim: 1 residual_channels: 32 skip_channels: 256 + train: batch_size: 32 debug: false @@ -46,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/METR-LA.yaml b/config/STAWnet/METR-LA.yaml index dc84df8..d0fc158 100644 --- a/config/STAWnet/METR-LA.yaml +++ b/config/STAWnet/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAWnet seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: addaptadj: true aptonly: false @@ -30,9 +32,11 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 207 output_dim: 1 residual_channels: 32 skip_channels: 256 + train: batch_size: 16 debug: false @@ -46,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/NYCBike-InFlow.yaml b/config/STAWnet/NYCBike-InFlow.yaml index caea941..563d80c 100644 --- a/config/STAWnet/NYCBike-InFlow.yaml +++ b/config/STAWnet/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAWnet seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: addaptadj: true aptonly: false @@ -30,9 +32,11 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 128 output_dim: 1 residual_channels: 32 skip_channels: 256 + train: batch_size: 32 debug: false @@ -46,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/NYCBike-OutFlow.yaml b/config/STAWnet/NYCBike-OutFlow.yaml index 33a377e..38853aa 100644 --- a/config/STAWnet/NYCBike-OutFlow.yaml +++ b/config/STAWnet/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAWnet seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: addaptadj: true aptonly: false @@ -30,9 +32,11 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 128 output_dim: 1 residual_channels: 32 skip_channels: 256 + train: batch_size: 32 debug: false @@ -46,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/PEMSD3.yaml b/config/STAWnet/PEMSD3.yaml index 30aaddc..9b0f48f 100644 --- a/config/STAWnet/PEMSD3.yaml +++ b/config/STAWnet/PEMSD3.yaml @@ -32,6 +32,7 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 358 output_dim: 1 residual_channels: 32 skip_channels: 256 @@ -49,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/PEMSD4.yaml b/config/STAWnet/PEMSD4.yaml index b89454f..d17a5ab 100644 --- a/config/STAWnet/PEMSD4.yaml +++ b/config/STAWnet/PEMSD4.yaml @@ -32,6 +32,7 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 307 output_dim: 1 residual_channels: 32 skip_channels: 256 @@ -49,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/PEMSD7.yaml b/config/STAWnet/PEMSD7.yaml index 5c52d57..c018131 100644 --- a/config/STAWnet/PEMSD7.yaml +++ b/config/STAWnet/PEMSD7.yaml @@ -32,6 +32,7 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 883 output_dim: 1 residual_channels: 32 skip_channels: 256 @@ -49,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/PEMSD8.yaml b/config/STAWnet/PEMSD8.yaml index 52fcddf..a0f9e0a 100644 --- a/config/STAWnet/PEMSD8.yaml +++ b/config/STAWnet/PEMSD8.yaml @@ -32,6 +32,7 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 170 output_dim: 1 residual_channels: 32 skip_channels: 256 @@ -49,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STAWnet/SolarEnergy.yaml b/config/STAWnet/SolarEnergy.yaml index d4e3b0a..6a2cfb3 100644 --- a/config/STAWnet/SolarEnergy.yaml +++ b/config/STAWnet/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STAWnet seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: addaptadj: true aptonly: false @@ -30,9 +32,11 @@ model: kernel_size: 2 layers: 2 noapt: false + num_nodes: 137 output_dim: 137 residual_channels: 32 skip_channels: 256 + train: batch_size: 16 debug: false @@ -46,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/AirQuality.yaml b/config/STFGNN/AirQuality.yaml index 34ab48b..d559c4e 100644 --- a/config/STFGNN/AirQuality.yaml +++ b/config/STFGNN/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STFGNN seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,31 +13,34 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 24 input_dim: 6 mask: None + num_nodes: 12 out_layer_dim: 128 output_dim: 6 spatial_emb: true temporal_emb: true window: 24 + train: batch_size: 16 debug: false @@ -50,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/BJTaxi-InFlow.yaml b/config/STFGNN/BJTaxi-InFlow.yaml index ca1d078..0b9f284 100644 --- a/config/STFGNN/BJTaxi-InFlow.yaml +++ b/config/STFGNN/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STFGNN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,27 +17,30 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 24 input_dim: 1 mask: None + num_nodes: 1024 out_layer_dim: 128 output_dim: 1 spatial_emb: true temporal_emb: true window: 24 + train: batch_size: 32 debug: false @@ -50,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/BJTaxi-OutFlow.yaml b/config/STFGNN/BJTaxi-OutFlow.yaml index 32e5a5c..75916dd 100644 --- a/config/STFGNN/BJTaxi-OutFlow.yaml +++ b/config/STFGNN/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STFGNN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,27 +17,30 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 24 input_dim: 1 mask: None + num_nodes: 1024 out_layer_dim: 128 output_dim: 1 spatial_emb: true temporal_emb: true window: 24 + train: batch_size: 32 debug: false @@ -50,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/METR-LA.yaml b/config/STFGNN/METR-LA.yaml index 2f39be2..5553fd0 100644 --- a/config/STFGNN/METR-LA.yaml +++ b/config/STFGNN/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STFGNN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,27 +17,30 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 12 input_dim: 1 mask: None + num_nodes: 207 out_layer_dim: 128 output_dim: 1 spatial_emb: true temporal_emb: true window: 12 + train: batch_size: 16 debug: false @@ -50,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/NYCBike-InFlow.yaml b/config/STFGNN/NYCBike-InFlow.yaml index 7c123a3..0a903f9 100644 --- a/config/STFGNN/NYCBike-InFlow.yaml +++ b/config/STFGNN/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STFGNN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,31 +13,34 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 24 input_dim: 1 mask: None + num_nodes: 128 out_layer_dim: 128 output_dim: 1 spatial_emb: true temporal_emb: true window: 24 + train: batch_size: 32 debug: false @@ -50,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/NYCBike-OutFlow.yaml b/config/STFGNN/NYCBike-OutFlow.yaml index d170b59..199c36f 100644 --- a/config/STFGNN/NYCBike-OutFlow.yaml +++ b/config/STFGNN/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STFGNN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,31 +13,34 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 24 input_dim: 1 mask: None + num_nodes: 128 out_layer_dim: 128 output_dim: 1 spatial_emb: true temporal_emb: true window: 24 + train: batch_size: 32 debug: false @@ -50,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/PEMSD3.yaml b/config/STFGNN/PEMSD3.yaml index 35312b4..0935f72 100755 --- a/config/STFGNN/PEMSD3.yaml +++ b/config/STFGNN/PEMSD3.yaml @@ -21,10 +21,20 @@ data: model: activation: GLU first_layer_embedding_size: 64 - hidden_dims: [[64, 64, 64], [64, 64, 64], [64, 64, 64]] + hidden_dims: + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 12 input_dim: 1 mask: None + num_nodes: 358 out_layer_dim: 128 output_dim: 1 spatial_emb: true @@ -44,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/PEMSD4.yaml b/config/STFGNN/PEMSD4.yaml index 8df35b7..b47b851 100755 --- a/config/STFGNN/PEMSD4.yaml +++ b/config/STFGNN/PEMSD4.yaml @@ -21,10 +21,20 @@ data: model: activation: GLU first_layer_embedding_size: 64 - hidden_dims: [[64, 64, 64], [64, 64, 64], [64, 64, 64]] + hidden_dims: + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 12 input_dim: 1 mask: None + num_nodes: 307 out_layer_dim: 128 output_dim: 1 spatial_emb: true @@ -44,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/PEMSD7.yaml b/config/STFGNN/PEMSD7.yaml index d5338f8..b7320f6 100755 --- a/config/STFGNN/PEMSD7.yaml +++ b/config/STFGNN/PEMSD7.yaml @@ -21,10 +21,20 @@ data: model: activation: GLU first_layer_embedding_size: 64 - hidden_dims: [[32, 32, 32], [32, 32, 32], [32, 32, 32]] + hidden_dims: + - - 32 + - 32 + - 32 + - - 32 + - 32 + - 32 + - - 32 + - 32 + - 32 horizon: 12 input_dim: 1 mask: None + num_nodes: 883 out_layer_dim: 64 output_dim: 1 spatial_emb: true @@ -45,11 +55,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/PEMSD8.yaml b/config/STFGNN/PEMSD8.yaml index a1f49d4..08d4f72 100755 --- a/config/STFGNN/PEMSD8.yaml +++ b/config/STFGNN/PEMSD8.yaml @@ -21,10 +21,20 @@ data: model: activation: GLU first_layer_embedding_size: 64 - hidden_dims: [[64, 64, 64], [64, 64, 64], [64, 64, 64]] + hidden_dims: + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 12 input_dim: 1 mask: None + num_nodes: 170 out_layer_dim: 128 output_dim: 1 spatial_emb: true @@ -44,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STFGNN/SolarEnergy.yaml b/config/STFGNN/SolarEnergy.yaml index 5f5f052..2531a1a 100644 --- a/config/STFGNN/SolarEnergy.yaml +++ b/config/STFGNN/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STFGNN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,27 +17,30 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 horizon: 24 input_dim: 137 mask: None + num_nodes: 137 out_layer_dim: 128 output_dim: 137 spatial_emb: true temporal_emb: true window: 24 + train: batch_size: 16 debug: false @@ -50,11 +54,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGCN/AirQuality.yaml b/config/STGCN/AirQuality.yaml index 4a5b272..4c684d3 100644 --- a/config/STGCN/AirQuality.yaml +++ b/config/STGCN/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 6 n_his: 24 + num_nodes: 12 output_dim: 6 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/STGCN/BJTaxi-InFlow.yaml b/config/STGCN/BJTaxi-InFlow.yaml index 8a8cc89..6860b15 100644 --- a/config/STGCN/BJTaxi-InFlow.yaml +++ b/config/STGCN/BJTaxi-InFlow.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 24 + num_nodes: 1024 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/STGCN/BJTaxi-OutFlow.yaml b/config/STGCN/BJTaxi-OutFlow.yaml index b0641dc..8480e65 100644 --- a/config/STGCN/BJTaxi-OutFlow.yaml +++ b/config/STGCN/BJTaxi-OutFlow.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 24 + num_nodes: 1024 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/STGCN/METR-LA.yaml b/config/STGCN/METR-LA.yaml index e73ecbf..fe24edc 100644 --- a/config/STGCN/METR-LA.yaml +++ b/config/STGCN/METR-LA.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 24 + num_nodes: 207 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/STGCN/NYCBike-InFlow.yaml b/config/STGCN/NYCBike-InFlow.yaml index b01cd36..29a07a6 100644 --- a/config/STGCN/NYCBike-InFlow.yaml +++ b/config/STGCN/NYCBike-InFlow.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 24 + num_nodes: 128 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/STGCN/NYCBike-OutFlow.yaml b/config/STGCN/NYCBike-OutFlow.yaml index 8171033..1c747eb 100644 --- a/config/STGCN/NYCBike-OutFlow.yaml +++ b/config/STGCN/NYCBike-OutFlow.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 24 + num_nodes: 128 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/STGCN/PEMSD3.yaml b/config/STGCN/PEMSD3.yaml index ab254ec..6fa6c75 100755 --- a/config/STGCN/PEMSD3.yaml +++ b/config/STGCN/PEMSD3.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 12 + num_nodes: 358 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGCN/PEMSD4.yaml b/config/STGCN/PEMSD4.yaml index bc62528..596f6f2 100755 --- a/config/STGCN/PEMSD4.yaml +++ b/config/STGCN/PEMSD4.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 12 + num_nodes: 307 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGCN/PEMSD7.yaml b/config/STGCN/PEMSD7.yaml index 1f4139b..9615f85 100755 --- a/config/STGCN/PEMSD7.yaml +++ b/config/STGCN/PEMSD7.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 12 + num_nodes: 883 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGCN/PEMSD8.yaml b/config/STGCN/PEMSD8.yaml index 2dd7bce..846ad3f 100755 --- a/config/STGCN/PEMSD8.yaml +++ b/config/STGCN/PEMSD8.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 12 + num_nodes: 170 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGCN/SolarEnergy.yaml b/config/STGCN/SolarEnergy.yaml index fc9ecc7..d3e44dc 100644 --- a/config/STGCN/SolarEnergy.yaml +++ b/config/STGCN/SolarEnergy.yaml @@ -28,6 +28,7 @@ model: gso_type: sym_norm_lap input_dim: 1 n_his: 24 + num_nodes: 137 output_dim: 1 stblock_num: 2 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/STGNCDE/AirQuality.yaml b/config/STGNCDE/AirQuality.yaml index c7905f6..88bf8f0 100644 --- a/config/STGNCDE/AirQuality.yaml +++ b/config/STGNCDE/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNCDE seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 10 @@ -24,9 +26,11 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 12 output_dim: 6 solver: rk4 type: type1 + train: batch_size: 16 debug: false @@ -39,16 +43,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/BJTaxi-InFlow.yaml b/config/STGNCDE/BJTaxi-InFlow.yaml index 0bb3ab5..0de1907 100644 --- a/config/STGNCDE/BJTaxi-InFlow.yaml +++ b/config/STGNCDE/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNCDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 10 @@ -24,9 +26,11 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 1024 output_dim: 1 solver: rk4 type: type1 + train: batch_size: 32 debug: false @@ -39,16 +43,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/BJTaxi-OutFlow.yaml b/config/STGNCDE/BJTaxi-OutFlow.yaml index 4cc5fdb..2022544 100644 --- a/config/STGNCDE/BJTaxi-OutFlow.yaml +++ b/config/STGNCDE/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNCDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 10 @@ -24,9 +26,11 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 1024 output_dim: 1 solver: rk4 type: type1 + train: batch_size: 32 debug: false @@ -39,16 +43,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/METR-LA.yaml b/config/STGNCDE/METR-LA.yaml index 135de6f..127e6fa 100644 --- a/config/STGNCDE/METR-LA.yaml +++ b/config/STGNCDE/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNCDE seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 10 @@ -24,9 +26,11 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 207 output_dim: 1 solver: rk4 type: type1 + train: batch_size: 16 debug: false @@ -39,16 +43,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/NYCBike-InFlow.yaml b/config/STGNCDE/NYCBike-InFlow.yaml index a35aeb5..0d00183 100644 --- a/config/STGNCDE/NYCBike-InFlow.yaml +++ b/config/STGNCDE/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNCDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 10 @@ -24,9 +26,11 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 128 output_dim: 1 solver: rk4 type: type1 + train: batch_size: 32 debug: false @@ -39,16 +43,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/NYCBike-OutFlow.yaml b/config/STGNCDE/NYCBike-OutFlow.yaml index 98d94a2..b54a641 100644 --- a/config/STGNCDE/NYCBike-OutFlow.yaml +++ b/config/STGNCDE/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNCDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 10 @@ -24,9 +26,11 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 128 output_dim: 1 solver: rk4 type: type1 + train: batch_size: 32 debug: false @@ -39,16 +43,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/PEMSD3.yaml b/config/STGNCDE/PEMSD3.yaml index b6abea6..83e880c 100755 --- a/config/STGNCDE/PEMSD3.yaml +++ b/config/STGNCDE/PEMSD3.yaml @@ -26,6 +26,7 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 358 output_dim: 1 solver: rk4 type: type1 @@ -41,13 +42,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/PEMSD4.yaml b/config/STGNCDE/PEMSD4.yaml index 7c2d8d4..8df1d04 100755 --- a/config/STGNCDE/PEMSD4.yaml +++ b/config/STGNCDE/PEMSD4.yaml @@ -26,6 +26,7 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 307 output_dim: 1 solver: rk4 type: type1 @@ -41,13 +42,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/PEMSD7.yaml b/config/STGNCDE/PEMSD7.yaml index b11d474..ae7bbf5 100755 --- a/config/STGNCDE/PEMSD7.yaml +++ b/config/STGNCDE/PEMSD7.yaml @@ -26,6 +26,7 @@ model: hid_hid_dim: 64 input_dim: 2 num_layers: 2 + num_nodes: 883 output_dim: 1 solver: rk4 type: type1 @@ -41,13 +42,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/PEMSD8.yaml b/config/STGNCDE/PEMSD8.yaml index 452f280..22b22ab 100755 --- a/config/STGNCDE/PEMSD8.yaml +++ b/config/STGNCDE/PEMSD8.yaml @@ -26,6 +26,7 @@ model: hid_hid_dim: 64 input_dim: 2 num_layers: 2 + num_nodes: 170 output_dim: 1 solver: rk4 type: type1 @@ -41,13 +42,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNCDE/SolarEnergy.yaml b/config/STGNCDE/SolarEnergy.yaml index db78adc..134268a 100644 --- a/config/STGNCDE/SolarEnergy.yaml +++ b/config/STGNCDE/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNCDE seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_k: 2 embed_dim: 10 @@ -24,9 +26,11 @@ model: hid_hid_dim: 128 input_dim: 2 num_layers: 2 + num_nodes: 137 output_dim: 137 solver: rk4 type: type1 + train: batch_size: 16 debug: false @@ -39,16 +43,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/AirQuality.yaml b/config/STGNRDE/AirQuality.yaml index 340696a..7086bb6 100644 --- a/config/STGNRDE/AirQuality.yaml +++ b/config/STGNRDE/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNRDE seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: adp_opt: false cheb_k: 3 @@ -28,8 +30,10 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 12 output_dim: 6 solver: rk4 + train: batch_size: 16 debug: false @@ -42,16 +46,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/BJTaxi-InFlow.yaml b/config/STGNRDE/BJTaxi-InFlow.yaml index 891d32b..d1d6275 100644 --- a/config/STGNRDE/BJTaxi-InFlow.yaml +++ b/config/STGNRDE/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNRDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adp_opt: false cheb_k: 3 @@ -28,8 +30,10 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 1024 output_dim: 1 solver: rk4 + train: batch_size: 32 debug: false @@ -42,16 +46,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/BJTaxi-OutFlow.yaml b/config/STGNRDE/BJTaxi-OutFlow.yaml index 8646195..36ec5b2 100644 --- a/config/STGNRDE/BJTaxi-OutFlow.yaml +++ b/config/STGNRDE/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNRDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adp_opt: false cheb_k: 3 @@ -28,8 +30,10 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 1024 output_dim: 1 solver: rk4 + train: batch_size: 32 debug: false @@ -42,16 +46,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/METR-LA.yaml b/config/STGNRDE/METR-LA.yaml index 00b2934..7eafd79 100644 --- a/config/STGNRDE/METR-LA.yaml +++ b/config/STGNRDE/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNRDE seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: adp_opt: false cheb_k: 3 @@ -28,8 +30,10 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 207 output_dim: 1 solver: rk4 + train: batch_size: 16 debug: false @@ -42,16 +46,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/NYCBike-InFlow.yaml b/config/STGNRDE/NYCBike-InFlow.yaml index f25a31b..c204cdb 100644 --- a/config/STGNRDE/NYCBike-InFlow.yaml +++ b/config/STGNRDE/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNRDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adp_opt: false cheb_k: 3 @@ -28,8 +30,10 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 128 output_dim: 1 solver: rk4 + train: batch_size: 32 debug: false @@ -42,16 +46,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/NYCBike-OutFlow.yaml b/config/STGNRDE/NYCBike-OutFlow.yaml index 8a3336d..27d11ee 100644 --- a/config/STGNRDE/NYCBike-OutFlow.yaml +++ b/config/STGNRDE/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNRDE seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: adp_opt: false cheb_k: 3 @@ -28,8 +30,10 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 128 output_dim: 1 solver: rk4 + train: batch_size: 32 debug: false @@ -42,16 +46,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/PEMSD3.yaml b/config/STGNRDE/PEMSD3.yaml index 14a4ba9..af1529a 100644 --- a/config/STGNRDE/PEMSD3.yaml +++ b/config/STGNRDE/PEMSD3.yaml @@ -30,6 +30,7 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 358 output_dim: 1 solver: rk4 @@ -44,13 +45,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/PEMSD4.yaml b/config/STGNRDE/PEMSD4.yaml index aadfe01..b0d392a 100644 --- a/config/STGNRDE/PEMSD4.yaml +++ b/config/STGNRDE/PEMSD4.yaml @@ -30,6 +30,7 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 307 output_dim: 1 solver: rk4 @@ -44,13 +45,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/PEMSD7.yaml b/config/STGNRDE/PEMSD7.yaml index 1d068f0..4b8e399 100644 --- a/config/STGNRDE/PEMSD7.yaml +++ b/config/STGNRDE/PEMSD7.yaml @@ -30,6 +30,7 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 883 output_dim: 1 solver: rk4 @@ -44,13 +45,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/PEMSD8.yaml b/config/STGNRDE/PEMSD8.yaml index 66c53bd..e765d25 100644 --- a/config/STGNRDE/PEMSD8.yaml +++ b/config/STGNRDE/PEMSD8.yaml @@ -30,6 +30,7 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 170 output_dim: 1 solver: rk4 @@ -44,13 +45,16 @@ train: loss_func: mae lr_decay: false lr_decay_rate: 0.3 - lr_decay_step: [5, 20, 40, 70] + lr_decay_step: + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGNRDE/SolarEnergy.yaml b/config/STGNRDE/SolarEnergy.yaml index c6a96b7..a9512d8 100644 --- a/config/STGNRDE/SolarEnergy.yaml +++ b/config/STGNRDE/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGNRDE seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: adp_opt: false cheb_k: 3 @@ -28,8 +30,10 @@ model: interpolation: cubic model_type: rde num_layers: 2 + num_nodes: 137 output_dim: 137 solver: rk4 + train: batch_size: 16 debug: false @@ -42,16 +46,15 @@ train: lr_decay: false lr_decay_rate: 0.3 lr_decay_step: - - 5 - - 20 - - 40 - - 70 + - 5 + - 20 + - 40 + - 70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/AirQuality.yaml b/config/STGODE/AirQuality.yaml index 58bd244..14ad5b1 100644 --- a/config/STGODE/AirQuality.yaml +++ b/config/STGODE/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGODE seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,20 +13,23 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: history: 24 horizon: 24 input_dim: 6 num_features: 6 + num_nodes: 12 output_dim: 6 sigma1: 0.1 sigma2: 10 thres1: 0.6 thres2: 0.5 + train: batch_size: 16 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/BJTaxi-InFlow.yaml b/config/STGODE/BJTaxi-InFlow.yaml index d596a5c..5637bf5 100644 --- a/config/STGODE/BJTaxi-InFlow.yaml +++ b/config/STGODE/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGODE seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: history: 24 horizon: 24 input_dim: 1 num_features: 1 + num_nodes: 1024 output_dim: 1 sigma1: 0.1 sigma2: 10 thres1: 0.6 thres2: 0.5 + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/BJTaxi-OutFlow.yaml b/config/STGODE/BJTaxi-OutFlow.yaml index f2a476f..4ee73d3 100644 --- a/config/STGODE/BJTaxi-OutFlow.yaml +++ b/config/STGODE/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGODE seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: history: 24 horizon: 24 input_dim: 1 num_features: 1 + num_nodes: 1024 output_dim: 1 sigma1: 0.1 sigma2: 10 thres1: 0.6 thres2: 0.5 + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/METR-LA.yaml b/config/STGODE/METR-LA.yaml index 4527f0c..895050f 100644 --- a/config/STGODE/METR-LA.yaml +++ b/config/STGODE/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGODE seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: history: 12 horizon: 12 input_dim: 1 num_features: 1 + num_nodes: 207 output_dim: 1 sigma1: 0.1 sigma2: 10 thres1: 0.6 thres2: 0.5 + train: batch_size: 16 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/NYCBike-InFlow.yaml b/config/STGODE/NYCBike-InFlow.yaml index 68f9e95..c8b1757 100644 --- a/config/STGODE/NYCBike-InFlow.yaml +++ b/config/STGODE/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGODE seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,20 +13,23 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: history: 24 horizon: 24 input_dim: 1 num_features: 1 + num_nodes: 128 output_dim: 1 sigma1: 0.1 sigma2: 10 thres1: 0.6 thres2: 0.5 + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/NYCBike-OutFlow.yaml b/config/STGODE/NYCBike-OutFlow.yaml index a4fabdd..858c455 100644 --- a/config/STGODE/NYCBike-OutFlow.yaml +++ b/config/STGODE/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGODE seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,20 +13,23 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: history: 24 horizon: 24 input_dim: 1 num_features: 1 + num_nodes: 128 output_dim: 1 sigma1: 0.1 sigma2: 10 thres1: 0.6 thres2: 0.5 + train: batch_size: 32 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/PEMSD3.yaml b/config/STGODE/PEMSD3.yaml index 11cd3b3..c4b5790 100755 --- a/config/STGODE/PEMSD3.yaml +++ b/config/STGODE/PEMSD3.yaml @@ -23,6 +23,7 @@ model: horizon: 12 input_dim: 1 num_features: 1 + num_nodes: 358 output_dim: 1 sigma1: 0.1 sigma2: 10 @@ -42,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/PEMSD4.yaml b/config/STGODE/PEMSD4.yaml index db52560..ba17e0c 100755 --- a/config/STGODE/PEMSD4.yaml +++ b/config/STGODE/PEMSD4.yaml @@ -23,6 +23,7 @@ model: horizon: 12 input_dim: 1 num_features: 1 + num_nodes: 307 output_dim: 1 sigma1: 0.1 sigma2: 10 @@ -42,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/PEMSD7.yaml b/config/STGODE/PEMSD7.yaml index eadd560..d5f6442 100755 --- a/config/STGODE/PEMSD7.yaml +++ b/config/STGODE/PEMSD7.yaml @@ -23,6 +23,7 @@ model: horizon: 12 input_dim: 1 num_features: 1 + num_nodes: 883 output_dim: 1 sigma1: 0.1 sigma2: 10 @@ -42,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/PEMSD8.yaml b/config/STGODE/PEMSD8.yaml index 70fc37c..19d7e66 100755 --- a/config/STGODE/PEMSD8.yaml +++ b/config/STGODE/PEMSD8.yaml @@ -23,6 +23,7 @@ model: horizon: 12 input_dim: 1 num_features: 1 + num_nodes: 170 output_dim: 1 sigma1: 0.1 sigma2: 10 @@ -42,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STGODE/SolarEnergy.yaml b/config/STGODE/SolarEnergy.yaml index 304df8d..1275bdc 100644 --- a/config/STGODE/SolarEnergy.yaml +++ b/config/STGODE/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STGODE seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,16 +17,19 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: history: 24 horizon: 24 input_dim: 137 num_features: 137 + num_nodes: 137 output_dim: 137 sigma1: 0.1 sigma2: 10 thres1: 0.6 thres2: 0.5 + train: batch_size: 16 debug: false @@ -39,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STID/AirQuality.yaml b/config/STID/AirQuality.yaml index f8abb05..f499161 100755 --- a/config/STID/AirQuality.yaml +++ b/config/STID/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 @@ -31,7 +31,7 @@ model: input_len: 24 node_dim: 32 num_layer: 3 - num_nodes: 35 + num_nodes: 12 output_dim: 1 output_len: 24 temp_dim_diw: 32 @@ -51,7 +51,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 diff --git a/config/STID/BJTaxi-InFlow.yaml b/config/STID/BJTaxi-InFlow.yaml index 57b8e7f..59e9501 100644 --- a/config/STID/BJTaxi-InFlow.yaml +++ b/config/STID/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STID seed: 2023 + data: batch_size: 64 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 day_of_week_size: 7 @@ -35,6 +37,7 @@ model: temp_dim_diw: 32 temp_dim_tid: 32 time_of_day_size: 288 + train: batch_size: 64 debug: true @@ -48,7 +51,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 diff --git a/config/STID/BJTaxi-OutFlow.yaml b/config/STID/BJTaxi-OutFlow.yaml index 4a10026..e2fdf43 100644 --- a/config/STID/BJTaxi-OutFlow.yaml +++ b/config/STID/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STID seed: 2023 + data: batch_size: 64 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 day_of_week_size: 7 @@ -35,6 +37,7 @@ model: temp_dim_diw: 32 temp_dim_tid: 32 time_of_day_size: 288 + train: batch_size: 64 debug: true @@ -48,7 +51,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 diff --git a/config/STID/BJTaxi_Inflow.yaml b/config/STID/BJTaxi_Inflow.yaml index d29f33f..d50ba22 100755 --- a/config/STID/BJTaxi_Inflow.yaml +++ b/config/STID/BJTaxi_Inflow.yaml @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 1 weight_decay: 0.0001 diff --git a/config/STID/BJTaxi_Outflow.yaml b/config/STID/BJTaxi_Outflow.yaml index 4c3b344..e2fdf43 100755 --- a/config/STID/BJTaxi_Outflow.yaml +++ b/config/STID/BJTaxi_Outflow.yaml @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 1 weight_decay: 0.0001 diff --git a/config/STID/METR-LA.yaml b/config/STID/METR-LA.yaml index d79894a..7ceb4f0 100755 --- a/config/STID/METR-LA.yaml +++ b/config/STID/METR-LA.yaml @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 1 weight_decay: 0.0001 diff --git a/config/STID/NYCBike-InFlow.yaml b/config/STID/NYCBike-InFlow.yaml index 81d392d..e509007 100644 --- a/config/STID/NYCBike-InFlow.yaml +++ b/config/STID/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STID seed: 2023 + data: batch_size: 64 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 day_of_week_size: 7 @@ -29,12 +31,13 @@ model: input_len: 24 node_dim: 32 num_layer: 3 - num_nodes: 1024 + num_nodes: 128 output_dim: 1 output_len: 24 temp_dim_diw: 32 temp_dim_tid: 32 time_of_day_size: 288 + train: batch_size: 64 debug: true @@ -48,7 +51,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 diff --git a/config/STID/NYCBike-OutFlow.yaml b/config/STID/NYCBike-OutFlow.yaml index dc305ce..155baf3 100644 --- a/config/STID/NYCBike-OutFlow.yaml +++ b/config/STID/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STID seed: 2023 + data: batch_size: 64 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: batch_size: 64 day_of_week_size: 7 @@ -29,12 +31,13 @@ model: input_len: 24 node_dim: 32 num_layer: 3 - num_nodes: 1024 + num_nodes: 128 output_dim: 1 output_len: 24 temp_dim_diw: 32 temp_dim_tid: 32 time_of_day_size: 288 + train: batch_size: 64 debug: true @@ -48,7 +51,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 diff --git a/config/STID/NYCBike_Inflow.yaml b/config/STID/NYCBike_Inflow.yaml index e014c20..e509007 100755 --- a/config/STID/NYCBike_Inflow.yaml +++ b/config/STID/NYCBike_Inflow.yaml @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 1 weight_decay: 0.0001 diff --git a/config/STID/NYCBike_Outflow.yaml b/config/STID/NYCBike_Outflow.yaml index 634600a..155baf3 100755 --- a/config/STID/NYCBike_Outflow.yaml +++ b/config/STID/NYCBike_Outflow.yaml @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 1 weight_decay: 0.0001 diff --git a/config/STID/PEMS-BAY.yaml b/config/STID/PEMS-BAY.yaml index 176c39f..561102d 100755 --- a/config/STID/PEMS-BAY.yaml +++ b/config/STID/PEMS-BAY.yaml @@ -51,7 +51,7 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: null mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/STID/PEMSD4.yaml b/config/STID/PEMSD4.yaml index ddfaf8f..84dee4d 100755 --- a/config/STID/PEMSD4.yaml +++ b/config/STID/PEMSD4.yaml @@ -50,11 +50,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 1 weight_decay: 0.0001 diff --git a/config/STID/SolarEnergy.yaml b/config/STID/SolarEnergy.yaml index 6fa3ad6..0d787c9 100755 --- a/config/STID/SolarEnergy.yaml +++ b/config/STID/SolarEnergy.yaml @@ -51,11 +51,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 1,50,80 lr_init: 0.002 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 1 weight_decay: 0.0001 diff --git a/config/STIDGCN/AirQuality.yaml b/config/STIDGCN/AirQuality.yaml index 116549e..5af3fd4 100644 --- a/config/STIDGCN/AirQuality.yaml +++ b/config/STIDGCN/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STIDGCN seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: channels: 32 dropout: 0.1 @@ -23,7 +25,9 @@ model: history: 24 horizon: 24 input_dim: 3 + num_nodes: 12 output_dim: 6 + train: batch_size: 16 debug: false @@ -37,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/BJTaxi-InFlow.yaml b/config/STIDGCN/BJTaxi-InFlow.yaml index 06c2aa5..26f8c52 100644 --- a/config/STIDGCN/BJTaxi-InFlow.yaml +++ b/config/STIDGCN/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STIDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: channels: 32 dropout: 0.1 @@ -23,7 +25,9 @@ model: history: 24 horizon: 24 input_dim: 3 + num_nodes: 1024 output_dim: 1 + train: batch_size: 32 debug: false @@ -37,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/BJTaxi-OutFlow.yaml b/config/STIDGCN/BJTaxi-OutFlow.yaml index dae5ec9..f09fa95 100644 --- a/config/STIDGCN/BJTaxi-OutFlow.yaml +++ b/config/STIDGCN/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STIDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: channels: 32 dropout: 0.1 @@ -23,7 +25,9 @@ model: history: 24 horizon: 24 input_dim: 3 + num_nodes: 1024 output_dim: 1 + train: batch_size: 32 debug: false @@ -37,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/METR-LA.yaml b/config/STIDGCN/METR-LA.yaml index fac77f5..1022a11 100644 --- a/config/STIDGCN/METR-LA.yaml +++ b/config/STIDGCN/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STIDGCN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: channels: 32 dropout: 0.1 @@ -23,7 +25,9 @@ model: history: 12 horizon: 12 input_dim: 3 + num_nodes: 207 output_dim: 1 + train: batch_size: 16 debug: false @@ -37,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/NYCBike-InFlow.yaml b/config/STIDGCN/NYCBike-InFlow.yaml index 1237d5c..df6f976 100644 --- a/config/STIDGCN/NYCBike-InFlow.yaml +++ b/config/STIDGCN/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STIDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: channels: 32 dropout: 0.1 @@ -23,7 +25,9 @@ model: history: 24 horizon: 24 input_dim: 3 + num_nodes: 128 output_dim: 1 + train: batch_size: 32 debug: false @@ -37,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/NYCBike-OutFlow.yaml b/config/STIDGCN/NYCBike-OutFlow.yaml index 3f95335..e7159a2 100644 --- a/config/STIDGCN/NYCBike-OutFlow.yaml +++ b/config/STIDGCN/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STIDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: channels: 32 dropout: 0.1 @@ -23,7 +25,9 @@ model: history: 24 horizon: 24 input_dim: 3 + num_nodes: 128 output_dim: 1 + train: batch_size: 32 debug: false @@ -37,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/PEMSD3.yaml b/config/STIDGCN/PEMSD3.yaml index d4548cc..a1fc024 100644 --- a/config/STIDGCN/PEMSD3.yaml +++ b/config/STIDGCN/PEMSD3.yaml @@ -25,6 +25,7 @@ model: history: 12 horizon: 12 input_dim: 3 + num_nodes: 358 output_dim: 1 train: @@ -40,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/PEMSD4.yaml b/config/STIDGCN/PEMSD4.yaml index 35c69be..edd4118 100644 --- a/config/STIDGCN/PEMSD4.yaml +++ b/config/STIDGCN/PEMSD4.yaml @@ -25,6 +25,7 @@ model: history: 12 horizon: 12 input_dim: 3 + num_nodes: 307 output_dim: 1 train: @@ -40,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/PEMSD7.yaml b/config/STIDGCN/PEMSD7.yaml index ba92d98..942ba1b 100644 --- a/config/STIDGCN/PEMSD7.yaml +++ b/config/STIDGCN/PEMSD7.yaml @@ -25,6 +25,7 @@ model: history: 12 horizon: 12 input_dim: 3 + num_nodes: 883 output_dim: 1 train: @@ -40,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/PEMSD8.yaml b/config/STIDGCN/PEMSD8.yaml index fca7310..071ab05 100644 --- a/config/STIDGCN/PEMSD8.yaml +++ b/config/STIDGCN/PEMSD8.yaml @@ -25,6 +25,7 @@ model: history: 12 horizon: 12 input_dim: 3 + num_nodes: 170 output_dim: 1 train: @@ -40,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STIDGCN/SolarEnergy.yaml b/config/STIDGCN/SolarEnergy.yaml index 243b0f8..e4d66ba 100644 --- a/config/STIDGCN/SolarEnergy.yaml +++ b/config/STIDGCN/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STIDGCN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: channels: 32 dropout: 0.1 @@ -23,7 +25,9 @@ model: history: 24 horizon: 24 input_dim: 3 + num_nodes: 137 output_dim: 137 + train: batch_size: 16 debug: false @@ -37,11 +41,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STMLP/AirQuality.yaml b/config/STMLP/AirQuality.yaml index 7166af6..d3bbdb9 100644 --- a/config/STMLP/AirQuality.yaml +++ b/config/STMLP/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STMLP seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: buildA_true: true conv_channels: 32 @@ -30,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 12 num_split: 1 output_dim: 6 output_window: 24 @@ -42,6 +45,7 @@ model: tanhalpha: 3 task_level: 0 use_curriculum_learning: true + train: batch_size: 16 debug: false @@ -55,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/BJTaxi-InFlow.yaml b/config/STMLP/BJTaxi-InFlow.yaml index 9b90e0d..1dacb06 100644 --- a/config/STMLP/BJTaxi-InFlow.yaml +++ b/config/STMLP/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STMLP seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: buildA_true: true conv_channels: 32 @@ -30,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 1024 num_split: 1 output_dim: 1 output_window: 24 @@ -42,6 +45,7 @@ model: tanhalpha: 3 task_level: 0 use_curriculum_learning: true + train: batch_size: 32 debug: false @@ -55,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/BJTaxi-OutFlow.yaml b/config/STMLP/BJTaxi-OutFlow.yaml index cf499e3..5b34c75 100644 --- a/config/STMLP/BJTaxi-OutFlow.yaml +++ b/config/STMLP/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STMLP seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: buildA_true: true conv_channels: 32 @@ -30,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 1024 num_split: 1 output_dim: 1 output_window: 24 @@ -42,6 +45,7 @@ model: tanhalpha: 3 task_level: 0 use_curriculum_learning: true + train: batch_size: 32 debug: false @@ -55,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/METR-LA.yaml b/config/STMLP/METR-LA.yaml index 5313959..426dec7 100644 --- a/config/STMLP/METR-LA.yaml +++ b/config/STMLP/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STMLP seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: buildA_true: true conv_channels: 32 @@ -30,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 207 num_split: 1 output_dim: 1 output_window: 12 @@ -42,6 +45,7 @@ model: tanhalpha: 3 task_level: 0 use_curriculum_learning: true + train: batch_size: 16 debug: false @@ -55,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/NYCBike-InFlow.yaml b/config/STMLP/NYCBike-InFlow.yaml index 053deab..ccbc983 100644 --- a/config/STMLP/NYCBike-InFlow.yaml +++ b/config/STMLP/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STMLP seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: buildA_true: true conv_channels: 32 @@ -30,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 128 num_split: 1 output_dim: 1 output_window: 24 @@ -42,6 +45,7 @@ model: tanhalpha: 3 task_level: 0 use_curriculum_learning: true + train: batch_size: 32 debug: false @@ -55,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/NYCBike-OutFlow.yaml b/config/STMLP/NYCBike-OutFlow.yaml index 0a920cc..a709a21 100644 --- a/config/STMLP/NYCBike-OutFlow.yaml +++ b/config/STMLP/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STMLP seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: buildA_true: true conv_channels: 32 @@ -30,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 128 num_split: 1 output_dim: 1 output_window: 24 @@ -42,6 +45,7 @@ model: tanhalpha: 3 task_level: 0 use_curriculum_learning: true + train: batch_size: 32 debug: false @@ -55,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/PEMSD3.yaml b/config/STMLP/PEMSD3.yaml index 55372aa..1bbaaad 100644 --- a/config/STMLP/PEMSD3.yaml +++ b/config/STMLP/PEMSD3.yaml @@ -32,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 358 num_split: 1 output_dim: 1 output_window: 12 @@ -58,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/PEMSD4.yaml b/config/STMLP/PEMSD4.yaml index 780bf77..f90156b 100644 --- a/config/STMLP/PEMSD4.yaml +++ b/config/STMLP/PEMSD4.yaml @@ -32,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 307 num_split: 1 output_dim: 1 output_window: 12 @@ -58,13 +59,12 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher: true teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/PEMSD7.yaml b/config/STMLP/PEMSD7.yaml index bae2ec9..6fb1de2 100644 --- a/config/STMLP/PEMSD7.yaml +++ b/config/STMLP/PEMSD7.yaml @@ -32,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 883 num_split: 1 output_dim: 1 output_window: 12 @@ -58,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/PEMSD8.yaml b/config/STMLP/PEMSD8.yaml index 6532996..5858ec5 100644 --- a/config/STMLP/PEMSD8.yaml +++ b/config/STMLP/PEMSD8.yaml @@ -32,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 170 num_split: 1 output_dim: 1 output_window: 12 @@ -58,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STMLP/SolarEnergy.yaml b/config/STMLP/SolarEnergy.yaml index ca6bc0d..627ba66 100644 --- a/config/STMLP/SolarEnergy.yaml +++ b/config/STMLP/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STMLP seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: buildA_true: true conv_channels: 32 @@ -30,6 +32,7 @@ model: layers: 3 model_type: stmlp node_dim: 40 + num_nodes: 137 num_split: 1 output_dim: 137 output_window: 24 @@ -42,6 +45,7 @@ model: tanhalpha: 3 task_level: 0 use_curriculum_learning: true + train: batch_size: 16 debug: false @@ -55,12 +59,11 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 teacher_stu: true weight_decay: 0 diff --git a/config/STSGCN/AirQuality.yaml b/config/STSGCN/AirQuality.yaml index 0b65da7..d9ccba1 100644 --- a/config/STSGCN/AirQuality.yaml +++ b/config/STSGCN/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STSGCN seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,36 +13,39 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 24 horizon: 24 input_dim: 6 + num_nodes: 12 out_layer_dim: 128 output_dim: 6 spatial_emb: true strides: 3 temporal_emb: true use_mask: true + train: batch_size: 16 debug: false @@ -55,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/BJTaxi-InFlow.yaml b/config/STSGCN/BJTaxi-InFlow.yaml index 074da90..ca137d5 100644 --- a/config/STSGCN/BJTaxi-InFlow.yaml +++ b/config/STSGCN/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STSGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,32 +17,35 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 24 horizon: 24 input_dim: 1 + num_nodes: 1024 out_layer_dim: 128 output_dim: 1 spatial_emb: true strides: 3 temporal_emb: true use_mask: true + train: batch_size: 32 debug: false @@ -55,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/BJTaxi-OutFlow.yaml b/config/STSGCN/BJTaxi-OutFlow.yaml index 1b395db..460d836 100644 --- a/config/STSGCN/BJTaxi-OutFlow.yaml +++ b/config/STSGCN/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STSGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,32 +17,35 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 24 horizon: 24 input_dim: 1 + num_nodes: 1024 out_layer_dim: 128 output_dim: 1 spatial_emb: true strides: 3 temporal_emb: true use_mask: true + train: batch_size: 32 debug: false @@ -55,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/METR-LA.yaml b/config/STSGCN/METR-LA.yaml index cd54e19..898a399 100644 --- a/config/STSGCN/METR-LA.yaml +++ b/config/STSGCN/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STSGCN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,32 +17,35 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 12 horizon: 12 input_dim: 1 + num_nodes: 207 out_layer_dim: 128 output_dim: 1 spatial_emb: true strides: 3 temporal_emb: true use_mask: true + train: batch_size: 16 debug: false @@ -55,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/NYCBike-InFlow.yaml b/config/STSGCN/NYCBike-InFlow.yaml index 02752c5..143e0fa 100644 --- a/config/STSGCN/NYCBike-InFlow.yaml +++ b/config/STSGCN/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STSGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,36 +13,39 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 24 horizon: 24 input_dim: 1 + num_nodes: 128 out_layer_dim: 128 output_dim: 1 spatial_emb: true strides: 3 temporal_emb: true use_mask: true + train: batch_size: 32 debug: false @@ -55,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/NYCBike-OutFlow.yaml b/config/STSGCN/NYCBike-OutFlow.yaml index 868d4d1..b78fcb0 100644 --- a/config/STSGCN/NYCBike-OutFlow.yaml +++ b/config/STSGCN/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STSGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,36 +13,39 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 24 horizon: 24 input_dim: 1 + num_nodes: 128 out_layer_dim: 128 output_dim: 1 spatial_emb: true strides: 3 temporal_emb: true use_mask: true + train: batch_size: 32 debug: false @@ -55,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/PEMSD3.yaml b/config/STSGCN/PEMSD3.yaml index 0b08995..38250e8 100755 --- a/config/STSGCN/PEMSD3.yaml +++ b/config/STSGCN/PEMSD3.yaml @@ -22,10 +22,23 @@ model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 - hidden_dims: [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] + hidden_dims: + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 12 horizon: 12 input_dim: 1 + num_nodes: 358 out_layer_dim: 128 output_dim: 1 spatial_emb: true @@ -46,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/PEMSD4.yaml b/config/STSGCN/PEMSD4.yaml index b5ef9fe..bfa33c2 100755 --- a/config/STSGCN/PEMSD4.yaml +++ b/config/STSGCN/PEMSD4.yaml @@ -22,10 +22,23 @@ model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 - hidden_dims: [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] + hidden_dims: + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 12 horizon: 12 input_dim: 1 + num_nodes: 307 out_layer_dim: 128 output_dim: 1 spatial_emb: true @@ -46,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/PEMSD7.yaml b/config/STSGCN/PEMSD7.yaml index c13cde9..e1ef225 100755 --- a/config/STSGCN/PEMSD7.yaml +++ b/config/STSGCN/PEMSD7.yaml @@ -22,10 +22,23 @@ model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 - hidden_dims: [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] + hidden_dims: + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 12 horizon: 12 input_dim: 1 + num_nodes: 883 out_layer_dim: 128 output_dim: 1 spatial_emb: true @@ -46,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/PEMSD8.yaml b/config/STSGCN/PEMSD8.yaml index df35209..183646d 100755 --- a/config/STSGCN/PEMSD8.yaml +++ b/config/STSGCN/PEMSD8.yaml @@ -22,10 +22,23 @@ model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 - hidden_dims: [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] + hidden_dims: + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 12 horizon: 12 input_dim: 1 + num_nodes: 170 out_layer_dim: 128 output_dim: 1 spatial_emb: true @@ -46,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/STSGCN/SolarEnergy.yaml b/config/STSGCN/SolarEnergy.yaml index 9baa783..3f25c2a 100644 --- a/config/STSGCN/SolarEnergy.yaml +++ b/config/STSGCN/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: STSGCN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,32 +17,35 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: activation: GLU construct_type: connectivity first_layer_embedding_size: 64 hidden_dims: - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 - - - 64 - - 64 - - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 + - - 64 + - 64 + - 64 history: 24 horizon: 24 input_dim: 137 + num_nodes: 137 out_layer_dim: 128 output_dim: 137 spatial_emb: true strides: 3 temporal_emb: true use_mask: true + train: batch_size: 16 debug: false @@ -55,11 +59,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/AirQuality.yaml b/config/ST_SSL/AirQuality.yaml index a1ecc1d..e2459b3 100644 --- a/config/ST_SSL/AirQuality.yaml +++ b/config/ST_SSL/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ST_SSL seed: 2023 + data: batch_size: 16 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: d_model: 64 dropout: 0.1 @@ -24,10 +26,12 @@ model: input_dim: 6 n_his: 24 nmb_prototype: 10 + num_nodes: 12 output_dim: 6 percent: 0.1 shm_temp: 0.1 yita: 0.5 + train: batch_size: 16 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/BJTaxi-InFlow.yaml b/config/ST_SSL/BJTaxi-InFlow.yaml index 4dbf256..6da077f 100644 --- a/config/ST_SSL/BJTaxi-InFlow.yaml +++ b/config/ST_SSL/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ST_SSL seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_model: 64 dropout: 0.1 @@ -24,10 +26,12 @@ model: input_dim: 1 n_his: 24 nmb_prototype: 10 + num_nodes: 1024 output_dim: 1 percent: 0.1 shm_temp: 0.1 yita: 0.5 + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/BJTaxi-OutFlow.yaml b/config/ST_SSL/BJTaxi-OutFlow.yaml index 6801117..969be92 100644 --- a/config/ST_SSL/BJTaxi-OutFlow.yaml +++ b/config/ST_SSL/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ST_SSL seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_model: 64 dropout: 0.1 @@ -24,10 +26,12 @@ model: input_dim: 1 n_his: 24 nmb_prototype: 10 + num_nodes: 1024 output_dim: 1 percent: 0.1 shm_temp: 0.1 yita: 0.5 + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/METR-LA.yaml b/config/ST_SSL/METR-LA.yaml index d80ccb9..7805bc4 100644 --- a/config/ST_SSL/METR-LA.yaml +++ b/config/ST_SSL/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ST_SSL seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: d_model: 64 dropout: 0.1 @@ -24,10 +26,12 @@ model: input_dim: 1 n_his: 12 nmb_prototype: 10 + num_nodes: 207 output_dim: 1 percent: 0.1 shm_temp: 0.1 yita: 0.5 + train: batch_size: 16 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/NYCBike-InFlow.yaml b/config/ST_SSL/NYCBike-InFlow.yaml index 3283cb6..af85c3a 100644 --- a/config/ST_SSL/NYCBike-InFlow.yaml +++ b/config/ST_SSL/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ST_SSL seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_model: 64 dropout: 0.1 @@ -24,10 +26,12 @@ model: input_dim: 1 n_his: 24 nmb_prototype: 10 + num_nodes: 128 output_dim: 1 percent: 0.1 shm_temp: 0.1 yita: 0.5 + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/NYCBike-OutFlow.yaml b/config/ST_SSL/NYCBike-OutFlow.yaml index 3a3e06f..e3b0c3c 100644 --- a/config/ST_SSL/NYCBike-OutFlow.yaml +++ b/config/ST_SSL/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ST_SSL seed: 2023 + data: batch_size: 32 column_wise: false @@ -12,10 +13,11 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 1024 + num_nodes: 128 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: d_model: 64 dropout: 0.1 @@ -24,10 +26,12 @@ model: input_dim: 1 n_his: 24 nmb_prototype: 10 + num_nodes: 128 output_dim: 1 percent: 0.1 shm_temp: 0.1 yita: 0.5 + train: batch_size: 32 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/PEMSD3.yaml b/config/ST_SSL/PEMSD3.yaml index 70dc619..c9c933c 100644 --- a/config/ST_SSL/PEMSD3.yaml +++ b/config/ST_SSL/PEMSD3.yaml @@ -26,6 +26,7 @@ model: input_dim: 1 n_his: 12 nmb_prototype: 10 + num_nodes: 358 output_dim: 1 percent: 0.1 shm_temp: 0.1 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/PEMSD4.yaml b/config/ST_SSL/PEMSD4.yaml index 2aab8c8..7a53a20 100644 --- a/config/ST_SSL/PEMSD4.yaml +++ b/config/ST_SSL/PEMSD4.yaml @@ -26,6 +26,7 @@ model: input_dim: 1 n_his: 12 nmb_prototype: 10 + num_nodes: 307 output_dim: 1 percent: 0.1 shm_temp: 0.1 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/PEMSD7.yaml b/config/ST_SSL/PEMSD7.yaml index 4ec9986..75019a3 100644 --- a/config/ST_SSL/PEMSD7.yaml +++ b/config/ST_SSL/PEMSD7.yaml @@ -26,6 +26,7 @@ model: input_dim: 1 n_his: 12 nmb_prototype: 10 + num_nodes: 883 output_dim: 1 percent: 0.1 shm_temp: 0.1 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/PEMSD8.yaml b/config/ST_SSL/PEMSD8.yaml index 1129639..9aeb16f 100644 --- a/config/ST_SSL/PEMSD8.yaml +++ b/config/ST_SSL/PEMSD8.yaml @@ -26,6 +26,7 @@ model: input_dim: 1 n_his: 12 nmb_prototype: 10 + num_nodes: 170 output_dim: 1 percent: 0.1 shm_temp: 0.1 @@ -44,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/ST_SSL/SolarEnergy.yaml b/config/ST_SSL/SolarEnergy.yaml index cbe6d71..2752c44 100644 --- a/config/ST_SSL/SolarEnergy.yaml +++ b/config/ST_SSL/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: ST_SSL seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: d_model: 64 dropout: 0.1 @@ -24,10 +26,12 @@ model: input_dim: 137 n_his: 24 nmb_prototype: 10 + num_nodes: 137 output_dim: 137 percent: 0.1 shm_temp: 0.1 yita: 0.5 + train: batch_size: 16 debug: false @@ -41,11 +45,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: '' + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 137 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TCN/AirQuality.yaml b/config/TCN/AirQuality.yaml index 3b19a78..c04eeac 100644 --- a/config/TCN/AirQuality.yaml +++ b/config/TCN/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 @@ -21,10 +21,14 @@ data: model: batch_size: 16 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 6 kernel_size: 3 num_layers: 3 + num_nodes: 12 output_dim: 6 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 6 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/TCN/BJTaxi-InFlow.yaml b/config/TCN/BJTaxi-InFlow.yaml index 68cf5e3..c49b1dc 100644 --- a/config/TCN/BJTaxi-InFlow.yaml +++ b/config/TCN/BJTaxi-InFlow.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 32 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 1024 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/TCN/BJTaxi-OutFlow.yaml b/config/TCN/BJTaxi-OutFlow.yaml index 377bb5b..077c9a3 100644 --- a/config/TCN/BJTaxi-OutFlow.yaml +++ b/config/TCN/BJTaxi-OutFlow.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 32 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 1024 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/TCN/METR-LA.yaml b/config/TCN/METR-LA.yaml index 10fc585..a588114 100644 --- a/config/TCN/METR-LA.yaml +++ b/config/TCN/METR-LA.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 16 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 207 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/TCN/NYCBike-InFlow.yaml b/config/TCN/NYCBike-InFlow.yaml index cd5242c..f5005d0 100644 --- a/config/TCN/NYCBike-InFlow.yaml +++ b/config/TCN/NYCBike-InFlow.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 32 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 128 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/TCN/NYCBike-OutFlow.yaml b/config/TCN/NYCBike-OutFlow.yaml index 761431b..0d06c3b 100644 --- a/config/TCN/NYCBike-OutFlow.yaml +++ b/config/TCN/NYCBike-OutFlow.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 32 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 128 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/TCN/PEMSD3.yaml b/config/TCN/PEMSD3.yaml index 47a59f2..00396e0 100755 --- a/config/TCN/PEMSD3.yaml +++ b/config/TCN/PEMSD3.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 64 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 358 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TCN/PEMSD4.yaml b/config/TCN/PEMSD4.yaml index 810859f..552df9e 100755 --- a/config/TCN/PEMSD4.yaml +++ b/config/TCN/PEMSD4.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 64 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 307 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TCN/PEMSD7.yaml b/config/TCN/PEMSD7.yaml index 6436803..ca414f2 100755 --- a/config/TCN/PEMSD7.yaml +++ b/config/TCN/PEMSD7.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 64 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 883 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TCN/PEMSD8.yaml b/config/TCN/PEMSD8.yaml index d47fdc5..e0f4761 100755 --- a/config/TCN/PEMSD8.yaml +++ b/config/TCN/PEMSD8.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 64 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 170 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TCN/SolarEnergy.yaml b/config/TCN/SolarEnergy.yaml index d620185..b12c1bf 100644 --- a/config/TCN/SolarEnergy.yaml +++ b/config/TCN/SolarEnergy.yaml @@ -21,10 +21,14 @@ data: model: batch_size: 64 dropout: 0.2 - hidden_channels: [32, 64, 32] + hidden_channels: + - 32 + - 64 + - 32 input_dim: 1 kernel_size: 3 num_layers: 3 + num_nodes: 137 output_dim: 1 train: @@ -40,11 +44,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 - weight_decay: 0 \ No newline at end of file + weight_decay: 0 diff --git a/config/TWDGCN/AirQuality.yaml b/config/TWDGCN/AirQuality.yaml index 97f31a1..6cef32e 100644 --- a/config/TWDGCN/AirQuality.yaml +++ b/config/TWDGCN/AirQuality.yaml @@ -4,6 +4,7 @@ basic: mode: train model: TWDGCN seed: 2023 + data: batch_size: 64 column_wise: false @@ -12,21 +13,23 @@ data: input_dim: 1 lag: 24 normalizer: std - num_nodes: 35 + num_nodes: 12 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 horizon: 24 input_dim: 1 num_layers: 1 - num_nodes: 35 + num_nodes: 12 output_dim: 1 rnn_units: 64 use_day: true use_week: false + train: batch_size: 64 debug: false @@ -46,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/BJTaxi-InFlow.yaml b/config/TWDGCN/BJTaxi-InFlow.yaml index cf543c8..63b28d1 100644 --- a/config/TWDGCN/BJTaxi-InFlow.yaml +++ b/config/TWDGCN/BJTaxi-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: TWDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 @@ -27,6 +29,7 @@ model: rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -46,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/BJTaxi-OutFlow.yaml b/config/TWDGCN/BJTaxi-OutFlow.yaml index a9ff5f9..feb1636 100644 --- a/config/TWDGCN/BJTaxi-OutFlow.yaml +++ b/config/TWDGCN/BJTaxi-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: TWDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 @@ -27,6 +29,7 @@ model: rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -46,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/Hainan.yaml b/config/TWDGCN/Hainan.yaml index 7774f92..058ca11 100755 --- a/config/TWDGCN/Hainan.yaml +++ b/config/TWDGCN/Hainan.yaml @@ -13,7 +13,7 @@ data: input_dim: 1 lag: 12 normalizer: std - num_nodes: 13 + num_nodes: 200 steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 @@ -25,7 +25,7 @@ model: horizon: 12 input_dim: 1 num_layers: 1 - num_nodes: 13 + num_nodes: 200 output_dim: 1 rnn_units: 32 use_day: true @@ -48,7 +48,7 @@ train: - 40 - 70 lr_init: 0.003 - mae_thresh: + mae_thresh: null mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 diff --git a/config/TWDGCN/METR-LA.yaml b/config/TWDGCN/METR-LA.yaml index 0788a9d..29015a7 100644 --- a/config/TWDGCN/METR-LA.yaml +++ b/config/TWDGCN/METR-LA.yaml @@ -4,6 +4,7 @@ basic: mode: train model: TWDGCN seed: 2023 + data: batch_size: 16 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 @@ -27,6 +29,7 @@ model: rnn_units: 64 use_day: true use_week: false + train: batch_size: 16 debug: false @@ -46,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/NYCBike-InFlow.yaml b/config/TWDGCN/NYCBike-InFlow.yaml index af27b83..0ca0c1d 100644 --- a/config/TWDGCN/NYCBike-InFlow.yaml +++ b/config/TWDGCN/NYCBike-InFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: TWDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 @@ -27,6 +29,7 @@ model: rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -46,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/NYCBike-OutFlow.yaml b/config/TWDGCN/NYCBike-OutFlow.yaml index 2b509a1..7226490 100644 --- a/config/TWDGCN/NYCBike-OutFlow.yaml +++ b/config/TWDGCN/NYCBike-OutFlow.yaml @@ -4,6 +4,7 @@ basic: mode: train model: TWDGCN seed: 2023 + data: batch_size: 32 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 @@ -27,6 +29,7 @@ model: rnn_units: 64 use_day: true use_week: false + train: batch_size: 32 debug: false @@ -46,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/PEMSD3.yaml b/config/TWDGCN/PEMSD3.yaml index 196970e..04a0311 100755 --- a/config/TWDGCN/PEMSD3.yaml +++ b/config/TWDGCN/PEMSD3.yaml @@ -43,11 +43,10 @@ train: lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 lr_init: 0.003 - mae_thresh: + mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/PEMSD4.yaml b/config/TWDGCN/PEMSD4.yaml index c8e14a6..131c972 100755 --- a/config/TWDGCN/PEMSD4.yaml +++ b/config/TWDGCN/PEMSD4.yaml @@ -24,7 +24,7 @@ model: horizon: 12 input_dim: 1 num_layers: 1 - num_nodes: 307 + num_nodes: 307 output_dim: 1 rnn_units: 64 use_day: true @@ -49,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/PEMSD7(L).yaml b/config/TWDGCN/PEMSD7(L).yaml index f83a8ad..1cb5f73 100755 --- a/config/TWDGCN/PEMSD7(L).yaml +++ b/config/TWDGCN/PEMSD7(L).yaml @@ -51,5 +51,4 @@ train: output_dim: 1 plot: true real_value: true - seed: 12 weight_decay: 0 diff --git a/config/TWDGCN/PEMSD7(M).yaml b/config/TWDGCN/PEMSD7(M).yaml index c1bcd2f..19fc9a8 100755 --- a/config/TWDGCN/PEMSD7(M).yaml +++ b/config/TWDGCN/PEMSD7(M).yaml @@ -51,5 +51,4 @@ train: output_dim: 1 plot: true real_value: true - seed: 12 weight_decay: 0 diff --git a/config/TWDGCN/PEMSD7.yaml b/config/TWDGCN/PEMSD7.yaml index 2f79918..6861724 100755 --- a/config/TWDGCN/PEMSD7.yaml +++ b/config/TWDGCN/PEMSD7.yaml @@ -53,5 +53,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/config/TWDGCN/PEMSD8.yaml b/config/TWDGCN/PEMSD8.yaml index 6dac03a..37cceea 100755 --- a/config/TWDGCN/PEMSD8.yaml +++ b/config/TWDGCN/PEMSD8.yaml @@ -49,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 3407 weight_decay: 0 diff --git a/config/TWDGCN/SolarEnergy.yaml b/config/TWDGCN/SolarEnergy.yaml index 859116a..da5d9db 100644 --- a/config/TWDGCN/SolarEnergy.yaml +++ b/config/TWDGCN/SolarEnergy.yaml @@ -4,6 +4,7 @@ basic: mode: train model: TWDGCN seed: 2023 + data: batch_size: 64 column_wise: false @@ -16,6 +17,7 @@ data: steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 + model: cheb_order: 2 embed_dim: 12 @@ -27,6 +29,7 @@ model: rnn_units: 64 use_day: true use_week: false + train: batch_size: 64 debug: false @@ -46,5 +49,4 @@ train: output_dim: 1 plot: false real_value: true - seed: 10 weight_decay: 0 diff --git a/test_configs.py b/test_configs.py deleted file mode 100644 index 4fb4bbe..0000000 --- a/test_configs.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import subprocess -import yaml -import time - -# 配置路径 -CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config" -RUN_SCRIPT = "/user/czzhangheng/code/TrafficWheel/run.py" -RESULTS_FILE = "/user/czzhangheng/code/TrafficWheel/test_results.txt" - -# 记录测试结果的字典 -results = { - "passed": [], - "failed": [], - "error": [] -} - -# 遍历所有yaml文件 -def find_all_yaml_files(directory): - yaml_files = [] - for root, dirs, files in os.walk(directory): - for file in files: - if file.endswith(".yaml") and not file.startswith("BJTaxi"): - yaml_files.append(os.path.join(root, file)) - return yaml_files - -# 测试单个yaml文件 -def test_yaml_file(yaml_path): - print(f"\n=== Testing {yaml_path} ===") - - # 检查文件是否存在 - if not os.path.exists(yaml_path): - print(f"File not found: {yaml_path}") - return "error", f"File not found: {yaml_path}" - - # 运行测试命令 - command = ["python", RUN_SCRIPT, "--config", yaml_path] - try: - result = subprocess.run( - command, - capture_output=True, - text=True, - timeout=600 # 10分钟超时 - ) - - # 分析结果 - if result.returncode == 0: - if "Test passed" in result.stdout: - print(f"✓ PASSED: {yaml_path}") - return "passed", result.stdout.strip() - else: - print(f"✗ FAILED: {yaml_path}") - return "failed", result.stdout.strip() + "\n" + result.stderr.strip() - else: - print(f"✗ ERROR: {yaml_path}") - return "error", result.stdout.strip() + "\n" + result.stderr.strip() - except subprocess.TimeoutExpired: - print(f"✗ TIMEOUT: {yaml_path}") - return "error", "Timeout after 10 minutes" - except Exception as e: - print(f"✗ EXCEPTION: {yaml_path}") - return "error", str(e) - -# 生成测试报告 -def generate_report(results): - total = len(results["passed"]) + len(results["failed"]) + len(results["error"]) - - report = f"""# 测试报告 - -## 测试概述 -- 测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')} -- 总测试文件数: {total} -- 通过: {len(results['passed'])} -- 失败: {len(results['failed'])} -- 错误: {len(results['error'])} - -## 通过的配置文件 -""" - - for file_path, output in results["passed"]: - report += f"- ✅ {file_path}\n" - - report += "\n## 失败的配置文件\n" - for file_path, output in results["failed"]: - report += f"- ❌ {file_path}\n" - - report += "\n## 出错的配置文件\n" - for file_path, output in results["error"]: - report += f"- ⚠️ {file_path}\n" - - report += "\n## 详细输出\n" - - for status, files in results.items(): - report += f"\n### {status.upper()}\n\n" - for file_path, output in files: - report += f"#### {file_path}\n\n```\n{output}\n```\n\n" - - return report - -# 主函数 -def main(): - # 找到所有符合条件的yaml文件 - yaml_files = find_all_yaml_files(CONFIG_DIR) - print(f"Found {len(yaml_files)} yaml files to test") - - # 测试每个文件 - for yaml_file in yaml_files: - status, output = test_yaml_file(yaml_file) - results[status].append((yaml_file, output)) - - # 生成并保存报告 - report = generate_report(results) - with open(RESULTS_FILE, "w") as f: - f.write(report) - - print(f"\n=== Test Results ===") - print(f"Total: {len(yaml_files)}") - print(f"Passed: {len(results['passed'])}") - print(f"Failed: {len(results['failed'])}") - print(f"Error: {len(results['error'])}") - print(f"Report saved to: {RESULTS_FILE}") - -if __name__ == "__main__": - main() \ No newline at end of file -- 2.40.1 From a9313390ac3062be453f52d416c74d33ca80fff4 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Dec 2025 12:05:02 +0800 Subject: [PATCH 06/41] =?UTF-8?q?=E9=80=82=E9=85=8DGraphWaveNet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/GWN/AirQuality.yaml | 23 +-- config/GWN/BJTaxi-InFlow.yaml | 12 +- config/GWN/BJTaxi-OutFlow.yaml | 9 +- config/GWN/METR-LA.yaml | 15 +- config/GWN/NYCBike-InFlow.yaml | 9 +- config/GWN/NYCBike-OutFlow.yaml | 15 +- config/GWN/PEMS-BAY.yaml | 61 +++++++ config/GWN/PEMSD3.yaml | 2 +- config/GWN/PEMSD4.yaml | 2 +- config/GWN/PEMSD7.yaml | 2 +- config/GWN/PEMSD8.yaml | 2 +- config/GWN/SolarEnergy.yaml | 11 +- config/tmp.py | 234 -------------------------- model/GWN/GraphWaveNet.py | 261 +++++++++++++--------------- model/GWN/GraphWaveNet_bk.py | 290 ++++++++++++-------------------- run_tests.sh | 95 +++++++++++ trainer/Trainer.py | 8 + 17 files changed, 448 insertions(+), 603 deletions(-) create mode 100644 config/GWN/PEMS-BAY.yaml delete mode 100644 config/tmp.py create mode 100755 run_tests.sh diff --git a/config/GWN/AirQuality.yaml b/config/GWN/AirQuality.yaml index 786219f..c8e57d5 100644 --- a/config/GWN/AirQuality.yaml +++ b/config/GWN/AirQuality.yaml @@ -6,40 +6,41 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 + num_nodes: 35 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 16 + batch_size: 64 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 - input_dim: 6 + in_dim: 1 + input_dim: 1 kernel_size: 2 - layers: 2 - num_nodes: 12 - out_dim: 12 - output_dim: 6 + layers: 4 + num_nodes: 35 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null train: - batch_size: 16 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 @@ -54,7 +55,7 @@ train: mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 - output_dim: 6 + output_dim: 1 plot: false real_value: true weight_decay: 0 diff --git a/config/GWN/BJTaxi-InFlow.yaml b/config/GWN/BJTaxi-InFlow.yaml index f2f10c8..8f4de85 100644 --- a/config/GWN/BJTaxi-InFlow.yaml +++ b/config/GWN/BJTaxi-InFlow.yaml @@ -20,24 +20,26 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 32 + batch_size: 16 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 1024 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null + train: batch_size: 32 debug: false diff --git a/config/GWN/BJTaxi-OutFlow.yaml b/config/GWN/BJTaxi-OutFlow.yaml index cef9af4..f86270e 100644 --- a/config/GWN/BJTaxi-OutFlow.yaml +++ b/config/GWN/BJTaxi-OutFlow.yaml @@ -20,20 +20,21 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 1024 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null diff --git a/config/GWN/METR-LA.yaml b/config/GWN/METR-LA.yaml index 9ffb5d1..ef38574 100644 --- a/config/GWN/METR-LA.yaml +++ b/config/GWN/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -20,26 +20,27 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 16 + batch_size: 64 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 207 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null train: - batch_size: 16 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/GWN/NYCBike-InFlow.yaml b/config/GWN/NYCBike-InFlow.yaml index c536802..a85e36c 100644 --- a/config/GWN/NYCBike-InFlow.yaml +++ b/config/GWN/NYCBike-InFlow.yaml @@ -20,20 +20,21 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 128 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null diff --git a/config/GWN/NYCBike-OutFlow.yaml b/config/GWN/NYCBike-OutFlow.yaml index c67790b..3ef3c8f 100644 --- a/config/GWN/NYCBike-OutFlow.yaml +++ b/config/GWN/NYCBike-OutFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -20,26 +20,27 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 32 + batch_size: 16 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 128 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/GWN/PEMS-BAY.yaml b/config/GWN/PEMS-BAY.yaml new file mode 100644 index 0000000..3dc7acd --- /dev/null +++ b/config/GWN/PEMS-BAY.yaml @@ -0,0 +1,61 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: GWN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + addaptadj: true + apt_size: 10 + aptinit: null + batch_size: 64 + blocks: 4 + dilation_channels: 32 + dropout: 0.3 + do_graph_conv: True + end_channels: 512 + gcn_bool: true + in_dim: 1 + input_dim: 1 + kernel_size: 2 + layers: 4 + num_nodes: 325 + out_dim: 24 + residual_channels: 32 + skip_channels: 256 + supports: null + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 300 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: 0.0 + mape_thresh: 0.0 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/GWN/PEMSD3.yaml b/config/GWN/PEMSD3.yaml index 9e75da7..9194d3d 100755 --- a/config/GWN/PEMSD3.yaml +++ b/config/GWN/PEMSD3.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 3 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/PEMSD4.yaml b/config/GWN/PEMSD4.yaml index 5435727..ab6f18e 100755 --- a/config/GWN/PEMSD4.yaml +++ b/config/GWN/PEMSD4.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/PEMSD7.yaml b/config/GWN/PEMSD7.yaml index 7330998..4d82415 100755 --- a/config/GWN/PEMSD7.yaml +++ b/config/GWN/PEMSD7.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 3 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/PEMSD8.yaml b/config/GWN/PEMSD8.yaml index cebe500..26d0de8 100755 --- a/config/GWN/PEMSD8.yaml +++ b/config/GWN/PEMSD8.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 3 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/SolarEnergy.yaml b/config/GWN/SolarEnergy.yaml index afdce7a..cd1d043 100644 --- a/config/GWN/SolarEnergy.yaml +++ b/config/GWN/SolarEnergy.yaml @@ -20,20 +20,21 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 64 + batch_size: 32 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 137 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null diff --git a/config/tmp.py b/config/tmp.py deleted file mode 100644 index 17cbe0b..0000000 --- a/config/tmp.py +++ /dev/null @@ -1,234 +0,0 @@ -#!/usr/bin/env python3 -import os -from collections import defaultdict -from ruamel.yaml import YAML -from ruamel.yaml.comments import CommentedMap - -yaml = YAML() -yaml.preserve_quotes = True -yaml.indent(mapping=2, sequence=4, offset=2) - -# 允许的 data keys -DATA_ALLOWED_KEYS = { - "lag", - "horizon", - "num_nodes", - "steps_per_day", - "days_per_week", - "test_ratio", - "val_ratio", - "batch_size", - "input_dim", - "column_wise", - "normalizer", -} - -# 全局默认值 -GLOBAL_DEFAULTS = { - "lag": 24, - "horizon": 24, - "num_nodes": 1, - "steps_per_day": 24, - "days_per_week": 7, - "test_ratio": 0.2, - "val_ratio": 0.2, - "batch_size": 16, - "input_dim": 1, - "column_wise": False, - "normalizer": "std", -} - -# train全局默认值 -GLOBAL_TRAIN_DEFAULTS = { - "output_dim": 1 -} - - -def load_yaml(path): - try: - with open(path, "r", encoding="utf-8") as f: - return yaml.load(f) - except Exception: - return None - - -def collect_dataset_defaults(base="."): - """ - 收集每个数据集 data 的 key 默认值,以及 train.output_dim 默认值 - """ - data_defaults = defaultdict(dict) - train_output_defaults = dict() - - for root, _, files in os.walk(base): - for name in files: - if not (name.endswith(".yaml") or name.endswith(".yml")): - continue - path = os.path.join(root, name) - cm = load_yaml(path) - if not isinstance(cm, CommentedMap): - continue - basic = cm.get("basic") - if not isinstance(basic, dict): - continue - dataset = basic.get("dataset") - if dataset is None: - continue - ds = str(dataset) - - # data 默认值 - data_sec = cm.get("data") - if isinstance(data_sec, dict): - for key in DATA_ALLOWED_KEYS: - if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None: - data_defaults[ds][key] = data_sec[key] - - # train.output_dim 默认值 - train_sec = cm.get("train") - if isinstance(train_sec, dict): - val = train_sec.get("output_dim") - if val is not None and ds not in train_output_defaults: - train_output_defaults[ds] = val - - return data_defaults, train_output_defaults - - -def ensure_basic_seed(cm: CommentedMap, path: str): - if "basic" not in cm or not isinstance(cm["basic"], dict): - cm["basic"] = CommentedMap() - basic = cm["basic"] - if "seed" not in basic: - basic["seed"] = 2023 - print(f"[ADD] {path}: basic.seed = 2023") - - -def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str): - if "data" not in cm or not isinstance(cm["data"], dict): - cm["data"] = CommentedMap() - data_sec = cm["data"] - - basic = cm.get("basic", {}) - dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None - - for key in sorted(DATA_ALLOWED_KEYS): - if key in data_sec and data_sec[key] is not None: - continue - if dataset and dataset in data_defaults and key in data_defaults[dataset]: - chosen = data_defaults[dataset][key] - src = f"default_from_dataset[{dataset}]" - else: - chosen = GLOBAL_DEFAULTS[key] - src = "GLOBAL_DEFAULTS" - data_sec[key] = chosen - print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})") - - -def merge_test_log_into_train(cm: CommentedMap, path: str): - """ - 将 test 和 log 的 key 合并到 train,并删除 test 和 log - 同时确保 train.debug 存在 - """ - train_sec = cm.setdefault("train", CommentedMap()) - - for section in ["test", "log"]: - if section in cm and isinstance(cm[section], dict): - for k, v in cm[section].items(): - if k not in train_sec: - train_sec[k] = v - print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})") - del cm[section] - print(f"[DEL] {path}: deleted section '{section}'") - - # train.debug - if "debug" not in train_sec: - train_sec["debug"] = False - print(f"[ADD] {path}: train.debug = False") - - -def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str): - train_sec = cm.setdefault("train", CommentedMap()) - if "output_dim" not in train_sec or train_sec["output_dim"] is None: - basic = cm.get("basic", {}) - dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None - if dataset and dataset in train_output_defaults: - val = train_output_defaults[dataset] - src = f"default_from_dataset[{dataset}]" - else: - val = GLOBAL_TRAIN_DEFAULTS["output_dim"] - src = "GLOBAL_TRAIN_DEFAULTS" - train_sec["output_dim"] = val - print(f"[FILL] {path}: train.output_dim <- {src} ({val})") - - -def sync_train_batch_size(cm: CommentedMap, path: str): - """ - 如果 train.batch_size 与 data.batch_size 不一致,以 data 为准 - """ - data_sec = cm.get("data", {}) - train_sec = cm.get("train", {}) - data_bs = data_sec.get("batch_size") - train_bs = train_sec.get("batch_size") - - if data_bs is not None and train_bs != data_bs: - train_sec["batch_size"] = data_bs - print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})") - - -def sort_subkeys_and_insert_blanklines(cm: CommentedMap): - for sec in list(cm.keys()): - if isinstance(cm[sec], dict): - sorted_cm = CommentedMap() - for k in sorted(cm[sec].keys()): - sorted_cm[k] = cm[sec][k] - cm[sec] = sorted_cm - - keys = list(cm.keys()) - for i, k in enumerate(keys): - if i == 0: - try: - cm.yaml_set_comment_before_after_key(k, before=None) - except Exception: - pass - else: - try: - cm.yaml_set_comment_before_after_key(k, before="\n") - except Exception: - pass - - -def process_all(base="."): - print(">> Collecting dataset defaults ...") - data_defaults, train_output_defaults = collect_dataset_defaults(base) - print(">> Collected data defaults per dataset:") - for ds, kv in data_defaults.items(): - print(f" - {ds}: {kv}") - print(">> Collected train.output_dim defaults per dataset:") - for ds, val in train_output_defaults.items(): - print(f" - {ds}: output_dim = {val}") - - for root, _, files in os.walk(base): - for name in files: - if not (name.endswith(".yaml") or name.endswith(".yml")): - continue - path = os.path.join(root, name) - cm = load_yaml(path) - if not isinstance(cm, CommentedMap): - print(f"[SKIP] {path}: top-level not mapping or load failed") - continue - - ensure_basic_seed(cm, path) - fill_data_defaults(cm, data_defaults, path) - merge_test_log_into_train(cm, path) - fill_train_output_dim(cm, train_output_defaults, path) - sync_train_batch_size(cm, path) # <-- 新增逻辑 - sort_subkeys_and_insert_blanklines(cm) - - try: - with open(path, "w", encoding="utf-8") as f: - yaml.dump(cm, f) - print(f"[OK] Written: {path}") - except Exception as e: - print(f"[ERROR] Write failed {path}: {e}") - - -if __name__ == "__main__": - process_all(".") diff --git a/model/GWN/GraphWaveNet.py b/model/GWN/GraphWaveNet.py index 6f290f5..5bece37 100755 --- a/model/GWN/GraphWaveNet.py +++ b/model/GWN/GraphWaveNet.py @@ -1,53 +1,35 @@ -import torch, torch.nn as nn, torch.nn.functional as F +import torch +import torch.nn as nn +from torch.nn import BatchNorm2d, Conv1d, Conv2d, ModuleList, Parameter +import torch.nn.functional as F + +def nconv(x, A): + """Multiply x by adjacency matrix along source node axis""" + return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous() -class nconv(nn.Module): - """ - 图卷积操作的实现类 - 使用einsum进行矩阵运算,实现图卷积操作 - """ - - def forward(self, x, A): - return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous() - - -class linear(nn.Module): - """ - 线性变换层 - 使用1x1卷积实现线性变换 - """ - - def __init__(self, c_in, c_out): - super().__init__() - self.mlp = nn.Conv2d(c_in, c_out, 1) - - def forward(self, x): - return self.mlp(x) - - -class gcn(nn.Module): - """ - 图卷积网络层 - 实现高阶图卷积操作,支持多阶邻接矩阵 - """ - +class GraphConvNet(nn.Module): def __init__(self, c_in, c_out, dropout, support_len=3, order=2): super().__init__() - self.nconv = nconv() c_in = (order * support_len + 1) * c_in - self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order + self.final_conv = Conv2d(c_in, c_out, (1, 1), padding=(0, 0), stride=(1, 1), bias=True) + self.dropout = dropout + self.order = order - def forward(self, x, support): + def forward(self, x, support: list): out = [x] for a in support: - x1 = self.nconv(x, a) + x1 = nconv(x, a) out.append(x1) - for _ in range(2, self.order + 1): - x1 = self.nconv(x1, a) - out.append(x1) - return F.dropout( - self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training - ) + for k in range(2, self.order + 1): + x2 = nconv(x1, a) + out.append(x2) + x1 = x2 + + h = torch.cat(out, dim=1) + h = self.final_conv(h) + h = F.dropout(h, self.dropout, training=self.training) + return h class gwnet(nn.Module): @@ -59,126 +41,121 @@ class gwnet(nn.Module): def __init__(self, args): super().__init__() # 初始化基本参数 - self.dropout, self.blocks, self.layers = ( - args["dropout"], - args["blocks"], - args["layers"], - ) - self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"] + self.dropout = args["dropout"] + self.blocks = args["blocks"] + self.layers = args["layers"] + self.do_graph_conv = args.get("do_graph_conv", True) + self.cat_feat_gc = args.get("cat_feat_gc", False) + self.addaptadj = args.get("addaptadj", True) + supports = None + aptinit = args.get("aptinit", None) + in_dim = args.get("in_dim") + out_dim = args.get("out_dim") + residual_channels = args.get("residual_channels") + dilation_channels = args.get("dilation_channels") + skip_channels = args.get("skip_channels") + end_channels = args.get("end_channels") + kernel_size = args.get("kernel_size") + apt_size = args.get("apt_size", 10) - # 初始化各种卷积层和模块 - self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList() - self.residual_convs, self.skip_convs, self.bn, self.gconv = ( - nn.ModuleList(), - nn.ModuleList(), - nn.ModuleList(), - nn.ModuleList(), - ) - self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1) - self.supports = args.get("supports", None) - # 计算感受野 + if self.cat_feat_gc: + self.start_conv = nn.Conv2d(in_channels=1, # hard code to avoid errors + out_channels=residual_channels, + kernel_size=(1, 1)) + self.cat_feature_conv = nn.Conv2d(in_channels=in_dim - 1, + out_channels=residual_channels, + kernel_size=(1, 1)) + else: + self.start_conv = nn.Conv2d(in_channels=in_dim, + out_channels=residual_channels, + kernel_size=(1, 1)) + + self.fixed_supports = supports or [] receptive_field = 1 - self.supports_len = len(self.supports) if self.supports is not None else 0 - # 如果使用自适应邻接矩阵,初始化相关参数 - if self.gcn_bool and self.addaptadj: - aptinit = args.get("aptinit", None) + self.supports_len = len(self.fixed_supports) + if self.do_graph_conv and self.addaptadj: if aptinit is None: - if self.supports is None: - self.supports = [] - self.nodevec1 = nn.Parameter( - torch.randn(args["num_nodes"], 10, device=args["device"]) - ) - self.nodevec2 = nn.Parameter( - torch.randn(10, args["num_nodes"], device=args["device"]) - ) - self.supports_len += 1 + nodevecs = torch.randn(args["num_nodes"], apt_size), torch.randn(apt_size, args["num_nodes"]) else: - if self.supports is None: - self.supports = [] - m, p, n = torch.svd(aptinit) - initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5)) - initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t()) - self.nodevec1 = nn.Parameter(initemb1) - self.nodevec2 = nn.Parameter(initemb2) - self.supports_len += 1 + nodevecs = self.svd_init(args["num_nodes"], apt_size, aptinit) + self.supports_len += 1 + self.nodevec1, self.nodevec2 = [Parameter(n.to(args["device"]), requires_grad=True) for n in nodevecs] - # 获取模型参数 - ks, res, dil, skip, endc, out_dim = ( - args["kernel_size"], - args["residual_channels"], - args["dilation_channels"], - args["skip_channels"], - args["end_channels"], - args["out_dim"], - ) + depth = list(range(self.blocks * self.layers)) - # 构建模型层 + # 1x1 convolution for residual and skip connections (slightly different see docstring) + self.residual_convs = ModuleList([Conv2d(dilation_channels, residual_channels, (1, 1)) for _ in depth]) + self.skip_convs = ModuleList([Conv2d(dilation_channels, skip_channels, (1, 1)) for _ in depth]) + self.bn = ModuleList([BatchNorm2d(residual_channels) for _ in depth]) + self.graph_convs = ModuleList([GraphConvNet(dilation_channels, residual_channels, self.dropout, support_len=self.supports_len) + for _ in depth]) + + self.filter_convs = ModuleList() + self.gate_convs = ModuleList() for b in range(self.blocks): - add_scope, new_dil = ks - 1, 1 + additional_scope = kernel_size - 1 + D = 1 # dilation for i in range(self.layers): - # 添加时间卷积层 - self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) - self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) - self.residual_convs.append(nn.Conv2d(dil, res, 1)) - self.skip_convs.append(nn.Conv2d(dil, skip, 1)) - self.bn.append(nn.BatchNorm2d(res)) - new_dil *= 2 - receptive_field += add_scope - add_scope *= 2 - if self.gcn_bool: - self.gconv.append( - gcn(dil, res, args["dropout"], support_len=self.supports_len) - ) - - # 输出层 - self.end_conv_1 = nn.Conv2d(skip, endc, 1) - self.end_conv_2 = nn.Conv2d(endc, out_dim, 1) + # dilated convolutions + self.filter_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D)) + self.gate_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D)) + D *= 2 + receptive_field += additional_scope + additional_scope *= 2 self.receptive_field = receptive_field + self.end_conv_1 = Conv2d(skip_channels, end_channels, (1, 1), bias=True) + self.end_conv_2 = Conv2d(end_channels, out_dim, (1, 1), bias=True) + def forward(self, input): - """ - 前向传播函数 - 实现模型的推理过程 - """ - # 数据预处理 - input = input[..., 0:2].transpose(1, 3) - input = F.pad(input, (1, 0, 0, 0)) - in_len = input.size(3) - x = ( - F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) - if in_len < self.receptive_field - else input - ) - - # 初始卷积 - x, skip, new_supports = self.start_conv(x), 0, None - - # 如果使用自适应邻接矩阵,计算新的邻接矩阵 - if self.gcn_bool and self.addaptadj and self.supports is not None: + x = input[..., 0:1].transpose(1, 3) + # Input shape is (bs, features, n_nodes, n_timesteps) + in_len = x.size(3) + if in_len < self.receptive_field: + x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0)) + if self.cat_feat_gc: + f1, f2 = x[:, [0]], x[:, 1:] + x1 = self.start_conv(f1) + x2 = F.leaky_relu(self.cat_feature_conv(f2)) + x = x1 + x2 + else: + x = self.start_conv(x) + skip = 0 + adjacency_matrices = self.fixed_supports + # calculate the current adaptive adj matrix once per iteration + if self.addaptadj: adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1) - new_supports = self.supports + [adp] + adjacency_matrices = self.fixed_supports + [adp] - # 主网络层的前向传播 + # WaveNet layers for i in range(self.blocks * self.layers): residual = x - # 时间卷积操作 - f = self.filter_convs[i](residual).tanh() - g = self.gate_convs[i](residual).sigmoid() - x = f * g - s = self.skip_convs[i](x) - skip = ( - skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0 - ) + s + # dilated convolution + filter = torch.tanh(self.filter_convs[i](residual)) + gate = torch.sigmoid(self.gate_convs[i](residual)) + x = filter * gate + # parametrized skip connection + s = self.skip_convs[i](x) # what are we skipping?? + try: # if i > 0 this works + skip = skip[:, :, :, -s.size(3):] # TODO(SS): Mean/Max Pool? + except: + skip = 0 + skip = s + skip + if i == (self.blocks * self.layers - 1): # last X getting ignored anyway + break - # 图卷积操作 - if self.gcn_bool and self.supports is not None: - x = self.gconv[i](x, new_supports if self.addaptadj else self.supports) + if self.do_graph_conv: + graph_out = self.graph_convs[i](x, adjacency_matrices) + x = x + graph_out if self.cat_feat_gc else graph_out else: x = self.residual_convs[i](x) - x = x + residual[:, :, :, -x.size(3) :] + x = x + residual[:, :, :, -x.size(3):] # TODO(SS): Mean/Max Pool? x = self.bn[i](x) - # 输出层处理 - return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip)))) + x = F.relu(skip) # ignore last X? + x = F.relu(self.end_conv_1(x)) + x = self.end_conv_2(x) # downsample to (bs, seq_length, 207, nfeatures) + # x = x.transpose(1, 3) + return x diff --git a/model/GWN/GraphWaveNet_bk.py b/model/GWN/GraphWaveNet_bk.py index 19308d4..6f290f5 100755 --- a/model/GWN/GraphWaveNet_bk.py +++ b/model/GWN/GraphWaveNet_bk.py @@ -1,97 +1,98 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable -import sys +import torch, torch.nn as nn, torch.nn.functional as F class nconv(nn.Module): - def __init__(self): - super(nconv, self).__init__() + """ + 图卷积操作的实现类 + 使用einsum进行矩阵运算,实现图卷积操作 + """ def forward(self, x, A): - x = torch.einsum("ncvl,vw->ncwl", (x, A)) - return x.contiguous() + return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous() class linear(nn.Module): + """ + 线性变换层 + 使用1x1卷积实现线性变换 + """ + def __init__(self, c_in, c_out): - super(linear, self).__init__() - self.mlp = torch.nn.Conv2d( - c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True - ) + super().__init__() + self.mlp = nn.Conv2d(c_in, c_out, 1) def forward(self, x): return self.mlp(x) class gcn(nn.Module): + """ + 图卷积网络层 + 实现高阶图卷积操作,支持多阶邻接矩阵 + """ + def __init__(self, c_in, c_out, dropout, support_len=3, order=2): - super(gcn, self).__init__() + super().__init__() self.nconv = nconv() c_in = (order * support_len + 1) * c_in - self.mlp = linear(c_in, c_out) - self.dropout = dropout - self.order = order + self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order def forward(self, x, support): out = [x] for a in support: x1 = self.nconv(x, a) out.append(x1) - for k in range(2, self.order + 1): - x2 = self.nconv(x1, a) - out.append(x2) - x1 = x2 - - h = torch.cat(out, dim=1) - h = self.mlp(h) - h = F.dropout(h, self.dropout, training=self.training) - return h + for _ in range(2, self.order + 1): + x1 = self.nconv(x1, a) + out.append(x1) + return F.dropout( + self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training + ) class gwnet(nn.Module): + """ + Graph WaveNet模型的主类 + 结合了图卷积网络和时序卷积网络,用于时空预测任务 + """ + def __init__(self, args): - super(gwnet, self).__init__() - self.dropout = args["dropout"] - self.blocks = args["blocks"] - self.layers = args["layers"] - self.gcn_bool = args["gcn_bool"] - self.addaptadj = args["addaptadj"] - - self.filter_convs = nn.ModuleList() - self.gate_convs = nn.ModuleList() - self.residual_convs = nn.ModuleList() - self.skip_convs = nn.ModuleList() - self.bn = nn.ModuleList() - self.gconv = nn.ModuleList() - - self.start_conv = nn.Conv2d( - in_channels=args["in_dim"], - out_channels=args["residual_channels"], - kernel_size=(1, 1), + super().__init__() + # 初始化基本参数 + self.dropout, self.blocks, self.layers = ( + args["dropout"], + args["blocks"], + args["layers"], ) + self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"] + + # 初始化各种卷积层和模块 + self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList() + self.residual_convs, self.skip_convs, self.bn, self.gconv = ( + nn.ModuleList(), + nn.ModuleList(), + nn.ModuleList(), + nn.ModuleList(), + ) + self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1) self.supports = args.get("supports", None) + # 计算感受野 receptive_field = 1 + self.supports_len = len(self.supports) if self.supports is not None else 0 - self.supports_len = 0 - if self.supports is not None: - self.supports_len += len(self.supports) - + # 如果使用自适应邻接矩阵,初始化相关参数 if self.gcn_bool and self.addaptadj: aptinit = args.get("aptinit", None) if aptinit is None: if self.supports is None: self.supports = [] self.nodevec1 = nn.Parameter( - torch.randn(args["num_nodes"], 10).to(args["device"]), - requires_grad=True, - ).to(args["device"]) + torch.randn(args["num_nodes"], 10, device=args["device"]) + ) self.nodevec2 = nn.Parameter( - torch.randn(10, args["num_nodes"]).to(args["device"]), - requires_grad=True, - ).to(args["device"]) + torch.randn(10, args["num_nodes"], device=args["device"]) + ) self.supports_len += 1 else: if self.supports is None: @@ -99,156 +100,85 @@ class gwnet(nn.Module): m, p, n = torch.svd(aptinit) initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5)) initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t()) - self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to( - args["device"] - ) - self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to( - args["device"] - ) + self.nodevec1 = nn.Parameter(initemb1) + self.nodevec2 = nn.Parameter(initemb2) self.supports_len += 1 - kernel_size = args["kernel_size"] - residual_channels = args["residual_channels"] - dilation_channels = args["dilation_channels"] - kernel_size = args["kernel_size"] - skip_channels = args["skip_channels"] - end_channels = args["end_channels"] - out_dim = args["out_dim"] - dropout = args["dropout"] + # 获取模型参数 + ks, res, dil, skip, endc, out_dim = ( + args["kernel_size"], + args["residual_channels"], + args["dilation_channels"], + args["skip_channels"], + args["end_channels"], + args["out_dim"], + ) + # 构建模型层 for b in range(self.blocks): - additional_scope = kernel_size - 1 - new_dilation = 1 + add_scope, new_dil = ks - 1, 1 for i in range(self.layers): - # dilated convolutions - self.filter_convs.append( - nn.Conv2d( - in_channels=residual_channels, - out_channels=dilation_channels, - kernel_size=(1, kernel_size), - dilation=new_dilation, - ) - ) - - self.gate_convs.append( - nn.Conv2d( - in_channels=residual_channels, - out_channels=dilation_channels, - kernel_size=(1, kernel_size), - dilation=new_dilation, - ) - ) - - # 1x1 convolution for residual connection - self.residual_convs.append( - nn.Conv2d( - in_channels=dilation_channels, - out_channels=residual_channels, - kernel_size=(1, 1), - ) - ) - - # 1x1 convolution for skip connection - self.skip_convs.append( - nn.Conv2d( - in_channels=dilation_channels, - out_channels=skip_channels, - kernel_size=(1, 1), - ) - ) - self.bn.append(nn.BatchNorm2d(residual_channels)) - new_dilation *= 2 - receptive_field += additional_scope - additional_scope *= 2 + # 添加时间卷积层 + self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) + self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) + self.residual_convs.append(nn.Conv2d(dil, res, 1)) + self.skip_convs.append(nn.Conv2d(dil, skip, 1)) + self.bn.append(nn.BatchNorm2d(res)) + new_dil *= 2 + receptive_field += add_scope + add_scope *= 2 if self.gcn_bool: self.gconv.append( - gcn( - dilation_channels, - residual_channels, - dropout, - support_len=self.supports_len, - ) + gcn(dil, res, args["dropout"], support_len=self.supports_len) ) - self.end_conv_1 = nn.Conv2d( - in_channels=skip_channels, - out_channels=end_channels, - kernel_size=(1, 1), - bias=True, - ) - - self.end_conv_2 = nn.Conv2d( - in_channels=end_channels, - out_channels=out_dim, - kernel_size=(1, 1), - bias=True, - ) - + # 输出层 + self.end_conv_1 = nn.Conv2d(skip, endc, 1) + self.end_conv_2 = nn.Conv2d(endc, out_dim, 1) self.receptive_field = receptive_field def forward(self, input): - input = input[..., 0:2] - input = input.transpose(1, 3) - input = nn.functional.pad(input, (1, 0, 0, 0)) + """ + 前向传播函数 + 实现模型的推理过程 + """ + # 数据预处理 + input = input[..., 0:2].transpose(1, 3) + input = F.pad(input, (1, 0, 0, 0)) in_len = input.size(3) - if in_len < self.receptive_field: - x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0)) - else: - x = input - x = self.start_conv(x) - skip = 0 + x = ( + F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) + if in_len < self.receptive_field + else input + ) - # calculate the current adaptive adj matrix once per iteration - new_supports = None + # 初始卷积 + x, skip, new_supports = self.start_conv(x), 0, None + + # 如果使用自适应邻接矩阵,计算新的邻接矩阵 if self.gcn_bool and self.addaptadj and self.supports is not None: adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1) new_supports = self.supports + [adp] - # WaveNet layers + # 主网络层的前向传播 for i in range(self.blocks * self.layers): - # |----------------------------------------| *residual* - # | | - # | |-- conv -- tanh --| | - # -> dilate -|----| * ----|-- 1x1 -- + --> *input* - # |-- conv -- sigm --| | - # 1x1 - # | - # ---------------------------------------> + -------------> *skip* - - # (dilation, init_dilation) = self.dilations[i] - - # residual = dilation_func(x, dilation, init_dilation, i) residual = x - # dilated convolution - filter = self.filter_convs[i](residual) - filter = torch.tanh(filter) - gate = self.gate_convs[i](residual) - gate = torch.sigmoid(gate) - x = filter * gate - - # parametrized skip connection - - s = x - s = self.skip_convs[i](s) - try: - skip = skip[:, :, :, -s.size(3) :] - except: - skip = 0 - skip = s + skip + # 时间卷积操作 + f = self.filter_convs[i](residual).tanh() + g = self.gate_convs[i](residual).sigmoid() + x = f * g + s = self.skip_convs[i](x) + skip = ( + skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0 + ) + s + # 图卷积操作 if self.gcn_bool and self.supports is not None: - if self.addaptadj: - x = self.gconv[i](x, new_supports) - else: - x = self.gconv[i](x, self.supports) + x = self.gconv[i](x, new_supports if self.addaptadj else self.supports) else: x = self.residual_convs[i](x) - x = x + residual[:, :, :, -x.size(3) :] - x = self.bn[i](x) - x = F.relu(skip) - x = F.relu(self.end_conv_1(x)) - x = self.end_conv_2(x) - return x + # 输出层处理 + return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip)))) diff --git a/run_tests.sh b/run_tests.sh new file mode 100755 index 0000000..a27a3bf --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# 设置默认模型名和数据集列表 +MODEL_NAME="GWN" +DATASETS=( + "METR-LA" + "PEMS-BAY" + "NYCBike-InFlow" + "NYCBike-OutFlow" + "AirQuality" + "SolarEnergy" +) + +# 初始化统计变量 +success_count=0 +failure_count=0 +missing_count=0 +total_count=0 +success_datasets=() +failure_datasets=() +missing_datasets=() + +# 检查是否有参数传入来覆盖默认值 +if [ $# -gt 0 ]; then + MODEL_NAME=$1 + # 如果传入了更多参数,使用它们作为数据集列表 + if [ $# -gt 1 ]; then + DATASETS=(${@:2}) + fi +fi + +echo "使用模型: $MODEL_NAME" +echo "数据集列表: ${DATASETS[*]}" +echo "开始测试..." +echo "" + +# 循环测试每个数据集 +for dataset in "${DATASETS[@]}"; do + total_count=$((total_count + 1)) + # 构建配置文件路径 + CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml" + + echo "测试数据集: $dataset" + echo "使用配置文件: $CONFIG_PATH" + + # 检查配置文件是否存在 + if [ ! -f "$CONFIG_PATH" ]; then + echo "错误: 配置文件 $CONFIG_PATH 不存在!" + missing_count=$((missing_count + 1)) + missing_datasets+=("$dataset") + echo "----------------------------------------" + continue + fi + + # 执行测试命令并捕获输出 + echo "执行: python run.py --config $CONFIG_PATH" + output=$(python run.py --config "$CONFIG_PATH" 2>&1) + + # 如果没有找到明确的标记,回退到检查退出码 + if [ $? -eq 0 ]; then + echo "数据集 $dataset 测试成功! (基于退出码)" + success_count=$((success_count + 1)) + success_datasets+=("$dataset") + else + echo "数据集 $dataset 测试失败! (基于退出码)" + failure_count=$((failure_count + 1)) + failure_datasets+=("$dataset") + fi + + echo "----------------------------------------" +done + +# 输出总结 +echo "=======================================" +echo "测试总结" +echo "=======================================" +echo "总数据集数量: $total_count" +echo "成功数量: $success_count" +echo "失败数量: $failure_count" +echo "缺失配置文件数量: $missing_count" + +if [ ${#success_datasets[@]} -gt 0 ]; then + echo "成功的数据集: ${success_datasets[*]}" +fi + +if [ ${#failure_datasets[@]} -gt 0 ]; then + echo "失败的数据集: ${failure_datasets[*]}" +fi + +if [ ${#missing_datasets[@]} -gt 0 ]; then + echo "缺失配置的数据集: ${missing_datasets[*]}" +fi + +echo "=======================================" +echo "所有测试完成!" \ No newline at end of file diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 85060f1..2bd7e6e 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -177,6 +177,14 @@ class Trainer: # 前向传播 label = target[..., : self.args["output_dim"]] output = self.model(data).to(self.device) + # if output.shape != label.shape: + # import sys + # print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(1) + # else: + # import sys + # print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(0) loss = self.loss(output, label) # 反归一化 -- 2.40.1 From 440cb6936bea4e1ba4f855959f72d11787862c2d Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Dec 2025 17:12:46 +0800 Subject: [PATCH 07/41] =?UTF-8?q?=E5=85=BC=E5=AE=B9STAEFormer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/AEPSA/AirQuality.yaml | 4 +-- config/STAEFormer/AirQuality.yaml | 10 +++--- config/STAEFormer/PEMS-BAY.yaml | 58 ++++++++++++++++++++++++++++++ config/STAEFormer/SolarEnergy.yaml | 8 ++--- model/STAEFormer/STAEFormer.py | 14 ++++---- run_tests.sh | 2 +- 6 files changed, 78 insertions(+), 18 deletions(-) create mode 100644 config/STAEFormer/PEMS-BAY.yaml diff --git a/config/AEPSA/AirQuality.yaml b/config/AEPSA/AirQuality.yaml index c2d905a..d6061d9 100644 --- a/config/AEPSA/AirQuality.yaml +++ b/config/AEPSA/AirQuality.yaml @@ -13,7 +13,7 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 + num_nodes: 35 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 @@ -26,7 +26,7 @@ model: gpt_path: ./GPT-2 input_dim: 6 n_heads: 1 - num_nodes: 12 + num_nodes: 35 patch_len: 6 pred_len: 24 seq_len: 24 diff --git a/config/STAEFormer/AirQuality.yaml b/config/STAEFormer/AirQuality.yaml index e5f07e8..b7956da 100644 --- a/config/STAEFormer/AirQuality.yaml +++ b/config/STAEFormer/AirQuality.yaml @@ -13,8 +13,8 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 - steps_per_day: 24 + num_nodes: 35 + steps_per_day: 288 test_ratio: 0.2 val_ratio: 0.2 @@ -28,7 +28,7 @@ model: input_embedding_dim: 24 num_heads: 4 num_layers: 3 - num_nodes: 12 + num_nodes: 35 out_steps: 24 output_dim: 6 spatial_embedding_dim: 0 @@ -41,9 +41,9 @@ train: debug: false early_stop: true early_stop_patience: 15 - epochs: 300 + epochs: 100 grad_norm: false - log_step: 200 + log_step: 20000 loss_func: mae lr_decay: false lr_decay_rate: 0.3 diff --git a/config/STAEFormer/PEMS-BAY.yaml b/config/STAEFormer/PEMS-BAY.yaml new file mode 100644 index 0000000..353da67 --- /dev/null +++ b/config/STAEFormer/PEMS-BAY.yaml @@ -0,0 +1,58 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: STAEFormer + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + adaptive_embedding_dim: 80 + dow_embedding_dim: 24 + dropout: 0.1 + feed_forward_dim: 256 + in_steps: 24 + input_dim: 1 + input_embedding_dim: 24 + num_heads: 4 + num_layers: 3 + num_nodes: 325 + out_steps: 24 + output_dim: 1 + spatial_embedding_dim: 0 + steps_per_day: 288 + tod_embedding_dim: 24 + use_mixed_proj: true + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 300 + grad_norm: false + log_step: 200 + loss_func: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: 0.0 + mape_thresh: 0.0 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/STAEFormer/SolarEnergy.yaml b/config/STAEFormer/SolarEnergy.yaml index fd97a63..c1151ca 100644 --- a/config/STAEFormer/SolarEnergy.yaml +++ b/config/STAEFormer/SolarEnergy.yaml @@ -24,13 +24,13 @@ model: dropout: 0.1 feed_forward_dim: 256 in_steps: 24 - input_dim: 137 + input_dim: 1 input_embedding_dim: 24 num_heads: 4 num_layers: 3 num_nodes: 137 out_steps: 24 - output_dim: 137 + output_dim: 1 spatial_embedding_dim: 0 steps_per_day: 24 tod_embedding_dim: 24 @@ -41,7 +41,7 @@ train: debug: false early_stop: true early_stop_patience: 15 - epochs: 300 + epochs: 100 grad_norm: false log_step: 200 loss_func: mae @@ -52,7 +52,7 @@ train: mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 - output_dim: 137 + output_dim: 1 plot: false real_value: true weight_decay: 0 diff --git a/model/STAEFormer/STAEFormer.py b/model/STAEFormer/STAEFormer.py index 63fdb01..91b8188 100755 --- a/model/STAEFormer/STAEFormer.py +++ b/model/STAEFormer/STAEFormer.py @@ -187,17 +187,19 @@ class STAEformer(nn.Module): batch_size = x.shape[0] if self.tod_embedding_dim > 0: - tod = x[..., 1] + tod = x[..., -2] if self.dow_embedding_dim > 0: - dow = x[..., 2] - x = x[..., 0:1] + dow = x[..., -1] + x = x[..., 0:self.input_dim] x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) features = [x] if self.tod_embedding_dim > 0: - tod_emb = self.tod_embedding( - (tod * self.steps_per_day).long() - ) # (batch_size, in_steps, num_nodes, tod_embedding_dim) + # 确保索引在有效范围内 + tod_index = (tod * self.steps_per_day).long() + # 防止索引越界 + tod_index = torch.clamp(tod_index, 0, self.steps_per_day - 1) + tod_emb = self.tod_embedding(tod_index) # (batch_size, in_steps, num_nodes, tod_embedding_dim) features.append(tod_emb) if self.dow_embedding_dim > 0: dow_emb = self.dow_embedding( diff --git a/run_tests.sh b/run_tests.sh index a27a3bf..94080b6 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,7 +1,7 @@ #!/bin/bash # 设置默认模型名和数据集列表 -MODEL_NAME="GWN" +MODEL_NAME="STAEFormer" DATASETS=( "METR-LA" "PEMS-BAY" -- 2.40.1 From 9b1cf5f0ce94e215df3943026de36c26904dc2ff Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Dec 2025 18:11:06 +0800 Subject: [PATCH 08/41] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run_tests.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run_tests.sh b/run_tests.sh index 94080b6..ef700a1 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -52,9 +52,9 @@ for dataset in "${DATASETS[@]}"; do continue fi - # 执行测试命令并捕获输出 + # 执行测试命令,同时捕获输出并显示在控制台上 echo "执行: python run.py --config $CONFIG_PATH" - output=$(python run.py --config "$CONFIG_PATH" 2>&1) + output=$(python run.py --config "$CONFIG_PATH" 2>&1 | tee /dev/tty) # 如果没有找到明确的标记,回退到检查退出码 if [ $? -eq 0 ]; then -- 2.40.1 From 07d7d4385748fcad1faf1433cc1cb1094326780c Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Dec 2025 18:51:54 +0800 Subject: [PATCH 09/41] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=EF=BC=8CMETR-LA=20-STAEFormer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/STAEFormer/METR-LA.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config/STAEFormer/METR-LA.yaml b/config/STAEFormer/METR-LA.yaml index 003e50e..6fa708e 100644 --- a/config/STAEFormer/METR-LA.yaml +++ b/config/STAEFormer/METR-LA.yaml @@ -9,9 +9,9 @@ data: batch_size: 16 column_wise: false days_per_week: 7 - horizon: 12 + horizon: 24 input_dim: 1 - lag: 12 + lag: 24 normalizer: std num_nodes: 207 steps_per_day: 288 @@ -23,13 +23,13 @@ model: dow_embedding_dim: 24 dropout: 0.1 feed_forward_dim: 256 - in_steps: 12 + in_steps: 24 input_dim: 1 input_embedding_dim: 24 num_heads: 4 num_layers: 3 num_nodes: 207 - out_steps: 12 + out_steps: 24 output_dim: 1 spatial_embedding_dim: 0 steps_per_day: 288 -- 2.40.1 From 865c5a30823c507b486fddb8faf530142e584aa1 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sat, 6 Dec 2025 19:47:33 +0800 Subject: [PATCH 10/41] ASTRA-V3 --- .vscode/launch.json | 8 ++ config/AEPSA/v3_METR-LA.yaml | 57 ++++++++++ config/AEPSA/v3_PEMS-BAY.yaml | 54 +++++++++ model/AEPSA/aepsav3.py | 209 ++++++++++++++++++++++++++++++++++ model/model_selector.py | 3 + 5 files changed, 331 insertions(+) create mode 100644 config/AEPSA/v3_METR-LA.yaml create mode 100755 config/AEPSA/v3_PEMS-BAY.yaml create mode 100644 model/AEPSA/aepsav3.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 2b530ca..3dc2b03 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -234,6 +234,14 @@ "console": "integratedTerminal", "args": "--config ./config/AEPSA/v2_SolarEnergy.yaml" }, + { + "name": "AEPSA_v3: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/AEPSA/v3_METR-LA.yaml" + }, { "name": "EXPB: NYCBike-InFlow", "type": "debugpy", diff --git a/config/AEPSA/v3_METR-LA.yaml b/config/AEPSA/v3_METR-LA.yaml new file mode 100644 index 0000000..5d22820 --- /dev/null +++ b/config/AEPSA/v3_METR-LA.yaml @@ -0,0 +1,57 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: AEPSA_v3 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + chebyshev_order: 3 + d_ff: 128 + d_model: 64 + dropout: 0.2 + graph_hidden_dim: 32 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 207 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/AEPSA/v3_PEMS-BAY.yaml b/config/AEPSA/v3_PEMS-BAY.yaml new file mode 100755 index 0000000..9f98483 --- /dev/null +++ b/config/AEPSA/v3_PEMS-BAY.yaml @@ -0,0 +1,54 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: AEPSA_v3 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 325 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/model/AEPSA/aepsav3.py b/model/AEPSA/aepsav3.py new file mode 100644 index 0000000..6a579b6 --- /dev/null +++ b/model/AEPSA/aepsav3.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Model +from einops import rearrange +from model.AEPSA.normalizer import GumbelSoftmax +from model.AEPSA.reprogramming import ReprogrammingLayer +import torch.nn.functional as F + +# 基于动态图增强的时空序列预测模型实现 + +class DynamicGraphEnhancer(nn.Module): + """动态图增强编码器""" + def __init__(self, num_nodes, in_dim, embed_dim=10): + super().__init__() + self.num_nodes = num_nodes # 节点个数 + self.embed_dim = embed_dim # 节点嵌入维度 + + self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True) # 节点嵌入参数 + + self.feature_transform = nn.Sequential( # 特征转换网络 + nn.Linear(in_dim, 16), + nn.Sigmoid(), + nn.Linear(16, 2), + nn.Sigmoid(), + nn.Linear(2, embed_dim) + ) + + self.register_buffer("eye", torch.eye(num_nodes)) # 注册单位矩阵 + + def get_laplacian(self, graph, I, normalize=True): + D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) # 度矩阵的逆平方根 + D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题 + if normalize: + return torch.matmul(torch.matmul(D_inv, graph), D_inv) # 归一化拉普拉斯矩阵 + else: + return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) # 带自环的归一化拉普拉斯矩阵 + + def forward(self, X): + """生成动态拉普拉斯矩阵""" + batch_size = X.size(0) # 批次大小 + laplacians = [] # 存储各批次的拉普拉斯矩阵 + I = self.eye.to(X.device) # 移动单位矩阵到目标设备 + + for b in range(batch_size): + filt = self.feature_transform(X[b]) # 特征转换 + nodevec = torch.tanh(self.node_embeddings * filt) # 计算节点嵌入 + adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) # 计算邻接矩阵 + laplacian = self.get_laplacian(adj, I) # 计算拉普拉斯矩阵 + laplacians.append(laplacian) + return torch.stack(laplacians, dim=0) # 堆叠并返回 + +class GraphEnhancedEncoder(nn.Module): + """图增强编码器""" + def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu', + temporal_dim=12, num_features=1): + super().__init__() + self.K = K # Chebyshev多项式阶数 + self.in_dim = in_dim # 输入特征维度 + self.hidden_dim = hidden_dim # 隐藏层维度 + self.device = device # 运行设备 + self.temporal_dim = temporal_dim # 时间序列长度 + self.num_features = num_features # 特征通道数量 + + self.input_projection = nn.Sequential( # 输入投影层 + nn.Conv2d(num_features, 16, kernel_size=(1, 3), padding=(0, 1)), + nn.ReLU(), + nn.Conv2d(16, in_dim, kernel_size=(1, temporal_dim)), + nn.ReLU() + ) + + self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim) # 动态图增强器 + self.alpha = nn.Parameter(torch.randn(K + 1, 1)) # 谱系数 + self.W = nn.ParameterList([nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)]) # 传播权重 + self.to(device) # 移动到指定设备 + + def chebyshev_polynomials(self, L_tilde, X): + """计算Chebyshev多项式展开""" + T_k_list = [X] # T_0(X) = X + if self.K >= 1: + T_k_list.append(torch.matmul(L_tilde, X)) # T_1(X) = L_tilde * X + for k in range(2, self.K + 1): + T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2]) # 递推计算 + return T_k_list # 返回多项式列表 + + def forward(self, X): + """输入特征[B,N,C,T],返回增强特征[B,N,hidden_dim*(K+1)]""" + batch_size = X.size(0) # 批次大小 + num_nodes = X.size(1) # 节点数量 + + x = X.permute(0, 2, 1, 3) # [B,C,N,T] + x_proj = self.input_projection(x).squeeze(-1) # [B,in_dim,N] + x_proj = x_proj.permute(0, 2, 1) # [B,N,in_dim] + + enhanced_features = [] # 存储增强特征 + laplacians = self.graph_enhancer(x_proj) # 生成动态拉普拉斯矩阵 + + for b in range(batch_size): + L = laplacians[b] # 当前批次的拉普拉斯矩阵 + + # 特征值缩放 + try: + lambda_max = torch.linalg.eigvalsh(L).max().real # 最大特征值 + lambda_max = 1.0 if lambda_max < 1e-6 else lambda_max # 防止除零 + L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device) # 归一化拉普拉斯 + except: + L_tilde = torch.eye(num_nodes, device=X.device) # 异常处理 + + # 计算展开并应用权重 + T_k_list = self.chebyshev_polynomials(L_tilde, x_proj[b]) # 计算Chebyshev多项式 + H_list = [torch.matmul(T_k_list[k], self.W[k]) for k in range(self.K + 1)] # 应用权重 + X_enhanced = torch.cat(H_list, dim=-1) # 拼接特征 + enhanced_features.append(X_enhanced) + + return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征 + +class AEPSA(nn.Module): + """自适应特征投影时空自注意力模型""" + def __init__(self, configs): + super(AEPSA, self).__init__() + self.device = configs['device'] # 运行设备 + self.pred_len = configs['pred_len'] # 预测序列长度 + self.seq_len = configs['seq_len'] # 输入序列长度 + self.patch_len = configs['patch_len'] # 补丁长度 + self.input_dim = configs['input_dim'] # 输入特征维度 + self.stride = configs['stride'] # 步长 + self.dropout = configs['dropout'] # Dropout概率 + self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 + self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 + self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 + self.num_nodes = configs.get('num_nodes', 325) # 节点数量 + + self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 + + self.d_model = configs['d_model'] # 模型维度 + self.n_heads = configs['n_heads'] # 注意力头数量 + self.d_keys = None # 键维度 + self.d_llm = 768 # GPT2隐藏层维度 + + self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2) # 补丁数量 + self.head_nf = self.d_ff * self.patch_nums # 头特征维度 + + # 初始化GPT2模型 + self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) # GPT2模型 + self.gpts.h = self.gpts.h[:self.gpt_layers] # 截取指定层数 + self.gpts.apply(self.reset_parameters) # 重置参数 + + self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) # 词嵌入权重 + self.vocab_size = self.word_embeddings.shape[0] # 词汇表大小 + self.mapping_layer = nn.Linear(self.vocab_size, 1) # 映射层 + self.reprogramming_layer = ReprogrammingLayer(self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), self.n_heads, self.d_keys, self.d_llm) # 重编程层 + + # 初始化图增强编码器 + self.graph_encoder = GraphEnhancedEncoder( + K=configs.get('chebyshev_order', 3), # Chebyshev多项式阶数 + in_dim=self.d_model, # 输入特征维度 + hidden_dim=configs.get('graph_hidden_dim', 32), # 隐藏层维度 + num_nodes=self.num_nodes, # 节点数量 + embed_dim=configs.get('graph_embed_dim', 10), # 节点嵌入维度 + device=self.device, # 运行设备 + temporal_dim=self.seq_len, # 时间序列长度 + num_features=self.input_dim # 特征通道数 + ) + + self.graph_projection = nn.Linear( # 图特征投影层,每一k阶的切比雪夫权重映射到隐藏维度 + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), # 输入维度 + self.d_model # 输出维度 + ) + + self.out_mlp = nn.Sequential( + nn.Linear(self.d_llm, 128), + nn.ReLU(), + nn.Linear(128, self.pred_len) + ) + + # 设置参数可训练性 wps=word position embeddings + for name, param in self.gpts.named_parameters(): + param.requires_grad = 'wpe' in name + + def reset_parameters(self, module): + if hasattr(module, 'weight') and module.weight is not None: + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if hasattr(module, 'bias') and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + def forward(self, x): + # 数据处理 + x = x[..., :1] # [B,T,N,1] + x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T] + + # 图编码 + H_t = self.graph_encoder(x_enc) # [B,N,1,T] -> [B, N, hidden_dim*(K+1)] + X_t_1 = self.graph_projection(H_t) # [B,N,d_model] + enc_out = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)] + + # 词嵌入处理 + self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) + masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0)) # [d_llm,1] + source_embeddings = self.word_embeddings[masks==1] # [selected_words,d_llm] + + # 重编程与预测 + enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) + enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state # [B,N,d_llm] + dec_out = self.out_mlp(enc_out) # [B,N,pred_len] + + # 维度调整 + outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] + outputs = outputs.permute(0, 2, 1, 3) # [B,pred_len,N,1] + + return outputs \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index c669d82..633b02c 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -25,6 +25,7 @@ from model.STAWnet.STAWnet import STAWnet from model.REPST.repst import repst as REPST from model.AEPSA.aepsa import AEPSA as AEPSA from model.AEPSA.aepsav2 import AEPSA as AEPSAv2 +from model.AEPSA.aepsav3 import AEPSA as AEPSAv3 @@ -86,3 +87,5 @@ def model_selector(config): return AEPSA(model_config) case "AEPSA_v2": return AEPSAv2(model_config) + case "AEPSA_v3": + return AEPSAv3(model_config) -- 2.40.1 From f899f50b163c9c5b4c2882a00c432bd0125eaa11 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sat, 6 Dec 2025 19:48:27 +0800 Subject: [PATCH 11/41] =?UTF-8?q?=E5=8F=98=E9=87=8F=E5=90=8D=E6=9B=B4?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/AEPSA/aepsav3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model/AEPSA/aepsav3.py b/model/AEPSA/aepsav3.py index 6a579b6..99c6748 100644 --- a/model/AEPSA/aepsav3.py +++ b/model/AEPSA/aepsav3.py @@ -190,7 +190,7 @@ class AEPSA(nn.Module): # 图编码 H_t = self.graph_encoder(x_enc) # [B,N,1,T] -> [B, N, hidden_dim*(K+1)] X_t_1 = self.graph_projection(H_t) # [B,N,d_model] - enc_out = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)] + X_enc = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)] # 词嵌入处理 self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) @@ -198,9 +198,9 @@ class AEPSA(nn.Module): source_embeddings = self.word_embeddings[masks==1] # [selected_words,d_llm] # 重编程与预测 - enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) - enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state # [B,N,d_llm] - dec_out = self.out_mlp(enc_out) # [B,N,pred_len] + X_enc = self.reprogramming_layer(X_enc, source_embeddings, source_embeddings) + X_enc = self.gpts(inputs_embeds=X_enc).last_hidden_state # [B,N,d_llm] + dec_out = self.out_mlp(X_enc) # [B,N,pred_len] # 维度调整 outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] -- 2.40.1 From 15989181126cecd49b4b7a5673f5f29b519d2586 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sat, 6 Dec 2025 19:49:31 +0800 Subject: [PATCH 12/41] =?UTF-8?q?=E6=9B=B4=E6=94=B9v3=5FPEWSBAY=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/AEPSA/v2_PEMS-BAY.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/config/AEPSA/v2_PEMS-BAY.yaml b/config/AEPSA/v2_PEMS-BAY.yaml index 7b5c97e..c40034d 100755 --- a/config/AEPSA/v2_PEMS-BAY.yaml +++ b/config/AEPSA/v2_PEMS-BAY.yaml @@ -19,9 +19,11 @@ data: val_ratio: 0.2 model: + chebyshev_order: 3 d_ff: 128 d_model: 64 dropout: 0.2 + graph_hidden_dim: 32 gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 -- 2.40.1 From aed1e53f0f163c6ea62b426cab8ef3e64eff5f04 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 9 Dec 2025 14:06:56 +0800 Subject: [PATCH 13/41] =?UTF-8?q?=E6=9B=B4=E5=90=8DASTRA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/{AEPSA => ASTRA}/Chebyshev+Laplacian_construction.py | 0 model/{AEPSA/aepsa.py => ASTRA/astra.py} | 0 model/{AEPSA/aepsav2.py => ASTRA/astrav2.py} | 0 model/{AEPSA/aepsav3.py => ASTRA/astrav3.py} | 0 model/{AEPSA => ASTRA}/normalizer.py | 0 model/{AEPSA => ASTRA}/reprogramming.py | 0 6 files changed, 0 insertions(+), 0 deletions(-) rename model/{AEPSA => ASTRA}/Chebyshev+Laplacian_construction.py (100%) rename model/{AEPSA/aepsa.py => ASTRA/astra.py} (100%) rename model/{AEPSA/aepsav2.py => ASTRA/astrav2.py} (100%) rename model/{AEPSA/aepsav3.py => ASTRA/astrav3.py} (100%) rename model/{AEPSA => ASTRA}/normalizer.py (100%) rename model/{AEPSA => ASTRA}/reprogramming.py (100%) diff --git a/model/AEPSA/Chebyshev+Laplacian_construction.py b/model/ASTRA/Chebyshev+Laplacian_construction.py similarity index 100% rename from model/AEPSA/Chebyshev+Laplacian_construction.py rename to model/ASTRA/Chebyshev+Laplacian_construction.py diff --git a/model/AEPSA/aepsa.py b/model/ASTRA/astra.py similarity index 100% rename from model/AEPSA/aepsa.py rename to model/ASTRA/astra.py diff --git a/model/AEPSA/aepsav2.py b/model/ASTRA/astrav2.py similarity index 100% rename from model/AEPSA/aepsav2.py rename to model/ASTRA/astrav2.py diff --git a/model/AEPSA/aepsav3.py b/model/ASTRA/astrav3.py similarity index 100% rename from model/AEPSA/aepsav3.py rename to model/ASTRA/astrav3.py diff --git a/model/AEPSA/normalizer.py b/model/ASTRA/normalizer.py similarity index 100% rename from model/AEPSA/normalizer.py rename to model/ASTRA/normalizer.py diff --git a/model/AEPSA/reprogramming.py b/model/ASTRA/reprogramming.py similarity index 100% rename from model/AEPSA/reprogramming.py rename to model/ASTRA/reprogramming.py -- 2.40.1 From 9c76975056eb8fe18320aecd3121dc9a3bc2dc84 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 9 Dec 2025 14:07:38 +0800 Subject: [PATCH 14/41] =?UTF-8?q?=E6=9B=B4=E5=90=8DASTRA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 2 +- config/AEPSA/AirQuality.yaml | 2 +- config/AEPSA/BJTaxi-InFlow.yaml | 2 +- config/AEPSA/BJTaxi-Inflow.yaml | 2 +- config/AEPSA/BJTaxi-OutFlow.yaml | 2 +- config/AEPSA/BJTaxi-outflow.yaml | 2 +- config/AEPSA/METR-LA.yaml | 2 +- config/AEPSA/NYCBike-InFlow.yaml | 2 +- config/AEPSA/NYCBike-OutFlow.yaml | 2 +- config/AEPSA/NYCBike-inflow.yaml | 2 +- config/AEPSA/NYCBike-outflow.yaml | 2 +- config/AEPSA/PEMS-BAY.yaml | 2 +- config/AEPSA/SolarEnergy.yaml | 2 +- model/ASTRA/astra.py | 8 ++++---- model/ASTRA/astrav2.py | 8 ++++---- model/ASTRA/astrav3.py | 8 ++++---- model/model_selector.py | 18 +++++++++--------- trainer/DCRNN_Trainer.py | 16 ++++++++-------- 18 files changed, 42 insertions(+), 42 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 3dc2b03..4fb2121 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -153,7 +153,7 @@ "args": "--config ./config/REPST/AirQuality.yaml" }, - // AEPSA 模型组 + // ASTRA 模型组 { "name": "AEPSA: PEMS-BAY", "type": "debugpy", diff --git a/config/AEPSA/AirQuality.yaml b/config/AEPSA/AirQuality.yaml index d6061d9..455fc4b 100644 --- a/config/AEPSA/AirQuality.yaml +++ b/config/AEPSA/AirQuality.yaml @@ -2,7 +2,7 @@ basic: dataset: AirQuality device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-InFlow.yaml b/config/AEPSA/BJTaxi-InFlow.yaml index a453b38..c2766bb 100644 --- a/config/AEPSA/BJTaxi-InFlow.yaml +++ b/config/AEPSA/BJTaxi-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-Inflow.yaml b/config/AEPSA/BJTaxi-Inflow.yaml index a453b38..c2766bb 100644 --- a/config/AEPSA/BJTaxi-Inflow.yaml +++ b/config/AEPSA/BJTaxi-Inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-OutFlow.yaml b/config/AEPSA/BJTaxi-OutFlow.yaml index 9fa0f5f..ee570f3 100644 --- a/config/AEPSA/BJTaxi-OutFlow.yaml +++ b/config/AEPSA/BJTaxi-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-outflow.yaml b/config/AEPSA/BJTaxi-outflow.yaml index 9fa0f5f..ee570f3 100644 --- a/config/AEPSA/BJTaxi-outflow.yaml +++ b/config/AEPSA/BJTaxi-outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/METR-LA.yaml b/config/AEPSA/METR-LA.yaml index a623226..87bf1ac 100644 --- a/config/AEPSA/METR-LA.yaml +++ b/config/AEPSA/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-InFlow.yaml b/config/AEPSA/NYCBike-InFlow.yaml index b561493..1c80773 100644 --- a/config/AEPSA/NYCBike-InFlow.yaml +++ b/config/AEPSA/NYCBike-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-OutFlow.yaml b/config/AEPSA/NYCBike-OutFlow.yaml index 5c4da71..1ece121 100644 --- a/config/AEPSA/NYCBike-OutFlow.yaml +++ b/config/AEPSA/NYCBike-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-inflow.yaml b/config/AEPSA/NYCBike-inflow.yaml index e4ba138..5431fba 100644 --- a/config/AEPSA/NYCBike-inflow.yaml +++ b/config/AEPSA/NYCBike-inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-outflow.yaml b/config/AEPSA/NYCBike-outflow.yaml index 7cb6798..194c330 100644 --- a/config/AEPSA/NYCBike-outflow.yaml +++ b/config/AEPSA/NYCBike-outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/PEMS-BAY.yaml b/config/AEPSA/PEMS-BAY.yaml index f75c63a..e111654 100755 --- a/config/AEPSA/PEMS-BAY.yaml +++ b/config/AEPSA/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/SolarEnergy.yaml b/config/AEPSA/SolarEnergy.yaml index 669c9f4..4160077 100644 --- a/config/AEPSA/SolarEnergy.yaml +++ b/config/AEPSA/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/model/ASTRA/astra.py b/model/ASTRA/astra.py index 7ea003d..0ed2333 100644 --- a/model/ASTRA/astra.py +++ b/model/ASTRA/astra.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Model from einops import rearrange -from model.AEPSA.normalizer import GumbelSoftmax -from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer +from model.ASTRA.normalizer import GumbelSoftmax +from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer import torch.nn.functional as F class DynamicGraphEnhancer(nn.Module): @@ -147,10 +147,10 @@ class GraphEnhancedEncoder(nn.Module): return torch.stack(enhanced_features, dim=0) -class AEPSA(nn.Module): +class ASTRA(nn.Module): def __init__(self, configs): - super(AEPSA, self).__init__() + super(ASTRA, self).__init__() self.device = configs['device'] self.pred_len = configs['pred_len'] self.seq_len = configs['seq_len'] diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index aac9149..79a1330 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Model from einops import rearrange -from model.AEPSA.normalizer import GumbelSoftmax -from model.AEPSA.reprogramming import ReprogrammingLayer +from model.ASTRA.normalizer import GumbelSoftmax +from model.ASTRA.reprogramming import ReprogrammingLayer import torch.nn.functional as F # 基于动态图增强的时空序列预测模型实现 @@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module): return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征 -class AEPSA(nn.Module): +class ASTRA(nn.Module): """自适应特征投影时空自注意力模型""" def __init__(self, configs): - super(AEPSA, self).__init__() + super(ASTRA, self).__init__() self.device = configs['device'] # 运行设备 self.pred_len = configs['pred_len'] # 预测序列长度 self.seq_len = configs['seq_len'] # 输入序列长度 diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index 99c6748..a29bfc3 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Model from einops import rearrange -from model.AEPSA.normalizer import GumbelSoftmax -from model.AEPSA.reprogramming import ReprogrammingLayer +from model.ASTRA.normalizer import GumbelSoftmax +from model.ASTRA.reprogramming import ReprogrammingLayer import torch.nn.functional as F # 基于动态图增强的时空序列预测模型实现 @@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module): return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征 -class AEPSA(nn.Module): +class ASTRA(nn.Module): """自适应特征投影时空自注意力模型""" def __init__(self, configs): - super(AEPSA, self).__init__() + super(ASTRA, self).__init__() self.device = configs['device'] # 运行设备 self.pred_len = configs['pred_len'] # 预测序列长度 self.seq_len = configs['seq_len'] # 输入序列长度 diff --git a/model/model_selector.py b/model/model_selector.py index 633b02c..da54b33 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -23,9 +23,9 @@ from model.ST_SSL.ST_SSL import STSSLModel from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STAWnet.STAWnet import STAWnet from model.REPST.repst import repst as REPST -from model.AEPSA.aepsa import AEPSA as AEPSA -from model.AEPSA.aepsav2 import AEPSA as AEPSAv2 -from model.AEPSA.aepsav3 import AEPSA as AEPSAv3 +from model.ASTRA.astra import ASTRA as ASTRA +from model.ASTRA.astrav2 import ASTRA as ASTRAv2 +from model.ASTRA.astrav3 import ASTRA as ASTRAv3 @@ -83,9 +83,9 @@ def model_selector(config): return STAWnet(model_config) case "REPST": return REPST(model_config) - case "AEPSA": - return AEPSA(model_config) - case "AEPSA_v2": - return AEPSAv2(model_config) - case "AEPSA_v3": - return AEPSAv3(model_config) + case "ASTRA": + return ASTRA(model_config) + case "ASTRA_v2": + return ASTRAv2(model_config) + case "ASTRA_v3": + return ASTRAv3(model_config) diff --git a/trainer/DCRNN_Trainer.py b/trainer/DCRNN_Trainer.py index a60eddb..1911248 100755 --- a/trainer/DCRNN_Trainer.py +++ b/trainer/DCRNN_Trainer.py @@ -104,14 +104,14 @@ class Trainer: loss = self.loss(output, label) # 检查output和label的shape是否一致 - if output.shape == label.shape: - print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") - import sys - sys.exit(0) - else: - print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") - import sys - sys.exit(1) + # if output.shape == label.shape: + # print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + # import sys + # sys.exit(0) + # else: + # print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + # import sys + # sys.exit(1) # 反归一化 d_output = self.scaler.inverse_transform(output) -- 2.40.1 From e4a7884c987909f4bf223424124b0a12508e9642 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 9 Dec 2025 14:08:59 +0800 Subject: [PATCH 15/41] =?UTF-8?q?=E6=9B=B4=E5=90=8DASTRA=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 52 ++++++++++---------- config/{AEPSA => ASTRA}/AirQuality.yaml | 0 config/{AEPSA => ASTRA}/BJTaxi-InFlow.yaml | 0 config/{AEPSA => ASTRA}/BJTaxi-Inflow.yaml | 0 config/{AEPSA => ASTRA}/BJTaxi-OutFlow.yaml | 0 config/{AEPSA => ASTRA}/BJTaxi-outflow.yaml | 0 config/{AEPSA => ASTRA}/METR-LA.yaml | 0 config/{AEPSA => ASTRA}/NYCBike-InFlow.yaml | 0 config/{AEPSA => ASTRA}/NYCBike-OutFlow.yaml | 0 config/{AEPSA => ASTRA}/NYCBike-inflow.yaml | 0 config/{AEPSA => ASTRA}/NYCBike-outflow.yaml | 0 config/{AEPSA => ASTRA}/PEMS-BAY.yaml | 0 config/{AEPSA => ASTRA}/SolarEnergy.yaml | 0 config/{AEPSA => ASTRA}/v2_METR-LA.yaml | 0 config/{AEPSA => ASTRA}/v2_PEMS-BAY.yaml | 0 config/{AEPSA => ASTRA}/v2_SolarEnergy.yaml | 0 config/{AEPSA => ASTRA}/v3_METR-LA.yaml | 0 config/{AEPSA => ASTRA}/v3_PEMS-BAY.yaml | 0 18 files changed, 26 insertions(+), 26 deletions(-) rename config/{AEPSA => ASTRA}/AirQuality.yaml (100%) rename config/{AEPSA => ASTRA}/BJTaxi-InFlow.yaml (100%) rename config/{AEPSA => ASTRA}/BJTaxi-Inflow.yaml (100%) rename config/{AEPSA => ASTRA}/BJTaxi-OutFlow.yaml (100%) rename config/{AEPSA => ASTRA}/BJTaxi-outflow.yaml (100%) rename config/{AEPSA => ASTRA}/METR-LA.yaml (100%) rename config/{AEPSA => ASTRA}/NYCBike-InFlow.yaml (100%) rename config/{AEPSA => ASTRA}/NYCBike-OutFlow.yaml (100%) rename config/{AEPSA => ASTRA}/NYCBike-inflow.yaml (100%) rename config/{AEPSA => ASTRA}/NYCBike-outflow.yaml (100%) rename config/{AEPSA => ASTRA}/PEMS-BAY.yaml (100%) rename config/{AEPSA => ASTRA}/SolarEnergy.yaml (100%) rename config/{AEPSA => ASTRA}/v2_METR-LA.yaml (100%) rename config/{AEPSA => ASTRA}/v2_PEMS-BAY.yaml (100%) rename config/{AEPSA => ASTRA}/v2_SolarEnergy.yaml (100%) rename config/{AEPSA => ASTRA}/v3_METR-LA.yaml (100%) rename config/{AEPSA => ASTRA}/v3_PEMS-BAY.yaml (100%) diff --git a/.vscode/launch.json b/.vscode/launch.json index 4fb2121..e855fc7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -155,92 +155,92 @@ // ASTRA 模型组 { - "name": "AEPSA: PEMS-BAY", + "name": "ASTRA: PEMS-BAY", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/PEMS-BAY.yaml" + "args": "--config ./config/ASTRA/PEMS-BAY.yaml" }, { - "name": "AEPSA: METR-LA", + "name": "ASTRA: METR-LA", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/METR-LA.yaml" + "args": "--config ./config/ASTRA/METR-LA.yaml" }, { - "name": "AEPSA: AirQuality", + "name": "ASTRA: AirQuality", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/AirQuality.yaml" + "args": "--config ./config/ASTRA/AirQuality.yaml" }, { - "name": "AEPSA: BJTaxi-Inflow", + "name": "ASTRA: BJTaxi-Inflow", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/BJTaxi-Inflow.yaml" + "args": "--config ./config/ASTRA/BJTaxi-Inflow.yaml" }, { - "name": "AEPSA: BJTaxi-outflow", + "name": "ASTRA: BJTaxi-outflow", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/BJTaxi-outflow.yaml" + "args": "--config ./config/ASTRA/BJTaxi-outflow.yaml" }, { - "name": "AEPSA: NYCBike-inflow", + "name": "ASTRA: NYCBike-inflow", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/NYCBike-inflow.yaml" + "args": "--config ./config/ASTRA/NYCBike-inflow.yaml" }, { - "name": "AEPSA: NYCBike-outflow", + "name": "ASTRA: NYCBike-outflow", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/NYCBike-outflow.yaml" + "args": "--config ./config/ASTRA/NYCBike-outflow.yaml" }, { - "name": "AEPSA: SolarEnergy", + "name": "ASTRA: SolarEnergy", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/SolarEnergy.yaml" + "args": "--config ./config/ASTRA/SolarEnergy.yaml" }, { - "name": "AEPSA_v2: METR-LA", + "name": "ASTRA_v2: METR-LA", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/v2_METR-LA.yaml" + "args": "--config ./config/ASTRA/v2_METR-LA.yaml" }, { - "name": "AEPSA_v2: SolarEnergy", + "name": "ASTRA_v2: SolarEnergy", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/v2_SolarEnergy.yaml" + "args": "--config ./config/ASTRA/v2_SolarEnergy.yaml" }, { - "name": "AEPSA_v3: METR-LA", + "name": "ASTRA_v3: METR-LA", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/v3_METR-LA.yaml" + "args": "--config ./config/ASTRA/v3_METR-LA.yaml" }, { "name": "EXPB: NYCBike-InFlow", @@ -843,20 +843,20 @@ "args": "--config ./config/STGNCDE/PEMSD7.yaml" }, { - "name": "AEPSA: NYCBike-InFlow", + "name": "ASTRA: NYCBike-InFlow", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/NYCBike-InFlow.yaml" + "args": "--config ./config/ASTRA/NYCBike-InFlow.yaml" }, { - "name": "AEPSA: NYCBike-OutFlow", + "name": "ASTRA: NYCBike-OutFlow", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/AEPSA/NYCBike-OutFlow.yaml" + "args": "--config ./config/ASTRA/NYCBike-OutFlow.yaml" }, { "name": "ST_SSL: NYCBike-InFlow", diff --git a/config/AEPSA/AirQuality.yaml b/config/ASTRA/AirQuality.yaml similarity index 100% rename from config/AEPSA/AirQuality.yaml rename to config/ASTRA/AirQuality.yaml diff --git a/config/AEPSA/BJTaxi-InFlow.yaml b/config/ASTRA/BJTaxi-InFlow.yaml similarity index 100% rename from config/AEPSA/BJTaxi-InFlow.yaml rename to config/ASTRA/BJTaxi-InFlow.yaml diff --git a/config/AEPSA/BJTaxi-Inflow.yaml b/config/ASTRA/BJTaxi-Inflow.yaml similarity index 100% rename from config/AEPSA/BJTaxi-Inflow.yaml rename to config/ASTRA/BJTaxi-Inflow.yaml diff --git a/config/AEPSA/BJTaxi-OutFlow.yaml b/config/ASTRA/BJTaxi-OutFlow.yaml similarity index 100% rename from config/AEPSA/BJTaxi-OutFlow.yaml rename to config/ASTRA/BJTaxi-OutFlow.yaml diff --git a/config/AEPSA/BJTaxi-outflow.yaml b/config/ASTRA/BJTaxi-outflow.yaml similarity index 100% rename from config/AEPSA/BJTaxi-outflow.yaml rename to config/ASTRA/BJTaxi-outflow.yaml diff --git a/config/AEPSA/METR-LA.yaml b/config/ASTRA/METR-LA.yaml similarity index 100% rename from config/AEPSA/METR-LA.yaml rename to config/ASTRA/METR-LA.yaml diff --git a/config/AEPSA/NYCBike-InFlow.yaml b/config/ASTRA/NYCBike-InFlow.yaml similarity index 100% rename from config/AEPSA/NYCBike-InFlow.yaml rename to config/ASTRA/NYCBike-InFlow.yaml diff --git a/config/AEPSA/NYCBike-OutFlow.yaml b/config/ASTRA/NYCBike-OutFlow.yaml similarity index 100% rename from config/AEPSA/NYCBike-OutFlow.yaml rename to config/ASTRA/NYCBike-OutFlow.yaml diff --git a/config/AEPSA/NYCBike-inflow.yaml b/config/ASTRA/NYCBike-inflow.yaml similarity index 100% rename from config/AEPSA/NYCBike-inflow.yaml rename to config/ASTRA/NYCBike-inflow.yaml diff --git a/config/AEPSA/NYCBike-outflow.yaml b/config/ASTRA/NYCBike-outflow.yaml similarity index 100% rename from config/AEPSA/NYCBike-outflow.yaml rename to config/ASTRA/NYCBike-outflow.yaml diff --git a/config/AEPSA/PEMS-BAY.yaml b/config/ASTRA/PEMS-BAY.yaml similarity index 100% rename from config/AEPSA/PEMS-BAY.yaml rename to config/ASTRA/PEMS-BAY.yaml diff --git a/config/AEPSA/SolarEnergy.yaml b/config/ASTRA/SolarEnergy.yaml similarity index 100% rename from config/AEPSA/SolarEnergy.yaml rename to config/ASTRA/SolarEnergy.yaml diff --git a/config/AEPSA/v2_METR-LA.yaml b/config/ASTRA/v2_METR-LA.yaml similarity index 100% rename from config/AEPSA/v2_METR-LA.yaml rename to config/ASTRA/v2_METR-LA.yaml diff --git a/config/AEPSA/v2_PEMS-BAY.yaml b/config/ASTRA/v2_PEMS-BAY.yaml similarity index 100% rename from config/AEPSA/v2_PEMS-BAY.yaml rename to config/ASTRA/v2_PEMS-BAY.yaml diff --git a/config/AEPSA/v2_SolarEnergy.yaml b/config/ASTRA/v2_SolarEnergy.yaml similarity index 100% rename from config/AEPSA/v2_SolarEnergy.yaml rename to config/ASTRA/v2_SolarEnergy.yaml diff --git a/config/AEPSA/v3_METR-LA.yaml b/config/ASTRA/v3_METR-LA.yaml similarity index 100% rename from config/AEPSA/v3_METR-LA.yaml rename to config/ASTRA/v3_METR-LA.yaml diff --git a/config/AEPSA/v3_PEMS-BAY.yaml b/config/ASTRA/v3_PEMS-BAY.yaml similarity index 100% rename from config/AEPSA/v3_PEMS-BAY.yaml rename to config/ASTRA/v3_PEMS-BAY.yaml -- 2.40.1 From 4984d245065f7d997011494391224b7cb22e1170 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 9 Dec 2025 16:11:49 +0800 Subject: [PATCH 16/41] =?UTF-8?q?=E5=AE=9E=E7=8E=B0iTransformer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 8 + config/iTransformer/METR-LA.yaml | 52 +++++ dataloader/PeMSDdataloader.py | 14 +- dataloader/TSloader.py | 216 ++++++++++++++++++ dataloader/loader_selector.py | 3 + model/iTransformer/iTransformer.py | 43 ++++ model/iTransformer/layers/Embed.py | 19 ++ model/iTransformer/layers/SelfAttn.py | 82 +++++++ .../iTransformer/layers/Transformer_EncDec.py | 57 +++++ model/model_selector.py | 3 + requirements.txt | 1 + run.py | 26 +-- 12 files changed, 499 insertions(+), 25 deletions(-) create mode 100644 config/iTransformer/METR-LA.yaml create mode 100755 dataloader/TSloader.py create mode 100644 model/iTransformer/iTransformer.py create mode 100644 model/iTransformer/layers/Embed.py create mode 100644 model/iTransformer/layers/SelfAttn.py create mode 100644 model/iTransformer/layers/Transformer_EncDec.py diff --git a/.vscode/launch.json b/.vscode/launch.json index e855fc7..40c72a3 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -2089,6 +2089,14 @@ "program": "run.py", "console": "integratedTerminal", "args": "--config ./config/STGCN/PEMSD7.yaml" + }, + { + "name": "iTransformer: METR-LA", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/iTransformer/METR-LA.yaml" } ] } \ No newline at end of file diff --git a/config/iTransformer/METR-LA.yaml b/config/iTransformer/METR-LA.yaml new file mode 100644 index 0000000..60772c8 --- /dev/null +++ b/config/iTransformer/METR-LA.yaml @@ -0,0 +1,52 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/dataloader/PeMSDdataloader.py b/dataloader/PeMSDdataloader.py index ea6e89b..0e079e1 100755 --- a/dataloader/PeMSDdataloader.py +++ b/dataloader/PeMSDdataloader.py @@ -1,9 +1,9 @@ -from utils.normalization import normalize_dataset -from dataloader.data_selector import load_st_dataset - import numpy as np import torch +from dataloader.data_selector import load_st_dataset +from utils.normalization import normalize_dataset + def get_dataloader(args, normalizer="std", single=True): data = load_st_dataset(args) @@ -152,7 +152,7 @@ def add_window_y(data, window=3, horizon=1, single=False): offset = window if not single else window + horizon - 1 return _generate_windows(data, window=1 if single else horizon, horizon=horizon, offset=offset) -if __name__ == "__main__": - from dataloader.data_selector import load_st_dataset - res = load_st_dataset({"dataset": "SD"}) - print(f"Dataset shape: {res.shape}") +# if __name__ == "__main__": +# from dataloader.data_selector import load_st_dataset +# res = load_st_dataset({"dataset": "SD"}) +# print(f"Dataset shape: {res.shape}") diff --git a/dataloader/TSloader.py b/dataloader/TSloader.py new file mode 100755 index 0000000..abcc604 --- /dev/null +++ b/dataloader/TSloader.py @@ -0,0 +1,216 @@ +from dataloader.data_selector import load_st_dataset +from utils.normalization import normalize_dataset + +import numpy as np +import torch + + +def get_dataloader(args, normalizer="std", single=True): + data = load_st_dataset(args) + + args = args["data"] + L, N, F = data.shape + data = data.reshape(L, N*F) # [L, N*F] + + # Generate sliding windows for main data and add time features + x, y = _prepare_data_with_windows(data, args, single) + + # Split data + split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio + x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"]) + y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"]) + + # Normalize x and y using the same scaler + scaler = _normalize_data(x_train, x_val, x_test, args, normalizer) + _apply_existing_scaler(y_train, y_val, y_test, scaler, args) + + # Create dataloaders + return ( + _create_dataloader(x_train, y_train, args["batch_size"], True, False), + _create_dataloader(x_val, y_val, args["batch_size"], False, False), + _create_dataloader(x_test, y_test, args["batch_size"], False, False), + scaler + ) + + +def _prepare_data_with_windows(data, args, single): + # Generate sliding windows for main data + x = add_window_x(data, args["lag"], args["horizon"], single) + y = add_window_y(data, args["lag"], args["horizon"], single) + + # Generate time features + time_features = _generate_time_features(data.shape[0], args) + + # Add time features to x and y + x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) + y = _add_time_features(y, time_features, args["lag"], args["horizon"], single, add_window_y) + + return x, y + + +def _generate_time_features(L, args): + # For time series data, we generate time features for each time step + # [L, 1] -> [L, T, 1] by repeating across time dimension + T = args.get("time_dim", 1) # Get time dimension size if available + + time_in_day = [i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)] + time_in_day = np.array(time_in_day)[:, None, None] # [L, 1, 1] + time_in_day = np.tile(time_in_day, (1, T, 1)) # [L, T, 1] + + day_in_week = [(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)] + day_in_week = np.array(day_in_week)[:, None, None] # [L, 1, 1] + day_in_week = np.tile(day_in_week, (1, T, 1)) # [L, T, 1] + + return time_in_day, day_in_week + + + +def _add_time_features(data, time_features, lag, horizon, single, window_fn): + time_in_day, day_in_week = time_features + time_day = window_fn(time_in_day, lag, horizon, single) + time_week = window_fn(day_in_week, lag, horizon, single) + return np.concatenate([data, time_day, time_week], axis=-1) + + +def _normalize_data(train_data, val_data, test_data, args, normalizer): + scaler = normalize_dataset(train_data[..., : args["input_dim"]], normalizer, args["column_wise"]) + + for data in [train_data, val_data, test_data]: + data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]]) + + return scaler + + +def _apply_existing_scaler(train_data, val_data, test_data, scaler, args): + for data in [train_data, val_data, test_data]: + data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]]) + + +def _create_dataloader(X_data, Y_data, batch_size, shuffle, drop_last): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + X_tensor = torch.tensor(X_data, dtype=torch.float32, device=device) + Y_tensor = torch.tensor(Y_data, dtype=torch.float32, device=device) + dataset = torch.utils.data.TensorDataset(X_tensor, Y_tensor) + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) + + +def split_data_by_days(data, val_days, test_days, interval=30): + t = int((24 * 60) / interval) + test_data = data[-t * int(test_days) :] + val_data = data[-t * int(test_days + val_days) : -t * int(test_days)] + train_data = data[: -t * int(test_days + val_days)] + return train_data, val_data, test_data + + +def split_data_by_ratio(data, val_ratio, test_ratio): + data_len = data.shape[0] + test_data = data[-int(data_len * test_ratio) :] + val_data = data[ + -int(data_len * (test_ratio + val_ratio)) : -int(data_len * test_ratio) + ] + train_data = data[: -int(data_len * (test_ratio + val_ratio))] + return train_data, val_data, test_data + + + + +def _generate_windows(data, window=3, horizon=1, offset=0): + """ + Internal helper function to generate sliding windows. + + :param data: Input data, shape [L, T, C] + :param window: Window size + :param horizon: Horizon size + :param offset: Offset from window start + :return: Windowed data, shape [num_windows, window, T, C] + """ + length = len(data) + end_index = length - horizon - window + 1 + windows = [] + index = 0 + + if end_index <= 0: + raise ValueError(f"end_index is non-positive: {end_index}, length={length}, horizon={horizon}, window={window}") + + while index < end_index: + window_data = data[index + offset : index + offset + window] + windows.append(window_data) + index += 1 + + if not windows: + raise ValueError("No windows generated") + + # Check window shapes + first_shape = windows[0].shape + for i, w in enumerate(windows): + if w.shape != first_shape: + raise ValueError(f"Window {i} has shape {w.shape}, expected {first_shape}") + + return np.array(windows) + +def add_window_x(data, window=3, horizon=1, single=False): + """ + Generate windowed X values from the input data. + + :param data: Input data, shape [L, T, C] + :param window: Size of the sliding window + :param horizon: Horizon size + :param single: If True, generate single-step windows, else multi-step + :return: X with shape [num_windows, window, T, C] + """ + return _generate_windows(data, window, horizon, offset=0) + +def add_window_y(data, window=3, horizon=1, single=False): + """ + Generate windowed Y values from the input data. + + :param data: Input data, shape [L, T, C] + :param window: Size of the sliding window + :param horizon: Horizon size + :param single: If True, generate single-step windows, else multi-step + :return: Y with shape [num_windows, horizon, T, C] + """ + offset = window if not single else window + horizon - 1 + return _generate_windows(data, window=1 if single else horizon, horizon=horizon, offset=offset) + +if __name__ == "__main__": + + # Test with a dummy config using METR-LA dataset + dummy_args = { + "basic": { + "dataset": "METR-LA" + }, + "data": { + "lag": 3, + "horizon": 1, + "val_ratio": 0.1, + "test_ratio": 0.2, + "steps_per_day": 288, + "days_per_week": 7, + "input_dim": 1, + "column_wise": False, + "batch_size": 32, + "time_dim": 1 # Add time dimension parameter + } + } + + try: + # Load data + data = load_st_dataset(dummy_args) + print(f"Original data shape: {data.shape}") + + # Get dataloader + train_loader, val_loader, test_loader, scaler = get_dataloader(dummy_args) + + # Test data loader + for batch_x, batch_y in train_loader: + print(f"Batch X shape: {batch_x.shape}") + print(f"Batch Y shape: {batch_y.shape}") + break + + print("Test passed successfully!") + + except Exception as e: + print(f"Test failed with error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index f8dacdb..afd606d 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -3,6 +3,7 @@ from dataloader.PeMSDdataloader import get_dataloader as normal_loader from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader from dataloader.EXPdataloader import get_dataloader as EXP_loader from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader +from dataloader.TSloader import get_dataloader as TS_loader def get_dataloader(config, normalizer, single): @@ -16,5 +17,7 @@ def get_dataloader(config, normalizer, single): return DCRNN_loader(config, normalizer, single) case "EXP": return EXP_loader(config, normalizer, single) + case "iTransformer": + return TS_loader(config, normalizer, single) case _: return normal_loader(config, normalizer, single) diff --git a/model/iTransformer/iTransformer.py b/model/iTransformer/iTransformer.py new file mode 100644 index 0000000..3fa7422 --- /dev/null +++ b/model/iTransformer/iTransformer.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +from model.iTransformer.layers.Transformer_EncDec import Encoder, EncoderLayer +from model.iTransformer.layers.SelfAttn import FullAttention, AttentionLayer +from model.iTransformer.layers.Embed import DataEmbedding_inverted + +class iTransformer(nn.Module): + """ + Paper link: https://arxiv.org/abs/2310.06625 + """ + + def __init__(self, args): + super(iTransformer, self).__init__() + self.pred_len = args['pred_len'] + # Embedding + self.enc_embedding = DataEmbedding_inverted(args['seq_len'], args['d_model'], args['dropout']) + # Encoder-only architecture + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention(False, attention_dropout=args['dropout'], + output_attention=args['output_attention']), args['d_model'], args['n_heads']), + args['d_model'], + args['d_ff'], + dropout=args['dropout'], + activation=args['activation'] + ) for l in range(args['e_layers']) + ], + norm_layer=torch.nn.LayerNorm(args['d_model']) + ) + self.projector = nn.Linear(args['d_model'], args['pred_len'], bias=True) + + def forecast(self, x_enc, x_mark_enc): + _, _, N = x_enc.shape # B, T, C + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=None) + dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates + return dec_out, attns + + def forward(self, x_enc, x_mark_enc): + dec_out, attns = self.forecast(x_enc, x_mark_enc) + return dec_out[:, -self.pred_len:, :] # [B, T, C] \ No newline at end of file diff --git a/model/iTransformer/layers/Embed.py b/model/iTransformer/layers/Embed.py new file mode 100644 index 0000000..8e7209b --- /dev/null +++ b/model/iTransformer/layers/Embed.py @@ -0,0 +1,19 @@ +import torch +import torch.nn as nn + +class DataEmbedding_inverted(nn.Module): + def __init__(self, c_in, d_model, dropout=0.1): + super(DataEmbedding_inverted, self).__init__() + self.value_embedding = nn.Linear(c_in, d_model) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = x.permute(0, 2, 1) + # x: [Batch Variate Time] + if x_mark is None: + x = self.value_embedding(x) + else: + # the potential to take covariates (e.g. timestamps) as tokens + x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) + # x: [Batch Variate d_model] + return self.dropout(x) \ No newline at end of file diff --git a/model/iTransformer/layers/SelfAttn.py b/model/iTransformer/layers/SelfAttn.py new file mode 100644 index 0000000..e5670e1 --- /dev/null +++ b/model/iTransformer/layers/SelfAttn.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +import numpy as np +from math import sqrt + + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None): + super(AttentionLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, + keys, + values, + attn_mask, + tau=tau, + delta=delta + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class TriangularCausalMask: + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + + @property + def mask(self): + return self._mask + diff --git a/model/iTransformer/layers/Transformer_EncDec.py b/model/iTransformer/layers/Transformer_EncDec.py new file mode 100644 index 0000000..6116325 --- /dev/null +++ b/model/iTransformer/layers/Transformer_EncDec.py @@ -0,0 +1,57 @@ +import torch.nn as nn +import torch.nn.functional as F + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None, tau=None, delta=None): + new_x, attn = self.attention( + x, x, x, + attn_mask=attn_mask, + tau=tau, delta=delta + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None, tau=None, delta=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): + delta = delta if i == 0 else None + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x, tau=tau, delta=None) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index da54b33..0d9d120 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -26,6 +26,7 @@ from model.REPST.repst import repst as REPST from model.ASTRA.astra import ASTRA as ASTRA from model.ASTRA.astrav2 import ASTRA as ASTRAv2 from model.ASTRA.astrav3 import ASTRA as ASTRAv3 +from model.iTransformer.iTransformer import iTransformer @@ -89,3 +90,5 @@ def model_selector(config): return ASTRAv2(model_config) case "ASTRA_v3": return ASTRAv3(model_config) + case "iTransformer": + return iTransformer(model_config) diff --git a/requirements.txt b/requirements.txt index c964fff..a3568f0 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +numpy pyyaml tqdm statsmodels diff --git a/run.py b/run.py index 95867f4..e62bfd6 100755 --- a/run.py +++ b/run.py @@ -11,36 +11,28 @@ from trainer.trainer_selector import select_trainer def main(): + # 读取配置 args = parse_args() + + # 初始化 device, seed, model, data, trainer args = init.init_device(args) init.init_seed(args["basic"]["seed"]) - - # Load model model = init.init_model(args) - - # Load dataset train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( args, normalizer=args["data"]["normalizer"], single=False ) - loss = init.init_loss(args, scaler) optimizer, lr_scheduler = init.init_optimizer(model, args["train"]) init.create_logs(args) - - # Start training or testing trainer = select_trainer( model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, + loss, optimizer, + train_loader, val_loader, test_loader, scaler, args, - lr_scheduler, - extra_data, + lr_scheduler, extra_data, ) + # 开始训练 match args["basic"]["mode"]: case "train": trainer.train() @@ -54,9 +46,7 @@ def main(): ) trainer.test( model.to(args["basic"]["device"]), - trainer.args, - test_loader, - scaler, + trainer.args, test_loader, scaler, trainer.logger, ) case _: -- 2.40.1 From faeb90e734408ad5af3b687f7cd4fb340120ea29 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 9 Dec 2025 16:12:05 +0800 Subject: [PATCH 17/41] =?UTF-8?q?=E5=AE=9E=E7=8E=B0iTransformer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/iTransformer/METR-LA.yaml | 4 +-- dataloader/TSloader.py | 47 +++--------------------------- model/iTransformer/iTransformer.py | 2 +- 3 files changed, 7 insertions(+), 46 deletions(-) diff --git a/config/iTransformer/METR-LA.yaml b/config/iTransformer/METR-LA.yaml index 60772c8..925c6d2 100644 --- a/config/iTransformer/METR-LA.yaml +++ b/config/iTransformer/METR-LA.yaml @@ -42,11 +42,11 @@ train: lr_decay: true lr_decay_rate: 0.3 lr_decay_step: 5,20,40,70 - lr_init: 0.003 + lr_init: 0.0001 mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1 + output_dim: 207 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/dataloader/TSloader.py b/dataloader/TSloader.py index abcc604..bd72b9a 100755 --- a/dataloader/TSloader.py +++ b/dataloader/TSloader.py @@ -32,58 +32,24 @@ def get_dataloader(args, normalizer="std", single=True): scaler ) - def _prepare_data_with_windows(data, args, single): # Generate sliding windows for main data x = add_window_x(data, args["lag"], args["horizon"], single) y = add_window_y(data, args["lag"], args["horizon"], single) - - # Generate time features - time_features = _generate_time_features(data.shape[0], args) - - # Add time features to x and y - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - y = _add_time_features(y, time_features, args["lag"], args["horizon"], single, add_window_y) - return x, y - -def _generate_time_features(L, args): - # For time series data, we generate time features for each time step - # [L, 1] -> [L, T, 1] by repeating across time dimension - T = args.get("time_dim", 1) # Get time dimension size if available - - time_in_day = [i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)] - time_in_day = np.array(time_in_day)[:, None, None] # [L, 1, 1] - time_in_day = np.tile(time_in_day, (1, T, 1)) # [L, T, 1] - - day_in_week = [(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)] - day_in_week = np.array(day_in_week)[:, None, None] # [L, 1, 1] - day_in_week = np.tile(day_in_week, (1, T, 1)) # [L, T, 1] - - return time_in_day, day_in_week - - - -def _add_time_features(data, time_features, lag, horizon, single, window_fn): - time_in_day, day_in_week = time_features - time_day = window_fn(time_in_day, lag, horizon, single) - time_week = window_fn(day_in_week, lag, horizon, single) - return np.concatenate([data, time_day, time_week], axis=-1) - - def _normalize_data(train_data, val_data, test_data, args, normalizer): - scaler = normalize_dataset(train_data[..., : args["input_dim"]], normalizer, args["column_wise"]) + scaler = normalize_dataset(train_data[..., : args["num_nodes"]], normalizer, args["column_wise"]) for data in [train_data, val_data, test_data]: - data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]]) + data[..., : args["num_nodes"]] = scaler.transform(data[..., : args["num_nodes"]]) return scaler def _apply_existing_scaler(train_data, val_data, test_data, scaler, args): for data in [train_data, val_data, test_data]: - data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]]) + data[..., : args["num_nodes"]] = scaler.transform(data[..., : args["num_nodes"]]) def _create_dataloader(X_data, Y_data, batch_size, shuffle, drop_last): @@ -105,15 +71,10 @@ def split_data_by_days(data, val_days, test_days, interval=30): def split_data_by_ratio(data, val_ratio, test_ratio): data_len = data.shape[0] test_data = data[-int(data_len * test_ratio) :] - val_data = data[ - -int(data_len * (test_ratio + val_ratio)) : -int(data_len * test_ratio) - ] + val_data = data[-int(data_len * (test_ratio + val_ratio)) : -int(data_len * test_ratio)] train_data = data[: -int(data_len * (test_ratio + val_ratio))] return train_data, val_data, test_data - - - def _generate_windows(data, window=3, horizon=1, offset=0): """ Internal helper function to generate sliding windows. diff --git a/model/iTransformer/iTransformer.py b/model/iTransformer/iTransformer.py index 3fa7422..3cc0818 100644 --- a/model/iTransformer/iTransformer.py +++ b/model/iTransformer/iTransformer.py @@ -38,6 +38,6 @@ class iTransformer(nn.Module): dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates return dec_out, attns - def forward(self, x_enc, x_mark_enc): + def forward(self, x_enc, x_mark_enc=None): dec_out, attns = self.forecast(x_enc, x_mark_enc) return dec_out[:, -self.pred_len:, :] # [B, T, C] \ No newline at end of file -- 2.40.1 From b57fcef039983637791e204ec10896738247101a Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 9 Dec 2025 16:56:30 +0800 Subject: [PATCH 18/41] =?UTF-8?q?=E6=96=B0=E5=A2=9ESolarEnergy-iTransforme?= =?UTF-8?q?r=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 18 +++++++- config/iTransformer/AirQuality.yaml | 52 ++++++++++++++++++++++++ config/iTransformer/BJTaxi-Inflow.yaml | 52 ++++++++++++++++++++++++ config/iTransformer/BJTaxi-Outflow.yaml | 52 ++++++++++++++++++++++++ config/iTransformer/METR-LA.yaml | 2 +- config/iTransformer/NYCBike-Inflow.yaml | 52 ++++++++++++++++++++++++ config/iTransformer/NYCBike-Outflow.yaml | 52 ++++++++++++++++++++++++ config/iTransformer/PEMS-BAY.yaml | 52 ++++++++++++++++++++++++ config/iTransformer/SolarEnergy.yaml | 52 ++++++++++++++++++++++++ dataloader/TSloader.py | 1 + 10 files changed, 383 insertions(+), 2 deletions(-) create mode 100644 config/iTransformer/AirQuality.yaml create mode 100644 config/iTransformer/BJTaxi-Inflow.yaml create mode 100644 config/iTransformer/BJTaxi-Outflow.yaml create mode 100644 config/iTransformer/NYCBike-Inflow.yaml create mode 100644 config/iTransformer/NYCBike-Outflow.yaml create mode 100644 config/iTransformer/PEMS-BAY.yaml create mode 100644 config/iTransformer/SolarEnergy.yaml diff --git a/.vscode/launch.json b/.vscode/launch.json index 40c72a3..21af481 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -2097,6 +2097,22 @@ "program": "run.py", "console": "integratedTerminal", "args": "--config ./config/iTransformer/METR-LA.yaml" - } + }, + { + "name": "iTransformer: AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/iTransformer/AirQuality.yaml" + }, + { + "name": "iTransformer: SolarEnergy", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/iTransformer/SolarEnergy.yaml" + }, ] } \ No newline at end of file diff --git a/config/iTransformer/AirQuality.yaml b/config/iTransformer/AirQuality.yaml new file mode 100644 index 0000000..74bf69d --- /dev/null +++ b/config/iTransformer/AirQuality.yaml @@ -0,0 +1,52 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 35 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/BJTaxi-Inflow.yaml b/config/iTransformer/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..8a0e7c9 --- /dev/null +++ b/config/iTransformer/BJTaxi-Inflow.yaml @@ -0,0 +1,52 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1024 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/BJTaxi-Outflow.yaml b/config/iTransformer/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..ea4af50 --- /dev/null +++ b/config/iTransformer/BJTaxi-Outflow.yaml @@ -0,0 +1,52 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1024 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/METR-LA.yaml b/config/iTransformer/METR-LA.yaml index 925c6d2..3d02d8b 100644 --- a/config/iTransformer/METR-LA.yaml +++ b/config/iTransformer/METR-LA.yaml @@ -1,6 +1,6 @@ basic: dataset: METR-LA - device: cuda:0 + device: cuda:1 mode: train model: iTransformer seed: 2023 diff --git a/config/iTransformer/NYCBike-Inflow.yaml b/config/iTransformer/NYCBike-Inflow.yaml new file mode 100644 index 0000000..598ca1e --- /dev/null +++ b/config/iTransformer/NYCBike-Inflow.yaml @@ -0,0 +1,52 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 128 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/NYCBike-Outflow.yaml b/config/iTransformer/NYCBike-Outflow.yaml new file mode 100644 index 0000000..b6a8994 --- /dev/null +++ b/config/iTransformer/NYCBike-Outflow.yaml @@ -0,0 +1,52 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 128 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/PEMS-BAY.yaml b/config/iTransformer/PEMS-BAY.yaml new file mode 100644 index 0000000..5140b73 --- /dev/null +++ b/config/iTransformer/PEMS-BAY.yaml @@ -0,0 +1,52 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 325 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/SolarEnergy.yaml b/config/iTransformer/SolarEnergy.yaml new file mode 100644 index 0000000..bab0108 --- /dev/null +++ b/config/iTransformer/SolarEnergy.yaml @@ -0,0 +1,52 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 137 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/dataloader/TSloader.py b/dataloader/TSloader.py index bd72b9a..33b5a17 100755 --- a/dataloader/TSloader.py +++ b/dataloader/TSloader.py @@ -7,6 +7,7 @@ import torch def get_dataloader(args, normalizer="std", single=True): data = load_st_dataset(args) + data = data[..., 0:1] args = args["data"] L, N, F = data.shape -- 2.40.1 From 5c2380ae2115dc34cdc4d927f56a40f677cd2e7f Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 9 Dec 2025 17:49:18 +0800 Subject: [PATCH 19/41] =?UTF-8?q?=E9=80=82=E9=85=8DHI?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 4 +-- config/HI/AirQuality.yaml | 48 ++++++++++++++++++++++++++++++++++ config/HI/BJTaxi-Inflow.yaml | 48 ++++++++++++++++++++++++++++++++++ config/HI/BJTaxi-Outflow.yaml | 48 ++++++++++++++++++++++++++++++++++ config/HI/METR-LA.yaml | 48 ++++++++++++++++++++++++++++++++++ config/HI/NYCBike-Inflow.yaml | 48 ++++++++++++++++++++++++++++++++++ config/HI/NYCBike-Outflow.yaml | 48 ++++++++++++++++++++++++++++++++++ config/HI/PEMS-BAY.yaml | 48 ++++++++++++++++++++++++++++++++++ config/HI/SolarEnergy.yaml | 48 ++++++++++++++++++++++++++++++++++ dataloader/loader_selector.py | 28 +++++++++++--------- model/HI/HI.py | 45 +++++++++++++++++++++++++++++++ model/model_selector.py | 4 +++ utils/initializer.py | 3 +++ 13 files changed, 453 insertions(+), 15 deletions(-) create mode 100644 config/HI/AirQuality.yaml create mode 100644 config/HI/BJTaxi-Inflow.yaml create mode 100644 config/HI/BJTaxi-Outflow.yaml create mode 100644 config/HI/METR-LA.yaml create mode 100644 config/HI/NYCBike-Inflow.yaml create mode 100644 config/HI/NYCBike-Outflow.yaml create mode 100644 config/HI/PEMS-BAY.yaml create mode 100644 config/HI/SolarEnergy.yaml create mode 100644 model/HI/HI.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 21af481..cc2c023 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -2107,12 +2107,12 @@ "args": "--config ./config/iTransformer/AirQuality.yaml" }, { - "name": "iTransformer: SolarEnergy", + "name": "HI: PEMS-BAY", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/iTransformer/SolarEnergy.yaml" + "args": "--config ./config/HI/PEMS-BAY.yaml" }, ] } \ No newline at end of file diff --git a/config/HI/AirQuality.yaml b/config/HI/AirQuality.yaml new file mode 100644 index 0000000..07300c4 --- /dev/null +++ b/config/HI/AirQuality.yaml @@ -0,0 +1,48 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 35 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/config/HI/BJTaxi-Inflow.yaml b/config/HI/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..d752667 --- /dev/null +++ b/config/HI/BJTaxi-Inflow.yaml @@ -0,0 +1,48 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1024 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/config/HI/BJTaxi-Outflow.yaml b/config/HI/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..271fbc7 --- /dev/null +++ b/config/HI/BJTaxi-Outflow.yaml @@ -0,0 +1,48 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1024 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/config/HI/METR-LA.yaml b/config/HI/METR-LA.yaml new file mode 100644 index 0000000..0826302 --- /dev/null +++ b/config/HI/METR-LA.yaml @@ -0,0 +1,48 @@ +basic: + dataset: METR-LA + device: cuda:1 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 207 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/config/HI/NYCBike-Inflow.yaml b/config/HI/NYCBike-Inflow.yaml new file mode 100644 index 0000000..be217a9 --- /dev/null +++ b/config/HI/NYCBike-Inflow.yaml @@ -0,0 +1,48 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 128 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/config/HI/NYCBike-Outflow.yaml b/config/HI/NYCBike-Outflow.yaml new file mode 100644 index 0000000..0f93fe5 --- /dev/null +++ b/config/HI/NYCBike-Outflow.yaml @@ -0,0 +1,48 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 128 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/config/HI/PEMS-BAY.yaml b/config/HI/PEMS-BAY.yaml new file mode 100644 index 0000000..832f455 --- /dev/null +++ b/config/HI/PEMS-BAY.yaml @@ -0,0 +1,48 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 325 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/config/HI/SolarEnergy.yaml b/config/HI/SolarEnergy.yaml new file mode 100644 index 0000000..8f55fac --- /dev/null +++ b/config/HI/SolarEnergy.yaml @@ -0,0 +1,48 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: HI + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + input_len: 24 + output_len: 24 + reverse: False + + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 1 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: + lr_decay_rate: + lr_decay_step: + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 137 + optimizer: null + plot: false + real_value: true + scheduler: null + weight_decay: 0 \ No newline at end of file diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index afd606d..f9bf823 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -7,17 +7,19 @@ from dataloader.TSloader import get_dataloader as TS_loader def get_dataloader(config, normalizer, single): + TS_model = ["iTransformer", "HI"] model_name = config["basic"]["model"] - match model_name: - case "STGNCDE": - return cde_loader(config, normalizer, single) - case "STGNRDE": - return nrde_loader(config, normalizer, single) - case "DCRNN": - return DCRNN_loader(config, normalizer, single) - case "EXP": - return EXP_loader(config, normalizer, single) - case "iTransformer": - return TS_loader(config, normalizer, single) - case _: - return normal_loader(config, normalizer, single) + if model_name in TS_model: + return TS_loader(config, normalizer, single) + else : + match model_name: + case "STGNCDE": + return cde_loader(config, normalizer, single) + case "STGNRDE": + return nrde_loader(config, normalizer, single) + case "DCRNN": + return DCRNN_loader(config, normalizer, single) + case "EXP": + return EXP_loader(config, normalizer, single) + case _: + return normal_loader(config, normalizer, single) diff --git a/model/HI/HI.py b/model/HI/HI.py new file mode 100644 index 0000000..aefbd12 --- /dev/null +++ b/model/HI/HI.py @@ -0,0 +1,45 @@ +from typing import List +import torch +from torch import nn + + +class HI(nn.Module): + """ + Paper: Historical Inertia: A Neglected but Powerful Baseline for Long Sequence Time-series Forecasting + Link: https://arxiv.org/abs/2103.16349 + Official code: None + Venue: CIKM 2021 + Task: Long-term Time Series Forecasting + """ + + def __init__(self, config): + """ + Init HI. + + Args: + config (HIConfig): model config. + """ + + super().__init__() + self.input_len = config['input_len'] + self.output_len = config['output_len'] + assert self.input_len >= self.output_len, "HI model requires input length > output length" + self.reverse = config['reverse'] + # self.fake_param = nn.Linear(1, 1, bias=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Forward function of HI. + + Args: + inputs (torch.Tensor): shape = [B, L_in, N] + + Returns: + torch.Tensor: model prediction [B, L_out, N]. + """ + # historical inertia + prediction = inputs[:, -self.output_len:, :] + # last point + # prediction = inputs[:, [-1], :].expand(-1, self.output_len, -1) + if self.reverse: + prediction = prediction.flip(dims=[1]) + return prediction \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index 0d9d120..7403893 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -1,4 +1,5 @@ from model.DDGCRN.DDGCRN import DDGCRN +from model.HI import HI from model.TWDGCN.TWDGCN import TWDGCN from model.AGCRN.AGCRN import AGCRN from model.NLT.HierAttnLstm import HierAttnLstm @@ -27,6 +28,7 @@ from model.ASTRA.astra import ASTRA as ASTRA from model.ASTRA.astrav2 import ASTRA as ASTRAv2 from model.ASTRA.astrav3 import ASTRA as ASTRAv3 from model.iTransformer.iTransformer import iTransformer +from model.HI.HI import HI @@ -92,3 +94,5 @@ def model_selector(config): return ASTRAv3(model_config) case "iTransformer": return iTransformer(model_config) + case "HI": + return HI(model_config) diff --git a/utils/initializer.py b/utils/initializer.py index b69c67f..7bee2be 100755 --- a/utils/initializer.py +++ b/utils/initializer.py @@ -23,6 +23,9 @@ def init_model(args): def init_optimizer(model, args): + optimizer = None + lr_scheduler = None + optim = args.get("optimizer", "Adam") match optim : case "Adam": -- 2.40.1 From 560d24e5a86a1eb67a62303e6f61df692faf4f9b Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 10:39:41 +0800 Subject: [PATCH 20/41] =?UTF-8?q?=E6=9B=B4=E6=96=B0v2=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 4 +- config/ASTRA/v2_AirQuality.yaml | 54 +++++++++++++++++++ ...Taxi-Inflow.yaml => v2_BJTaxi-InFlow.yaml} | 2 +- ...xi-outflow.yaml => v2_BJTaxi-OutFlow.yaml} | 2 +- ...ike-inflow.yaml => v2_NYCBike-InFlow.yaml} | 4 +- ...e-outflow.yaml => v2_NYCBike-OutFlow.yaml} | 4 +- model/ASTRA/astrav2.py | 8 +-- 7 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 config/ASTRA/v2_AirQuality.yaml rename config/ASTRA/{BJTaxi-Inflow.yaml => v2_BJTaxi-InFlow.yaml} (97%) rename config/ASTRA/{BJTaxi-outflow.yaml => v2_BJTaxi-OutFlow.yaml} (97%) rename config/ASTRA/{NYCBike-inflow.yaml => v2_NYCBike-InFlow.yaml} (95%) rename config/ASTRA/{NYCBike-outflow.yaml => v2_NYCBike-OutFlow.yaml} (95%) diff --git a/.vscode/launch.json b/.vscode/launch.json index cc2c023..54aad8a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -219,12 +219,12 @@ "args": "--config ./config/ASTRA/SolarEnergy.yaml" }, { - "name": "ASTRA_v2: METR-LA", + "name": "ASTRA_v2: AirQuality", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v2_METR-LA.yaml" + "args": "--config ./config/ASTRA/v2_AirQuality.yaml" }, { "name": "ASTRA_v2: SolarEnergy", diff --git a/config/ASTRA/v2_AirQuality.yaml b/config/ASTRA/v2_AirQuality.yaml new file mode 100644 index 0000000..10796d2 --- /dev/null +++ b/config/ASTRA/v2_AirQuality.yaml @@ -0,0 +1,54 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: ASTRA_v2 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 6 + n_heads: 1 + num_nodes: 35 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + weight_decay: 0 diff --git a/config/ASTRA/BJTaxi-Inflow.yaml b/config/ASTRA/v2_BJTaxi-InFlow.yaml similarity index 97% rename from config/ASTRA/BJTaxi-Inflow.yaml rename to config/ASTRA/v2_BJTaxi-InFlow.yaml index c2766bb..d1cc5ea 100644 --- a/config/ASTRA/BJTaxi-Inflow.yaml +++ b/config/ASTRA/v2_BJTaxi-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/BJTaxi-outflow.yaml b/config/ASTRA/v2_BJTaxi-OutFlow.yaml similarity index 97% rename from config/ASTRA/BJTaxi-outflow.yaml rename to config/ASTRA/v2_BJTaxi-OutFlow.yaml index ee570f3..d6e0723 100644 --- a/config/ASTRA/BJTaxi-outflow.yaml +++ b/config/ASTRA/v2_BJTaxi-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/NYCBike-inflow.yaml b/config/ASTRA/v2_NYCBike-InFlow.yaml similarity index 95% rename from config/ASTRA/NYCBike-inflow.yaml rename to config/ASTRA/v2_NYCBike-InFlow.yaml index 5431fba..de5b6a1 100644 --- a/config/ASTRA/NYCBike-inflow.yaml +++ b/config/ASTRA/v2_NYCBike-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: @@ -14,7 +14,7 @@ data: lag: 24 normalizer: std num_nodes: 128 - steps_per_day: 24 + steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 diff --git a/config/ASTRA/NYCBike-outflow.yaml b/config/ASTRA/v2_NYCBike-OutFlow.yaml similarity index 95% rename from config/ASTRA/NYCBike-outflow.yaml rename to config/ASTRA/v2_NYCBike-OutFlow.yaml index 194c330..dda718d 100644 --- a/config/ASTRA/NYCBike-outflow.yaml +++ b/config/ASTRA/v2_NYCBike-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: ASTRA + model: ASTRA_v2 seed: 2023 data: @@ -14,7 +14,7 @@ data: lag: 24 normalizer: std num_nodes: 128 - steps_per_day: 24 + steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index 79a1330..6a47206 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -184,7 +184,7 @@ class ASTRA(nn.Module): def forward(self, x): # 数据处理 - x = x[..., :1] # [B,T,N,1] + x = x[..., :self.input_dim] # [B,T,N,1] x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T] # 图编码 @@ -202,7 +202,9 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(enc_out) # [B,N,pred_len] # 维度调整 - outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] - outputs = outputs.permute(0, 2, 1, 3) # [B,pred_len,N,1] + dec_out = self.out_mlp(enc_out) + outputs = dec_out.unsqueeze(dim=-1) + outputs = outputs.repeat(1, 1, 1, self.input_dim) + outputs = outputs.permute(0,2,1,3) return outputs \ No newline at end of file -- 2.40.1 From 44ffe94c95edd61ed3b7f0d0333a87ce445468cf Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 11:01:00 +0800 Subject: [PATCH 21/41] =?UTF-8?q?=E6=9B=B4=E6=96=B0iTransformer,=20HI?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E3=80=82=E6=9B=B4=E6=96=B0TS=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E8=BD=BD=E5=85=A5=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/HI/AirQuality.yaml | 6 +++--- config/HI/BJTaxi-Inflow.yaml | 6 +++--- config/HI/BJTaxi-Outflow.yaml | 6 +++--- config/HI/METR-LA.yaml | 6 +++--- config/HI/NYCBike-Inflow.yaml | 6 +++--- config/HI/NYCBike-Outflow.yaml | 6 +++--- config/HI/PEMS-BAY.yaml | 6 +++--- config/HI/SolarEnergy.yaml | 6 +++--- config/iTransformer/AirQuality.yaml | 6 +++--- config/iTransformer/BJTaxi-Inflow.yaml | 6 +++--- config/iTransformer/BJTaxi-Outflow.yaml | 6 +++--- config/iTransformer/METR-LA.yaml | 6 +++--- config/iTransformer/NYCBike-Inflow.yaml | 6 +++--- config/iTransformer/NYCBike-Outflow.yaml | 6 +++--- config/iTransformer/PEMS-BAY.yaml | 6 +++--- config/iTransformer/SolarEnergy.yaml | 6 +++--- dataloader/TSloader.py | 20 +++++++++++++++++--- 17 files changed, 65 insertions(+), 51 deletions(-) diff --git a/config/HI/AirQuality.yaml b/config/HI/AirQuality.yaml index 07300c4..147da8a 100644 --- a/config/HI/AirQuality.yaml +++ b/config/HI/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 35 + output_dim: 6 optimizer: null plot: false real_value: true diff --git a/config/HI/BJTaxi-Inflow.yaml b/config/HI/BJTaxi-Inflow.yaml index d752667..d3b39ea 100644 --- a/config/HI/BJTaxi-Inflow.yaml +++ b/config/HI/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/BJTaxi-Outflow.yaml b/config/HI/BJTaxi-Outflow.yaml index 271fbc7..96f4253 100644 --- a/config/HI/BJTaxi-Outflow.yaml +++ b/config/HI/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/METR-LA.yaml b/config/HI/METR-LA.yaml index 0826302..203db0d 100644 --- a/config/HI/METR-LA.yaml +++ b/config/HI/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 207 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/NYCBike-Inflow.yaml b/config/HI/NYCBike-Inflow.yaml index be217a9..a24a481 100644 --- a/config/HI/NYCBike-Inflow.yaml +++ b/config/HI/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/NYCBike-Outflow.yaml b/config/HI/NYCBike-Outflow.yaml index 0f93fe5..87d6156 100644 --- a/config/HI/NYCBike-Outflow.yaml +++ b/config/HI/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/PEMS-BAY.yaml b/config/HI/PEMS-BAY.yaml index 832f455..e012772 100644 --- a/config/HI/PEMS-BAY.yaml +++ b/config/HI/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 325 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/SolarEnergy.yaml b/config/HI/SolarEnergy.yaml index 8f55fac..1558d3a 100644 --- a/config/HI/SolarEnergy.yaml +++ b/config/HI/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 137 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/iTransformer/AirQuality.yaml b/config/iTransformer/AirQuality.yaml index 74bf69d..23eba27 100644 --- a/config/iTransformer/AirQuality.yaml +++ b/config/iTransformer/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 35 + output_dim: 6 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/BJTaxi-Inflow.yaml b/config/iTransformer/BJTaxi-Inflow.yaml index 8a0e7c9..dfc2df2 100644 --- a/config/iTransformer/BJTaxi-Inflow.yaml +++ b/config/iTransformer/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/BJTaxi-Outflow.yaml b/config/iTransformer/BJTaxi-Outflow.yaml index ea4af50..d14bed5 100644 --- a/config/iTransformer/BJTaxi-Outflow.yaml +++ b/config/iTransformer/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/METR-LA.yaml b/config/iTransformer/METR-LA.yaml index 3d02d8b..20c4068 100644 --- a/config/iTransformer/METR-LA.yaml +++ b/config/iTransformer/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 207 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/NYCBike-Inflow.yaml b/config/iTransformer/NYCBike-Inflow.yaml index 598ca1e..8afa656 100644 --- a/config/iTransformer/NYCBike-Inflow.yaml +++ b/config/iTransformer/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/NYCBike-Outflow.yaml b/config/iTransformer/NYCBike-Outflow.yaml index b6a8994..7abba88 100644 --- a/config/iTransformer/NYCBike-Outflow.yaml +++ b/config/iTransformer/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/PEMS-BAY.yaml b/config/iTransformer/PEMS-BAY.yaml index 5140b73..17f2fd4 100644 --- a/config/iTransformer/PEMS-BAY.yaml +++ b/config/iTransformer/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 325 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/SolarEnergy.yaml b/config/iTransformer/SolarEnergy.yaml index bab0108..cce005a 100644 --- a/config/iTransformer/SolarEnergy.yaml +++ b/config/iTransformer/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 137 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/dataloader/TSloader.py b/dataloader/TSloader.py index 33b5a17..66ef45c 100755 --- a/dataloader/TSloader.py +++ b/dataloader/TSloader.py @@ -7,16 +7,16 @@ import torch def get_dataloader(args, normalizer="std", single=True): data = load_st_dataset(args) - data = data[..., 0:1] + # data = data[..., 0:1] args = args["data"] L, N, F = data.shape - data = data.reshape(L, N*F) # [L, N*F] + # data = data.reshape(L, N*F) # [L, N*F] # Generate sliding windows for main data and add time features x, y = _prepare_data_with_windows(data, args, single) - # Split data + # Split data [b,t,n,c] split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"]) y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"]) @@ -25,6 +25,10 @@ def get_dataloader(args, normalizer="std", single=True): scaler = _normalize_data(x_train, x_val, x_test, args, normalizer) _apply_existing_scaler(y_train, y_val, y_test, scaler, args) + # reshape [b,t,n,c] -> [b*n, t, c] + x_train, x_val, x_test, y_train, y_val, y_test = \ + _reshape_tensor(x_train, x_val, x_test, y_train, y_val, y_test) + # Create dataloaders return ( _create_dataloader(x_train, y_train, args["batch_size"], True, False), @@ -33,6 +37,16 @@ def get_dataloader(args, normalizer="std", single=True): scaler ) +def _reshape_tensor(*tensors): + """Reshape tensors from [b, t, n, c] -> [b*n, t, c].""" + reshaped = [] + for x in tensors: + # x 是 ndarray:shape (b, t, n, c) + b, t, n, c = x.shape + x_new = x.transpose(0, 2, 1, 3).reshape(b * n, t, c) + reshaped.append(x_new) + return reshaped + def _prepare_data_with_windows(data, args, single): # Generate sliding windows for main data x = add_window_x(data, args["lag"], args["horizon"], single) -- 2.40.1 From d8f4cc5825647c26f819040b5edbe4f99cf40ab0 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 21:08:20 +0800 Subject: [PATCH 22/41] =?UTF-8?q?=E7=AE=80=E5=8C=96trainer=EF=BC=8C?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=AE=BE=E5=A4=87bug=EF=BC=8C=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E6=89=B9=E9=87=8F=E8=BF=90=E8=A1=8C=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/args_parser.py | 36 +- generate_launch_configs.py | 134 - mypy.ini | 4 - run_tests.sh | 95 - test_results.txt | 5406 ------------------------------------ train.py | 63 + trainer/Trainer.py | 250 +- trainer/Trainer_bk.py | 420 +++ trainer/Trainer_old.py | 229 -- utils/initializer.py | 6 +- 10 files changed, 550 insertions(+), 6093 deletions(-) delete mode 100644 generate_launch_configs.py delete mode 100644 mypy.ini delete mode 100755 run_tests.sh delete mode 100644 test_results.txt create mode 100644 train.py create mode 100755 trainer/Trainer_bk.py delete mode 100755 trainer/Trainer_old.py diff --git a/config/args_parser.py b/config/args_parser.py index ebd7bda..256c1f7 100755 --- a/config/args_parser.py +++ b/config/args_parser.py @@ -15,39 +15,5 @@ def parse_args(): config = yaml.safe_load(file) else: raise ValueError("Configuration file path must be provided using --config") - - # Update configuration with command-line arguments - # Merge 'basic' configuration into the root dictionary - # config.update(config.get('basic', {})) - - # Add adaptive configuration based on external commands - if "data" in config and "type" in config["data"]: - config["data"]["type"] = config["basic"].get("dataset", config["data"]["type"]) - if "model" in config and "type" in config["model"]: - config["model"]["type"] = config["basic"].get("model", config["model"]["type"]) - if "model" in config and "rnn_units" in config["model"]: - config["model"]["rnn_units"] = config["basic"].get( - "rnn", config["model"]["rnn_units"] - ) - if "model" in config and "embed_dim" in config["model"]: - config["model"]["embed_dim"] = config["basic"].get( - "emb", config["model"]["embed_dim"] - ) - if "data" in config and "sample" in config["data"]: - config["data"]["sample"] = config["basic"].get( - "sample", config["data"]["sample"] - ) - if "train" in config and "device" in config["train"]: - config["train"]["device"] = config["basic"].get( - "device", config["train"]["device"] - ) - if "train" in config and "debug" in config["train"]: - config["train"]["debug"] = config["basic"].get( - "debug", config["train"]["debug"] - ) - if "cuda" in config: - config["cuda"] = config["basic"].get("cuda", config["cuda"]) - if "mode" in config: - config["mode"] = config["basic"].get("mode", config["mode"]) - + return config diff --git a/generate_launch_configs.py b/generate_launch_configs.py deleted file mode 100644 index 6477e16..0000000 --- a/generate_launch_configs.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import re - -# 配置路径 -CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config" -LAUNCH_FILE = "/user/czzhangheng/code/TrafficWheel/.vscode/launch.json" - -# 遍历所有yaml文件 -def find_all_yaml_files(directory): - yaml_files = [] - for root, dirs, files in os.walk(directory): - for file in files: - if file.endswith(".yaml") and not file.startswith("BJTaxi"): - yaml_files.append(os.path.join(root, file)) - return yaml_files - -# 生成launch配置字符串 -def generate_launch_config_string(yaml_files): - config_strings = [] - - for file_path in yaml_files: - # 提取模型名和数据集名 - relative_path = os.path.relpath(file_path, CONFIG_DIR) - model_name = relative_path.split(os.sep)[0] - dataset_name = os.path.splitext(os.path.basename(file_path))[0] - - # 处理v2版本 - if "v2_" in dataset_name: - model_display_name = f"{model_name}_v2" - dataset_display_name = dataset_name.replace("v2_", "") - else: - model_display_name = model_name - dataset_display_name = dataset_name - - # 生成配置字符串 - config_string = f''' - {{ - "name": "{model_display_name}: {dataset_display_name}", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/{model_name}/{os.path.basename(file_path)}" - }}''' - - config_strings.append(config_string) - - return ",".join(config_strings) - -# 读取现有的launch.json文件,提取配置名称 -def get_existing_config_names(): - with open(LAUNCH_FILE, 'r') as f: - content = f.read() - - # 提取所有配置名称 - name_pattern = re.compile(r'"name"\s*:\s*"([^"]+)"') - matches = name_pattern.findall(content) - - return set(matches) - -# 生成新的配置,过滤掉已存在的 -def generate_new_configs(yaml_files, existing_names): - new_configs = [] - - for file_path in yaml_files: - # 提取模型名和数据集名 - relative_path = os.path.relpath(file_path, CONFIG_DIR) - model_name = relative_path.split(os.sep)[0] - dataset_name = os.path.splitext(os.path.basename(file_path))[0] - - # 处理v2版本 - if "v2_" in dataset_name: - model_display_name = f"{model_name}_v2" - dataset_display_name = dataset_name.replace("v2_", "") - else: - model_display_name = model_name - dataset_display_name = dataset_name - - # 生成配置名称 - config_name = f"{model_display_name}: {dataset_display_name}" - - # 如果配置不存在,则添加 - if config_name not in existing_names: - new_configs.append(file_path) - - return new_configs - -# 更新launch.json文件 -def update_launch_json(new_configs_string): - with open(LAUNCH_FILE, 'r') as f: - content = f.read() - - # 找到configurations数组的结束位置 - configs_end_match = re.search(r'\s*\]\s*\}', content) - if not configs_end_match: - return False - - # 插入新的配置 - insert_pos = configs_end_match.start() - new_content = content[:insert_pos] + new_configs_string + content[insert_pos:] - - # 保存文件 - with open(LAUNCH_FILE, 'w') as f: - f.write(new_content) - - return True - -# 主函数 -def main(): - # 查找所有yaml文件 - yaml_files = find_all_yaml_files(CONFIG_DIR) - - # 获取现有配置名称 - existing_names = get_existing_config_names() - - # 生成新的配置,过滤掉已存在的 - new_config_files = generate_new_configs(yaml_files, existing_names) - - if not new_config_files: - print("No new configurations to add") - return - - # 生成新的配置字符串 - new_configs_string = generate_launch_config_string(new_config_files) - - # 更新launch.json文件 - if update_launch_json(new_configs_string): - print(f"Added {len(new_config_files)} new launch configurations") - print(f"Total configurations: {len(existing_names) + len(new_config_files)}") - else: - print("Failed to update launch.json") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index c77f418..0000000 --- a/mypy.ini +++ /dev/null @@ -1,4 +0,0 @@ -[mypy] -explicit_package_bases = True -ignore_missing_imports = True -no_site_packages = True diff --git a/run_tests.sh b/run_tests.sh deleted file mode 100755 index ef700a1..0000000 --- a/run_tests.sh +++ /dev/null @@ -1,95 +0,0 @@ -#!/bin/bash - -# 设置默认模型名和数据集列表 -MODEL_NAME="STAEFormer" -DATASETS=( - "METR-LA" - "PEMS-BAY" - "NYCBike-InFlow" - "NYCBike-OutFlow" - "AirQuality" - "SolarEnergy" -) - -# 初始化统计变量 -success_count=0 -failure_count=0 -missing_count=0 -total_count=0 -success_datasets=() -failure_datasets=() -missing_datasets=() - -# 检查是否有参数传入来覆盖默认值 -if [ $# -gt 0 ]; then - MODEL_NAME=$1 - # 如果传入了更多参数,使用它们作为数据集列表 - if [ $# -gt 1 ]; then - DATASETS=(${@:2}) - fi -fi - -echo "使用模型: $MODEL_NAME" -echo "数据集列表: ${DATASETS[*]}" -echo "开始测试..." -echo "" - -# 循环测试每个数据集 -for dataset in "${DATASETS[@]}"; do - total_count=$((total_count + 1)) - # 构建配置文件路径 - CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml" - - echo "测试数据集: $dataset" - echo "使用配置文件: $CONFIG_PATH" - - # 检查配置文件是否存在 - if [ ! -f "$CONFIG_PATH" ]; then - echo "错误: 配置文件 $CONFIG_PATH 不存在!" - missing_count=$((missing_count + 1)) - missing_datasets+=("$dataset") - echo "----------------------------------------" - continue - fi - - # 执行测试命令,同时捕获输出并显示在控制台上 - echo "执行: python run.py --config $CONFIG_PATH" - output=$(python run.py --config "$CONFIG_PATH" 2>&1 | tee /dev/tty) - - # 如果没有找到明确的标记,回退到检查退出码 - if [ $? -eq 0 ]; then - echo "数据集 $dataset 测试成功! (基于退出码)" - success_count=$((success_count + 1)) - success_datasets+=("$dataset") - else - echo "数据集 $dataset 测试失败! (基于退出码)" - failure_count=$((failure_count + 1)) - failure_datasets+=("$dataset") - fi - - echo "----------------------------------------" -done - -# 输出总结 -echo "=======================================" -echo "测试总结" -echo "=======================================" -echo "总数据集数量: $total_count" -echo "成功数量: $success_count" -echo "失败数量: $failure_count" -echo "缺失配置文件数量: $missing_count" - -if [ ${#success_datasets[@]} -gt 0 ]; then - echo "成功的数据集: ${success_datasets[*]}" -fi - -if [ ${#failure_datasets[@]} -gt 0 ]; then - echo "失败的数据集: ${failure_datasets[*]}" -fi - -if [ ${#missing_datasets[@]} -gt 0 ]; then - echo "缺失配置的数据集: ${missing_datasets[*]}" -fi - -echo "=======================================" -echo "所有测试完成!" \ No newline at end of file diff --git a/test_results.txt b/test_results.txt deleted file mode 100644 index 6116217..0000000 --- a/test_results.txt +++ /dev/null @@ -1,5406 +0,0 @@ -# 测试报告 - -## 测试概述 -- 测试时间: 2025-12-01 22:20:35 -- 总测试文件数: 252 -- 通过: 41 -- 失败: 0 -- 错误: 211 - -## 通过的配置文件 -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Inflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/AirQuality.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/PEMS-BAY.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Outflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD4.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD3.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD7.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-inflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/AirQuality.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/PEMS-BAY.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-outflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-InFlow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD4.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-OutFlow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD3.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD7.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD4.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD3.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD7.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-inflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/BeijingAirQuality(Deprecated).yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/AirQuality.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-outflow.yaml - -## 失败的配置文件 - -## 出错的配置文件 -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(L).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/Hainan.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(M).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/SD.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(L).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/Hainan.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(M).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY_paper.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(L).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/Hainan.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(M).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD7.yaml - -## 详细输出 - -### PASSED - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Inflow.yaml - -``` -模型参数量: 118040 -加载 NYCBike-InFlow 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 128, 1]) matches label shape torch.Size([64, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/METR-LA.yaml - -``` -模型参数量: 120568 -加载 METR-LA 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 207, 1]) matches label shape torch.Size([64, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/AirQuality.yaml - -``` -模型参数量: 115064 -加载 AirQuality 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 35, 1]) matches label shape torch.Size([64, 24, 35, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/PEMS-BAY.yaml - -``` -模型参数量: 124344 -加载 PEMS-BAY 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 325, 1]) matches label shape torch.Size([64, 24, 325, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Outflow.yaml - -``` -模型参数量: 118040 -加载 NYCBike-OutFlow 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 128, 1]) matches label shape torch.Size([64, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/SolarEnergy.yaml - -``` -模型参数量: 118328 -加载 SolarEnergy 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD4.yaml - -``` -模型参数量: 1354932 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD4/2025-12-01_21-52-10/run.log -✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/METR-LA.yaml - -``` -模型参数量: 1258932 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/METR-LA/2025-12-01_21-52-24/run.log -✓ Test passed: output shape torch.Size([16, 12, 207, 1]) matches label shape torch.Size([16, 12, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD8.yaml - -``` -模型参数量: 1223412 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD8/2025-12-01_21-52-49/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD3.yaml - -``` -模型参数量: 1403892 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD3/2025-12-01_21-53-06/run.log -✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD7.yaml - -``` -模型参数量: 1907892 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD7/2025-12-01_21-54-04/run.log -✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-inflow.yaml - -``` -模型参数量: 103504579 -加载 NYCBike-InFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/NYCBike-InFlow/2025-12-01_21-55-58/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/METR-LA.yaml - -``` -模型参数量: 103505369 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/METR-LA/2025-12-01_21-56-29/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/AirQuality.yaml - -``` -模型参数量: 103503669 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/AirQuality/2025-12-01_21-56-40/run.log -✓ Test passed: output shape torch.Size([16, 24, 35, 6]) matches label shape torch.Size([16, 24, 35, 6]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/PEMS-BAY.yaml - -``` -模型参数量: 103506549 -加载 PEMS-BAY 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/PEMS-BAY/2025-12-01_21-57-30/run.log -✓ Test passed: output shape torch.Size([16, 24, 325, 1]) matches label shape torch.Size([16, 24, 325, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/SolarEnergy.yaml - -``` -模型参数量: 103504669 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/SolarEnergy/2025-12-01_21-57-55/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_METR-LA.yaml - -``` -模型参数量: 103524820 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA_v2/METR-LA/2025-12-01_21-58-18/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-outflow.yaml - -``` -模型参数量: 103504579 -加载 NYCBike-OutFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/NYCBike-OutFlow/2025-12-01_21-58-29/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_SolarEnergy.yaml - -``` -模型参数量: 103524120 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA_v2/SolarEnergy/2025-12-01_21-58-54/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-InFlow.yaml - -``` -模型参数量: 35873 -加载 NYCBike-InFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/NYCBike-InFlow/2025-12-01_21-59-55/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD4.yaml - -``` -模型参数量: 35873 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD4/2025-12-01_22-00-07/run.log -✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/METR-LA.yaml - -``` -模型参数量: 35873 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/METR-LA/2025-12-01_22-00-28/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-OutFlow.yaml - -``` -模型参数量: 35873 -加载 NYCBike-OutFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/NYCBike-OutFlow/2025-12-01_22-00-44/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD8.yaml - -``` -模型参数量: 35873 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD8/2025-12-01_22-00-54/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD3.yaml - -``` -模型参数量: 35873 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD3/2025-12-01_22-01-10/run.log -✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/SolarEnergy.yaml - -``` -模型参数量: 35873 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/SolarEnergy/2025-12-01_22-01-33/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD7.yaml - -``` -模型参数量: 35873 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD7/2025-12-01_22-02-05/run.log -✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/METR-LA.yaml - -``` -模型参数量: 671644 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DDGCRN/METR-LA/2025-12-01_22-03-35/run.log -✓ Test passed: output shape torch.Size([64, 24, 207, 1]) matches label shape torch.Size([64, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD8.yaml - -``` -模型参数量: 311759 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DDGCRN/PEMSD8/2025-12-01_22-03-57/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD4.yaml - -``` -模型参数量: 37896712 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD4/2025-12-01_22-04-52/run.log -✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/METR-LA.yaml - -``` -模型参数量: 37896712 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/METR-LA/2025-12-01_22-05-06/run.log -✓ Test passed: output shape torch.Size([16, 12, 207, 1]) matches label shape torch.Size([16, 12, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD8.yaml - -``` -模型参数量: 37896712 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD8/2025-12-01_22-05-33/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD3.yaml - -``` -模型参数量: 37896712 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD3/2025-12-01_22-05-49/run.log -✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD7.yaml - -``` -模型参数量: 615304 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD7/2025-12-01_22-06-48/run.log -✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-inflow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-InFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/NYCBike-InFlow/2025-12-01_22-09-34/run.log -✓ Test passed: output shape torch.Size([16, 24, 128, 1]) matches label shape torch.Size([16, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/BeijingAirQuality(Deprecated).yaml - -``` -模型参数量: 103815937 -加载 BeijingAirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/BeijingAirQuality/2025-12-01_22-09-59/run.log -✓ Test passed: output shape torch.Size([16, 24, 7, 3]) matches label shape torch.Size([16, 24, 7, 3]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/METR-LA.yaml - -``` -模型参数量: 103481647 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/METR-LA/2025-12-01_22-10-22/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/AirQuality.yaml - -``` -模型参数量: 103815973 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/AirQuality/2025-12-01_22-10-33/run.log -✓ Test passed: output shape torch.Size([16, 24, 35, 3]) matches label shape torch.Size([16, 24, 35, 3]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY.yaml - -``` -模型参数量: 103481647 -加载 PEMS-BAY 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/PEMS-BAY/2025-12-01_22-11-23/run.log -✓ Test passed: output shape torch.Size([16, 24, 325, 1]) matches label shape torch.Size([16, 24, 325, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/SolarEnergy.yaml - -``` -模型参数量: 103481647 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/SolarEnergy/2025-12-01_22-11-48/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-outflow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-OutFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/NYCBike-OutFlow/2025-12-01_22-11-58/run.log -✓ Test passed: output shape torch.Size([16, 24, 128, 1]) matches label shape torch.Size([16, 24, 128, 1]) -``` - - -### FAILED - - -### ERROR - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-InFlow.yaml - -``` -模型参数量: 146712 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 70, in model_selector - return STID(model_config) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STID/STID.py", line 13, in __init__ - self.embed_dim = model_args["embed_dim"] - ~~~~~~~~~~^^^^^^^^^^^^^ -KeyError: 'embed_dim' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-OutFlow.yaml - -``` -模型参数量: 146712 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 29, in get_adj - return adj - ^^^ -UnboundLocalError: cannot access local variable 'adj' where it is not associated with a value -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-InFlow.yaml - -``` -模型参数量: 3086208 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/AirQuality.yaml - -``` -模型参数量: 1624752 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/AirQuality/2025-12-01_21-52-33/run.log -2025/12/01 21:52:33 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/AirQuality/2025-12-01_21-52-33 -2025/12/01 21:52:33 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAEFormer/STAEFormer.py", line 195, in forward - x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward - return F.linear(input, self.weight, self.bias) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: mat1 and mat2 shapes cannot be multiplied (13440x1 and 6x24) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-OutFlow.yaml - -``` -模型参数量: 3086208 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/SolarEnergy.yaml - -``` -模型参数量: 13296192 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/SolarEnergy/2025-12-01_21-53-32/run.log -2025/12/01 21:53:32 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/SolarEnergy/2025-12-01_21-53-32 -2025/12/01 21:53:32 - Training process started - -Train Epoch 1: 0%| | 0/1970 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAEFormer/STAEFormer.py", line 195, in forward - x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward - return F.linear(input, self.weight, self.bias) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: mat1 and mat2 shapes cannot be multiplied (52608x1 and 137x24) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-InFlow.yaml - -``` -模型参数量: 103513539 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-OutFlow.yaml - -``` -模型参数量: 103513539 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/AirQuality.yaml - -``` -模型参数量: 36678 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/AirQuality/2025-12-01_22-00-37/run.log -2025/12/01 22:00:37 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/AirQuality/2025-12-01_22-00-37 -2025/12/01 22:00:37 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TCN/TCN.py", line 43, in forward - x = self.network(x) - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward - input = module(input) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TCN/TCN.py", line 89, in forward - res = x if self.downsample is None else self.downsample(x) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 554, in forward - return self._conv_forward(input, self.weight, self.bias) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward - return F.conv2d( - ^^^^^^^^^ -RuntimeError: Given groups=1, weight of size [32, 6, 1, 1], expected input[16, 1, 35, 24] to have 6 channels, but got 1 channels instead -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/SD.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD8.yaml - -``` -模型参数量: 235788 -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 18, in get_dataloader - return EXP_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/EXPdataloader.py", line 8, in get_dataloader - data = load_st_dataset(args["type"], args["sample"]) # [T, N, F] - ~~~~^^^^^^^^ -KeyError: 'type' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(L).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/Hainan.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(M).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-InFlow.yaml - -``` -模型参数量: 37897240 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/AirQuality.yaml - -``` -模型参数量: 37897240 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/AirQuality/2025-12-01_22-05-16/run.log -2025/12/01 22:05:16 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/AirQuality/2025-12-01_22-05-16 -2025/12/01 22:05:16 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 180, in _run_epoch - loss = self.loss(output, label) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 128, in forward - return F.l1_loss(input, target, reduction=self.reduction) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/functional.py", line 3810, in l1_loss - expanded_input, expanded_target = torch.broadcast_tensors(input, target) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/functional.py", line 76, in broadcast_tensors - return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: The size of tensor a (12) must match the size of tensor b (24) at non-singleton dimension 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-OutFlow.yaml - -``` -模型参数量: 37897240 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/SolarEnergy.yaml - -``` -模型参数量: 37897240 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/SolarEnergy/2025-12-01_22-06-16/run.log -2025/12/01 22:06:16 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/SolarEnergy/2025-12-01_22-06-16 -2025/12/01 22:06:16 - Training process started - -Train Epoch 1: 0%| | 0/1970 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 180, in _run_epoch - loss = self.loss(output, label) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 128, in forward - return F.l1_loss(input, target, reduction=self.reduction) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/functional.py", line 3810, in l1_loss - expanded_input, expanded_target = torch.broadcast_tensors(input, target) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/functional.py", line 76, in broadcast_tensors - return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: The size of tensor a (12) must match the size of tensor b (24) at non-singleton dimension 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 53, in __init__ - self.input_dim = args["input_dim"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'input_dim' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY_paper.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 84, in model_selector - return REPST(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/REPST/repst.py", line 24, in __init__ - self.word_choice = GumbelSoftmax(configs['word_num']) - ~~~~~~~^^^^^^^^^^^^ -KeyError: 'word_num' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-InFlow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-OutFlow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-InFlow.yaml - -``` -模型参数量: 4 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD4.yaml - -``` -模型参数量: 4 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD4/2025-12-01_22-14-54/run.log -2025/12/01 22:14:54 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD4/2025-12-01_22-14-54 -2025/12/01 22:14:54 - Training process started - -Train Epoch 1: 0%| | 0/160 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/METR-LA.yaml - -``` -模型参数量: 4 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/METR-LA/2025-12-01_22-15-07/run.log -2025/12/01 22:15:07 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/METR-LA/2025-12-01_22-15-07 -2025/12/01 22:15:07 - Training process started - -Train Epoch 1: 0%| | 0/1285 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/AirQuality.yaml - -``` -模型参数量: 4 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/AirQuality/2025-12-01_22-15-15/run.log -2025/12/01 22:15:15 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/AirQuality/2025-12-01_22-15-15 -2025/12/01 22:15:15 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-OutFlow.yaml - -``` -模型参数量: 4 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD8.yaml - -``` -模型参数量: 4 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD8/2025-12-01_22-15-31/run.log -2025/12/01 22:15:31 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD8/2025-12-01_22-15-31 -2025/12/01 22:15:31 - Training process started - -Train Epoch 1: 0%| | 0/168 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(L).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector - return ARIMA(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ - self.p = args["p"] # 自回归阶数 - ~~~~^^^^^ -KeyError: 'p' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD3.yaml - -``` -模型参数量: 4 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD3/2025-12-01_22-15-53/run.log -2025/12/01 22:15:53 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD3/2025-12-01_22-15-53 -2025/12/01 22:15:53 - Training process started - -Train Epoch 1: 0%| | 0/982 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/SolarEnergy.yaml - -``` -模型参数量: 4 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/SolarEnergy/2025-12-01_22-16-18/run.log -2025/12/01 22:16:18 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/SolarEnergy/2025-12-01_22-16-18 -2025/12/01 22:16:18 - Training process started - -Train Epoch 1: 0%| | 0/1970 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/Hainan.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector - return ARIMA(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ - self.p = args["p"] # 自回归阶数 - ~~~~^^^^^ -KeyError: 'p' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7.yaml - -``` -模型参数量: 4 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD7/2025-12-01_22-16-56/run.log -2025/12/01 22:16:56 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD7/2025-12-01_22-16-56 -2025/12/01 22:16:56 - Training process started - -Train Epoch 1: 0%| | 0/1058 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(M).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector - return ARIMA(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ - self.p = args["p"] # 自回归阶数 - ~~~~^^^^^ -KeyError: 'p' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - diff --git a/train.py b/train.py new file mode 100644 index 0000000..9d58921 --- /dev/null +++ b/train.py @@ -0,0 +1,63 @@ +import yaml +import torch + +import utils.initializer as init +from dataloader.loader_selector import get_dataloader +from trainer.trainer_selector import select_trainer + +def run(config): + init.init_seed(config["basic"]["seed"]) + model = init.init_model(config) + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + config, normalizer=config["data"]["normalizer"], single=False + ) + loss = init.init_loss(config, scaler) + optimizer, lr_scheduler = init.init_optimizer(model, config["train"]) + init.create_logs(config) + trainer = select_trainer( + model, + loss, optimizer, + train_loader, val_loader, test_loader, scaler, + config, + lr_scheduler, extra_data, + ) + + # 开始训练 + match config["basic"]["mode"]: + case "train": + trainer.train() + case "test": + model.load_state_dict( + torch.load( + f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth", + map_location=config["basic"]["device"], + weights_only=True, + ) + ) + trainer.test( + model.to(config["basic"]["device"]), + trainer.args, test_loader, scaler, + trainer.logger, + ) + case _: + raise ValueError(f"Unsupported mode: {config['basic']['mode']}") + + +if __name__ == "__main__": + # 指定模型 + model_list = ["HI"] + # 指定数据集 + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + device = "cuda:0" # 指定设备 + seed = 2023 # 随机种子 + for model in model_list: + for dataset in dataset_list: + config_path = f"./config/{model}/{dataset}.yaml" + with open(config_path, "r") as file: + config = yaml.safe_load(file) + config["basic"]["device"] = device + config["basic"]["seed"] = seed + print(f"\nRunning {model} on {dataset} with seed {seed} on {device}") + print(f"config: {config}") + run(config) + diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 2bd7e6e..4bd82a4 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -8,125 +8,31 @@ from utils.logger import get_logger from utils.loss_function import all_metrics from tqdm import tqdm - -class TrainingStats: - """记录训练过程中的统计信息""" - - def __init__(self, device): - self.device = device - self.reset() - - def reset(self): - """重置所有统计数据""" - self.gpu_mem_usage_list = [] - self.cpu_mem_usage_list = [] - self.train_time_list = [] - self.infer_time_list = [] - self.total_iters = 0 - self.start_time = None - self.end_time = None - - def start_training(self): - """记录训练开始时间""" - self.start_time = time.time() - - def end_training(self): - """记录训练结束时间""" - self.end_time = time.time() - - def record_step_time(self, duration, mode): - """记录单步耗时和总迭代次数""" - if mode == "train": - self.train_time_list.append(duration) - else: - self.infer_time_list.append(duration) - self.total_iters += 1 - - def record_memory_usage(self): - """记录当前 GPU 和 CPU 内存占用""" - process = psutil.Process(os.getpid()) - cpu_mem = process.memory_info().rss / (1024**2) - - if torch.cuda.is_available(): - gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024**2) - torch.cuda.reset_peak_memory_stats(device=self.device) - else: - gpu_mem = 0.0 - - self.cpu_mem_usage_list.append(cpu_mem) - self.gpu_mem_usage_list.append(gpu_mem) - - def _calculate_average(self, values_list): - """安全计算平均值,避免除零错误""" - return sum(values_list) / len(values_list) if values_list else 0 - - def report(self, logger): - """在训练结束时输出汇总统计""" - if not self.start_time or not self.end_time: - logger.warning("TrainingStats: start/end time not recorded properly.") - return - - total_time = self.end_time - self.start_time - avg_gpu_mem = self._calculate_average(self.gpu_mem_usage_list) - avg_cpu_mem = self._calculate_average(self.cpu_mem_usage_list) - avg_train_time = self._calculate_average(self.train_time_list) - avg_infer_time = self._calculate_average(self.infer_time_list) - iters_per_sec = self.total_iters / total_time if total_time > 0 else 0 - - logger.info("===== Training Summary =====") - logger.info(f"Total training time: {total_time:.2f} s") - logger.info(f"Total iterations: {self.total_iters}") - logger.info(f"Average iterations per second: {iters_per_sec:.2f}") - logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB") - logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB") - if avg_train_time: - logger.info(f"Average training step time: {avg_train_time * 1000:.2f} ms") - if avg_infer_time: - logger.info(f"Average inference step time: {avg_infer_time * 1000:.2f} ms") - - class Trainer: """模型训练器,负责整个训练流程的管理""" - def __init__( - self, - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler=None, - ): + def __init__(self, model, loss, optimizer, + train_loader, val_loader, test_loader, scaler, + args, lr_scheduler=None,): # 设备和基本参数 + self.config = args self.device = args["basic"]["device"] train_args = args["train"] - # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer self.lr_scheduler = lr_scheduler - # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader - # 数据处理工具 self.scaler = scaler self.args = train_args - - # 统计信息 - self.train_per_epoch = len(train_loader) - self.val_per_epoch = len(val_loader) if val_loader else 0 - # 初始化路径、日志和统计 self._initialize_paths(train_args) self._initialize_logger(train_args) - self._initialize_stats() def _initialize_paths(self, args): """初始化模型保存路径""" @@ -138,24 +44,14 @@ class Trainer: """初始化日志记录器""" if not os.path.isdir(args["log_dir"]) and not args["debug"]: os.makedirs(args["log_dir"], exist_ok=True) - self.logger = get_logger( - args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] - ) + self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"]) self.logger.info(f"Experiment log path in: {args['log_dir']}") - - def _initialize_stats(self): - """初始化统计信息记录器""" - self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): """运行一个训练/验证/测试epoch""" # 设置模型模式和是否进行优化 - if mode == "train": - self.model.train() - optimizer_step = True - else: - self.model.eval() - optimizer_step = False + if mode == "train": self.model.train(); optimizer_step = True + else: self.model.eval(); optimizer_step = False # 初始化变量 total_loss = 0 @@ -169,73 +65,42 @@ class Trainer: total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" ) - for _, (data, target) in progress_bar: - # 记录步骤开始时间 - start_time = time.time() - - # 前向传播 + # 转移数据 + data = data.to(self.device) + target = target.to(self.device) label = target[..., : self.args["output_dim"]] - output = self.model(data).to(self.device) - # if output.shape != label.shape: - # import sys - # print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}") - # sys.exit(1) - # else: - # import sys - # print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}") - # sys.exit(0) + # 计算loss和反归一化loss + output = self.model(data) loss = self.loss(output, label) - - # 反归一化 d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) - - # 反向传播和优化(仅在训练模式) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() - - # 梯度裁剪(如果需要) - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() - - # 反归一化的loss d_loss = self.loss(d_output, d_label) - - # 记录步骤时间和内存使用 - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - # 累积损失和预测结果 total_loss += d_loss.item() y_pred.append(d_output.detach().cpu()) y_true.append(d_label.detach().cpu()) - + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + self.optimizer.step() # 更新进度条 progress_bar.set_postfix(loss=d_loss.item()) # 合并所有批次的预测结果 y_pred = torch.cat(y_pred, dim=0) y_true = torch.cat(y_true, dim=0) - - # 计算平均损失 + # 计算损失并记录指标 avg_loss = total_loss / len(dataloader) - - # 计算并记录指标 - mae, rmse, mape = all_metrics( - y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] - ) + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) self.logger.info( - f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + f"Epoch #{epoch:02d}: {mode.capitalize():<5} " + f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" ) - - # 记录内存使用情况 - self.stats.record_memory_usage() - return avg_loss def train_epoch(self, epoch): @@ -248,28 +113,22 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): - """执行完整的训练流程""" - # 初始化最佳模型和损失记录 + # 初始化记录 best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 - # 开始训练 - self.stats.start_training() self.logger.info("Training process started") - # 训练循环 for epoch in range(1, self.args["epochs"] + 1): # 训练、验证和测试一个epoch train_epoch_loss = self.train_epoch(epoch) val_epoch_loss = self.val_epoch(epoch) test_epoch_loss = self.test_epoch(epoch) - # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break - # 更新最佳验证模型 if val_epoch_loss < best_loss: best_loss = val_epoch_loss @@ -278,29 +137,18 @@ class Trainer: self.logger.info("Best validation model saved!") else: not_improved_count += 1 - - # 检查早停条件 + # 早停 if self._should_early_stop(not_improved_count): break - # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) - # 保存最佳模型 if not self.args["debug"]: self._save_best_models(best_model, best_test_model) - - # 结束训练并输出统计信息 - self.stats.end_training() - self.stats.report(self.logger) - # 最终评估 self._finalize_training(best_model, best_test_model) - - # 输出模型参数量 - self._log_model_params() def _should_early_stop(self, not_improved_count): """检查是否满足早停条件""" @@ -331,20 +179,35 @@ class Trainer: def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) self.logger.info("Testing on best validation model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) - + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) self.model.load_state_dict(best_test_model) self.logger.info("Testing on best test model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) @staticmethod def test(model, args, data_loader, scaler, logger, path=None): """对模型进行评估并输出性能指标""" + # 确定设备信息 + device = None + output_dim = None + # 处理不同的参数格式 + if isinstance(args, dict): + if "basic" in args: + # 完整配置情况 + device = args["basic"]["device"] + output_dim = args["train"]["output_dim"] + else: + # 只有train_args情况,从模型获取设备 + device = next(model.parameters()).device + output_dim = args["output_dim"] + else: + raise ValueError(f"Unsupported args type: {type(args)}") + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) - model.to(args["basic"]["device"]) + model.to(device) # 设置为评估模式 model.eval() @@ -355,27 +218,40 @@ class Trainer: # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target in data_loader: - label = target[..., : args["output_dim"]] + # 将数据和标签移动到指定设备 + data = data.to(device) + target = target.to(device) + + label = target[..., : output_dim] output = model(data) y_pred.append(output.detach().cpu()) y_true.append(label.detach().cpu()) - d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + # 获取metrics参数 + if "basic" in args: + # 完整配置情况 + mae_thresh = args["train"]["mae_thresh"] + mape_thresh = args["train"]["mape_thresh"] + else: + # 只有train_args情况 + mae_thresh = args["mae_thresh"] + mape_thresh = args["mape_thresh"] + # 计算并记录每个时间步的指标 for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( d_y_pred[:, t, ...], d_y_true[:, t, ...], - args["mae_thresh"], - args["mape_thresh"], + mae_thresh, + mape_thresh, ) logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") # 计算并记录平均指标 - mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod diff --git a/trainer/Trainer_bk.py b/trainer/Trainer_bk.py new file mode 100755 index 0000000..ee6e388 --- /dev/null +++ b/trainer/Trainer_bk.py @@ -0,0 +1,420 @@ +import math +import os +import time +import copy +import psutil +import torch +from utils.logger import get_logger +from utils.loss_function import all_metrics +from tqdm import tqdm + + +# class TrainingStats: +# """记录训练过程中的统计信息""" + +# def __init__(self, device): +# self.device = device +# self.reset() + +# def reset(self): +# """重置所有统计数据""" +# self.gpu_mem_usage_list = [] +# self.cpu_mem_usage_list = [] +# self.train_time_list = [] +# self.infer_time_list = [] +# self.total_iters = 0 +# self.start_time = None +# self.end_time = None + +# def start_training(self): +# """记录训练开始时间""" +# self.start_time = time.time() + +# def end_training(self): +# """记录训练结束时间""" +# self.end_time = time.time() + +# def record_step_time(self, duration, mode): +# """记录单步耗时和总迭代次数""" +# if mode == "train": +# self.train_time_list.append(duration) +# else: +# self.infer_time_list.append(duration) +# self.total_iters += 1 + +# def record_memory_usage(self): +# """记录当前 GPU 和 CPU 内存占用""" +# process = psutil.Process(os.getpid()) +# cpu_mem = process.memory_info().rss / (1024**2) + +# if torch.cuda.is_available(): +# gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024**2) +# torch.cuda.reset_peak_memory_stats(device=self.device) +# else: +# gpu_mem = 0.0 + +# self.cpu_mem_usage_list.append(cpu_mem) +# self.gpu_mem_usage_list.append(gpu_mem) + +# def _calculate_average(self, values_list): +# """安全计算平均值,避免除零错误""" +# return sum(values_list) / len(values_list) if values_list else 0 + +# def report(self, logger): +# """在训练结束时输出汇总统计""" +# if not self.start_time or not self.end_time: +# logger.warning("TrainingStats: start/end time not recorded properly.") +# return + +# total_time = self.end_time - self.start_time +# avg_gpu_mem = self._calculate_average(self.gpu_mem_usage_list) +# avg_cpu_mem = self._calculate_average(self.cpu_mem_usage_list) +# avg_train_time = self._calculate_average(self.train_time_list) +# avg_infer_time = self._calculate_average(self.infer_time_list) +# iters_per_sec = self.total_iters / total_time if total_time > 0 else 0 + +# logger.info("===== Training Summary =====") +# logger.info(f"Total training time: {total_time:.2f} s") +# logger.info(f"Total iterations: {self.total_iters}") +# logger.info(f"Average iterations per second: {iters_per_sec:.2f}") +# logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB") +# logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB") +# if avg_train_time: +# logger.info(f"Average training step time: {avg_train_time * 1000:.2f} ms") +# if avg_infer_time: +# logger.info(f"Average inference step time: {avg_infer_time * 1000:.2f} ms") + + +class Trainer: + """模型训练器,负责整个训练流程的管理""" + + def __init__( + self, + model, + loss, + optimizer, + train_loader, + val_loader, + test_loader, + scaler, + args, + lr_scheduler=None, + ): + # 设备和基本参数 + self.device = args["basic"]["device"] + self.config = args # 保存完整的配置参数 + train_args = args["train"] + + # 模型和训练相关组件 + self.model = model + self.loss = loss + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + + # 数据处理工具 + self.scaler = scaler + self.args = train_args + + # 统计信息 + # self.train_per_epoch = len(train_loader) + # self.val_per_epoch = len(val_loader) if val_loader else 0 + + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + self._initialize_stats() + + def _initialize_paths(self, args): + """初始化模型保存路径""" + self.best_path = os.path.join(args["log_dir"], "best_model.pth") + self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") + self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") + + def _initialize_logger(self, args): + """初始化日志记录器""" + if not os.path.isdir(args["log_dir"]) and not args["debug"]: + os.makedirs(args["log_dir"], exist_ok=True) + self.logger = get_logger( + args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] + ) + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + # def _initialize_stats(self): + # """初始化统计信息记录器""" + # self.stats = TrainingStats(device=self.device) + + def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 + if mode == "train": + self.model.train() + optimizer_step = True + else: + self.model.eval() + optimizer_step = False + + # 初始化变量 + total_loss = 0 + epoch_time = time.time() + y_pred, y_true = [], [] + + # 训练/验证循环 + with torch.set_grad_enabled(optimizer_step): + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + + for _, (data, target) in progress_bar: + # 记录步骤开始时间 + start_time = time.time() + + # 将数据和标签移动到指定设备 + data = data.to(self.device) + target = target.to(self.device) + + # 前向传播 + label = target[..., : self.args["output_dim"]] + output = self.model(data) + # if output.shape != label.shape: + # import sys + # print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(1) + # else: + # import sys + # print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(0) + loss = self.loss(output, label) + + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) + + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args["max_grad_norm"] + ) + self.optimizer.step() + + # 反归一化的loss + d_loss = self.loss(d_output, d_label) + + # 记录步骤时间和内存使用 + # step_time = time.time() - start_time + # self.stats.record_step_time(step_time, mode) + + # 累积损失和预测结果 + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + + # 更新进度条 + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 计算平均损失 + avg_loss = total_loss / len(dataloader) + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] + ) + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + + # 记录内存使用情况 + # self.stats.record_memory_usage() + + return avg_loss + + def train_epoch(self, epoch): + return self._run_epoch(epoch, self.train_loader, "train") + + def val_epoch(self, epoch): + return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") + + def test_epoch(self, epoch): + return self._run_epoch(epoch, self.test_loader, "test") + + def train(self): + """执行完整的训练流程""" + # 初始化最佳模型和损失记录 + best_model, best_test_model = None, None + best_loss, best_test_loss = float("inf"), float("inf") + not_improved_count = 0 + + # 开始训练 + # self.stats.start_training() + self.logger.info("Training process started") + + # 训练循环 + for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch + train_epoch_loss = self.train_epoch(epoch) + val_epoch_loss = self.val_epoch(epoch) + test_epoch_loss = self.test_epoch(epoch) + + # 检查梯度爆炸 + if train_epoch_loss > 1e6: + self.logger.warning("Gradient explosion detected. Ending...") + break + + # 更新最佳验证模型 + if val_epoch_loss < best_loss: + best_loss = val_epoch_loss + not_improved_count = 0 + best_model = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved!") + else: + not_improved_count += 1 + + # 检查早停条件 + if self._should_early_stop(not_improved_count): + break + + # 更新最佳测试模型 + if test_epoch_loss < best_test_loss: + best_test_loss = test_epoch_loss + best_test_model = copy.deepcopy(self.model.state_dict()) + + # 保存最佳模型 + if not self.args["debug"]: + self._save_best_models(best_model, best_test_model) + + # 结束训练并输出统计信息 + # self.stats.end_training() + # self.stats.report(self.logger) + + # 最终评估 + self._finalize_training(best_model, best_test_model) + + # 输出模型参数量 + self._log_model_params() + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + + def _finalize_training(self, best_model, best_test_model): + self.model.load_state_dict(best_model) + self.logger.info("Testing on best validation model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + + self.model.load_state_dict(best_test_model) + self.logger.info("Testing on best test model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + + @staticmethod + def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 确定设备信息 + device = None + output_dim = None + + # 处理不同的参数格式 + if isinstance(args, dict): + if "basic" in args: + # 完整配置情况 + device = args["basic"]["device"] + output_dim = args["train"]["output_dim"] + else: + # 只有train_args情况 + # 从模型获取设备 + device = next(model.parameters()).device + output_dim = args["output_dim"] + else: + raise ValueError(f"Unsupported args type: {type(args)}") + + # 加载模型检查点(如果提供了路径) + if path: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["state_dict"]) + model.to(device) + + # 设置为评估模式 + model.eval() + + # 收集预测和真实标签 + y_pred, y_true = [], [] + + # 不计算梯度的情况下进行预测 + with torch.no_grad(): + for data, target in data_loader: + # 将数据和标签移动到指定设备 + data = data.to(device) + target = target.to(device) + + label = target[..., : output_dim] + output = model(data) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) + + + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + + # 获取metrics参数 + if "basic" in args: + # 完整配置情况 + mae_thresh = args["train"]["mae_thresh"] + mape_thresh = args["train"]["mape_thresh"] + else: + # 只有train_args情况 + mae_thresh = args["mae_thresh"] + mape_thresh = args["mape_thresh"] + + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): + mae, rmse, mape = all_metrics( + d_y_pred[:, t, ...], + d_y_true[:, t, ...], + mae_thresh, + mape_thresh, + ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + @staticmethod + def _compute_sampling_threshold(global_step, k): + return k / (k + math.exp(global_step / k)) diff --git a/trainer/Trainer_old.py b/trainer/Trainer_old.py deleted file mode 100755 index bd49b29..0000000 --- a/trainer/Trainer_old.py +++ /dev/null @@ -1,229 +0,0 @@ -import math -import os -import time -import copy -from tqdm import tqdm - -import torch -from utils.logger import get_logger -from utils.loss_function import all_metrics -from utils.training_stats import TrainingStats - - -class Trainer: - def __init__( - self, - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler=None, - ): - self.model = model - self.loss = loss - self.optimizer = optimizer - self.train_loader = train_loader - self.val_loader = val_loader - self.test_loader = test_loader - self.scaler = scaler - self.args = args - self.lr_scheduler = lr_scheduler - self.train_per_epoch = len(train_loader) - self.val_per_epoch = len(val_loader) if val_loader else 0 - - # Paths for saving models and logs - self.best_path = os.path.join(args["log_dir"], "best_model.pth") - self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") - self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") - - # Initialize logger - if not os.path.isdir(args["log_dir"]) and not args["debug"]: - os.makedirs(args["log_dir"], exist_ok=True) - self.logger = get_logger( - args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] - ) - self.logger.info(f"Experiment log path in: {args['log_dir']}") - # Stats tracker - self.stats = TrainingStats(device=args["device"]) - - def _run_epoch(self, epoch, dataloader, mode): - if mode == "train": - self.model.train() - optimizer_step = True - else: - self.model.eval() - optimizer_step = False - - total_loss = 0 - epoch_time = time.time() - - with torch.set_grad_enabled(optimizer_step): - with tqdm( - total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" - ) as pbar: - for batch_idx, (data, target) in enumerate(dataloader): - start_time = time.time() - label = target[..., : self.args["output_dim"]] - output = self.model(data).to(self.args["device"]) - - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) - - loss = self.loss(output, label) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() - - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() - - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - total_loss += loss.item() - - if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: - self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" - ) - - # 更新 tqdm 的进度 - pbar.update(1) - pbar.set_postfix(loss=loss.item()) - - avg_loss = total_loss / len(dataloader) - self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" - ) - # 记录内存 - self.stats.record_memory_usage() - return avg_loss - - def train_epoch(self, epoch): - return self._run_epoch(epoch, self.train_loader, "train") - - def val_epoch(self, epoch): - return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") - - def test_epoch(self, epoch): - return self._run_epoch(epoch, self.test_loader, "test") - - def train(self): - best_model, best_test_model = None, None - best_loss, best_test_loss = float("inf"), float("inf") - not_improved_count = 0 - - self.stats.start_training() - self.logger.info("Training process started") - for epoch in range(1, self.args["epochs"] + 1): - train_epoch_loss = self.train_epoch(epoch) - val_epoch_loss = self.val_epoch(epoch) - test_epoch_loss = self.test_epoch(epoch) - - if train_epoch_loss > 1e6: - self.logger.warning("Gradient explosion detected. Ending...") - break - - if val_epoch_loss < best_loss: - best_loss = val_epoch_loss - not_improved_count = 0 - best_model = copy.deepcopy(self.model.state_dict()) - self.logger.info("Best validation model saved!") - else: - not_improved_count += 1 - - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) - break - - if test_epoch_loss < best_test_loss: - best_test_loss = test_epoch_loss - best_test_model = copy.deepcopy(self.model.state_dict()) - - if not self.args["debug"]: - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) - - # 输出统计与参数 - self.stats.end_training() - self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass - self._finalize_training(best_model, best_test_model) - - def _finalize_training(self, best_model, best_test_model): - self.model.load_state_dict(best_model) - self.logger.info("Testing on best validation model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) - - self.model.load_state_dict(best_test_model) - self.logger.info("Testing on best test model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) - - @staticmethod - def test(model, args, data_loader, scaler, logger, path=None): - if path: - checkpoint = torch.load(path) - model.load_state_dict(checkpoint["state_dict"]) - model.to(args["device"]) - - model.eval() - y_pred, y_true = [], [] - - with torch.no_grad(): - for data, target in data_loader: - label = target[..., : args["output_dim"]] - output = model(data) - y_pred.append(output) - y_true.append(label) - - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) - - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - - for t in range(y_true.shape[1]): - mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], - args["mae_thresh"], - args["mape_thresh"], - ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) - - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) - - @staticmethod - def _compute_sampling_threshold(global_step, k): - return k / (k + math.exp(global_step / k)) diff --git a/utils/initializer.py b/utils/initializer.py index 7bee2be..183bfd3 100755 --- a/utils/initializer.py +++ b/utils/initializer.py @@ -9,9 +9,9 @@ import os import yaml -def init_model(args): - device = args["device"] - model = model_selector(args).to(device) +def init_model(config): + device = config["basic"]["device"] + model = model_selector(config).to(device) for p in model.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) -- 2.40.1 From 4ccb029d7ee8948e1ff2533d1fe122a993457e8d Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 21:53:46 +0800 Subject: [PATCH 23/41] impl PatchTST --- config/PatchTST/AirQuality.yaml | 54 ++++ config/PatchTST/BJTaxi-Inflow.yaml | 54 ++++ config/PatchTST/BJTaxi-Outflow.yaml | 54 ++++ config/PatchTST/METR-LA.yaml | 54 ++++ config/PatchTST/NYCBike-Inflow.yaml | 54 ++++ config/PatchTST/NYCBike-Outflow.yaml | 54 ++++ config/PatchTST/PEMS-BAY.yaml | 54 ++++ config/PatchTST/SolarEnergy.yaml | 54 ++++ dataloader/loader_selector.py | 2 +- model/MTGNN/MTGNN.py | 134 ++++++++++ model/MTGNN/layer.py | 328 +++++++++++++++++++++++++ model/PatchTST/PatchTST.py | 109 ++++++++ model/PatchTST/layers/Embed.py | 29 +++ model/PatchTST/layers/SelfAttention.py | 80 ++++++ model/PatchTST/layers/Transformer.py | 57 +++++ model/model_selector.py | 3 + train.py | 5 +- trainer/Trainer.py | 1 - 18 files changed, 1177 insertions(+), 3 deletions(-) create mode 100644 config/PatchTST/AirQuality.yaml create mode 100644 config/PatchTST/BJTaxi-Inflow.yaml create mode 100644 config/PatchTST/BJTaxi-Outflow.yaml create mode 100644 config/PatchTST/METR-LA.yaml create mode 100644 config/PatchTST/NYCBike-Inflow.yaml create mode 100644 config/PatchTST/NYCBike-Outflow.yaml create mode 100644 config/PatchTST/PEMS-BAY.yaml create mode 100644 config/PatchTST/SolarEnergy.yaml create mode 100644 model/MTGNN/MTGNN.py create mode 100644 model/MTGNN/layer.py create mode 100644 model/PatchTST/PatchTST.py create mode 100644 model/PatchTST/layers/Embed.py create mode 100644 model/PatchTST/layers/SelfAttention.py create mode 100644 model/PatchTST/layers/Transformer.py diff --git a/config/PatchTST/AirQuality.yaml b/config/PatchTST/AirQuality.yaml new file mode 100644 index 0000000..a3e6418 --- /dev/null +++ b/config/PatchTST/AirQuality.yaml @@ -0,0 +1,54 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..9bd66d9 --- /dev/null +++ b/config/PatchTST/BJTaxi-Inflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..2382695 --- /dev/null +++ b/config/PatchTST/BJTaxi-Outflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/METR-LA.yaml b/config/PatchTST/METR-LA.yaml new file mode 100644 index 0000000..d076d35 --- /dev/null +++ b/config/PatchTST/METR-LA.yaml @@ -0,0 +1,54 @@ +basic: + dataset: METR-LA + device: cuda:1 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-Inflow.yaml new file mode 100644 index 0000000..2c3026c --- /dev/null +++ b/config/PatchTST/NYCBike-Inflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-Outflow.yaml new file mode 100644 index 0000000..16eee20 --- /dev/null +++ b/config/PatchTST/NYCBike-Outflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/PEMS-BAY.yaml b/config/PatchTST/PEMS-BAY.yaml new file mode 100644 index 0000000..6186db3 --- /dev/null +++ b/config/PatchTST/PEMS-BAY.yaml @@ -0,0 +1,54 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + patch_len: 6 + stride: 8 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml new file mode 100644 index 0000000..28b85b9 --- /dev/null +++ b/config/PatchTST/SolarEnergy.yaml @@ -0,0 +1,54 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + patch_len: 6 + stride: 8 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index f9bf823..88d1e2d 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -7,7 +7,7 @@ from dataloader.TSloader import get_dataloader as TS_loader def get_dataloader(config, normalizer, single): - TS_model = ["iTransformer", "HI"] + TS_model = ["iTransformer", "HI", "PatchTST"] model_name = config["basic"]["model"] if model_name in TS_model: return TS_loader(config, normalizer, single) diff --git a/model/MTGNN/MTGNN.py b/model/MTGNN/MTGNN.py new file mode 100644 index 0000000..483a184 --- /dev/null +++ b/model/MTGNN/MTGNN.py @@ -0,0 +1,134 @@ +import torch.nn as nn +from model.MTGNN.layer import * + + +class gtnet(nn.Module): + def __init__(self, gcn_true, buildA_true, gcn_depth, num_nodes, device, predefined_A=None, static_feat=None, dropout=0.3, subgraph_size=20, node_dim=40, dilation_exponential=1, conv_channels=32, residual_channels=32, skip_channels=64, end_channels=128, seq_length=12, in_dim=2, out_dim=12, layers=3, propalpha=0.05, tanhalpha=3, layer_norm_affline=True): + super(gtnet, self).__init__() + self.gcn_true = gcn_true + self.buildA_true = buildA_true + self.num_nodes = num_nodes + self.dropout = dropout + self.predefined_A = predefined_A + self.filter_convs = nn.ModuleList() + self.gate_convs = nn.ModuleList() + self.residual_convs = nn.ModuleList() + self.skip_convs = nn.ModuleList() + self.gconv1 = nn.ModuleList() + self.gconv2 = nn.ModuleList() + self.norm = nn.ModuleList() + self.start_conv = nn.Conv2d(in_channels=in_dim, + out_channels=residual_channels, + kernel_size=(1, 1)) + self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, device, alpha=tanhalpha, static_feat=static_feat) + + self.seq_length = seq_length + kernel_size = 7 + if dilation_exponential>1: + self.receptive_field = int(1+(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + else: + self.receptive_field = layers*(kernel_size-1) + 1 + + for i in range(1): + if dilation_exponential>1: + rf_size_i = int(1 + i*(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + else: + rf_size_i = i*layers*(kernel_size-1)+1 + new_dilation = 1 + for j in range(1,layers+1): + if dilation_exponential > 1: + rf_size_j = int(rf_size_i + (kernel_size-1)*(dilation_exponential**j-1)/(dilation_exponential-1)) + else: + rf_size_j = rf_size_i+j*(kernel_size-1) + + self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) + self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) + self.residual_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=residual_channels, + kernel_size=(1, 1))) + if self.seq_length>self.receptive_field: + self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=skip_channels, + kernel_size=(1, self.seq_length-rf_size_j+1))) + else: + self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=skip_channels, + kernel_size=(1, self.receptive_field-rf_size_j+1))) + + if self.gcn_true: + self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + + if self.seq_length>self.receptive_field: + self.norm.append(LayerNorm((residual_channels, num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + else: + self.norm.append(LayerNorm((residual_channels, num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + + new_dilation *= dilation_exponential + + self.layers = layers + self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, + out_channels=end_channels, + kernel_size=(1,1), + bias=True) + self.end_conv_2 = nn.Conv2d(in_channels=end_channels, + out_channels=out_dim, + kernel_size=(1,1), + bias=True) + if self.seq_length > self.receptive_field: + self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.seq_length), bias=True) + self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) + + else: + self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True) + self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True) + + + self.idx = torch.arange(self.num_nodes).to(device) + + + def forward(self, input, idx=None): + seq_len = input.size(3) + assert seq_len==self.seq_length, 'input sequence length not equal to preset sequence length' + + if self.seq_lengthncvl',(x,A)) + return x.contiguous() + +class dy_nconv(nn.Module): + def __init__(self): + super(dy_nconv,self).__init__() + + def forward(self,x, A): + x = torch.einsum('ncvl,nvwl->ncwl',(x,A)) + return x.contiguous() + +class linear(nn.Module): + def __init__(self,c_in,c_out,bias=True): + super(linear,self).__init__() + self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=bias) + + def forward(self,x): + return self.mlp(x) + + +class prop(nn.Module): + def __init__(self,c_in,c_out,gdep,dropout,alpha): + super(prop, self).__init__() + self.nconv = nconv() + self.mlp = linear(c_in,c_out) + self.gdep = gdep + self.dropout = dropout + self.alpha = alpha + + def forward(self,x,adj): + adj = adj + torch.eye(adj.size(0)).to(x.device) + d = adj.sum(1) + h = x + dv = d + a = adj / dv.view(-1, 1) + for i in range(self.gdep): + h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) + ho = self.mlp(h) + return ho + + +class mixprop(nn.Module): + def __init__(self,c_in,c_out,gdep,dropout,alpha): + super(mixprop, self).__init__() + self.nconv = nconv() + self.mlp = linear((gdep+1)*c_in,c_out) + self.gdep = gdep + self.dropout = dropout + self.alpha = alpha + + + def forward(self,x,adj): + adj = adj + torch.eye(adj.size(0)).to(x.device) + d = adj.sum(1) + h = x + out = [h] + a = adj / d.view(-1, 1) + for i in range(self.gdep): + h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) + out.append(h) + ho = torch.cat(out,dim=1) + ho = self.mlp(ho) + return ho + +class dy_mixprop(nn.Module): + def __init__(self,c_in,c_out,gdep,dropout,alpha): + super(dy_mixprop, self).__init__() + self.nconv = dy_nconv() + self.mlp1 = linear((gdep+1)*c_in,c_out) + self.mlp2 = linear((gdep+1)*c_in,c_out) + + self.gdep = gdep + self.dropout = dropout + self.alpha = alpha + self.lin1 = linear(c_in,c_in) + self.lin2 = linear(c_in,c_in) + + + def forward(self,x): + #adj = adj + torch.eye(adj.size(0)).to(x.device) + #d = adj.sum(1) + x1 = torch.tanh(self.lin1(x)) + x2 = torch.tanh(self.lin2(x)) + adj = self.nconv(x1.transpose(2,1),x2) + adj0 = torch.softmax(adj, dim=2) + adj1 = torch.softmax(adj.transpose(2,1), dim=2) + + h = x + out = [h] + for i in range(self.gdep): + h = self.alpha*x + (1-self.alpha)*self.nconv(h,adj0) + out.append(h) + ho = torch.cat(out,dim=1) + ho1 = self.mlp1(ho) + + + h = x + out = [h] + for i in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1) + out.append(h) + ho = torch.cat(out, dim=1) + ho2 = self.mlp2(ho) + + return ho1+ho2 + + + +class dilated_1D(nn.Module): + def __init__(self, cin, cout, dilation_factor=2): + super(dilated_1D, self).__init__() + self.tconv = nn.ModuleList() + self.kernel_set = [2,3,6,7] + self.tconv = nn.Conv2d(cin,cout,(1,7),dilation=(1,dilation_factor)) + + def forward(self,input): + x = self.tconv(input) + return x + +class dilated_inception(nn.Module): + def __init__(self, cin, cout, dilation_factor=2): + super(dilated_inception, self).__init__() + self.tconv = nn.ModuleList() + self.kernel_set = [2,3,6,7] + cout = int(cout/len(self.kernel_set)) + for kern in self.kernel_set: + self.tconv.append(nn.Conv2d(cin,cout,(1,kern),dilation=(1,dilation_factor))) + + def forward(self,input): + x = [] + for i in range(len(self.kernel_set)): + x.append(self.tconv[i](input)) + for i in range(len(self.kernel_set)): + x[i] = x[i][...,-x[-1].size(3):] + x = torch.cat(x,dim=1) + return x + + +class graph_constructor(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_constructor, self).__init__() + self.nnodes = nnodes + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1 = nn.Linear(xd, dim) + self.lin2 = nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.emb2 = nn.Embedding(nnodes, dim) + self.lin1 = nn.Linear(dim,dim) + self.lin2 = nn.Linear(dim,dim) + + self.device = device + self.k = k + self.dim = dim + self.alpha = alpha + self.static_feat = static_feat + + def forward(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb2(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) + mask.fill_(float('0')) + s1,t1 = (adj + torch.rand_like(adj)*0.01).topk(self.k,1) + mask.scatter_(1,t1,s1.fill_(1)) + adj = adj*mask + return adj + + def fullA(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb2(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + return adj + +class graph_global(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_global, self).__init__() + self.nnodes = nnodes + self.A = nn.Parameter(torch.randn(nnodes, nnodes).to(device), requires_grad=True).to(device) + + def forward(self, idx): + return F.relu(self.A) + + +class graph_undirected(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_undirected, self).__init__() + self.nnodes = nnodes + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1 = nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.lin1 = nn.Linear(dim,dim) + + self.device = device + self.k = k + self.dim = dim + self.alpha = alpha + self.static_feat = static_feat + + def forward(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb1(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin1(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) + mask.fill_(float('0')) + s1,t1 = adj.topk(self.k,1) + mask.scatter_(1,t1,s1.fill_(1)) + adj = adj*mask + return adj + + + +class graph_directed(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_directed, self).__init__() + self.nnodes = nnodes + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1 = nn.Linear(xd, dim) + self.lin2 = nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.emb2 = nn.Embedding(nnodes, dim) + self.lin1 = nn.Linear(dim,dim) + self.lin2 = nn.Linear(dim,dim) + + self.device = device + self.k = k + self.dim = dim + self.alpha = alpha + self.static_feat = static_feat + + def forward(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb2(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) + mask.fill_(float('0')) + s1,t1 = adj.topk(self.k,1) + mask.scatter_(1,t1,s1.fill_(1)) + adj = adj*mask + return adj + + +class LayerNorm(nn.Module): + __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine'] + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) + self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.reset_parameters() + + + def reset_parameters(self): + if self.elementwise_affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input, idx): + if self.elementwise_affine: + return F.layer_norm(input, tuple(input.shape[1:]), self.weight[:,idx,:], self.bias[:,idx,:], self.eps) + else: + return F.layer_norm(input, tuple(input.shape[1:]), self.weight, self.bias, self.eps) + + def extra_repr(self): + return '{normalized_shape}, eps={eps}, ' \ + 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) \ No newline at end of file diff --git a/model/PatchTST/PatchTST.py b/model/PatchTST/PatchTST.py new file mode 100644 index 0000000..3112030 --- /dev/null +++ b/model/PatchTST/PatchTST.py @@ -0,0 +1,109 @@ +import torch +from torch import nn +from model.PatchTST.layers.Transformer import Encoder, EncoderLayer +from model.PatchTST.layers.SelfAttention import FullAttention, AttentionLayer +from model.PatchTST.layers.Embed import PatchEmbedding + +class Transpose(nn.Module): + def __init__(self, *dims, contiguous=False): + super().__init__() + self.dims, self.contiguous = dims, contiguous + def forward(self, x): + if self.contiguous: return x.transpose(*self.dims).contiguous() + else: return x.transpose(*self.dims) + + +class FlattenHead(nn.Module): + def __init__(self, n_vars, nf, target_window, head_dropout=0): + super().__init__() + self.n_vars = n_vars + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): # x: [bs x nvars x d_model x patch_num] + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x + + +class Model(nn.Module): + """ + Paper link: https://arxiv.org/pdf/2211.14730.pdf + """ + + def __init__(self, configs): + """ + patch_len: int, patch len for patch_embedding + stride: int, stride for patch_embedding + """ + super().__init__() + self.seq_len = configs['seq_len'] + self.pred_len = configs['pred_len'] + self.patch_len = configs['patch_len'] + self.stride = configs['stride'] + padding = self.stride + + # patching and embedding + self.patch_embedding = PatchEmbedding( + configs['d_model'], self.patch_len, self.stride, padding, configs['dropout']) + + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention(False, attention_dropout=configs['dropout'], + output_attention=False), configs['d_model'], configs['n_heads']), + configs['d_model'], + configs['d_ff'], + dropout=configs['dropout'], + activation=configs['activation'] + ) for l in range(configs['e_layers']) + ], + norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs.d_model), Transpose(1,2)) + ) + + # Prediction Head + self.head_nf = configs.d_model * \ + int((configs.seq_len - self.patch_len) / self.stride + 2) + self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len, + head_dropout=configs.dropout) + + def forecast(self, x_enc): + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - means + stdev = torch.sqrt( + torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) + x_enc /= stdev + + # do patching and embedding + x_enc = x_enc.permute(0, 2, 1) + # u: [bs * nvars x patch_num x d_model] + enc_out, n_vars = self.patch_embedding(x_enc) + + # Encoder + # z: [bs * nvars x patch_num x d_model] + enc_out, attns = self.encoder(enc_out) + # z: [bs x nvars x patch_num x d_model] + enc_out = torch.reshape( + enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])) + # z: [bs x nvars x d_model x patch_num] + enc_out = enc_out.permute(0, 1, 3, 2) + + # Decoder + dec_out = self.head(enc_out) # z: [bs x nvars x target_window] + dec_out = dec_out.permute(0, 2, 1) + + # De-Normalization from Non-stationary Transformer + dec_out = dec_out * \ + (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) + dec_out = dec_out + \ + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) + return dec_out + + def forward(self, x_enc): + dec_out = self.forecast(x_enc) + return dec_out[:, -self.pred_len:, :] # [B, L, D] diff --git a/model/PatchTST/layers/Embed.py b/model/PatchTST/layers/Embed.py new file mode 100644 index 0000000..94896e0 --- /dev/null +++ b/model/PatchTST/layers/Embed.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + +class PatchEmbedding(nn.Module): + def __init__(self, d_model, patch_len, stride, padding, dropout): + super(PatchEmbedding, self).__init__() + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) + + # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space + self.value_embedding = nn.Linear(patch_len, d_model, bias=False) + + # Positional embedding + self.position_embedding = PositionalEmbedding(d_model) + + # Residual dropout + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + # do patching + n_vars = x.shape[1] + x = self.padding_patch_layer(x) + x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + # Input encoding + x = self.value_embedding(x) + self.position_embedding(x) + return self.dropout(x), n_vars \ No newline at end of file diff --git a/model/PatchTST/layers/SelfAttention.py b/model/PatchTST/layers/SelfAttention.py new file mode 100644 index 0000000..55b2493 --- /dev/null +++ b/model/PatchTST/layers/SelfAttention.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +import numpy as np +from math import sqrt + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None): + super(AttentionLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, + keys, + values, + attn_mask, + tau=tau, + delta=delta + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class TriangularCausalMask: + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + + @property + def mask(self): + return self._mask \ No newline at end of file diff --git a/model/PatchTST/layers/Transformer.py b/model/PatchTST/layers/Transformer.py new file mode 100644 index 0000000..6116325 --- /dev/null +++ b/model/PatchTST/layers/Transformer.py @@ -0,0 +1,57 @@ +import torch.nn as nn +import torch.nn.functional as F + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None, tau=None, delta=None): + new_x, attn = self.attention( + x, x, x, + attn_mask=attn_mask, + tau=tau, delta=delta + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None, tau=None, delta=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): + delta = delta if i == 0 else None + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x, tau=tau, delta=None) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index 7403893..09b7fdc 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -29,6 +29,7 @@ from model.ASTRA.astrav2 import ASTRA as ASTRAv2 from model.ASTRA.astrav3 import ASTRA as ASTRAv3 from model.iTransformer.iTransformer import iTransformer from model.HI.HI import HI +from model.PatchTST.PatchTST import Model as PatchTST @@ -96,3 +97,5 @@ def model_selector(config): return iTransformer(model_config) case "HI": return HI(model_config) + case "PatchTST": + return PatchTST(model_config) diff --git a/train.py b/train.py index 9d58921..dad4609 100644 --- a/train.py +++ b/train.py @@ -45,11 +45,13 @@ def run(config): if __name__ == "__main__": # 指定模型 - model_list = ["HI"] + model_list = ["PatchTST"] # 指定数据集 dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + # dataset_list = ["AirQuality"] device = "cuda:0" # 指定设备 seed = 2023 # 随机种子 + epochs = 1 for model in model_list: for dataset in dataset_list: config_path = f"./config/{model}/{dataset}.yaml" @@ -57,6 +59,7 @@ if __name__ == "__main__": config = yaml.safe_load(file) config["basic"]["device"] = device config["basic"]["seed"] = seed + config["train"]["epochs"] = epochs print(f"\nRunning {model} on {dataset} with seed {seed} on {device}") print(f"config: {config}") run(config) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 4bd82a4..80a6672 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -2,7 +2,6 @@ import math import os import time import copy -import psutil import torch from utils.logger import get_logger from utils.loss_function import all_metrics -- 2.40.1 From 600420e8df134d7e3e342830d482f413c5067a7c Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 23:31:17 +0800 Subject: [PATCH 24/41] impl mtgnn --- .vscode/launch.json | 2109 +------------------------- config/MTGNN/AirQuality.yaml | 64 + config/MTGNN/BJTaxi-Inflow.yaml | 64 + config/MTGNN/BJTaxi-Outflow.yaml | 64 + config/MTGNN/METR-LA.yaml | 64 + config/MTGNN/NYCBike-Inflow.yaml | 64 + config/MTGNN/NYCBike-Outflow.yaml | 64 + config/MTGNN/PEMS-BAY.yaml | 64 + config/MTGNN/SolarEnergy.yaml | 64 + config/PatchTST/AirQuality.yaml | 3 +- config/PatchTST/BJTaxi-Inflow.yaml | 3 +- config/PatchTST/BJTaxi-Outflow.yaml | 3 +- config/PatchTST/METR-LA.yaml | 3 +- config/PatchTST/NYCBike-Inflow.yaml | 3 +- config/PatchTST/NYCBike-Outflow.yaml | 3 +- config/PatchTST/PEMS-BAY.yaml | 3 +- config/PatchTST/SolarEnergy.yaml | 3 +- model/MTGNN/MTGNN.py | 123 +- model/PatchTST/PatchTST.py | 10 +- model/PatchTST/layers/Embed.py | 21 + model/model_selector.py | 3 + train.py | 47 +- trainer/Trainer.py | 4 + 23 files changed, 671 insertions(+), 2182 deletions(-) create mode 100644 config/MTGNN/AirQuality.yaml create mode 100644 config/MTGNN/BJTaxi-Inflow.yaml create mode 100644 config/MTGNN/BJTaxi-Outflow.yaml create mode 100644 config/MTGNN/METR-LA.yaml create mode 100644 config/MTGNN/NYCBike-Inflow.yaml create mode 100644 config/MTGNN/NYCBike-Outflow.yaml create mode 100644 config/MTGNN/PEMS-BAY.yaml create mode 100644 config/MTGNN/SolarEnergy.yaml diff --git a/.vscode/launch.json b/.vscode/launch.json index 54aad8a..fb16dc0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,2113 +6,12 @@ "configurations": [ { - "name": "DDGCRN: METR-LA", + "name": "train", "type": "debugpy", "request": "launch", - "program": "run.py", + "program": "train.py", "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/METR-LA.yaml" - }, - // STID 模型组 - { - "name": "STID: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/PEMS-BAY.yaml" - }, - { - "name": "STID: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/METR-LA.yaml" - }, - { - "name": "STID: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/PEMSD4.yaml" - }, - { - "name": "STID: BJTaxi-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/BJTaxi_Inflow.yaml" - }, - { - "name": "STID: BJTaxi-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/BJTaxi_Outflow.yaml" - }, - { - "name": "STID: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Inflow.yaml" - }, - { - "name": "STID: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Outflow.yaml" - }, - { - "name": "STID: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/SolarEnergy.yaml" - }, - - // REPST 模型组 - { - "name": "REPST: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/PEMSD8.yaml" - }, - { - "name": "REPST: BJTaxi-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/BJTaxi-Inflow.yaml" - }, - { - "name": "REPST: NYCBike-outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-outflow.yaml" - }, - { - "name": "REPST: NYCBike-inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-inflow.yaml" - }, - { - "name": "REPST: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/PEMS-BAY.yaml" - }, - { - "name": "REPST: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/METR-LA.yaml" - }, - { - "name": "REPST: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/SolarEnergy.yaml" - }, - { - "name": "REPST: BeijingAirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/BeijingAirQuality.yaml" - }, - { - "name": "REPST: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/AirQuality.yaml" - }, - - // ASTRA 模型组 - { - "name": "ASTRA: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/PEMS-BAY.yaml" - }, - { - "name": "ASTRA: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/METR-LA.yaml" - }, - { - "name": "ASTRA: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/AirQuality.yaml" - }, - { - "name": "ASTRA: BJTaxi-Inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/BJTaxi-Inflow.yaml" - }, - { - "name": "ASTRA: BJTaxi-outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/BJTaxi-outflow.yaml" - }, - { - "name": "ASTRA: NYCBike-inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-inflow.yaml" - }, - { - "name": "ASTRA: NYCBike-outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-outflow.yaml" - }, - { - "name": "ASTRA: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/SolarEnergy.yaml" - }, - { - "name": "ASTRA_v2: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v2_AirQuality.yaml" - }, - { - "name": "ASTRA_v2: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v2_SolarEnergy.yaml" - }, - { - "name": "ASTRA_v3: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v3_METR-LA.yaml" - }, - { - "name": "EXPB: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/NYCBike-InFlow.yaml" - }, - { - "name": "EXPB: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/PEMSD4.yaml" - }, - { - "name": "EXPB: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/METR-LA.yaml" - }, - { - "name": "EXPB: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/AirQuality.yaml" - }, - { - "name": "EXPB: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/NYCBike-OutFlow.yaml" - }, - { - "name": "EXPB: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/SolarEnergy.yaml" - }, - { - "name": "TWDGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/NYCBike-InFlow.yaml" - }, - { - "name": "TWDGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD4.yaml" - }, - { - "name": "TWDGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/METR-LA.yaml" - }, - { - "name": "TWDGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/AirQuality.yaml" - }, - { - "name": "TWDGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "TWDGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD8.yaml" - }, - { - "name": "TWDGCN: PEMSD7(L)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD7(L).yaml" - }, - { - "name": "TWDGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD3.yaml" - }, - { - "name": "TWDGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/SolarEnergy.yaml" - }, - { - "name": "TWDGCN: Hainan", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/Hainan.yaml" - }, - { - "name": "TWDGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD7.yaml" - }, - { - "name": "TWDGCN: PEMSD7(M)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD7(M).yaml" - }, - { - "name": "STSGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/NYCBike-InFlow.yaml" - }, - { - "name": "STSGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD4.yaml" - }, - { - "name": "STSGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/METR-LA.yaml" - }, - { - "name": "STSGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/AirQuality.yaml" - }, - { - "name": "STSGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "STSGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD8.yaml" - }, - { - "name": "STSGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD3.yaml" - }, - { - "name": "STSGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/SolarEnergy.yaml" - }, - { - "name": "STSGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD7.yaml" - }, - { - "name": "STID: NYCBike_Inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Inflow.yaml" - }, - { - "name": "STID: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/AirQuality.yaml" - }, - { - "name": "STID: NYCBike_Outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Outflow.yaml" - }, - { - "name": "STAWnet: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/NYCBike-InFlow.yaml" - }, - { - "name": "STAWnet: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD4.yaml" - }, - { - "name": "STAWnet: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/METR-LA.yaml" - }, - { - "name": "STAWnet: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/AirQuality.yaml" - }, - { - "name": "STAWnet: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/NYCBike-OutFlow.yaml" - }, - { - "name": "STAWnet: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD8.yaml" - }, - { - "name": "STAWnet: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD3.yaml" - }, - { - "name": "STAWnet: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/SolarEnergy.yaml" - }, - { - "name": "STAWnet: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD7.yaml" - }, - { - "name": "DCRNN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/NYCBike-InFlow.yaml" - }, - { - "name": "DCRNN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD4.yaml" - }, - { - "name": "DCRNN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/METR-LA.yaml" - }, - { - "name": "DCRNN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/AirQuality.yaml" - }, - { - "name": "DCRNN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/NYCBike-OutFlow.yaml" - }, - { - "name": "DCRNN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD8.yaml" - }, - { - "name": "DCRNN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD3.yaml" - }, - { - "name": "DCRNN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/SolarEnergy.yaml" - }, - { - "name": "DCRNN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD7.yaml" - }, - { - "name": "STAEFormer: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/NYCBike-InFlow.yaml" - }, - { - "name": "STAEFormer: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD4.yaml" - }, - { - "name": "STAEFormer: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/METR-LA.yaml" - }, - { - "name": "STAEFormer: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/AirQuality.yaml" - }, - { - "name": "STAEFormer: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/NYCBike-OutFlow.yaml" - }, - { - "name": "STAEFormer: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD8.yaml" - }, - { - "name": "STAEFormer: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD3.yaml" - }, - { - "name": "STAEFormer: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/SolarEnergy.yaml" - }, - { - "name": "STAEFormer: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD7.yaml" - }, - { - "name": "STGODE: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/NYCBike-InFlow.yaml" - }, - { - "name": "STGODE: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD4.yaml" - }, - { - "name": "STGODE: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/METR-LA.yaml" - }, - { - "name": "STGODE: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/AirQuality.yaml" - }, - { - "name": "STGODE: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/NYCBike-OutFlow.yaml" - }, - { - "name": "STGODE: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD8.yaml" - }, - { - "name": "STGODE: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD3.yaml" - }, - { - "name": "STGODE: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/SolarEnergy.yaml" - }, - { - "name": "STGODE: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD7.yaml" - }, - { - "name": "STGNCDE: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/NYCBike-InFlow.yaml" - }, - { - "name": "STGNCDE: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD4.yaml" - }, - { - "name": "STGNCDE: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/METR-LA.yaml" - }, - { - "name": "STGNCDE: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/AirQuality.yaml" - }, - { - "name": "STGNCDE: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/NYCBike-OutFlow.yaml" - }, - { - "name": "STGNCDE: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD8.yaml" - }, - { - "name": "STGNCDE: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD3.yaml" - }, - { - "name": "STGNCDE: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/SolarEnergy.yaml" - }, - { - "name": "STGNCDE: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD7.yaml" - }, - { - "name": "ASTRA: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-InFlow.yaml" - }, - { - "name": "ASTRA: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-OutFlow.yaml" - }, - { - "name": "ST_SSL: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/NYCBike-InFlow.yaml" - }, - { - "name": "ST_SSL: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD4.yaml" - }, - { - "name": "ST_SSL: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/METR-LA.yaml" - }, - { - "name": "ST_SSL: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/AirQuality.yaml" - }, - { - "name": "ST_SSL: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/NYCBike-OutFlow.yaml" - }, - { - "name": "ST_SSL: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD8.yaml" - }, - { - "name": "ST_SSL: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD3.yaml" - }, - { - "name": "ST_SSL: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/SolarEnergy.yaml" - }, - { - "name": "ST_SSL: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD7.yaml" - }, - { - "name": "TCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/NYCBike-InFlow.yaml" - }, - { - "name": "TCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD4.yaml" - }, - { - "name": "TCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/METR-LA.yaml" - }, - { - "name": "TCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/AirQuality.yaml" - }, - { - "name": "TCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/NYCBike-OutFlow.yaml" - }, - { - "name": "TCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD8.yaml" - }, - { - "name": "TCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD3.yaml" - }, - { - "name": "TCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/SolarEnergy.yaml" - }, - { - "name": "TCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD7.yaml" - }, - { - "name": "EXP: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/NYCBike-InFlow.yaml" - }, - { - "name": "EXP: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD4.yaml" - }, - { - "name": "EXP: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/METR-LA.yaml" - }, - { - "name": "EXP: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/AirQuality.yaml" - }, - { - "name": "EXP: SD", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/SD.yaml" - }, - { - "name": "EXP: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/NYCBike-OutFlow.yaml" - }, - { - "name": "EXP: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD8.yaml" - }, - { - "name": "EXP: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD3.yaml" - }, - { - "name": "EXP: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/SolarEnergy.yaml" - }, - { - "name": "EXP: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD7.yaml" - }, - { - "name": "DDGCRN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/NYCBike-InFlow.yaml" - }, - { - "name": "DDGCRN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD4.yaml" - }, - { - "name": "DDGCRN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/AirQuality.yaml" - }, - { - "name": "DDGCRN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/NYCBike-OutFlow.yaml" - }, - { - "name": "DDGCRN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD8.yaml" - }, - { - "name": "DDGCRN: PEMSD7(L)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD7(L).yaml" - }, - { - "name": "DDGCRN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD3.yaml" - }, - { - "name": "DDGCRN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/SolarEnergy.yaml" - }, - { - "name": "DDGCRN: Hainan", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/Hainan.yaml" - }, - { - "name": "DDGCRN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD7.yaml" - }, - { - "name": "DDGCRN: PEMSD7(M)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD7(M).yaml" - }, - { - "name": "DSANET: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/NYCBike-InFlow.yaml" - }, - { - "name": "DSANET: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD4.yaml" - }, - { - "name": "DSANET: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/METR-LA.yaml" - }, - { - "name": "DSANET: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/AirQuality.yaml" - }, - { - "name": "DSANET: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/NYCBike-OutFlow.yaml" - }, - { - "name": "DSANET: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD8.yaml" - }, - { - "name": "DSANET: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD3.yaml" - }, - { - "name": "DSANET: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/SolarEnergy.yaml" - }, - { - "name": "DSANET: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD7.yaml" - }, - { - "name": "STFGNN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/NYCBike-InFlow.yaml" - }, - { - "name": "STFGNN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD4.yaml" - }, - { - "name": "STFGNN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/METR-LA.yaml" - }, - { - "name": "STFGNN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/AirQuality.yaml" - }, - { - "name": "STFGNN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/NYCBike-OutFlow.yaml" - }, - { - "name": "STFGNN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD8.yaml" - }, - { - "name": "STFGNN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD3.yaml" - }, - { - "name": "STFGNN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/SolarEnergy.yaml" - }, - { - "name": "STFGNN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD7.yaml" - }, - { - "name": "AGCRN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/NYCBike-InFlow.yaml" - }, - { - "name": "AGCRN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD4.yaml" - }, - { - "name": "AGCRN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/METR-LA.yaml" - }, - { - "name": "AGCRN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/AirQuality.yaml" - }, - { - "name": "AGCRN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/NYCBike-OutFlow.yaml" - }, - { - "name": "AGCRN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD8.yaml" - }, - { - "name": "AGCRN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD3.yaml" - }, - { - "name": "AGCRN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/SolarEnergy.yaml" - }, - { - "name": "AGCRN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD7.yaml" - }, - { - "name": "STGNRDE: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/NYCBike-InFlow.yaml" - }, - { - "name": "STGNRDE: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD4.yaml" - }, - { - "name": "STGNRDE: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/METR-LA.yaml" - }, - { - "name": "STGNRDE: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/AirQuality.yaml" - }, - { - "name": "STGNRDE: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/NYCBike-OutFlow.yaml" - }, - { - "name": "STGNRDE: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD8.yaml" - }, - { - "name": "STGNRDE: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD3.yaml" - }, - { - "name": "STGNRDE: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/SolarEnergy.yaml" - }, - { - "name": "STGNRDE: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD7.yaml" - }, - { - "name": "REPST: PEMS-BAY_paper", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/PEMS-BAY_paper.yaml" - }, - { - "name": "REPST: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-InFlow.yaml" - }, - { - "name": "REPST: BeijingAirQuality(Deprecated)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/BeijingAirQuality(Deprecated).yaml" - }, - { - "name": "REPST: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-OutFlow.yaml" - }, - { - "name": "STIDGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/NYCBike-InFlow.yaml" - }, - { - "name": "STIDGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD4.yaml" - }, - { - "name": "STIDGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/METR-LA.yaml" - }, - { - "name": "STIDGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/AirQuality.yaml" - }, - { - "name": "STIDGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "STIDGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD8.yaml" - }, - { - "name": "STIDGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD3.yaml" - }, - { - "name": "STIDGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/SolarEnergy.yaml" - }, - { - "name": "STIDGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD7.yaml" - }, - { - "name": "PDG2SEQ: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/NYCBike-InFlow.yaml" - }, - { - "name": "PDG2SEQ: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD4.yaml" - }, - { - "name": "PDG2SEQ: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/METR-LA.yaml" - }, - { - "name": "PDG2SEQ: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/AirQuality.yaml" - }, - { - "name": "PDG2SEQ: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/NYCBike-OutFlow.yaml" - }, - { - "name": "PDG2SEQ: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD8.yaml" - }, - { - "name": "PDG2SEQ: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD3.yaml" - }, - { - "name": "PDG2SEQ: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/SolarEnergy.yaml" - }, - { - "name": "PDG2SEQ: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD7.yaml" - }, - { - "name": "NLT: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/NYCBike-InFlow.yaml" - }, - { - "name": "NLT: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD4.yaml" - }, - { - "name": "NLT: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/METR-LA.yaml" - }, - { - "name": "NLT: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/AirQuality.yaml" - }, - { - "name": "NLT: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/NYCBike-OutFlow.yaml" - }, - { - "name": "NLT: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD8.yaml" - }, - { - "name": "NLT: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD3.yaml" - }, - { - "name": "NLT: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/SolarEnergy.yaml" - }, - { - "name": "NLT: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD7.yaml" - }, - { - "name": "ARIMA: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/NYCBike-InFlow.yaml" - }, - { - "name": "ARIMA: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD4.yaml" - }, - { - "name": "ARIMA: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/METR-LA.yaml" - }, - { - "name": "ARIMA: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/AirQuality.yaml" - }, - { - "name": "ARIMA: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/NYCBike-OutFlow.yaml" - }, - { - "name": "ARIMA: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD8.yaml" - }, - { - "name": "ARIMA: PEMSD7(L)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD7(L).yaml" - }, - { - "name": "ARIMA: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD3.yaml" - }, - { - "name": "ARIMA: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/SolarEnergy.yaml" - }, - { - "name": "ARIMA: Hainan", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/Hainan.yaml" - }, - { - "name": "ARIMA: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD7.yaml" - }, - { - "name": "ARIMA: PEMSD7(M)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD7(M).yaml" - }, - { - "name": "STMLP: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/NYCBike-InFlow.yaml" - }, - { - "name": "STMLP: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD4.yaml" - }, - { - "name": "STMLP: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/METR-LA.yaml" - }, - { - "name": "STMLP: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/AirQuality.yaml" - }, - { - "name": "STMLP: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/NYCBike-OutFlow.yaml" - }, - { - "name": "STMLP: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD8.yaml" - }, - { - "name": "STMLP: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD3.yaml" - }, - { - "name": "STMLP: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/SolarEnergy.yaml" - }, - { - "name": "STMLP: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD7.yaml" - }, - { - "name": "MegaCRN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/NYCBike-InFlow.yaml" - }, - { - "name": "MegaCRN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD4.yaml" - }, - { - "name": "MegaCRN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/METR-LA.yaml" - }, - { - "name": "MegaCRN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/AirQuality.yaml" - }, - { - "name": "MegaCRN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/NYCBike-OutFlow.yaml" - }, - { - "name": "MegaCRN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD8.yaml" - }, - { - "name": "MegaCRN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD3.yaml" - }, - { - "name": "MegaCRN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/SolarEnergy.yaml" - }, - { - "name": "MegaCRN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD7.yaml" - }, - { - "name": "GWN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/NYCBike-InFlow.yaml" - }, - { - "name": "GWN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD4.yaml" - }, - { - "name": "GWN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/METR-LA.yaml" - }, - { - "name": "GWN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/AirQuality.yaml" - }, - { - "name": "GWN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/NYCBike-OutFlow.yaml" - }, - { - "name": "GWN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD8.yaml" - }, - { - "name": "GWN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD3.yaml" - }, - { - "name": "GWN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/SolarEnergy.yaml" - }, - { - "name": "GWN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD7.yaml" - }, - { - "name": "STGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/NYCBike-InFlow.yaml" - }, - { - "name": "STGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD4.yaml" - }, - { - "name": "STGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/METR-LA.yaml" - }, - { - "name": "STGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/AirQuality.yaml" - }, - { - "name": "STGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "STGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD8.yaml" - }, - { - "name": "STGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD3.yaml" - }, - { - "name": "STGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/SolarEnergy.yaml" - }, - { - "name": "STGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD7.yaml" - }, - { - "name": "iTransformer: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/iTransformer/METR-LA.yaml" - }, - { - "name": "iTransformer: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/iTransformer/AirQuality.yaml" - }, - { - "name": "HI: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/HI/PEMS-BAY.yaml" - }, + "justMyCode": false + } ] } \ No newline at end of file diff --git a/config/MTGNN/AirQuality.yaml b/config/MTGNN/AirQuality.yaml new file mode 100644 index 0000000..9846895 --- /dev/null +++ b/config/MTGNN/AirQuality.yaml @@ -0,0 +1,64 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 35 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 6 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 6 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/BJTaxi-Inflow.yaml b/config/MTGNN/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..09e453a --- /dev/null +++ b/config/MTGNN/BJTaxi-Inflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/BJTaxi-Outflow.yaml b/config/MTGNN/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..1b62a4e --- /dev/null +++ b/config/MTGNN/BJTaxi-Outflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/METR-LA.yaml b/config/MTGNN/METR-LA.yaml new file mode 100644 index 0000000..2518638 --- /dev/null +++ b/config/MTGNN/METR-LA.yaml @@ -0,0 +1,64 @@ +basic: + dataset: METR-LA + device: cuda:1 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 207 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/NYCBike-Inflow.yaml b/config/MTGNN/NYCBike-Inflow.yaml new file mode 100644 index 0000000..95ae41b --- /dev/null +++ b/config/MTGNN/NYCBike-Inflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/NYCBike-Outflow.yaml b/config/MTGNN/NYCBike-Outflow.yaml new file mode 100644 index 0000000..b1646ea --- /dev/null +++ b/config/MTGNN/NYCBike-Outflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/PEMS-BAY.yaml b/config/MTGNN/PEMS-BAY.yaml new file mode 100644 index 0000000..7f28aca --- /dev/null +++ b/config/MTGNN/PEMS-BAY.yaml @@ -0,0 +1,64 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 325 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/SolarEnergy.yaml b/config/MTGNN/SolarEnergy.yaml new file mode 100644 index 0000000..2f60b8d --- /dev/null +++ b/config/MTGNN/SolarEnergy.yaml @@ -0,0 +1,64 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 137 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/AirQuality.yaml b/config/PatchTST/AirQuality.yaml index a3e6418..3cdf977 100644 --- a/config/PatchTST/AirQuality.yaml +++ b/config/PatchTST/AirQuality.yaml @@ -2,7 +2,7 @@ basic: dataset: AirQuality device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 6 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-Inflow.yaml index 9bd66d9..576dbd6 100644 --- a/config/PatchTST/BJTaxi-Inflow.yaml +++ b/config/PatchTST/BJTaxi-Inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-Outflow.yaml index 2382695..773ba26 100644 --- a/config/PatchTST/BJTaxi-Outflow.yaml +++ b/config/PatchTST/BJTaxi-Outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/METR-LA.yaml b/config/PatchTST/METR-LA.yaml index d076d35..6b9461a 100644 --- a/config/PatchTST/METR-LA.yaml +++ b/config/PatchTST/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:1 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-Inflow.yaml index 2c3026c..408995c 100644 --- a/config/PatchTST/NYCBike-Inflow.yaml +++ b/config/PatchTST/NYCBike-Inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-Outflow.yaml index 16eee20..c50f4a1 100644 --- a/config/PatchTST/NYCBike-Outflow.yaml +++ b/config/PatchTST/NYCBike-Outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/PEMS-BAY.yaml b/config/PatchTST/PEMS-BAY.yaml index 6186db3..e798294 100644 --- a/config/PatchTST/PEMS-BAY.yaml +++ b/config/PatchTST/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -24,6 +24,7 @@ model: pred_len: 24 d_model: 128 patch_len: 6 + enc_in: 1 stride: 8 d_ff: 2048 dropout: 0.1 diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml index 28b85b9..b1de602 100644 --- a/config/PatchTST/SolarEnergy.yaml +++ b/config/PatchTST/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -24,6 +24,7 @@ model: pred_len: 24 d_model: 128 patch_len: 6 + enc_in: 6 stride: 8 d_ff: 2048 dropout: 0.1 diff --git a/model/MTGNN/MTGNN.py b/model/MTGNN/MTGNN.py index 483a184..43d9b31 100644 --- a/model/MTGNN/MTGNN.py +++ b/model/MTGNN/MTGNN.py @@ -3,91 +3,109 @@ from model.MTGNN.layer import * class gtnet(nn.Module): - def __init__(self, gcn_true, buildA_true, gcn_depth, num_nodes, device, predefined_A=None, static_feat=None, dropout=0.3, subgraph_size=20, node_dim=40, dilation_exponential=1, conv_channels=32, residual_channels=32, skip_channels=64, end_channels=128, seq_length=12, in_dim=2, out_dim=12, layers=3, propalpha=0.05, tanhalpha=3, layer_norm_affline=True): + def __init__(self, configs): super(gtnet, self).__init__() - self.gcn_true = gcn_true - self.buildA_true = buildA_true - self.num_nodes = num_nodes - self.dropout = dropout - self.predefined_A = predefined_A - self.filter_convs = nn.ModuleList() - self.gate_convs = nn.ModuleList() - self.residual_convs = nn.ModuleList() - self.skip_convs = nn.ModuleList() - self.gconv1 = nn.ModuleList() - self.gconv2 = nn.ModuleList() - self.norm = nn.ModuleList() - self.start_conv = nn.Conv2d(in_channels=in_dim, - out_channels=residual_channels, + self.gcn_true = configs['gcn_true'] # 是否使用图卷积网络 + self.buildA_true = configs['buildA_true'] # 是否动态构建邻接矩阵 + self.num_nodes = configs['num_nodes'] # 节点数量 + self.device = configs['device'] # 设备(CPU/GPU) + self.dropout = configs['dropout'] # dropout率 + self.predefined_A = configs.get('predefined_A', None) # 预定义邻接矩阵 + self.static_feat = configs.get('static_feat', None) # 静态特征 + self.subgraph_size = configs['subgraph_size'] # 子图大小 + self.node_dim = configs['node_dim'] # 节点嵌入维度 + self.dilation_exponential = configs['dilation_exponential'] # 膨胀卷积指数 + self.conv_channels = configs['conv_channels'] # 卷积通道数 + self.residual_channels = configs['residual_channels'] # 残差通道数 + self.skip_channels = configs['skip_channels'] # 跳跃连接通道数 + self.end_channels = configs['end_channels'] # 输出层通道数 + self.seq_length = configs['seq_len'] # 输入序列长度 + self.in_dim = configs['in_dim'] # 输入特征维度 + self.out_len = configs['out_len'] # 输出序列长度 + self.out_dim = configs['out_dim'] # 输出预测维度 + self.layers = configs['layers'] # 模型层数 + self.propalpha = configs['propalpha'] # 图传播参数alpha + self.tanhalpha = configs['tanhalpha'] # tanh激活参数alpha + self.layer_norm_affline = configs['layer_norm_affline'] # 层归一化是否使用affine变换 + self.gcn_depth = configs['gcn_depth'] # 图卷积深度 + self.filter_convs = nn.ModuleList() # 卷积滤波器列表 + self.gate_convs = nn.ModuleList() # 门控卷积列表 + self.residual_convs = nn.ModuleList() # 残差卷积列表 + self.skip_convs = nn.ModuleList() # 跳跃连接卷积列表 + self.gconv1 = nn.ModuleList() # 第一层图卷积列表 + self.gconv2 = nn.ModuleList() # 第二层图卷积列表 + self.norm = nn.ModuleList() # 归一化层列表 + self.start_conv = nn.Conv2d(in_channels=self.in_dim, + out_channels=self.residual_channels, kernel_size=(1, 1)) - self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, device, alpha=tanhalpha, static_feat=static_feat) + self.gc = graph_constructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device, alpha=self.tanhalpha, static_feat=self.static_feat) - self.seq_length = seq_length kernel_size = 7 - if dilation_exponential>1: - self.receptive_field = int(1+(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + if self.dilation_exponential>1: + self.receptive_field = int(1+(kernel_size-1)*(self.dilation_exponential**self.layers-1)/(self.dilation_exponential-1)) else: - self.receptive_field = layers*(kernel_size-1) + 1 + self.receptive_field = self.layers*(kernel_size-1) + 1 for i in range(1): - if dilation_exponential>1: - rf_size_i = int(1 + i*(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + if self.dilation_exponential>1: + rf_size_i = int(1 + i*(kernel_size-1)*(self.dilation_exponential**self.layers-1)/(self.dilation_exponential-1)) else: - rf_size_i = i*layers*(kernel_size-1)+1 + rf_size_i = i*self.layers*(kernel_size-1)+1 new_dilation = 1 - for j in range(1,layers+1): - if dilation_exponential > 1: - rf_size_j = int(rf_size_i + (kernel_size-1)*(dilation_exponential**j-1)/(dilation_exponential-1)) + for j in range(1,self.layers+1): + if self.dilation_exponential > 1: + rf_size_j = int(rf_size_i + (kernel_size-1)*(self.dilation_exponential**j-1)/(self.dilation_exponential-1)) else: rf_size_j = rf_size_i+j*(kernel_size-1) - self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) - self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) - self.residual_convs.append(nn.Conv2d(in_channels=conv_channels, - out_channels=residual_channels, + self.filter_convs.append(dilated_inception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation)) + self.gate_convs.append(dilated_inception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation)) + self.residual_convs.append(nn.Conv2d(in_channels=self.conv_channels, + out_channels=self.residual_channels, kernel_size=(1, 1))) if self.seq_length>self.receptive_field: - self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, - out_channels=skip_channels, + self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels, + out_channels=self.skip_channels, kernel_size=(1, self.seq_length-rf_size_j+1))) else: - self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, - out_channels=skip_channels, + self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels, + out_channels=self.skip_channels, kernel_size=(1, self.receptive_field-rf_size_j+1))) if self.gcn_true: - self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) - self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + self.gconv1.append(mixprop(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, self.propalpha)) + self.gconv2.append(mixprop(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, self.propalpha)) if self.seq_length>self.receptive_field: - self.norm.append(LayerNorm((residual_channels, num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=self.layer_norm_affline)) else: - self.norm.append(LayerNorm((residual_channels, num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=self.layer_norm_affline)) - new_dilation *= dilation_exponential + new_dilation *= self.dilation_exponential - self.layers = layers - self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, - out_channels=end_channels, + self.end_conv_1 = nn.Conv2d(in_channels=self.skip_channels, + out_channels=self.end_channels, kernel_size=(1,1), bias=True) - self.end_conv_2 = nn.Conv2d(in_channels=end_channels, - out_channels=out_dim, + self.end_conv_2 = nn.Conv2d(in_channels=self.end_channels, + out_channels=self.out_len * self.out_dim, kernel_size=(1,1), bias=True) if self.seq_length > self.receptive_field: - self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.seq_length), bias=True) - self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) + self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels, kernel_size=(1, self.seq_length), bias=True) + self.skipE = nn.Conv2d(in_channels=self.residual_channels, out_channels=self.skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) + else: - self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True) - self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True) + self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels, kernel_size=(1, self.receptive_field), bias=True) + self.skipE = nn.Conv2d(in_channels=self.residual_channels, out_channels=self.skip_channels, kernel_size=(1, 1), bias=True) - - self.idx = torch.arange(self.num_nodes).to(device) + self.idx = torch.arange(self.num_nodes).to(self.device) def forward(self, input, idx=None): + input = input[..., :-2] # 去掉周期嵌入 + input = input.transpose(1, 3) seq_len = input.size(3) assert seq_len==self.seq_length, 'input sequence length not equal to preset sequence length' @@ -130,5 +148,8 @@ class gtnet(nn.Module): skip = self.skipE(x) + skip x = F.relu(skip) x = F.relu(self.end_conv_1(x)) - x = self.end_conv_2(x) + x = self.end_conv_2(x) # [b, t*c, n, 1] + # [b, t*c, n, 1] -> [b,t,c,n] -> [b, t, n, c] + x = x.reshape(x.size(0), self.out_len, self.out_dim, self.num_nodes) + x = x.permute(0, 1, 3, 2) return x \ No newline at end of file diff --git a/model/PatchTST/PatchTST.py b/model/PatchTST/PatchTST.py index 3112030..4645c28 100644 --- a/model/PatchTST/PatchTST.py +++ b/model/PatchTST/PatchTST.py @@ -62,14 +62,14 @@ class Model(nn.Module): activation=configs['activation'] ) for l in range(configs['e_layers']) ], - norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs.d_model), Transpose(1,2)) + norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs['d_model']), Transpose(1,2)) ) # Prediction Head - self.head_nf = configs.d_model * \ - int((configs.seq_len - self.patch_len) / self.stride + 2) - self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len, - head_dropout=configs.dropout) + self.head_nf = configs['d_model'] * \ + int((configs['seq_len'] - self.patch_len) / self.stride + 2) + self.head = FlattenHead(configs['enc_in'], self.head_nf, configs['pred_len'], + head_dropout=configs['dropout']) def forecast(self, x_enc): # Normalization from Non-stationary Transformer diff --git a/model/PatchTST/layers/Embed.py b/model/PatchTST/layers/Embed.py index 94896e0..d38d093 100644 --- a/model/PatchTST/layers/Embed.py +++ b/model/PatchTST/layers/Embed.py @@ -1,5 +1,26 @@ import torch import torch.nn as nn +import math + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000): + super(PositionalEmbedding, self).__init__() + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() + * -(math.log(10000.0) / d_model)).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + return self.pe[:, :x.size(1)] class PatchEmbedding(nn.Module): def __init__(self, d_model, patch_len, stride, padding, dropout): diff --git a/model/model_selector.py b/model/model_selector.py index 09b7fdc..5621037 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -30,6 +30,7 @@ from model.ASTRA.astrav3 import ASTRA as ASTRAv3 from model.iTransformer.iTransformer import iTransformer from model.HI.HI import HI from model.PatchTST.PatchTST import Model as PatchTST +from model.MTGNN.MTGNN import gtnet as MTGNN @@ -99,3 +100,5 @@ def model_selector(config): return HI(model_config) case "PatchTST": return PatchTST(model_config) + case "MTGNN": + return MTGNN(model_config) diff --git a/train.py b/train.py index dad4609..7bd72ad 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,27 @@ import yaml import torch +import os import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer +def read_config(config_path): + with open(config_path, "r") as file: + config = yaml.safe_load(file) + + # 全局配置 + device = "cuda:0" # 指定设备 + seed = 2023 # 随机种子 + epochs = 100 + + # 拷贝项 + config["basic"]["device"] = device + config["model"]["device"] = device + config["basic"]["seed"] = seed + config["train"]["epochs"] = epochs + return config + def run(config): init.init_seed(config["basic"]["seed"]) model = init.init_model(config) @@ -45,22 +62,26 @@ def run(config): if __name__ == "__main__": # 指定模型 - model_list = ["PatchTST"] + model_list = ["MTGNN"] # 指定数据集 - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] # dataset_list = ["AirQuality"] - device = "cuda:0" # 指定设备 - seed = 2023 # 随机种子 - epochs = 1 + + # 我的调试开关,不做测试就填 str(False) + os.environ["TRY"] = str(False) + for model in model_list: for dataset in dataset_list: config_path = f"./config/{model}/{dataset}.yaml" - with open(config_path, "r") as file: - config = yaml.safe_load(file) - config["basic"]["device"] = device - config["basic"]["seed"] = seed - config["train"]["epochs"] = epochs - print(f"\nRunning {model} on {dataset} with seed {seed} on {device}") - print(f"config: {config}") - run(config) + # 可去这个函数里面调整统一的config项,⚠️注意调设备,epochs + config = read_config(config_path) + print(f"\nRunning {model} on {dataset}") + # print(f"config: {config}") + if os.environ.get("TRY") == "True": + try: + run(config) + except Exception as e: + pass + else: + run(config) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 80a6672..3372873 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -71,6 +71,10 @@ class Trainer: label = target[..., : self.args["output_dim"]] # 计算loss和反归一化loss output = self.model(data) + # 我的调试开关 + if os.environ.get("TRY") == "True": + print(f"[{'✅' if output.shape == label.shape else '❌'}]: output: {output.shape}, label: {label.shape}") + assert False loss = self.loss(output, label) d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) -- 2.40.1 From 19fd7622a379111e5772bd3b7f3d713ab154b15c Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 11 Dec 2025 23:16:25 +0800 Subject: [PATCH 25/41] =?UTF-8?q?=E5=85=BC=E5=AE=B9InFormer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/Informer/AirQuality.yaml | 66 +++++++ config/Informer/BJTaxi-Inflow.yaml | 66 +++++++ config/Informer/BJTaxi-Outflow.yaml | 66 +++++++ config/Informer/METR-LA.yaml | 66 +++++++ config/Informer/NYCBike-Inflow.yaml | 66 +++++++ config/Informer/NYCBike-Outflow.yaml | 66 +++++++ config/Informer/PEMS-BAY.yaml | 66 +++++++ config/Informer/SolarEnergy.yaml | 66 +++++++ dataloader/Informer_loader.py | 179 +++++++++++++++++++ dataloader/loader_selector.py | 5 +- model/Informer/attn.py | 163 +++++++++++++++++ model/Informer/decoder.py | 51 ++++++ model/Informer/embed.py | 129 ++++++++++++++ model/Informer/encoder.py | 98 +++++++++++ model/Informer/masking.py | 24 +++ model/Informer/model.py | 141 +++++++++++++++ model/model_selector.py | 3 + test_informer.py | 57 ++++++ train.py | 24 ++- trainer/InformerTrainer.py | 250 +++++++++++++++++++++++++++ trainer/trainer_selector.py | 13 ++ 21 files changed, 1656 insertions(+), 9 deletions(-) create mode 100644 config/Informer/AirQuality.yaml create mode 100644 config/Informer/BJTaxi-Inflow.yaml create mode 100644 config/Informer/BJTaxi-Outflow.yaml create mode 100644 config/Informer/METR-LA.yaml create mode 100644 config/Informer/NYCBike-Inflow.yaml create mode 100644 config/Informer/NYCBike-Outflow.yaml create mode 100644 config/Informer/PEMS-BAY.yaml create mode 100644 config/Informer/SolarEnergy.yaml create mode 100644 dataloader/Informer_loader.py create mode 100644 model/Informer/attn.py create mode 100644 model/Informer/decoder.py create mode 100644 model/Informer/embed.py create mode 100644 model/Informer/encoder.py create mode 100644 model/Informer/masking.py create mode 100644 model/Informer/model.py create mode 100644 test_informer.py create mode 100644 trainer/InformerTrainer.py diff --git a/config/Informer/AirQuality.yaml b/config/Informer/AirQuality.yaml new file mode 100644 index 0000000..4b1568a --- /dev/null +++ b/config/Informer/AirQuality.yaml @@ -0,0 +1,66 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 6 + dec_in: 6 + c_out: 6 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/BJTaxi-Inflow.yaml b/config/Informer/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..56d089d --- /dev/null +++ b/config/Informer/BJTaxi-Inflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/BJTaxi-Outflow.yaml b/config/Informer/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..875cce8 --- /dev/null +++ b/config/Informer/BJTaxi-Outflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/METR-LA.yaml b/config/Informer/METR-LA.yaml new file mode 100644 index 0000000..731fa3e --- /dev/null +++ b/config/Informer/METR-LA.yaml @@ -0,0 +1,66 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/NYCBike-Inflow.yaml b/config/Informer/NYCBike-Inflow.yaml new file mode 100644 index 0000000..30ca485 --- /dev/null +++ b/config/Informer/NYCBike-Inflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/NYCBike-Outflow.yaml b/config/Informer/NYCBike-Outflow.yaml new file mode 100644 index 0000000..9fcfe6b --- /dev/null +++ b/config/Informer/NYCBike-Outflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/PEMS-BAY.yaml b/config/Informer/PEMS-BAY.yaml new file mode 100644 index 0000000..961bd6f --- /dev/null +++ b/config/Informer/PEMS-BAY.yaml @@ -0,0 +1,66 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/SolarEnergy.yaml b/config/Informer/SolarEnergy.yaml new file mode 100644 index 0000000..0d31425 --- /dev/null +++ b/config/Informer/SolarEnergy.yaml @@ -0,0 +1,66 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 1024 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 1024 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/dataloader/Informer_loader.py b/dataloader/Informer_loader.py new file mode 100644 index 0000000..1192c03 --- /dev/null +++ b/dataloader/Informer_loader.py @@ -0,0 +1,179 @@ +import numpy as np +import torch +from dataloader.data_selector import load_st_dataset +from utils.normalization import normalize_dataset + + +# ============================================================== +# MAIN ENTRY +# ============================================================== + +def get_dataloader(args, normalizer="std", single=True): + """ + Return dataloaders with x, y, x_mark, y_mark. + This version follows Informer/ETSformer official dataloader behavior. + """ + data = load_st_dataset(args) + args = args["data"] + + x, y, x_mark, y_mark = _prepare_data_with_windows(data, args) + + # --- split --- + split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio + x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"]) + y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"]) + x_mark_train, x_mark_val, x_mark_test = split_fn(x_mark, args["val_ratio"], args["test_ratio"]) + y_mark_train, y_mark_val, y_mark_test = split_fn(y_mark, args["val_ratio"], args["test_ratio"]) + + # --- normalization --- + scaler = _normalize_data(x_train, x_val, x_test, args, normalizer) + _apply_existing_scaler(y_train, y_val, y_test, scaler, args) + + # reshape [b, t, n, c] -> [b*n, t, c] + (x_train, x_val, x_test, + y_train, y_val, y_test, + x_mark_train, x_mark_val, x_mark_test, + y_mark_train, y_mark_val, y_mark_test) = _reshape_tensor( + x_train, x_val, x_test, + y_train, y_val, y_test, + x_mark_train, x_mark_val, x_mark_test, + y_mark_train, y_mark_val, y_mark_test + ) + + # --- dataloaders --- + return ( + _create_dataloader(x_train, y_train, x_mark_train, y_mark_train, + args["batch_size"], True, False), + _create_dataloader(x_val, y_val, x_mark_val, y_mark_val, + args["batch_size"], False, False), + _create_dataloader(x_test, y_test, x_mark_test, y_mark_test, + args["batch_size"], False, False), + scaler + ) + + +# ============================================================== +# Informer-style WINDOW GENERATION +# ============================================================== + +def _prepare_data_with_windows(data, args): + """ + Generate x, y, x_mark, y_mark using Informer slicing rule. + + x: [seq_len] + y: [label_len + pred_len] + """ + seq_len = args["lag"] + label_len = args["label_len"] + pred_len = args["horizon"] + + L, N, C = data.shape + + # ---------- construct timestamp features ---------- + time_in_day, day_in_week = _generate_time_features(L, args) + data_mark = np.concatenate([time_in_day, day_in_week], axis=-1) + + xs, ys, x_marks, y_marks = [], [], [], [] + + for s_begin in range(L - seq_len - pred_len - 1): + s_end = s_begin + seq_len + r_begin = s_end - label_len + r_end = r_begin + label_len + pred_len + + xs.append(data[s_begin:s_end]) + ys.append(data[r_begin:r_end]) + + x_marks.append(data_mark[s_begin:s_end]) + y_marks.append(data_mark[r_begin:r_end]) + + return np.array(xs), np.array(ys), np.array(x_marks), np.array(y_marks) + + +# ============================================================== +# TIME FEATURE +# ============================================================== + +def _generate_time_features(L, args): + N = args["num_nodes"] + + # Time in day + tid = np.array([i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)]) + tid = np.tile(tid[:, None], (1, N)) + + # Day in week + diw = np.array([(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)]) + diw = np.tile(diw[:, None], (1, N)) + + return tid[..., None], diw[..., None] + + +# ============================================================== +# NORMALIZATION +# ============================================================== + +def _normalize_data(train_data, val_data, test_data, args, normalizer): + scaler = normalize_dataset( + train_data[..., :args["input_dim"]], + normalizer, args["column_wise"] + ) + for data in [train_data, val_data, test_data]: + data[..., :args["input_dim"]] = scaler.transform( + data[..., :args["input_dim"]] + ) + return scaler + + +def _apply_existing_scaler(train_data, val_data, test_data, scaler, args): + for data in [train_data, val_data, test_data]: + data[..., :args["input_dim"]] = scaler.transform( + data[..., :args["input_dim"]] + ) + + +# ============================================================== +# DATALOADER +# ============================================================== + +def _create_dataloader(x, y, x_mark, y_mark, batch_size, shuffle, drop_last): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = torch.utils.data.TensorDataset( + torch.tensor(x, dtype=torch.float32, device=device), + torch.tensor(y, dtype=torch.float32, device=device), + torch.tensor(x_mark, dtype=torch.float32, device=device), + torch.tensor(y_mark, dtype=torch.float32, device=device), + ) + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, + shuffle=shuffle, drop_last=drop_last) + + +# ============================================================== +# SPLIT +# ============================================================== + +def split_data_by_days(data, val_days, test_days, interval=30): + t = int((24 * 60) / interval) + test_data = data[-t * int(test_days):] + val_data = data[-t * int(test_days + val_days):-t * int(test_days)] + train_data = data[:-t * int(test_days + val_days)] + return train_data, val_data, test_data + + +def split_data_by_ratio(data, val_ratio, test_ratio): + L = len(data) + test_data = data[-int(L * test_ratio):] + val_data = data[-int(L * (test_ratio + val_ratio)):-int(L * test_ratio)] + train_data = data[: -int(L * (test_ratio + val_ratio))] + return train_data, val_data, test_data + + +# ============================================================== +# RESHAPE [B,T,N,C] -> [B*N,T,C] +# ============================================================== + +def _reshape_tensor(*tensors): + reshaped = [] + for x in tensors: + b, t, n, c = x.shape + x_new = x.transpose(0, 2, 1, 3).reshape(b * n, t, c) + reshaped.append(x_new) + return reshaped diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index 88d1e2d..c1862df 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -4,12 +4,15 @@ from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader from dataloader.EXPdataloader import get_dataloader as EXP_loader from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader from dataloader.TSloader import get_dataloader as TS_loader +from dataloader.Informer_loader import get_dataloader as Informer_loader def get_dataloader(config, normalizer, single): TS_model = ["iTransformer", "HI", "PatchTST"] model_name = config["basic"]["model"] - if model_name in TS_model: + if model_name == "Informer": + return Informer_loader(config, normalizer, single) + elif model_name in TS_model: return TS_loader(config, normalizer, single) else : match model_name: diff --git a/model/Informer/attn.py b/model/Informer/attn.py new file mode 100644 index 0000000..45344a8 --- /dev/null +++ b/model/Informer/attn.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from math import sqrt +from model.Informer.masking import TriangularCausalMask, ProbMask + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1./sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return (V.contiguous(), A) + else: + return (V.contiguous(), None) + +class ProbAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(ProbAttention, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) + # Q [B, H, L, D] + B, H, L_K, E = K.shape + _, _, L_Q, _ = Q.shape + + # calculate the sampled Q_K + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q + K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) + + # find the Top_k query with sparisty measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = Q[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + M_top, :] # factor*ln(L_q) + Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + B, H, L_V, D = V.shape + if not self.mask_flag: + # V_sum = V.sum(dim=-2) + V_sum = V.mean(dim=-2) + contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() + else: # use mask + assert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only + contex = V.cumsum(dim=-2) + return contex + + def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): + B, H, L_V, D = V.shape + + if self.mask_flag: + attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) + scores.masked_fill_(attn_mask.mask, -np.inf) + + attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) + + context_in[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :] = torch.matmul(attn, V).type_as(context_in) + if self.output_attention: + attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device) + attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn + return (context_in, attns) + else: + return (context_in, None) + + def forward(self, queries, keys, values, attn_mask): + B, L_Q, H, D = queries.shape + _, L_K, _, _ = keys.shape + + queries = queries.transpose(2,1) + keys = keys.transpose(2,1) + values = values.transpose(2,1) + + U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) + u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) + + U_part = U_part if U_part='1.5.0' else 2 + self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, + kernel_size=3, padding=padding, padding_mode='circular') + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu') + + def forward(self, x): + x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2) + return x + +class FixedEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(FixedEmbedding, self).__init__() + + w = torch.zeros(c_in, d_model).float() + w.require_grad = False + + position = torch.arange(0, c_in).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + w[:, 0::2] = torch.sin(position * div_term) + w[:, 1::2] = torch.cos(position * div_term) + + self.emb = nn.Embedding(c_in, d_model) + self.emb.weight = nn.Parameter(w, requires_grad=False) + + def forward(self, x): + return self.emb(x).detach() + +class TemporalEmbedding(nn.Module): + def __init__(self, d_model, embed_type='fixed', freq='h'): + super(TemporalEmbedding, self).__init__() + + minute_size = 4; hour_size = 24 + weekday_size = 7; day_size = 32; month_size = 13 + + Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding + if freq=='t': + self.minute_embed = Embed(minute_size, d_model) + self.hour_embed = Embed(hour_size, d_model) + self.weekday_embed = Embed(weekday_size, d_model) + self.day_embed = Embed(day_size, d_model) + self.month_embed = Embed(month_size, d_model) + + def forward(self, x): + x = x.long() + + # Check the size of x's last dimension to avoid index errors + last_dim = x.shape[-1] + + minute_x = 0. + hour_x = 0. + weekday_x = 0. + day_x = 0. + month_x = 0. + + # For our generated time features, we have only 2 dimensions: [day_of_week, hour] + # So we need to map them to the appropriate embedding layers + if last_dim > 0: + # Use the first dimension for hour + # Ensure hour is in the valid range [0, 23] + hour = torch.clamp(x[:,:,0], 0, 23) + hour_x = self.hour_embed(hour) + + if last_dim > 1: + # Use the second dimension for weekday + # Ensure weekday is in the valid range [0, 6] + weekday = torch.clamp(x[:,:,1], 0, 6) + weekday_x = self.weekday_embed(weekday) + + return hour_x + weekday_x + day_x + month_x + minute_x + +class TimeFeatureEmbedding(nn.Module): + def __init__(self, d_model, embed_type='timeF', freq='h'): + super(TimeFeatureEmbedding, self).__init__() + + freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3} + d_inp = freq_map[freq] + self.embed = nn.Linear(d_inp, d_model) + + def forward(self, x): + return self.embed(x) + +class DataEmbedding(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + a = self.value_embedding(x) + b = self.position_embedding(x) + c = self.temporal_embedding(x_mark) + x = a + b + c + + return self.dropout(x) \ No newline at end of file diff --git a/model/Informer/encoder.py b/model/Informer/encoder.py new file mode 100644 index 0000000..7aeb877 --- /dev/null +++ b/model/Informer/encoder.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ConvLayer(nn.Module): + def __init__(self, c_in): + super(ConvLayer, self).__init__() + padding = 1 if torch.__version__>='1.5.0' else 2 + self.downConv = nn.Conv1d(in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=padding, + padding_mode='circular') + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1,2) + return x + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4*d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + # x [B, L, D] + # x = x + self.dropout(self.attention( + # x, x, x, + # attn_mask = attn_mask + # )) + new_x, attn = self.attention( + x, x, x, + attn_mask = attn_mask + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1,1)))) + y = self.dropout(self.conv2(y).transpose(-1,1)) + + return self.norm2(x+y), attn + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x, attn_mask=attn_mask) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + +class EncoderStack(nn.Module): + def __init__(self, encoders, inp_lens): + super(EncoderStack, self).__init__() + self.encoders = nn.ModuleList(encoders) + self.inp_lens = inp_lens + + def forward(self, x, attn_mask=None): + # x [B, L, D] + x_stack = []; attns = [] + for i_len, encoder in zip(self.inp_lens, self.encoders): + inp_len = x.shape[1]//(2**i_len) + x_s, attn = encoder(x[:, -inp_len:, :]) + x_stack.append(x_s); attns.append(attn) + x_stack = torch.cat(x_stack, -2) + + return x_stack, attns \ No newline at end of file diff --git a/model/Informer/masking.py b/model/Informer/masking.py new file mode 100644 index 0000000..7fd479e --- /dev/null +++ b/model/Informer/masking.py @@ -0,0 +1,24 @@ +import torch + +class TriangularCausalMask(): + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + + @property + def mask(self): + return self._mask + +class ProbMask(): + def __init__(self, B, H, L, index, scores, device="cpu"): + _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) + _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) + indicator = _mask_ex[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :].to(device) + self._mask = indicator.view(scores.shape).to(device) + + @property + def mask(self): + return self._mask \ No newline at end of file diff --git a/model/Informer/model.py b/model/Informer/model.py new file mode 100644 index 0000000..fb7471f --- /dev/null +++ b/model/Informer/model.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack +from model.Informer.decoder import Decoder, DecoderLayer +from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer +from model.Informer.embed import DataEmbedding + +class Informer(nn.Module): + def __init__(self, args): + super(Informer, self).__init__() + self.pred_len = args['pred_len'] + self.attn = args['attn'] + self.output_attention = args['output_attention'] + + # Encoding + self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout']) + self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout']) + # Attention + Attn = ProbAttention if args['attn']=='prob' else FullAttention + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer(Attn(False, args['factor'], attention_dropout=args['dropout'], output_attention=args['output_attention']), + args['d_model'], args['n_heads'], mix=False), + args['d_model'], + args['d_ff'], + dropout=args['dropout'], + activation=args['activation'] + ) for l in range(args['e_layers']) + ], + [ + ConvLayer( + args['d_model'] + ) for l in range(args['e_layers']-1) + ] if args['distil'] else None, + norm_layer=torch.nn.LayerNorm(args['d_model']) + ) + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer(Attn(True, args['factor'], attention_dropout=args['dropout'], output_attention=False), + args['d_model'], args['n_heads'], mix=args['mix']), + AttentionLayer(FullAttention(False, args['factor'], attention_dropout=args['dropout'], output_attention=False), + args['d_model'], args['n_heads'], mix=False), + args['d_model'], + args['d_ff'], + dropout=args['dropout'], + activation=args['activation'], + ) + for l in range(args['d_layers']) + ], + norm_layer=torch.nn.LayerNorm(args['d_model']) + ) + self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True) + + def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, + enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + + dec_out = self.dec_embedding(x_dec, x_mark_dec) + dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) + dec_out = self.projection(dec_out) + + if self.output_attention: + return dec_out[:,-self.pred_len:,:], attns + else: + return dec_out[:,-self.pred_len:,:] # [B, L, D] + + +class InformerStack(nn.Module): + def __init__(self, args): + super(InformerStack, self).__init__() + self.pred_len = args['pred_len'] + self.attn = args['attn'] + self.output_attention = args['output_attention'] + + # Encoding + self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout']) + self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout']) + # Attention + Attn = ProbAttention if args['attn']=='prob' else FullAttention + # Encoder + + inp_lens = list(range(len(args['e_layers']))) # [0,1,2,...] you can customize here + encoders = [ + Encoder( + [ + EncoderLayer( + AttentionLayer(Attn(False, args['factor'], attention_dropout=args['dropout'], output_attention=args['output_attention']), + args['d_model'], args['n_heads'], mix=False), + args['d_model'], + args['d_ff'], + dropout=args['dropout'], + activation=args['activation'] + ) for l in range(el) + ], + [ + ConvLayer( + args['d_model'] + ) for l in range(el-1) + ] if args['distil'] else None, + norm_layer=torch.nn.LayerNorm(args['d_model']) + ) for el in args['e_layers']] + self.encoder = EncoderStack(encoders, inp_lens) + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer(Attn(True, args['factor'], attention_dropout=args['dropout'], output_attention=False), + args['d_model'], args['n_heads'], mix=args['mix']), + AttentionLayer(FullAttention(False, args['factor'], attention_dropout=args['dropout'], output_attention=False), + args['d_model'], args['n_heads'], mix=False), + args['d_model'], + args['d_ff'], + dropout=args['dropout'], + activation=args['activation'], + ) + for l in range(args['d_layers']) + ], + norm_layer=torch.nn.LayerNorm(args['d_model']) + ) + self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True) + + def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, + enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + + dec_out = self.dec_embedding(x_dec, x_mark_dec) + dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) + dec_out = self.projection(dec_out) + + if self.output_attention: + return dec_out[:,-self.pred_len:,:], attns + else: + return dec_out[:,-self.pred_len:,:] # [B, L, D] \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index 5621037..f74dde2 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -28,6 +28,7 @@ from model.ASTRA.astra import ASTRA as ASTRA from model.ASTRA.astrav2 import ASTRA as ASTRAv2 from model.ASTRA.astrav3 import ASTRA as ASTRAv3 from model.iTransformer.iTransformer import iTransformer +from model.Informer.model import Informer from model.HI.HI import HI from model.PatchTST.PatchTST import Model as PatchTST from model.MTGNN.MTGNN import gtnet as MTGNN @@ -96,6 +97,8 @@ def model_selector(config): return ASTRAv3(model_config) case "iTransformer": return iTransformer(model_config) + case "Informer": + return Informer(model_config) case "HI": return HI(model_config) case "PatchTST": diff --git a/test_informer.py b/test_informer.py new file mode 100644 index 0000000..b614533 --- /dev/null +++ b/test_informer.py @@ -0,0 +1,57 @@ +import torch +from model.model_selector import model_selector +import yaml + +# 读取配置文件 +with open('/user/czzhangheng/code/TrafficWheel/config/Informer/AirQuality.yaml', 'r') as f: + config = yaml.safe_load(f) + +# 初始化模型 +model = model_selector(config) +print('Informer模型初始化成功!') +print(f'模型参数数量: {sum(p.numel() for p in model.parameters())}') + +# 创建测试数据 +B, T, C = 2, 24, 6 +x_enc = torch.randn(B, T, C) + +# 测试1: 完整参数 +print('\n测试1: 完整参数') +x_mark_enc = torch.randn(B, T, 4) # 假设时间特征为4维 +x_dec = torch.randn(B, 12+24, C) # label_len + pred_len +x_mark_dec = torch.randn(B, 12+24, 4) +try: + output = model(x_enc, x_mark_enc, x_dec, x_mark_dec) + print(f'输出形状: {output.shape}') + print('测试1通过!') +except Exception as e: + print(f'测试1失败: {e}') + +# 测试2: 省略x_mark_enc +print('\n测试2: 省略x_mark_enc') +try: + output = model(x_enc, x_dec=x_dec, x_mark_dec=x_mark_dec) + print(f'输出形状: {output.shape}') + print('测试2通过!') +except Exception as e: + print(f'测试2失败: {e}') + +# 测试3: 省略x_dec和x_mark_dec +print('\n测试3: 省略x_dec和x_mark_dec') +try: + output = model(x_enc, x_mark_enc=x_mark_enc) + print(f'输出形状: {output.shape}') + print('测试3通过!') +except Exception as e: + print(f'测试3失败: {e}') + +# 测试4: 仅传入x_enc +print('\n测试4: 仅传入x_enc') +try: + output = model(x_enc) + print(f'输出形状: {output.shape}') + print('测试4通过!') +except Exception as e: + print(f'测试4失败: {e}') + +print('\n所有测试完成!') \ No newline at end of file diff --git a/train.py b/train.py index 7bd72ad..9c81209 100644 --- a/train.py +++ b/train.py @@ -11,13 +11,14 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备 + device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 100 + epochs = 1 # 拷贝项 config["basic"]["device"] = device config["model"]["device"] = device + config["train"]["device"] = device config["basic"]["seed"] = seed config["train"]["epochs"] = epochs return config @@ -62,14 +63,15 @@ def run(config): if __name__ == "__main__": # 指定模型 - model_list = ["MTGNN"] + model_list = ["Informer"] # 指定数据集 - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - # dataset_list = ["AirQuality"] + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + # dataset_list = ["PEMS-BAY"] # 我的调试开关,不做测试就填 str(False) - os.environ["TRY"] = str(False) - + # os.environ["TRY"] = str(False) + os.environ["TRY"] = str(True) + for model in model_list: for dataset in dataset_list: config_path = f"./config/{model}/{dataset}.yaml" @@ -81,7 +83,13 @@ if __name__ == "__main__": try: run(config) except Exception as e: - pass + import traceback + import sys, traceback + tb_lines = traceback.format_exc().splitlines() + # 如果不是AssertionError,才打印完整traceback + if not tb_lines[-1].startswith("AssertionError"): + traceback.print_exc() + print(f"\n===== {model} on {dataset} failed with error: {e} =====\n") else: run(config) diff --git a/trainer/InformerTrainer.py b/trainer/InformerTrainer.py new file mode 100644 index 0000000..7b7ed27 --- /dev/null +++ b/trainer/InformerTrainer.py @@ -0,0 +1,250 @@ +import math +import os +import time +import copy +import torch +from utils.logger import get_logger +from utils.loss_function import all_metrics +from tqdm import tqdm + +class InformerTrainer: + """Informer模型训练器,负责整个训练流程的管理,支持多输入模型""" + + def __init__(self, model, loss, optimizer, + train_loader, val_loader, test_loader, scaler, + args, lr_scheduler=None,): + # 设备和基本参数 + self.config = args + self.device = args["basic"]["device"] + train_args = args["train"] + # 模型和训练相关组件 + self.model = model + self.loss = loss + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + # 数据加载器 + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + # 数据处理工具 + self.scaler = scaler + self.args = train_args + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + + def _initialize_paths(self, args): + """初始化模型保存路径""" + self.best_path = os.path.join(args["log_dir"], "best_model.pth") + self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") + self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") + + def _initialize_logger(self, args): + """初始化日志记录器""" + if not os.path.isdir(args["log_dir"]) and not args["debug"]: + os.makedirs(args["log_dir"], exist_ok=True) + self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"]) + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch,支持多输入模型""" + # 设置模型模式和是否进行优化 + if mode == "train": self.model.train(); optimizer_step = True + else: self.model.eval(); optimizer_step = False + + # 初始化变量 + total_loss = 0 + epoch_time = time.time() + y_pred, y_true = [], [] + + # 训练/验证循环 + with torch.set_grad_enabled(optimizer_step): + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + for _, (x, y, x_mark, y_mark) in progress_bar: + # 转移数据 + x = x.to(self.device) + y = y[:, -self.args['pred_len']:, :self.args["output_dim"]].to(self.device) + x_mark = x_mark.to(self.device) + y_mark = y_mark.to(self.device) + # [256, 24, 6] + dec_inp = torch.zeros_like(y[:, -self.args['pred_len']:, :]).float() + # [256, 48(pred+label), 6] + dec_inp = torch.cat([y[:, :self.args['label_len'], :], dec_inp], dim=1).float().to(self.device) + + # 计算loss和反归一化loss + output = self.model(x, x_mark, dec_inp, y_mark) + if os.environ.get("TRY") == "True": + print(f"[{'✅' if output.shape == y.shape else '❌'}]: output: {output.shape}, label: {y.shape}") + assert False + loss = self.loss(output, y) + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(y) + d_loss = self.loss(d_output, d_label) + # 累积损失和预测结果 + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + self.optimizer.step() + # 更新进度条 + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + # 计算损失并记录指标 + avg_loss = total_loss / len(dataloader) + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} " + f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + return avg_loss + + def train_epoch(self, epoch): + return self._run_epoch(epoch, self.train_loader, "train") + + def val_epoch(self, epoch): + return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") + + def test_epoch(self, epoch): + return self._run_epoch(epoch, self.test_loader, "test") + + def train(self): + # 初始化记录 + best_model, best_test_model = None, None + best_loss, best_test_loss = float("inf"), float("inf") + not_improved_count = 0 + # 开始训练 + self.logger.info("Training process started") + # 训练循环 + for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch + train_epoch_loss = self.train_epoch(epoch) + val_epoch_loss = self.val_epoch(epoch) + test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 + if train_epoch_loss > 1e6: + self.logger.warning("Gradient explosion detected. Ending...") + break + # 更新最佳验证模型 + if val_epoch_loss < best_loss: + best_loss = val_epoch_loss + not_improved_count = 0 + best_model = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved!") + else: + not_improved_count += 1 + # 早停 + if self._should_early_stop(not_improved_count): + break + # 更新最佳测试模型 + if test_epoch_loss < best_test_loss: + best_test_loss = test_epoch_loss + best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 + if not self.args["debug"]: + self._save_best_models(best_model, best_test_model) + # 最终评估 + self._finalize_training(best_model, best_test_model) + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + + def _finalize_training(self, best_model, best_test_model): + self.model.load_state_dict(best_model) + self.logger.info("Testing on best validation model") + self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + self.model.load_state_dict(best_test_model) + self.logger.info("Testing on best test model") + self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + + @staticmethod + def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标,支持多输入模型""" + device = args["device"] + + if path: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["state_dict"]) + model.to(device) + + # 设置为评估模式 + model.eval() + + # 收集预测和真实标签 + y_pred, y_true = [], [] + pred_len = args['pred_len'] + label_len = args['label_len'] + output_dim = args['output_dim'] + + # 不计算梯度的情况下进行预测 + with torch.no_grad(): + for _, (x, y, x_mark, y_mark) in enumerate(data_loader): + # 转移数据 + x = x.to(device) + y = y[:, -pred_len:, :output_dim].to(device) + x_mark = x_mark.to(device) + y_mark = y_mark.to(device) + # 生成dec_inp + dec_inp = torch.zeros_like(y[:, -pred_len:, :]).float() + dec_inp = torch.cat([y[:, :label_len, :], dec_inp], dim=1).float().to(device) + output = model(x, x_mark, dec_inp, y_mark) + y_pred.append(output.detach().cpu()) + y_true.append(y.detach().cpu()) + + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + mae_thresh = args["mae_thresh"] + mape_thresh = args["mape_thresh"] + + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): + mae, rmse, mape = all_metrics( + d_y_pred[:, t, ...], + d_y_true[:, t, ...], + mae_thresh, + mape_thresh, + ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + @staticmethod + def _compute_sampling_threshold(global_step, k): + return k / (k + math.exp(global_step / k)) diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 97d6b0b..89340ea 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -4,6 +4,7 @@ from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer from trainer.STMLP_Trainer import Trainer as STMLP_Trainer from trainer.E32Trainer import Trainer as EXP_Trainer +from trainer.InformerTrainer import InformerTrainer def select_trainer( @@ -96,6 +97,18 @@ def select_trainer( args, lr_scheduler, ) + case "Informer": + return InformerTrainer( + model, + loss, + optimizer, + train_loader, + val_loader, + test_loader, + scaler, + args, + lr_scheduler, + ) case _: return Trainer( model, -- 2.40.1 From 9147803c2b18b4c226ff9c7fd07f965480ca0ae6 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 14 Dec 2025 17:47:38 +0800 Subject: [PATCH 26/41] =?UTF-8?q?=E6=94=B9=E8=BF=9Btrainer=EF=BC=8C?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E7=BB=9F=E4=B8=80=E7=9A=84loader=E8=80=8C?= =?UTF-8?q?=E4=B8=8D=E6=98=AFTSLoader=EF=BC=8C=E4=BB=85=E5=9C=A8Trainer?= =?UTF-8?q?=E4=B8=8A=E5=81=9A=E4=BA=86shape=E5=8F=98=E6=8D=A2=EF=BC=8C?= =?UTF-8?q?=E7=A1=AE=E5=AE=9A=E5=B0=BD=E5=8F=AF=E8=83=BD=E5=B0=91=E6=94=B9?= =?UTF-8?q?=E5=8A=A8=E6=95=B0=E6=8D=AE=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/PatchTST/AirQuality.yaml | 4 +- config/PatchTST/BJTaxi-Inflow.yaml | 4 +- config/PatchTST/BJTaxi-Outflow.yaml | 4 +- config/PatchTST/METR-LA.yaml | 4 +- config/PatchTST/NYCBike-Inflow.yaml | 4 +- config/PatchTST/NYCBike-Outflow.yaml | 4 +- config/PatchTST/PEMS-BAY.yaml | 4 +- config/PatchTST/SolarEnergy.yaml | 4 +- config/iTransformer/AirQuality.yaml | 4 +- config/iTransformer/BJTaxi-Inflow.yaml | 4 +- config/iTransformer/BJTaxi-Outflow.yaml | 4 +- config/iTransformer/METR-LA.yaml | 4 +- config/iTransformer/NYCBike-Inflow.yaml | 4 +- config/iTransformer/NYCBike-Outflow.yaml | 4 +- config/iTransformer/PEMS-BAY.yaml | 4 +- config/iTransformer/SolarEnergy.yaml | 6 +- dataloader/loader_selector.py | 32 +-- test_informer.py | 57 ----- train.py | 23 +- trainer/TSTrainer.py | 296 +++++++++++++++++++++++ trainer/trainer_selector.py | 17 +- 21 files changed, 376 insertions(+), 115 deletions(-) delete mode 100644 test_informer.py create mode 100755 trainer/TSTrainer.py diff --git a/config/PatchTST/AirQuality.yaml b/config/PatchTST/AirQuality.yaml index 3cdf977..91a497e 100644 --- a/config/PatchTST/AirQuality.yaml +++ b/config/PatchTST/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-Inflow.yaml index 576dbd6..a4e0308 100644 --- a/config/PatchTST/BJTaxi-Inflow.yaml +++ b/config/PatchTST/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 2048 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-Outflow.yaml index 773ba26..68c8476 100644 --- a/config/PatchTST/BJTaxi-Outflow.yaml +++ b/config/PatchTST/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 2048 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/METR-LA.yaml b/config/PatchTST/METR-LA.yaml index 6b9461a..3f88951 100644 --- a/config/PatchTST/METR-LA.yaml +++ b/config/PatchTST/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-Inflow.yaml index 408995c..0f7bc97 100644 --- a/config/PatchTST/NYCBike-Inflow.yaml +++ b/config/PatchTST/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-Outflow.yaml index c50f4a1..516e1e1 100644 --- a/config/PatchTST/NYCBike-Outflow.yaml +++ b/config/PatchTST/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/PEMS-BAY.yaml b/config/PatchTST/PEMS-BAY.yaml index e798294..ba93575 100644 --- a/config/PatchTST/PEMS-BAY.yaml +++ b/config/PatchTST/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml index b1de602..d31a458 100644 --- a/config/PatchTST/SolarEnergy.yaml +++ b/config/PatchTST/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/AirQuality.yaml b/config/iTransformer/AirQuality.yaml index 23eba27..b27d72c 100644 --- a/config/iTransformer/AirQuality.yaml +++ b/config/iTransformer/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/BJTaxi-Inflow.yaml b/config/iTransformer/BJTaxi-Inflow.yaml index dfc2df2..1df1a67 100644 --- a/config/iTransformer/BJTaxi-Inflow.yaml +++ b/config/iTransformer/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 2048 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/BJTaxi-Outflow.yaml b/config/iTransformer/BJTaxi-Outflow.yaml index d14bed5..8da0e92 100644 --- a/config/iTransformer/BJTaxi-Outflow.yaml +++ b/config/iTransformer/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 2048 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 2048 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/METR-LA.yaml b/config/iTransformer/METR-LA.yaml index 20c4068..996e44c 100644 --- a/config/iTransformer/METR-LA.yaml +++ b/config/iTransformer/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/NYCBike-Inflow.yaml b/config/iTransformer/NYCBike-Inflow.yaml index 8afa656..fdb4dce 100644 --- a/config/iTransformer/NYCBike-Inflow.yaml +++ b/config/iTransformer/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/NYCBike-Outflow.yaml b/config/iTransformer/NYCBike-Outflow.yaml index 7abba88..7401648 100644 --- a/config/iTransformer/NYCBike-Outflow.yaml +++ b/config/iTransformer/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/PEMS-BAY.yaml b/config/iTransformer/PEMS-BAY.yaml index 17f2fd4..80d354a 100644 --- a/config/iTransformer/PEMS-BAY.yaml +++ b/config/iTransformer/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/iTransformer/SolarEnergy.yaml b/config/iTransformer/SolarEnergy.yaml index cce005a..154be4a 100644 --- a/config/iTransformer/SolarEnergy.yaml +++ b/config/iTransformer/SolarEnergy.yaml @@ -6,11 +6,11 @@ basic: seed: 2023 data: - batch_size: 256 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 @@ -31,7 +31,7 @@ model: train: - batch_size: 256 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index c1862df..5ea47fa 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -10,19 +10,19 @@ from dataloader.Informer_loader import get_dataloader as Informer_loader def get_dataloader(config, normalizer, single): TS_model = ["iTransformer", "HI", "PatchTST"] model_name = config["basic"]["model"] - if model_name == "Informer": - return Informer_loader(config, normalizer, single) - elif model_name in TS_model: - return TS_loader(config, normalizer, single) - else : - match model_name: - case "STGNCDE": - return cde_loader(config, normalizer, single) - case "STGNRDE": - return nrde_loader(config, normalizer, single) - case "DCRNN": - return DCRNN_loader(config, normalizer, single) - case "EXP": - return EXP_loader(config, normalizer, single) - case _: - return normal_loader(config, normalizer, single) + # if model_name == "Informer": + # return Informer_loader(config, normalizer, single) + # elif model_name in TS_model: + # return TS_loader(config, normalizer, single) + # else : + match model_name: + case "STGNCDE": + return cde_loader(config, normalizer, single) + case "STGNRDE": + return nrde_loader(config, normalizer, single) + case "DCRNN": + return DCRNN_loader(config, normalizer, single) + case "EXP": + return EXP_loader(config, normalizer, single) + case _: + return normal_loader(config, normalizer, single) diff --git a/test_informer.py b/test_informer.py deleted file mode 100644 index b614533..0000000 --- a/test_informer.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch -from model.model_selector import model_selector -import yaml - -# 读取配置文件 -with open('/user/czzhangheng/code/TrafficWheel/config/Informer/AirQuality.yaml', 'r') as f: - config = yaml.safe_load(f) - -# 初始化模型 -model = model_selector(config) -print('Informer模型初始化成功!') -print(f'模型参数数量: {sum(p.numel() for p in model.parameters())}') - -# 创建测试数据 -B, T, C = 2, 24, 6 -x_enc = torch.randn(B, T, C) - -# 测试1: 完整参数 -print('\n测试1: 完整参数') -x_mark_enc = torch.randn(B, T, 4) # 假设时间特征为4维 -x_dec = torch.randn(B, 12+24, C) # label_len + pred_len -x_mark_dec = torch.randn(B, 12+24, 4) -try: - output = model(x_enc, x_mark_enc, x_dec, x_mark_dec) - print(f'输出形状: {output.shape}') - print('测试1通过!') -except Exception as e: - print(f'测试1失败: {e}') - -# 测试2: 省略x_mark_enc -print('\n测试2: 省略x_mark_enc') -try: - output = model(x_enc, x_dec=x_dec, x_mark_dec=x_mark_dec) - print(f'输出形状: {output.shape}') - print('测试2通过!') -except Exception as e: - print(f'测试2失败: {e}') - -# 测试3: 省略x_dec和x_mark_dec -print('\n测试3: 省略x_dec和x_mark_dec') -try: - output = model(x_enc, x_mark_enc=x_mark_enc) - print(f'输出形状: {output.shape}') - print('测试3通过!') -except Exception as e: - print(f'测试3失败: {e}') - -# 测试4: 仅传入x_enc -print('\n测试4: 仅传入x_enc') -try: - output = model(x_enc) - print(f'输出形状: {output.shape}') - print('测试4通过!') -except Exception as e: - print(f'测试4失败: {e}') - -print('\n所有测试完成!') \ No newline at end of file diff --git a/train.py b/train.py index 9c81209..df31ca8 100644 --- a/train.py +++ b/train.py @@ -6,14 +6,16 @@ import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer +import cProfile + def read_config(config_path): with open(config_path, "r") as file: config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cuda:1" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 1 + epochs = 120 # 拷贝项 config["basic"]["device"] = device @@ -60,17 +62,17 @@ def run(config): case _: raise ValueError(f"Unsupported mode: {config['basic']['mode']}") - -if __name__ == "__main__": +def main(debug=False): # 指定模型 - model_list = ["Informer"] + model_list = ["PatchTST"] # 指定数据集 - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - # dataset_list = ["PEMS-BAY"] + # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + # dataset_list = ["AirQuality"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) - os.environ["TRY"] = str(True) + os.environ["TRY"] = str(debug) for model in model_list: for dataset in dataset_list: @@ -93,3 +95,8 @@ if __name__ == "__main__": else: run(config) + + +if __name__ == "__main__": + # 调试用 + main(debug = False) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py new file mode 100755 index 0000000..b8def31 --- /dev/null +++ b/trainer/TSTrainer.py @@ -0,0 +1,296 @@ +import math +import os +import time +import copy +import torch +from utils.logger import get_logger +from utils.loss_function import all_metrics +from tqdm import tqdm + +class TSWrapper: + def __init__(self, args): + self.b = args['train']['batch_size'] + self.t = args['data']['lag'] + self.n = args['data']['num_nodes'] + self.c = args['data']['input_dim'] + + + def transpose(self, x : torch.Tensor): + # [b, t, n, c] -> [b*n, t, c] + self.b = x.shape[0] + x = x[..., :-2] + x = x.permute(0, 2, 1, 3) + x = x.reshape(self.b*self.n, self.t, self.c) + return x + + def inv_transpose(self, x : torch.Tensor): + x = x.reshape(self.b, self.n, self.t, self.c) + x = x.permute(0, 2, 1, 3) + return x + + +class Trainer: + """模型训练器,负责整个训练流程的管理""" + + def __init__(self, model, loss, optimizer, + train_loader, val_loader, test_loader, scaler, + args, lr_scheduler=None,): + # 设备和基本参数 + self.config = args + self.device = args["basic"]["device"] + train_args = args["train"] + # 模型和训练相关组件 + self.model = model + self.loss = loss + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + # 数据加载器 + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + # 数据处理工具 + self.scaler = scaler + self.args = train_args + self.ts_wrapper = TSWrapper(args) + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + + def _initialize_paths(self, args): + """初始化模型保存路径""" + self.best_path = os.path.join(args["log_dir"], "best_model.pth") + self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") + self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") + + def _initialize_logger(self, args): + """初始化日志记录器""" + if not os.path.isdir(args["log_dir"]) and not args["debug"]: + os.makedirs(args["log_dir"], exist_ok=True) + self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"]) + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 + if mode == "train": self.model.train(); optimizer_step = True + else: self.model.eval(); optimizer_step = False + + # 初始化变量 + total_loss = 0 + epoch_time = time.time() + y_pred, y_true = [], [] + + # 训练/验证循环 + with torch.set_grad_enabled(optimizer_step): + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + for _, (data, target) in progress_bar: + # 转移数据 + data = data.to(self.device) + target = target.to(self.device) + label = target[..., : self.args["output_dim"]] + # 转换为 [b*n, t, c] + data = self.ts_wrapper.transpose(data) + # 计算loss和反归一化loss + output = self.model(data) + # 转换回[b, t, n, c] + output = self.ts_wrapper.inv_transpose(output) + # 我的调试开关 + if os.environ.get("TRY") == "True": + print(f"[{'✅' if output.shape == label.shape else '❌'}]: output: {output.shape}, label: {label.shape}") + assert False + loss = self.loss(output, label) + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) + d_loss = self.loss(d_output, d_label) + # 累积损失和预测结果 + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + self.optimizer.step() + # 更新进度条 + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + # 计算损失并记录指标 + avg_loss = total_loss / len(dataloader) + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} " + f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + return avg_loss + + def train_epoch(self, epoch): + return self._run_epoch(epoch, self.train_loader, "train") + + def val_epoch(self, epoch): + return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") + + def test_epoch(self, epoch): + return self._run_epoch(epoch, self.test_loader, "test") + + def train(self): + # 初始化记录 + best_model, best_test_model = None, None + best_loss, best_test_loss = float("inf"), float("inf") + not_improved_count = 0 + # 开始训练 + self.logger.info("Training process started") + # 训练循环 + for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch + train_epoch_loss = self.train_epoch(epoch) + val_epoch_loss = self.val_epoch(epoch) + test_epoch_loss = self.test_epoch(epoch) + # 检查梯度爆炸 + if train_epoch_loss > 1e6: + self.logger.warning("Gradient explosion detected. Ending...") + break + # 更新最佳验证模型 + if val_epoch_loss < best_loss: + best_loss = val_epoch_loss + not_improved_count = 0 + best_model = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved!") + else: + not_improved_count += 1 + # 早停 + if self._should_early_stop(not_improved_count): + break + # 更新最佳测试模型 + if test_epoch_loss < best_test_loss: + best_test_loss = test_epoch_loss + best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 + if not self.args["debug"]: + self._save_best_models(best_model, best_test_model) + # 最终评估 + self._finalize_training(best_model, best_test_model) + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + + def _finalize_training(self, best_model, best_test_model): + self.model.load_state_dict(best_model) + self.logger.info("Testing on best validation model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + self.model.load_state_dict(best_test_model) + self.logger.info("Testing on best test model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + + @staticmethod + def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 确定设备信息 + device = None + output_dim = None + # 处理不同的参数格式 + if isinstance(args, dict): + if "basic" in args: + # 完整配置情况 + device = args["basic"]["device"] + output_dim = args["train"]["output_dim"] + else: + # 只有train_args情况,从模型获取设备 + device = next(model.parameters()).device + output_dim = args["output_dim"] + else: + raise ValueError(f"Unsupported args type: {type(args)}") + + # 加载模型检查点(如果提供了路径) + if path: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["state_dict"]) + model.to(device) + + # 设置为评估模式 + model.eval() + + # 收集预测和真实标签 + y_pred, y_true = [], [] + + # 不计算梯度的情况下进行预测 + with torch.no_grad(): + for data, target in data_loader: + # 将数据和标签移动到指定设备 + data = data.to(device) + target = target.to(device) + + data = data[..., :-2] + b, t, n, c = data.shape + data = data.permute(0, 2, 1, 3) + data = data.reshape(b*n, t, c) + label = target[..., : output_dim] + output = model(data) + output = output.reshape(b, n, t, c) + output = output.permute(0, 2, 1, 3) + + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) + + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + + # 获取metrics参数 + if "basic" in args: + # 完整配置情况 + mae_thresh = args["train"]["mae_thresh"] + mape_thresh = args["train"]["mape_thresh"] + else: + # 只有train_args情况 + mae_thresh = args["mae_thresh"] + mape_thresh = args["mape_thresh"] + + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): + mae, rmse, mape = all_metrics( + d_y_pred[:, t, ...], + d_y_true[:, t, ...], + mae_thresh, + mape_thresh, + ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + @staticmethod + def _compute_sampling_threshold(global_step, k): + return k / (k + math.exp(global_step / k)) diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 89340ea..723b257 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -5,7 +5,7 @@ from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer from trainer.STMLP_Trainer import Trainer as STMLP_Trainer from trainer.E32Trainer import Trainer as EXP_Trainer from trainer.InformerTrainer import InformerTrainer - +from trainer.TSTrainer import Trainer as TSTrainer def select_trainer( model, @@ -20,6 +20,21 @@ def select_trainer( kwargs, ): model_name = args["basic"]["model"] + TS_model = ["HI", "PatchTST", "iTransformer"] + if model_name in TS_model: + return TSTrainer( + model, + loss, + optimizer, + train_loader, + val_loader, + test_loader, + scaler, + args, + lr_scheduler, + ) + + match model_name: case "STGNCDE": return cdeTrainer( -- 2.40.1 From 5827554c735f1a0491ca32232ec73ab6f0c01b9b Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 14 Dec 2025 17:48:37 +0800 Subject: [PATCH 27/41] =?UTF-8?q?=E6=94=B9=E8=BF=9B=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=B3=A8=E5=86=8C=EF=BC=8C=E5=8A=A8=E6=80=81=E6=B3=A8=E5=86=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/AGCRN/model_config.json | 7 ++ model/ARIMA/model_config.json | 7 ++ model/ASTRA/model_config.json | 17 +++ model/DCRNN/model_config.json | 7 ++ model/DDGCRN/model_config.json | 7 ++ model/DSANET/model_config.json | 7 ++ model/EXP/model_config.json | 7 ++ model/GWN/model_config.json | 7 ++ model/HI/model_config.json | 7 ++ model/Informer/model_config.json | 7 ++ model/MTGNN/model_config.json | 7 ++ model/MegaCRN/model_config.json | 7 ++ model/NLT/model_config.json | 7 ++ model/PDG2SEQ/model_config.json | 7 ++ model/PatchTST/model_config.json | 7 ++ model/README.md | 109 +++++++++++++++++++ model/REPST/model_config.json | 7 ++ model/STAEFormer/model_config.json | 7 ++ model/STAWnet/model_config.json | 7 ++ model/STFGNN/model_config.json | 7 ++ model/STGCN/model_config.json | 7 ++ model/STGNCDE/model_config.json | 7 ++ model/STGNRDE/model_config.json | 7 ++ model/STGODE/model_config.json | 7 ++ model/STID/model_config.json | 7 ++ model/STIDGCN/model_config.json | 7 ++ model/STMLP/model_config.json | 7 ++ model/STSGCN/model_config.json | 7 ++ model/ST_SSL/model_config.json | 7 ++ model/TCN/model_config.json | 7 ++ model/TWDGCN/model_config.json | 7 ++ model/iTransformer/model_config.json | 7 ++ model/model_selector.py | 152 +++++++++------------------ train.py | 6 +- 34 files changed, 390 insertions(+), 104 deletions(-) create mode 100644 model/AGCRN/model_config.json create mode 100644 model/ARIMA/model_config.json create mode 100644 model/ASTRA/model_config.json create mode 100644 model/DCRNN/model_config.json create mode 100644 model/DDGCRN/model_config.json create mode 100644 model/DSANET/model_config.json create mode 100644 model/EXP/model_config.json create mode 100644 model/GWN/model_config.json create mode 100644 model/HI/model_config.json create mode 100644 model/Informer/model_config.json create mode 100644 model/MTGNN/model_config.json create mode 100644 model/MegaCRN/model_config.json create mode 100644 model/NLT/model_config.json create mode 100644 model/PDG2SEQ/model_config.json create mode 100644 model/PatchTST/model_config.json create mode 100644 model/README.md create mode 100644 model/REPST/model_config.json create mode 100644 model/STAEFormer/model_config.json create mode 100644 model/STAWnet/model_config.json create mode 100644 model/STFGNN/model_config.json create mode 100644 model/STGCN/model_config.json create mode 100644 model/STGNCDE/model_config.json create mode 100644 model/STGNRDE/model_config.json create mode 100644 model/STGODE/model_config.json create mode 100644 model/STID/model_config.json create mode 100644 model/STIDGCN/model_config.json create mode 100644 model/STMLP/model_config.json create mode 100644 model/STSGCN/model_config.json create mode 100644 model/ST_SSL/model_config.json create mode 100644 model/TCN/model_config.json create mode 100644 model/TWDGCN/model_config.json create mode 100644 model/iTransformer/model_config.json diff --git a/model/AGCRN/model_config.json b/model/AGCRN/model_config.json new file mode 100644 index 0000000..e1c9b61 --- /dev/null +++ b/model/AGCRN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "AGCRN", + "module": "model.AGCRN.AGCRN", + "entry": "AGCRN" + } +] \ No newline at end of file diff --git a/model/ARIMA/model_config.json b/model/ARIMA/model_config.json new file mode 100644 index 0000000..9b33c5c --- /dev/null +++ b/model/ARIMA/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "ARIMA", + "module": "model.ARIMA.ARIMA", + "entry": "ARIMA" + } +] \ No newline at end of file diff --git a/model/ASTRA/model_config.json b/model/ASTRA/model_config.json new file mode 100644 index 0000000..3cd0064 --- /dev/null +++ b/model/ASTRA/model_config.json @@ -0,0 +1,17 @@ +[ + { + "name": "ASTRA", + "module": "model.ASTRA.astra", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v2", + "module": "model.ASTRA.astrav2", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v3", + "module": "model.ASTRA.astrav3", + "entry": "ASTRA" + } +] \ No newline at end of file diff --git a/model/DCRNN/model_config.json b/model/DCRNN/model_config.json new file mode 100644 index 0000000..c92b599 --- /dev/null +++ b/model/DCRNN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "DCRNN", + "module": "model.DCRNN.dcrnn_model", + "entry": "DCRNNModel" + } +] \ No newline at end of file diff --git a/model/DDGCRN/model_config.json b/model/DDGCRN/model_config.json new file mode 100644 index 0000000..a07fc3a --- /dev/null +++ b/model/DDGCRN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "DDGCRN", + "module": "model.DDGCRN.DDGCRN", + "entry": "DDGCRN" + } +] \ No newline at end of file diff --git a/model/DSANET/model_config.json b/model/DSANET/model_config.json new file mode 100644 index 0000000..5624f8a --- /dev/null +++ b/model/DSANET/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "DSANET", + "module": "model.DSANET.DSANET", + "entry": "DSANet" + } +] \ No newline at end of file diff --git a/model/EXP/model_config.json b/model/EXP/model_config.json new file mode 100644 index 0000000..bdf39b7 --- /dev/null +++ b/model/EXP/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "EXP", + "module": "model.EXP.EXP32", + "entry": "EXP" + } +] \ No newline at end of file diff --git a/model/GWN/model_config.json b/model/GWN/model_config.json new file mode 100644 index 0000000..38d05b4 --- /dev/null +++ b/model/GWN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "GWN", + "module": "model.GWN.GraphWaveNet", + "entry": "gwnet" + } +] \ No newline at end of file diff --git a/model/HI/model_config.json b/model/HI/model_config.json new file mode 100644 index 0000000..3071864 --- /dev/null +++ b/model/HI/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "HI", + "module": "model.HI.HI", + "entry": "HI" + } +] \ No newline at end of file diff --git a/model/Informer/model_config.json b/model/Informer/model_config.json new file mode 100644 index 0000000..3836cd0 --- /dev/null +++ b/model/Informer/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "Informer", + "module": "model.Informer.model", + "entry": "Informer" + } +] \ No newline at end of file diff --git a/model/MTGNN/model_config.json b/model/MTGNN/model_config.json new file mode 100644 index 0000000..94aa32c --- /dev/null +++ b/model/MTGNN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "MTGNN", + "module": "model.MTGNN.MTGNN", + "entry": "gtnet" + } +] \ No newline at end of file diff --git a/model/MegaCRN/model_config.json b/model/MegaCRN/model_config.json new file mode 100644 index 0000000..e8c0599 --- /dev/null +++ b/model/MegaCRN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "MegaCRN", + "module": "model.MegaCRN.MegaCRNModel", + "entry": "MegaCRNModel" + } +] \ No newline at end of file diff --git a/model/NLT/model_config.json b/model/NLT/model_config.json new file mode 100644 index 0000000..a99a6b1 --- /dev/null +++ b/model/NLT/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "NLT", + "module": "model.NLT.HierAttnLstm", + "entry": "HierAttnLstm" + } +] \ No newline at end of file diff --git a/model/PDG2SEQ/model_config.json b/model/PDG2SEQ/model_config.json new file mode 100644 index 0000000..783f3bf --- /dev/null +++ b/model/PDG2SEQ/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "PDG2SEQ", + "module": "model.PDG2SEQ.PDG2Seqb", + "entry": "PDG2Seq" + } +] \ No newline at end of file diff --git a/model/PatchTST/model_config.json b/model/PatchTST/model_config.json new file mode 100644 index 0000000..d613fbb --- /dev/null +++ b/model/PatchTST/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "PatchTST", + "module": "model.PatchTST.PatchTST", + "entry": "Model" + } +] \ No newline at end of file diff --git a/model/README.md b/model/README.md new file mode 100644 index 0000000..24dd3ca --- /dev/null +++ b/model/README.md @@ -0,0 +1,109 @@ +# 模型注册说明 + +## 概述 + +本项目使用基于配置文件的模型注册机制,每个模型目录下的 `model_config.json` 文件用于注册该目录下的模型。 + +## model_config.json 格式 + +### 基本格式 + +每个 `model_config.json` 文件是一个 JSON 数组,包含一个或多个模型配置对象: + +```json +[ + { + "name": "模型名称", + "module": "模型模块路径", + "entry": "模型入口点" + } +] +``` + +### 字段说明 + +- **name**: 模型的唯一标识符,用于在配置文件中选择模型 +- **module**: 模型所在的模块路径,使用 Python 导入格式 +- **entry**: 模型的入口点,可以是类名或函数名 + +### 示例 + +#### 1. 单个模型 + +```json +[ + { + "name": "DDGCRN", + "module": "model.DDGCRN.DDGCRN", + "entry": "DDGCRN" + } +] +``` + +#### 2. 多个模型(同一目录下的不同版本) + +```json +[ + { + "name": "ASTRA", + "module": "model.ASTRA.astra", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v2", + "module": "model.ASTRA.astrav2", + "entry": "ASTRA" + }, + { + "name": "ASTRA_v3", + "module": "model.ASTRA.astrav3", + "entry": "ASTRA" + } +] +``` + +#### 3. 函数模型 + +```json +[ + { + "name": "STGNCDE", + "module": "model.STGNCDE.Make_model", + "entry": "make_model" + } +] +``` + +## 添加新模型 + +1. 在 `model` 目录下创建模型目录 +2. 在该目录下实现模型代码 +3. 创建 `model_config.json` 文件,配置模型信息 +4. 在配置文件中使用模型名称选择模型 + +## 注意事项 + +1. 模型名称必须唯一,不允许重复 +2. 模块路径必须是正确的 Python 导入路径 +3. 入口点必须是模块中存在的类或函数 +4. 配置文件必须是有效的 JSON 格式 +5. 每个模型目录下只能有一个 `model_config.json` 文件 + +## 模型选择 + +在配置文件中,通过 `basic.model` 字段指定要使用的模型名称: + +```json +{ + "basic": { + "model": "ASTRA" + }, + "model": { + // 模型特定配置 + } +} +``` + +## 冲突检测 + +系统会自动检测模型名冲突,如有冲突会抛出 `AssertionError` 并显示冲突信息。 diff --git a/model/REPST/model_config.json b/model/REPST/model_config.json new file mode 100644 index 0000000..5bdfce6 --- /dev/null +++ b/model/REPST/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "REPST", + "module": "model.REPST.repst", + "entry": "repst" + } +] \ No newline at end of file diff --git a/model/STAEFormer/model_config.json b/model/STAEFormer/model_config.json new file mode 100644 index 0000000..8823a88 --- /dev/null +++ b/model/STAEFormer/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STAEFormer", + "module": "model.STAEFormer.STAEFormer", + "entry": "STAEformer" + } +] \ No newline at end of file diff --git a/model/STAWnet/model_config.json b/model/STAWnet/model_config.json new file mode 100644 index 0000000..0e83de9 --- /dev/null +++ b/model/STAWnet/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STAWnet", + "module": "model.STAWnet.STAWnet", + "entry": "STAWnet" + } +] \ No newline at end of file diff --git a/model/STFGNN/model_config.json b/model/STFGNN/model_config.json new file mode 100644 index 0000000..ef5bd7e --- /dev/null +++ b/model/STFGNN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STFGNN", + "module": "model.STFGNN.STFGNN", + "entry": "STFGNN" + } +] \ No newline at end of file diff --git a/model/STGCN/model_config.json b/model/STGCN/model_config.json new file mode 100644 index 0000000..af5885a --- /dev/null +++ b/model/STGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGCN", + "module": "model.STGCN.models", + "entry": "STGCNChebGraphConv" + } +] \ No newline at end of file diff --git a/model/STGNCDE/model_config.json b/model/STGNCDE/model_config.json new file mode 100644 index 0000000..3ec8745 --- /dev/null +++ b/model/STGNCDE/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGNCDE", + "module": "model.STGNCDE.Make_model", + "entry": "make_model" + } +] \ No newline at end of file diff --git a/model/STGNRDE/model_config.json b/model/STGNRDE/model_config.json new file mode 100644 index 0000000..ec655a8 --- /dev/null +++ b/model/STGNRDE/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGNRDE", + "module": "model.STGNRDE.Make_model", + "entry": "make_model" + } +] \ No newline at end of file diff --git a/model/STGODE/model_config.json b/model/STGODE/model_config.json new file mode 100644 index 0000000..d6a03e2 --- /dev/null +++ b/model/STGODE/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STGODE", + "module": "model.STGODE.STGODE", + "entry": "ODEGCN" + } +] \ No newline at end of file diff --git a/model/STID/model_config.json b/model/STID/model_config.json new file mode 100644 index 0000000..1a39d87 --- /dev/null +++ b/model/STID/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STID", + "module": "model.STID.STID", + "entry": "STID" + } +] \ No newline at end of file diff --git a/model/STIDGCN/model_config.json b/model/STIDGCN/model_config.json new file mode 100644 index 0000000..a986383 --- /dev/null +++ b/model/STIDGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STIDGCN", + "module": "model.STIDGCN.STIDGCN", + "entry": "STIDGCN" + } +] \ No newline at end of file diff --git a/model/STMLP/model_config.json b/model/STMLP/model_config.json new file mode 100644 index 0000000..e7cfb08 --- /dev/null +++ b/model/STMLP/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STMLP", + "module": "model.STMLP.STMLP", + "entry": "STMLP" + } +] \ No newline at end of file diff --git a/model/STSGCN/model_config.json b/model/STSGCN/model_config.json new file mode 100644 index 0000000..a5e2b4d --- /dev/null +++ b/model/STSGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STSGCN", + "module": "model.STSGCN.STSGCN", + "entry": "STSGCN" + } +] \ No newline at end of file diff --git a/model/ST_SSL/model_config.json b/model/ST_SSL/model_config.json new file mode 100644 index 0000000..8bbfb74 --- /dev/null +++ b/model/ST_SSL/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "ST_SSL", + "module": "model.ST_SSL.ST_SSL", + "entry": "STSSLModel" + } +] \ No newline at end of file diff --git a/model/TCN/model_config.json b/model/TCN/model_config.json new file mode 100644 index 0000000..d083150 --- /dev/null +++ b/model/TCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "TCN", + "module": "model.TCN.TCN", + "entry": "TemporalConvNet" + } +] \ No newline at end of file diff --git a/model/TWDGCN/model_config.json b/model/TWDGCN/model_config.json new file mode 100644 index 0000000..92f3167 --- /dev/null +++ b/model/TWDGCN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "TWDGCN", + "module": "model.TWDGCN.TWDGCN", + "entry": "TWDGCN" + } +] \ No newline at end of file diff --git a/model/iTransformer/model_config.json b/model/iTransformer/model_config.json new file mode 100644 index 0000000..79c8db5 --- /dev/null +++ b/model/iTransformer/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "iTransformer", + "module": "model.iTransformer.iTransformer", + "entry": "iTransformer" + } +] \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index f74dde2..9afd0ff 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -1,107 +1,57 @@ -from model.DDGCRN.DDGCRN import DDGCRN -from model.HI import HI -from model.TWDGCN.TWDGCN import TWDGCN -from model.AGCRN.AGCRN import AGCRN -from model.NLT.HierAttnLstm import HierAttnLstm -from model.STGNCDE.Make_model import make_model -from model.DSANET.DSANET import DSANet -from model.STGCN.models import STGCNChebGraphConv -from model.DCRNN.dcrnn_model import DCRNNModel -from model.ARIMA.ARIMA import ARIMA -from model.TCN.TCN import TemporalConvNet -from model.GWN.GraphWaveNet import gwnet -from model.STFGNN.STFGNN import STFGNN -from model.STSGCN.STSGCN import STSGCN -from model.STGODE.STGODE import ODEGCN -from model.PDG2SEQ.PDG2Seqb import PDG2Seq -from model.STMLP.STMLP import STMLP -from model.STIDGCN.STIDGCN import STIDGCN -from model.STID.STID import STID -from model.STAEFormer.STAEFormer import STAEformer -from model.EXP.EXP32 import EXP as EXP -from model.MegaCRN.MegaCRNModel import MegaCRNModel -from model.ST_SSL.ST_SSL import STSSLModel -from model.STGNRDE.Make_model import make_model as make_nrde_model -from model.STAWnet.STAWnet import STAWnet -from model.REPST.repst import repst as REPST -from model.ASTRA.astra import ASTRA as ASTRA -from model.ASTRA.astrav2 import ASTRA as ASTRAv2 -from model.ASTRA.astrav3 import ASTRA as ASTRAv3 -from model.iTransformer.iTransformer import iTransformer -from model.Informer.model import Informer -from model.HI.HI import HI -from model.PatchTST.PatchTST import Model as PatchTST -from model.MTGNN.MTGNN import gtnet as MTGNN +import os +import json +import importlib +import sys +from pathlib import Path +class ModelRegistry: + def __init__(self): + self.models = {} + self.model_configs = {} + self.model_dir = Path(__file__).parent + self._load_model_configs() + + def _load_model_configs(self): + """加载所有model_config.json文件""" + # 直接遍历所有model_config.json文件 + for config_path in self.model_dir.rglob("model_config.json"): + # 读取配置文件 + with open(config_path, 'r') as f: + configs = json.load(f) + + # 处理每个模型配置 + for config in configs: + model_name = config["name"] + # 检查模型名冲突 + assert model_name not in self.model_configs, f"模型名冲突: {model_name} 已存在,冲突文件: {config_path}" + self.model_configs[model_name] = config + + def _load_model(self, model_name): + """动态加载模型""" + if model_name not in self.model_configs: + raise ValueError(f"模型 {model_name} 未注册") + + config = self.model_configs[model_name] + module = importlib.import_module(config["module"]) + model_cls = getattr(module, config["entry"]) + self.models[model_name] = model_cls + + def get_model(self, model_name): + """获取模型类或函数""" + if model_name not in self.models: + self._load_model(model_name) + return self.models[model_name] +# 初始化模型注册表 +model_registry = ModelRegistry() def model_selector(config): model_name = config["basic"]["model"] model_config = config["model"] - match model_name: - case "DDGCRN": - return DDGCRN(model_config) - case "TWDGCN": - return TWDGCN(model_config) - case "AGCRN": - return AGCRN(model_config) - case "NLT": - return HierAttnLstm(model_config) - case "STGNCDE": - return make_model(model_config) - case "DSANET": - return DSANet(model_config) - case "STGCN": - return STGCNChebGraphConv(model_config) - case "DCRNN": - return DCRNNModel(model_config) - case "ARIMA": - return ARIMA(model_config) - case "TCN": - return TemporalConvNet(model_config) - case "GWN": - return gwnet(model_config) - case "STFGNN": - return STFGNN(model_config) - case "STSGCN": - return STSGCN(model_config) - case "STGODE": - return ODEGCN(model_config) - case "PDG2SEQ": - return PDG2Seq(model_config) - case "STMLP": - return STMLP(model_config) - case "STIDGCN": - return STIDGCN(model_config) - case "STID": - return STID(model_config) - case "STAEFormer": - return STAEformer(model_config) - case "EXP": - return EXP(model_config) - case "MegaCRN": - return MegaCRNModel(model_config) - case "ST_SSL": - return STSSLModel(model_config) - case "STGNRDE": - return make_nrde_model(model_config) - case "STAWnet": - return STAWnet(model_config) - case "REPST": - return REPST(model_config) - case "ASTRA": - return ASTRA(model_config) - case "ASTRA_v2": - return ASTRAv2(model_config) - case "ASTRA_v3": - return ASTRAv3(model_config) - case "iTransformer": - return iTransformer(model_config) - case "Informer": - return Informer(model_config) - case "HI": - return HI(model_config) - case "PatchTST": - return PatchTST(model_config) - case "MTGNN": - return MTGNN(model_config) + + model_cls = model_registry.get_model(model_name) + model = model_cls(model_config) + # print(f"\n=== 模型选择结果 ===") + print(f"选择的模型: {model_name}") + print(f"模型入口: {model_registry.model_configs[model_name]['module']}:{model_registry.model_configs[model_name]['entry']}") + return model diff --git a/train.py b/train.py index 9c81209..5beb472 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ def read_config(config_path): # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 1 + epochs = 100 # 拷贝项 config["basic"]["device"] = device @@ -63,14 +63,14 @@ def run(config): if __name__ == "__main__": # 指定模型 - model_list = ["Informer"] + model_list = ["iTransformer"] # 指定数据集 dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] # dataset_list = ["PEMS-BAY"] # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) - os.environ["TRY"] = str(True) + os.environ["TRY"] = str(False) for model in model_list: for dataset in dataset_list: -- 2.40.1 From 97743dfd05941a7ef617985998d43f4c7a904eea Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 00:38:43 +0800 Subject: [PATCH 28/41] opt trainer --- train.py | 10 ++- trainer/Trainer.py | 163 ++++++++++++++++++++------------------------- 2 files changed, 78 insertions(+), 95 deletions(-) diff --git a/train.py b/train.py index 76ea652..acd0e60 100644 --- a/train.py +++ b/train.py @@ -6,14 +6,12 @@ import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer -import cProfile - def read_config(config_path): with open(config_path, "r") as file: config = yaml.safe_load(file) # 全局配置 - device = "cuda:1" # 指定设备为cuda:0 + device = "cpu" # 指定设备为cuda:0 seed = 2023 # 随机种子 epochs = 120 @@ -67,8 +65,8 @@ def main(debug=False): model_list = ["iTransformer"] # 指定数据集 # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - # dataset_list = ["AirQuality"] - dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] + dataset_list = ["AirQuality"] + # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) @@ -99,4 +97,4 @@ def main(debug=False): if __name__ == "__main__": # 调试用 - main(debug = False) \ No newline at end of file + main(debug = True) \ No newline at end of file diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 3372873..04842ba 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -10,47 +10,45 @@ from tqdm import tqdm class Trainer: """模型训练器,负责整个训练流程的管理""" - def __init__(self, model, loss, optimizer, - train_loader, val_loader, test_loader, scaler, - args, lr_scheduler=None,): + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): # 设备和基本参数 self.config = args self.device = args["basic"]["device"] - train_args = args["train"] + self.args = args["train"] + # 模型和训练相关组件 - self.model = model - self.loss = loss - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler + self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler + # 数据加载器 - self.train_loader = train_loader - self.val_loader = val_loader - self.test_loader = test_loader + self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader, test_loader + # 数据处理工具 self.scaler = scaler - self.args = train_args + # 初始化路径、日志和统计 - self._initialize_paths(train_args) - self._initialize_logger(train_args) + self._initialize_paths(self.args) + self._initialize_logger(self.args) def _initialize_paths(self, args): """初始化模型保存路径""" - self.best_path = os.path.join(args["log_dir"], "best_model.pth") - self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") - self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") + log_dir = args["log_dir"] + self.best_path = os.path.join(log_dir, "best_model.pth") + self.best_test_path = os.path.join(log_dir, "best_test_model.pth") + self.loss_figure_path = os.path.join(log_dir, "loss.png") def _initialize_logger(self, args): """初始化日志记录器""" - if not os.path.isdir(args["log_dir"]) and not args["debug"]: - os.makedirs(args["log_dir"], exist_ok=True) - self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"]) - self.logger.info(f"Experiment log path in: {args['log_dir']}") + log_dir = args["log_dir"] + if not args["debug"]: + os.makedirs(log_dir, exist_ok=True) + self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=args["debug"]) + self.logger.info(f"Experiment log path in: {log_dir}") def _run_epoch(self, epoch, dataloader, mode): """运行一个训练/验证/测试epoch""" # 设置模型模式和是否进行优化 - if mode == "train": self.model.train(); optimizer_step = True - else: self.model.eval(); optimizer_step = False + self.model.train() if mode == "train" else self.model.eval() + optimizer_step = mode == "train" # 初始化变量 total_loss = 0 @@ -60,105 +58,111 @@ class Trainer: # 训练/验证循环 with torch.set_grad_enabled(optimizer_step): progress_bar = tqdm( - enumerate(dataloader), + dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" ) - for _, (data, target) in progress_bar: - # 转移数据 - data = data.to(self.device) - target = target.to(self.device) + for data, target in progress_bar: + # 转移数据并提取标签 + data, target = data.to(self.device), target.to(self.device) label = target[..., : self.args["output_dim"]] - # 计算loss和反归一化loss + + # 计算输出 output = self.model(data) + # 我的调试开关 if os.environ.get("TRY") == "True": - print(f"[{'✅' if output.shape == label.shape else '❌'}]: output: {output.shape}, label: {label.shape}") + status = '✅' if output.shape == label.shape else '❌' + print(f"[{status}]: output: {output.shape}, label: {label.shape}") assert False + + # 计算损失 loss = self.loss(output, label) d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) d_loss = self.loss(d_output, d_label) + # 累积损失和预测结果 total_loss += d_loss.item() y_pred.append(d_output.detach().cpu()) y_true.append(d_label.detach().cpu()) + # 反向传播和优化(仅在训练模式) if optimizer_step and self.optimizer is not None: self.optimizer.zero_grad() loss.backward() + # 梯度裁剪(如果需要) if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + self.optimizer.step() + # 更新进度条 progress_bar.set_postfix(loss=d_loss.item()) # 合并所有批次的预测结果 - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0) + # 计算损失并记录指标 avg_loss = total_loss / len(dataloader) mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info( f"Epoch #{epoch:02d}: {mode.capitalize():<5} " f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" ) + return avg_loss - def train_epoch(self, epoch): - return self._run_epoch(epoch, self.train_loader, "train") - - def val_epoch(self, epoch): - return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") - - def test_epoch(self, epoch): - return self._run_epoch(epoch, self.test_loader, "test") - def train(self): # 初始化记录 - best_model, best_test_model = None, None - best_loss, best_test_loss = float("inf"), float("inf") + best_model = best_test_model = None + best_loss = best_test_loss = float("inf") not_improved_count = 0 + # 开始训练 self.logger.info("Training process started") + # 训练循环 for epoch in range(1, self.args["epochs"] + 1): # 训练、验证和测试一个epoch - train_epoch_loss = self.train_epoch(epoch) - val_epoch_loss = self.val_epoch(epoch) - test_epoch_loss = self.test_epoch(epoch) + train_epoch_loss = self._run_epoch(epoch, self.train_loader, "train") + val_epoch_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val") + test_epoch_loss = self._run_epoch(epoch, self.test_loader, "test") + # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break + # 更新最佳验证模型 if val_epoch_loss < best_loss: - best_loss = val_epoch_loss - not_improved_count = 0 + best_loss, not_improved_count = val_epoch_loss, 0 best_model = copy.deepcopy(self.model.state_dict()) self.logger.info("Best validation model saved!") else: not_improved_count += 1 - # 早停 + + # 早停检查 if self._should_early_stop(not_improved_count): break + # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) + # 保存最佳模型 if not self.args["debug"]: self._save_best_models(best_model, best_test_model) + # 最终评估 self._finalize_training(best_model, best_test_model) def _should_early_stop(self, not_improved_count): """检查是否满足早停条件""" - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): + if self.args["early_stop"] and not_improved_count == self.args["early_stop_patience"]: self.logger.info( f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." ) @@ -190,58 +194,43 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger, path=None): """对模型进行评估并输出性能指标""" - # 确定设备信息 - device = None - output_dim = None - # 处理不同的参数格式 - if isinstance(args, dict): - if "basic" in args: - # 完整配置情况 - device = args["basic"]["device"] - output_dim = args["train"]["output_dim"] - else: - # 只有train_args情况,从模型获取设备 - device = next(model.parameters()).device - output_dim = args["output_dim"] - else: + # 验证参数类型 + if not isinstance(args, dict): raise ValueError(f"Unsupported args type: {type(args)}") + # 确定设备和输出维度 + is_full_config = "basic" in args + device = args["basic"]["device"] if is_full_config else next(model.parameters()).device + output_dim = args["train"]["output_dim"] if is_full_config else args["output_dim"] + + # 获取metrics参数 + train_args = args["train"] if is_full_config else args + mae_thresh, mape_thresh = train_args["mae_thresh"], train_args["mape_thresh"] + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) model.to(device) - # 设置为评估模式 + # 设置为评估模式并收集预测结果 model.eval() - - # 收集预测和真实标签 y_pred, y_true = [], [] # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target in data_loader: # 将数据和标签移动到指定设备 - data = data.to(device) - target = target.to(device) - + data, target = data.to(device), target.to(device) label = target[..., : output_dim] + output = model(data) y_pred.append(output.detach().cpu()) y_true.append(label.detach().cpu()) + # 反归一化并计算指标 d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - - # 获取metrics参数 - if "basic" in args: - # 完整配置情况 - mae_thresh = args["train"]["mae_thresh"] - mape_thresh = args["train"]["mape_thresh"] - else: - # 只有train_args情况 - mae_thresh = args["mae_thresh"] - mape_thresh = args["mape_thresh"] # 计算并记录每个时间步的指标 for t in range(d_y_true.shape[1]): @@ -254,9 +243,5 @@ class Trainer: logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") # 计算并记录平均指标 - mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) - logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - - @staticmethod - def _compute_sampling_threshold(global_step, k): - return k / (k + math.exp(global_step / k)) + avg_mae, avg_rmse, avg_mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) + logger.info(f"Average Horizon, MAE: {avg_mae:.4f}, RMSE: {avg_rmse:.4f}, MAPE: {avg_mape:.4f}") -- 2.40.1 From 3095b7435b42d2833e002381a4320507504bb8f4 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 01:38:47 +0800 Subject: [PATCH 29/41] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=8A=A0=E8=BD=BD=E5=99=A8=E5=92=8C=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E5=99=A8=E4=BB=A3=E7=A0=81=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E7=BB=93=E6=9E=84=E5=92=8C=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 重构数据加载器模块,使用字典映射替代switch-case结构 简化训练器逻辑,合并重复代码,提高可维护性 优化日志时间格式,缩短显示长度 调整训练配置,减少默认epoch数并启用GPU训练 统一数据加载方式,提取公共方法减少重复代码 --- dataloader/data_selector.py | 155 ++++++-------- dataloader/loader_selector.py | 27 +-- train.py | 8 +- trainer/TSTrainer.py | 387 +++++++++++++--------------------- trainer/Trainer.py | 277 ++++++------------------ trainer/trainer_selector.py | 147 ++----------- utils/logger.py | 2 +- 7 files changed, 318 insertions(+), 685 deletions(-) diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index e0b23e1..bd8e61a 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -2,95 +2,80 @@ import os import numpy as np import h5py + def load_st_dataset(config): dataset = config["basic"]["dataset"] - # sample = config["data"]["sample"] - # output B, N, D - match dataset: - case "BeijingAirQuality": - data_path = os.path.join("./data/BeijingAirQuality/data.dat") - data = np.memmap(data_path, dtype=np.float32, mode='r') - L, N, C = 36000, 7, 3 - data = data.reshape(L, N, C) - case "AirQuality": - data_path = os.path.join("./data/AirQuality/data.dat") - data = np.memmap(data_path, dtype=np.float32, mode='r') - L, N, C = 8701,35,6 - data = data.reshape(L, N, C) - case "PEMS-BAY": - data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5") - with h5py.File(data_path, 'r') as f: - data = f['speed']['block0_values'][:] - case "METR-LA": - data_path = os.path.join("./data/METR-LA/METR-LA.h5") - with h5py.File(data_path, 'r') as f: - data = f['df']['block0_values'][:] - case "SolarEnergy": - data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv") - data = np.loadtxt(data_path, delimiter=",") - case "PEMSD3": - data_path = os.path.join("./data/PEMS03/PEMS03.npz") - data = np.load(data_path)["data"][:, :, 0] - case "PEMSD4": - data_path = os.path.join("./data/PEMS04/PEMS04.npz") - data = np.load(data_path)["data"][:, :, 0] - case "PEMSD7": - data_path = os.path.join("./data/PEMS07/PEMS07.npz") - data = np.load(data_path)["data"][:, :, 0] - case "PEMSD8": - data_path = os.path.join("./data/PEMS08/PEMS08.npz") - data = np.load(data_path)["data"][:, :, 0] - case "PEMSD7(L)": - data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz") - data = np.load(data_path)["data"][:, :, 0] - case "PEMSD7(M)": - data_path = os.path.join("./data/PEMS07(M)/V_228.csv") - data = np.genfromtxt(data_path, delimiter=",") - case "BJ": - data_path = os.path.join("./data/BJ/BJ500.csv") - data = np.genfromtxt(data_path, delimiter=",", skip_header=1) - case "Hainan": - data_path = os.path.join("./data/Hainan/Hainan.npz") - data = np.load(data_path)["data"][:, :, 0] - case "SD": - data_path = os.path.join("./data/SD/data.npz") - data = np.load(data_path)["data"][:, :, 0].astype(np.float32) - case "BJTaxi-InFlow": - data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32) - case "BJTaxi-OutFlow": - data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32) - case "NYCBike-InFlow": - data_path = os.path.join("./data/NYCBike/NYC16x8.h5") - with h5py.File(data_path, 'r') as f: - data = f['data'][:].astype(np.float32) - data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2) - data = data[:, :, 0:1] - case "NYCBike-OutFlow": - data_path = os.path.join("./data/NYCBike/NYC16x8.h5") - with h5py.File(data_path, 'r') as f: - data = f['data'][:].astype(np.float32) - data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2) - data = data[:, :, 1:2] - case _: - raise ValueError(f"Unsupported dataset: {dataset}") - # Ensure data shape compatibility - if len(data.shape) == 2: - data = np.expand_dims(data, axis=-1) + loaders = { + "BeijingAirQuality": lambda: _memmap("./data/BeijingAirQuality/data.dat", 36000, 7, 3), + "AirQuality": lambda: _memmap("./data/AirQuality/data.dat", 8701, 35, 6), - print("加载 %s 数据集中... " % dataset) - # return data[::sample] + "PEMS-BAY": lambda: _h5("./data/PEMS-BAY/pems-bay.h5", ("speed", "block0_values")), + "METR-LA": lambda: _h5("./data/METR-LA/METR-LA.h5", ("df", "block0_values")), + + "SolarEnergy": lambda: np.loadtxt("./data/SolarEnergy/SolarEnergy.csv", delimiter=","), + + "PEMSD3": lambda: _npz("./data/PEMS03/PEMS03.npz"), + "PEMSD4": lambda: _npz("./data/PEMS04/PEMS04.npz"), + "PEMSD7": lambda: _npz("./data/PEMS07/PEMS07.npz"), + "PEMSD8": lambda: _npz("./data/PEMS08/PEMS08.npz"), + + "PEMSD7(L)": lambda: _npz("./data/PEMS07(L)/PEMS07L.npz"), + "PEMSD7(M)": lambda: np.genfromtxt("./data/PEMS07(M)/V_228.csv", delimiter=","), + + "BJ": lambda: np.genfromtxt("./data/BJ/BJ500.csv", delimiter=",", skip_header=1), + "Hainan": lambda: _npz("./data/Hainan/Hainan.npz"), + "SD": lambda: _npz("./data/SD/data.npz", cast=True), + + "BJTaxi-InFlow": lambda: read_BeijingTaxi()[:, :, 0:1].astype(np.float32), + "BJTaxi-OutFlow": lambda: read_BeijingTaxi()[:, :, 1:2].astype(np.float32), + + "NYCBike-InFlow": lambda: _nyc_bike(0), + "NYCBike-OutFlow": lambda: _nyc_bike(1), + } + + if dataset not in loaders: + raise ValueError(f"Unsupported dataset: {dataset}") + + data = loaders[dataset]() + + if data.ndim == 2: + data = data[..., None] + + print(f"加载 {dataset} 数据集中... ") return data + +# ---------------- helpers ---------------- +def _memmap(path, L, N, C): + data = np.memmap(path, dtype=np.float32, mode="r") + return data.reshape(L, N, C) + + +def _h5(path, keys): + with h5py.File(path, "r") as f: + return f[keys[0]][keys[1]][:] + + +def _npz(path, cast=False): + data = np.load(path)["data"][:, :, 0] + return data.astype(np.float32) if cast else data + + +def _nyc_bike(channel): + with h5py.File("./data/NYCBike/NYC16x8.h5", "r") as f: + data = f["data"][:].astype(np.float32) + data = data.transpose(0, 2, 3, 1).reshape(-1, 16 * 8, 2) + return data[:, :, channel:channel + 1] + + def read_BeijingTaxi(): - files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", - "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] - all_data = [] - for file in files: - data_path = os.path.join(f"./data/BeijingTaxi/{file}") - data = np.load(data_path) - all_data.append(data) - all_data = np.concatenate(all_data, axis=0) - time_num = all_data.shape[0] - all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2) - return all_data \ No newline at end of file + files = [ + "TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", + "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy", + ] + data = np.concatenate( + [np.load(f"./data/BeijingTaxi/{f}") for f in files], axis=0 + ) + T = data.shape[0] + return data.transpose(0, 2, 3, 1).reshape(T, 32 * 32, 2) diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index 5ea47fa..caeeb03 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -8,21 +8,12 @@ from dataloader.Informer_loader import get_dataloader as Informer_loader def get_dataloader(config, normalizer, single): - TS_model = ["iTransformer", "HI", "PatchTST"] - model_name = config["basic"]["model"] - # if model_name == "Informer": - # return Informer_loader(config, normalizer, single) - # elif model_name in TS_model: - # return TS_loader(config, normalizer, single) - # else : - match model_name: - case "STGNCDE": - return cde_loader(config, normalizer, single) - case "STGNRDE": - return nrde_loader(config, normalizer, single) - case "DCRNN": - return DCRNN_loader(config, normalizer, single) - case "EXP": - return EXP_loader(config, normalizer, single) - case _: - return normal_loader(config, normalizer, single) + loader_map = { + "STGNCDE": cde_loader, + "STGNRDE": nrde_loader, + "DCRNN": DCRNN_loader, + "EXP": EXP_loader, + } + return loader_map.get(config["basic"]["model"], normal_loader)( + config, normalizer, single + ) diff --git a/train.py b/train.py index acd0e60..139cdfa 100644 --- a/train.py +++ b/train.py @@ -11,9 +11,9 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cpu" # 指定设备为cuda:0 + device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 120 + epochs = 1 # 拷贝项 config["basic"]["device"] = device @@ -65,8 +65,8 @@ def main(debug=False): model_list = ["iTransformer"] # 指定数据集 # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - dataset_list = ["AirQuality"] - # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] + # dataset_list = ["AirQuality"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index b8def31..c427072 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -1,296 +1,195 @@ -import math -import os -import time -import copy -import torch +import os, time, copy, torch +from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics -from tqdm import tqdm class TSWrapper: def __init__(self, args): - self.b = args['train']['batch_size'] - self.t = args['data']['lag'] self.n = args['data']['num_nodes'] - self.c = args['data']['input_dim'] - - def transpose(self, x : torch.Tensor): + def forward(self, x): # [b, t, n, c] -> [b*n, t, c] - self.b = x.shape[0] - x = x[..., :-2] - x = x.permute(0, 2, 1, 3) - x = x.reshape(self.b*self.n, self.t, self.c) - return x - - def inv_transpose(self, x : torch.Tensor): - x = x.reshape(self.b, self.n, self.t, self.c) - x = x.permute(0, 2, 1, 3) - return x + b, t, n, c = x.shape + x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c-2) + return x, b, t, n, c + + def inverse(self, x, b, t, n, c): + return x.reshape(b, n, t, c-2).permute(0, 2, 1, 3) class Trainer: - """模型训练器,负责整个训练流程的管理""" - def __init__(self, model, loss, optimizer, - train_loader, val_loader, test_loader, scaler, - args, lr_scheduler=None,): - # 设备和基本参数 + train_loader, val_loader, test_loader, + scaler, args, lr_scheduler=None): + self.config = args self.device = args["basic"]["device"] - train_args = args["train"] - # 模型和训练相关组件 - self.model = model + self.args = args["train"] + + self.model = model.to(self.device) self.loss = loss self.optimizer = optimizer self.lr_scheduler = lr_scheduler - # 数据加载器 + self.train_loader = train_loader - self.val_loader = val_loader + self.val_loader = val_loader or test_loader self.test_loader = test_loader - # 数据处理工具 self.scaler = scaler - self.args = train_args - self.ts_wrapper = TSWrapper(args) - # 初始化路径、日志和统计 - self._initialize_paths(train_args) - self._initialize_logger(train_args) - - def _initialize_paths(self, args): - """初始化模型保存路径""" - self.best_path = os.path.join(args["log_dir"], "best_model.pth") - self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") - self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") - - def _initialize_logger(self, args): - """初始化日志记录器""" - if not os.path.isdir(args["log_dir"]) and not args["debug"]: - os.makedirs(args["log_dir"], exist_ok=True) - self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"]) - self.logger.info(f"Experiment log path in: {args['log_dir']}") - def _run_epoch(self, epoch, dataloader, mode): - """运行一个训练/验证/测试epoch""" - # 设置模型模式和是否进行优化 - if mode == "train": self.model.train(); optimizer_step = True - else: self.model.eval(); optimizer_step = False + self.ts = TSWrapper(args) + self._init_paths() + self._init_logger() - # 初始化变量 - total_loss = 0 - epoch_time = time.time() + # ---------------- init ---------------- + def _init_paths(self): + d = self.args["log_dir"] + self.best_path = os.path.join(d, "best_model.pth") + self.best_test_path = os.path.join(d, "best_test_model.pth") + + def _init_logger(self): + if not self.args["debug"]: + os.makedirs(self.args["log_dir"], exist_ok=True) + self.logger = get_logger( + self.args["log_dir"], + name=self.model.__class__.__name__, + debug=self.args["debug"], + ) + + # ---------------- epoch ---------------- + def _run_epoch(self, epoch, loader, mode): + is_train = mode == "train" + self.model.train() if is_train else self.model.eval() + + total_loss, start = 0.0, time.time() y_pred, y_true = [], [] - # 训练/验证循环 - with torch.set_grad_enabled(optimizer_step): - progress_bar = tqdm( - enumerate(dataloader), - total=len(dataloader), - desc=f"{mode.capitalize()} Epoch {epoch}" - ) - for _, (data, target) in progress_bar: - # 转移数据 - data = data.to(self.device) - target = target.to(self.device) - label = target[..., : self.args["output_dim"]] - # 转换为 [b*n, t, c] - data = self.ts_wrapper.transpose(data) - # 计算loss和反归一化loss - output = self.model(data) - # 转换回[b, t, n, c] - output = self.ts_wrapper.inv_transpose(output) - # 我的调试开关 + with torch.set_grad_enabled(is_train): + for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): + data, target = data.to(self.device), target.to(self.device) + label = target[..., :self.args["output_dim"]] + + x, b, t, n, c = self.ts.forward(data) + out = self.model(x) + out = self.ts.inverse(out, b, t, n, c) + if os.environ.get("TRY") == "True": - print(f"[{'✅' if output.shape == label.shape else '❌'}]: output: {output.shape}, label: {label.shape}") - assert False - loss = self.loss(output, label) - d_output = self.scaler.inverse_transform(output) - d_label = self.scaler.inverse_transform(label) - d_loss = self.loss(d_output, d_label) - # 累积损失和预测结果 + print(out.shape, label.shape) + assert out.shape == label.shape + + loss = self.loss(out, label) + d_out = self.scaler.inverse_transform(out) + d_lbl = self.scaler.inverse_transform(label) + d_loss = self.loss(d_out, d_lbl) + total_loss += d_loss.item() - y_pred.append(d_output.detach().cpu()) - y_true.append(d_label.detach().cpu()) - # 反向传播和优化(仅在训练模式) - if optimizer_step and self.optimizer is not None: + y_pred.append(d_out.detach().cpu()) + y_true.append(d_lbl.detach().cpu()) + + if is_train and self.optimizer: self.optimizer.zero_grad() loss.backward() - # 梯度裁剪(如果需要) if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.args["max_grad_norm"] + ) self.optimizer.step() - # 更新进度条 - progress_bar.set_postfix(loss=d_loss.item()) - # 合并所有批次的预测结果 - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) - # 计算损失并记录指标 - avg_loss = total_loss / len(dataloader) - mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) - self.logger.info( - f"Epoch #{epoch:02d}: {mode.capitalize():<5} " - f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + y_pred = torch.cat(y_pred) + y_true = torch.cat(y_true) + + mae, rmse, mape = all_metrics( + y_pred, y_true, + self.args["mae_thresh"], + self.args["mape_thresh"] ) - return avg_loss - def train_epoch(self, epoch): - return self._run_epoch(epoch, self.train_loader, "train") - - def val_epoch(self, epoch): - return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") - - def test_epoch(self, epoch): - return self._run_epoch(epoch, self.test_loader, "test") + self.logger.info( + f"Epoch #{epoch:02d} {mode:<5} " + f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " + f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s" + ) + return total_loss / len(loader) + # ---------------- train ---------------- def train(self): - # 初始化记录 - best_model, best_test_model = None, None - best_loss, best_test_loss = float("inf"), float("inf") - not_improved_count = 0 - # 开始训练 - self.logger.info("Training process started") - # 训练循环 + best, best_test = float("inf"), float("inf") + best_w, best_test_w = None, None + patience = 0 + + self.logger.info("Training started") + for epoch in range(1, self.args["epochs"] + 1): - # 训练、验证和测试一个epoch - train_epoch_loss = self.train_epoch(epoch) - val_epoch_loss = self.val_epoch(epoch) - test_epoch_loss = self.test_epoch(epoch) - # 检查梯度爆炸 - if train_epoch_loss > 1e6: - self.logger.warning("Gradient explosion detected. Ending...") + losses = { + "train": self._run_epoch(epoch, self.train_loader, "train"), + "val": self._run_epoch(epoch, self.val_loader, "val"), + "test": self._run_epoch(epoch, self.test_loader, "test"), + } + + if losses["train"] > 1e6: + self.logger.warning("Gradient explosion detected") break - # 更新最佳验证模型 - if val_epoch_loss < best_loss: - best_loss = val_epoch_loss - not_improved_count = 0 - best_model = copy.deepcopy(self.model.state_dict()) - self.logger.info("Best validation model saved!") + + if losses["val"] < best: + best, patience = losses["val"], 0 + best_w = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved") else: - not_improved_count += 1 - # 早停 - if self._should_early_stop(not_improved_count): + patience += 1 + + if self.args["early_stop"] and patience == self.args["early_stop_patience"]: + self.logger.info("Early stopping triggered") break - # 更新最佳测试模型 - if test_epoch_loss < best_test_loss: - best_test_loss = test_epoch_loss - best_test_model = copy.deepcopy(self.model.state_dict()) - # 保存最佳模型 + + if losses["test"] < best_test: + best_test = losses["test"] + best_test_w = copy.deepcopy(self.model.state_dict()) + if not self.args["debug"]: - self._save_best_models(best_model, best_test_model) - # 最终评估 - self._finalize_training(best_model, best_test_model) - - def _should_early_stop(self, not_improved_count): - """检查是否满足早停条件""" - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) - return True - return False - - def _save_best_models(self, best_model, best_test_model): - """保存最佳模型到文件""" - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) - - def _log_model_params(self): - """输出模型可训练参数数量""" - total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) - self.logger.info(f"Trainable params: {total_params}") - + torch.save(best_w, self.best_path) + torch.save(best_test_w, self.best_test_path) - def _finalize_training(self, best_model, best_test_model): - self.model.load_state_dict(best_model) - self.logger.info("Testing on best validation model") - self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) - self.model.load_state_dict(best_test_model) - self.logger.info("Testing on best test model") - self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + self._final_test(best_w, best_test_w) - @staticmethod - def test(model, args, data_loader, scaler, logger, path=None): - """对模型进行评估并输出性能指标""" - # 确定设备信息 - device = None - output_dim = None - # 处理不同的参数格式 - if isinstance(args, dict): - if "basic" in args: - # 完整配置情况 - device = args["basic"]["device"] - output_dim = args["train"]["output_dim"] - else: - # 只有train_args情况,从模型获取设备 - device = next(model.parameters()).device - output_dim = args["output_dim"] - else: - raise ValueError(f"Unsupported args type: {type(args)}") - - # 加载模型检查点(如果提供了路径) - if path: - checkpoint = torch.load(path) - model.load_state_dict(checkpoint["state_dict"]) - model.to(device) + # ---------------- final test ---------------- + def _final_test(self, best_w, best_test_w): + for name, w in [("best val", best_w), ("best test", best_test_w)]: + self.model.load_state_dict(w) + self.logger.info(f"Testing on {name} model") + self.evaluate() - # 设置为评估模式 - model.eval() - - # 收集预测和真实标签 + # ---------------- evaluate ---------------- + def evaluate(self): + self.model.eval() y_pred, y_true = [], [] - # 不计算梯度的情况下进行预测 with torch.no_grad(): - for data, target in data_loader: - # 将数据和标签移动到指定设备 - data = data.to(device) - target = target.to(device) - - data = data[..., :-2] - b, t, n, c = data.shape - data = data.permute(0, 2, 1, 3) - data = data.reshape(b*n, t, c) - label = target[..., : output_dim] - output = model(data) - output = output.reshape(b, n, t, c) - output = output.permute(0, 2, 1, 3) + for data, target in self.test_loader: + data, target = data.to(self.device), target.to(self.device) + label = target[..., :self.args["output_dim"]] - y_pred.append(output.detach().cpu()) - y_true.append(label.detach().cpu()) + x, b, t, n, c = self.ts.forward(data) + out = self.model(x) + out = self.ts.inverse(out, b, t, n, c) - d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + y_pred.append(out.cpu()) + y_true.append(label.cpu()) - # 获取metrics参数 - if "basic" in args: - # 完整配置情况 - mae_thresh = args["train"]["mae_thresh"] - mape_thresh = args["train"]["mape_thresh"] - else: - # 只有train_args情况 - mae_thresh = args["mae_thresh"] - mape_thresh = args["mape_thresh"] - - # 计算并记录每个时间步的指标 - for t in range(d_y_true.shape[1]): + d_pred = self.scaler.inverse_transform(torch.cat(y_pred)) + d_true = self.scaler.inverse_transform(torch.cat(y_true)) + + for t in range(d_true.shape[1]): mae, rmse, mape = all_metrics( - d_y_pred[:, t, ...], - d_y_true[:, t, ...], - mae_thresh, - mape_thresh, + d_pred[:, t], d_true[:, t], + self.args["mae_thresh"], + self.args["mape_thresh"] ) - logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + self.logger.info( + f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" + ) + + avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info( + f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" + ) - # 计算并记录平均指标 - mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) - logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - - @staticmethod - def _compute_sampling_threshold(global_step, k): - return k / (k + math.exp(global_step / k)) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 04842ba..65980b9 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -1,4 +1,3 @@ -import math import os import time import copy @@ -8,240 +7,100 @@ from utils.loss_function import all_metrics from tqdm import tqdm class Trainer: - """模型训练器,负责整个训练流程的管理""" - def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): - # 设备和基本参数 - self.config = args - self.device = args["basic"]["device"] - self.args = args["train"] - - # 模型和训练相关组件 + self.config, self.device, self.args = args, args["basic"]["device"], args["train"] self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler - - # 数据加载器 - self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader, test_loader + self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler - # 数据处理工具 - self.scaler = scaler + log_dir = self.args["log_dir"] + self.best_path, self.best_test_path = [os.path.join(log_dir, f"best_{suffix}_model.pth") for suffix in ["", "test"]] - # 初始化路径、日志和统计 - self._initialize_paths(self.args) - self._initialize_logger(self.args) - - def _initialize_paths(self, args): - """初始化模型保存路径""" - log_dir = args["log_dir"] - self.best_path = os.path.join(log_dir, "best_model.pth") - self.best_test_path = os.path.join(log_dir, "best_test_model.pth") - self.loss_figure_path = os.path.join(log_dir, "loss.png") - - def _initialize_logger(self, args): - """初始化日志记录器""" - log_dir = args["log_dir"] - if not args["debug"]: - os.makedirs(log_dir, exist_ok=True) - self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=args["debug"]) + if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True) + self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"]) self.logger.info(f"Experiment log path in: {log_dir}") - def _run_epoch(self, epoch, dataloader, mode): - """运行一个训练/验证/测试epoch""" - # 设置模型模式和是否进行优化 - self.model.train() if mode == "train" else self.model.eval() - optimizer_step = mode == "train" - - # 初始化变量 - total_loss = 0 - epoch_time = time.time() - y_pred, y_true = [], [] - - # 训练/验证循环 - with torch.set_grad_enabled(optimizer_step): - progress_bar = tqdm( - dataloader, - total=len(dataloader), - desc=f"{mode.capitalize()} Epoch {epoch}" - ) - for data, target in progress_bar: - # 转移数据并提取标签 - data, target = data.to(self.device), target.to(self.device) - label = target[..., : self.args["output_dim"]] - - # 计算输出 - output = self.model(data) - - # 我的调试开关 - if os.environ.get("TRY") == "True": - status = '✅' if output.shape == label.shape else '❌' - print(f"[{status}]: output: {output.shape}, label: {label.shape}") - assert False - - # 计算损失 - loss = self.loss(output, label) - d_output = self.scaler.inverse_transform(output) - d_label = self.scaler.inverse_transform(label) - d_loss = self.loss(d_output, d_label) - - # 累积损失和预测结果 - total_loss += d_loss.item() - y_pred.append(d_output.detach().cpu()) - y_true.append(d_label.detach().cpu()) - - # 反向传播和优化(仅在训练模式) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() - - # 梯度裁剪(如果需要) - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) - - self.optimizer.step() - - # 更新进度条 - progress_bar.set_postfix(loss=d_loss.item()) - - # 合并所有批次的预测结果 - y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0) - - # 计算损失并记录指标 - avg_loss = total_loss / len(dataloader) - mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) - - self.logger.info( - f"Epoch #{epoch:02d}: {mode.capitalize():<5} " - f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" - ) - - return avg_loss - def train(self): - # 初始化记录 best_model = best_test_model = None best_loss = best_test_loss = float("inf") not_improved_count = 0 - # 开始训练 self.logger.info("Training process started") - # 训练循环 for epoch in range(1, self.args["epochs"] + 1): - # 训练、验证和测试一个epoch - train_epoch_loss = self._run_epoch(epoch, self.train_loader, "train") - val_epoch_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val") - test_epoch_loss = self._run_epoch(epoch, self.test_loader, "test") + train_loss = self._run_epoch(epoch, self.train_loader, "train") + val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val") + test_loss = self._run_epoch(epoch, self.test_loader, "test") - # 检查梯度爆炸 - if train_epoch_loss > 1e6: + if train_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break - # 更新最佳验证模型 - if val_epoch_loss < best_loss: - best_loss, not_improved_count = val_epoch_loss, 0 - best_model = copy.deepcopy(self.model.state_dict()) + if val_loss < best_loss: + best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict()) self.logger.info("Best validation model saved!") - else: - not_improved_count += 1 - - # 早停检查 - if self._should_early_stop(not_improved_count): + elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]: + self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.") break - # 更新最佳测试模型 - if test_epoch_loss < best_test_loss: - best_test_loss = test_epoch_loss - best_test_model = copy.deepcopy(self.model.state_dict()) + if test_loss < best_test_loss: + best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict()) - # 保存最佳模型 - if not self.args["debug"]: - self._save_best_models(best_model, best_test_model) - - # 最终评估 - self._finalize_training(best_model, best_test_model) - - def _should_early_stop(self, not_improved_count): - """检查是否满足早停条件""" - if self.args["early_stop"] and not_improved_count == self.args["early_stop_patience"]: - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) - return True - return False - - def _save_best_models(self, best_model, best_test_model): - """保存最佳模型到文件""" + torch.save(best_model, self.best_path) torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) + self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") + + for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]: + self.model.load_state_dict(state_dict) + self.logger.info(f"Testing on {model_name} model") + self._run_epoch(None, self.test_loader, "test", log_horizon=True) - def _log_model_params(self): - """输出模型可训练参数数量""" - total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) - self.logger.info(f"Trainable params: {total_params}") + def _run_epoch(self, epoch, dataloader, mode, log_horizon=False): + self.model.train() if mode == "train" else self.model.eval() + optimizer_step = mode == "train" - - def _finalize_training(self, best_model, best_test_model): - self.model.load_state_dict(best_model) - self.logger.info("Testing on best validation model") - self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) - self.model.load_state_dict(best_test_model) - self.logger.info("Testing on best test model") - self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) - - @staticmethod - def test(model, args, data_loader, scaler, logger, path=None): - """对模型进行评估并输出性能指标""" - # 验证参数类型 - if not isinstance(args, dict): - raise ValueError(f"Unsupported args type: {type(args)}") - - # 确定设备和输出维度 - is_full_config = "basic" in args - device = args["basic"]["device"] if is_full_config else next(model.parameters()).device - output_dim = args["train"]["output_dim"] if is_full_config else args["output_dim"] - - # 获取metrics参数 - train_args = args["train"] if is_full_config else args - mae_thresh, mape_thresh = train_args["mae_thresh"], train_args["mape_thresh"] - - # 加载模型检查点(如果提供了路径) - if path: - checkpoint = torch.load(path) - model.load_state_dict(checkpoint["state_dict"]) - model.to(device) - - # 设置为评估模式并收集预测结果 - model.eval() + total_loss, epoch_time = 0, time.time() y_pred, y_true = [], [] - - # 不计算梯度的情况下进行预测 - with torch.no_grad(): - for data, target in data_loader: - # 将数据和标签移动到指定设备 - data, target = data.to(device), target.to(device) - label = target[..., : output_dim] - - output = model(data) - y_pred.append(output.detach().cpu()) - y_true.append(label.detach().cpu()) - - # 反归一化并计算指标 - d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) - # 计算并记录每个时间步的指标 - for t in range(d_y_true.shape[1]): - mae, rmse, mape = all_metrics( - d_y_pred[:, t, ...], - d_y_true[:, t, ...], - mae_thresh, - mape_thresh, - ) - logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + with torch.set_grad_enabled(optimizer_step): + for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode): + data, target = data.to(self.device), target.to(self.device) + label = target[..., :self.args["output_dim"]] + + output = self.model(data) + loss = self.loss(output, label) + d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label) + d_loss = self.loss(d_output, d_label) + + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + + if optimizer_step and self.optimizer: + self.optimizer.zero_grad() + loss.backward() + if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + self.optimizer.step() + + y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0) + + if log_horizon: + for t in range(y_true.shape[1]): + mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + + if epoch and mode: + self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s") + elif mode: + self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}") + + return total_loss / len(dataloader) - # 计算并记录平均指标 - avg_mae, avg_rmse, avg_mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) - logger.info(f"Average Horizon, MAE: {avg_mae:.4f}, RMSE: {avg_rmse:.4f}, MAPE: {avg_mape:.4f}") + def test(self, path=None): + if path: + self.model.load_state_dict(torch.load(path)["state_dict"]) + self.model.to(self.device) + + self._run_epoch(None, self.test_loader, "test", log_horizon=True) diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 723b257..17aa81d 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -7,132 +7,31 @@ from trainer.E32Trainer import Trainer as EXP_Trainer from trainer.InformerTrainer import InformerTrainer from trainer.TSTrainer import Trainer as TSTrainer + def select_trainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - kwargs, + model, loss, optimizer, + train_loader, val_loader, test_loader, + scaler, args, lr_scheduler, kwargs ): model_name = args["basic"]["model"] - TS_model = ["HI", "PatchTST", "iTransformer"] - if model_name in TS_model: - return TSTrainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - ) + base_args = ( + model, loss, optimizer, + train_loader, val_loader, test_loader, + scaler, args, lr_scheduler + ) + if model_name in {"HI", "PatchTST", "iTransformer"}: + return TSTrainer(*base_args) - match model_name: - case "STGNCDE": - return cdeTrainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - kwargs[0], - None, - ) - case "STGNRDE": - return cdeTrainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - kwargs[0], - None, - ) - case "DCRNN": - return DCRNN_Trainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - ) - case "PDG2SEQ": - return PDG2SEQ_Trainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - ) - case "STMLP": - return STMLP_Trainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - ) - case "EXP": - return EXP_Trainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - ) - case "Informer": - return InformerTrainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - ) - case _: - return Trainer( - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler, - ) + trainer_map = { + "DCRNN": DCRNN_Trainer, + "PDG2SEQ": PDG2SEQ_Trainer, + "STMLP": STMLP_Trainer, + "EXP": EXP_Trainer, + "Informer": InformerTrainer, + } + + if model_name in {"STGNCDE", "STGNRDE"}: + return cdeTrainer(*base_args, kwargs[0], None) + + return trainer_map.get(model_name, Trainer)(*base_args) diff --git a/utils/logger.py b/utils/logger.py index 8a2f187..7a818f6 100755 --- a/utils/logger.py +++ b/utils/logger.py @@ -18,7 +18,7 @@ def get_logger(root, name=None, debug=True): logger.handlers.clear() # 时间格式改为 年/月/日 时:分:秒 - formatter = logging.Formatter("%(asctime)s - %(message)s", "%Y/%m/%d %H:%M:%S") + formatter = logging.Formatter("%(asctime)s - %(message)s", "%m/%d %H:%M") # 控制台输出 console_handler = logging.StreamHandler() -- 2.40.1 From 56b09ea8ac8eb386c209f8a9ddadb3a9ec8ebbb1 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 19:55:57 +0800 Subject: [PATCH 30/41] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E5=99=A8=E5=92=8C=E9=85=8D=E7=BD=AE=E7=BB=93=E6=9E=84?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor(trainer): 重构Trainer类结构,拆分初始化方法 perf(trainer): 优化训练循环和评估逻辑 style(config): 统一配置文件命名和结构 fix(trainer): 修复形状检查逻辑和调试模式处理 docs: 更新README和注释说明 --- .../AirQuality.yaml} | 0 .../BJTaxi-InFlow.yaml} | 0 .../BJTaxi-OutFlow.yaml} | 0 .../v2_METR-LA.yaml => ASTRA_v2/METR-LA.yaml} | 2 +- .../NYCBike-InFlow.yaml} | 0 .../NYCBike-OutFlow.yaml} | 0 .../PEMS-BAY.yaml} | 2 +- .../SolarEnergy.yaml} | 2 +- config/ASTRA_v3/AirQuality.yaml | 54 ++++ config/ASTRA_v3/BJTaxi-InFlow.yaml | 54 ++++ config/ASTRA_v3/BJTaxi-OutFlow.yaml | 54 ++++ .../v3_METR-LA.yaml => ASTRA_v3/METR-LA.yaml} | 4 +- config/ASTRA_v3/NYCBike-InFlow.yaml | 54 ++++ config/ASTRA_v3/NYCBike-OutFlow.yaml | 54 ++++ .../PEMS-BAY.yaml} | 4 +- config/ASTRA_v3/SolarEnergy.yaml | 54 ++++ ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 config/REPST/AirQuality.yaml | 4 +- model/ASTRA/astrav3.py | 8 +- train.py | 23 +- trainer/TSTrainer.py | 8 +- trainer/Trainer.py | 256 +++++++++++------- 23 files changed, 517 insertions(+), 120 deletions(-) rename config/{ASTRA/v2_AirQuality.yaml => ASTRA_v2/AirQuality.yaml} (100%) rename config/{ASTRA/v2_BJTaxi-InFlow.yaml => ASTRA_v2/BJTaxi-InFlow.yaml} (100%) rename config/{ASTRA/v2_BJTaxi-OutFlow.yaml => ASTRA_v2/BJTaxi-OutFlow.yaml} (100%) rename config/{ASTRA/v2_METR-LA.yaml => ASTRA_v2/METR-LA.yaml} (97%) rename config/{ASTRA/v2_NYCBike-InFlow.yaml => ASTRA_v2/NYCBike-InFlow.yaml} (100%) rename config/{ASTRA/v2_NYCBike-OutFlow.yaml => ASTRA_v2/NYCBike-OutFlow.yaml} (100%) rename config/{ASTRA/v3_PEMS-BAY.yaml => ASTRA_v2/PEMS-BAY.yaml} (97%) rename config/{ASTRA/v2_SolarEnergy.yaml => ASTRA_v2/SolarEnergy.yaml} (97%) create mode 100644 config/ASTRA_v3/AirQuality.yaml create mode 100644 config/ASTRA_v3/BJTaxi-InFlow.yaml create mode 100644 config/ASTRA_v3/BJTaxi-OutFlow.yaml rename config/{ASTRA/v3_METR-LA.yaml => ASTRA_v3/METR-LA.yaml} (93%) create mode 100644 config/ASTRA_v3/NYCBike-InFlow.yaml create mode 100644 config/ASTRA_v3/NYCBike-OutFlow.yaml rename config/{ASTRA/v2_PEMS-BAY.yaml => ASTRA_v3/PEMS-BAY.yaml} (92%) create mode 100644 config/ASTRA_v3/SolarEnergy.yaml rename config/MTGNN/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/MTGNN/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) diff --git a/config/ASTRA/v2_AirQuality.yaml b/config/ASTRA_v2/AirQuality.yaml similarity index 100% rename from config/ASTRA/v2_AirQuality.yaml rename to config/ASTRA_v2/AirQuality.yaml diff --git a/config/ASTRA/v2_BJTaxi-InFlow.yaml b/config/ASTRA_v2/BJTaxi-InFlow.yaml similarity index 100% rename from config/ASTRA/v2_BJTaxi-InFlow.yaml rename to config/ASTRA_v2/BJTaxi-InFlow.yaml diff --git a/config/ASTRA/v2_BJTaxi-OutFlow.yaml b/config/ASTRA_v2/BJTaxi-OutFlow.yaml similarity index 100% rename from config/ASTRA/v2_BJTaxi-OutFlow.yaml rename to config/ASTRA_v2/BJTaxi-OutFlow.yaml diff --git a/config/ASTRA/v2_METR-LA.yaml b/config/ASTRA_v2/METR-LA.yaml similarity index 97% rename from config/ASTRA/v2_METR-LA.yaml rename to config/ASTRA_v2/METR-LA.yaml index bf92089..dca4bb4 100644 --- a/config/ASTRA/v2_METR-LA.yaml +++ b/config/ASTRA_v2/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:0 mode: train - model: AEPSA_v2 + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/v2_NYCBike-InFlow.yaml b/config/ASTRA_v2/NYCBike-InFlow.yaml similarity index 100% rename from config/ASTRA/v2_NYCBike-InFlow.yaml rename to config/ASTRA_v2/NYCBike-InFlow.yaml diff --git a/config/ASTRA/v2_NYCBike-OutFlow.yaml b/config/ASTRA_v2/NYCBike-OutFlow.yaml similarity index 100% rename from config/ASTRA/v2_NYCBike-OutFlow.yaml rename to config/ASTRA_v2/NYCBike-OutFlow.yaml diff --git a/config/ASTRA/v3_PEMS-BAY.yaml b/config/ASTRA_v2/PEMS-BAY.yaml similarity index 97% rename from config/ASTRA/v3_PEMS-BAY.yaml rename to config/ASTRA_v2/PEMS-BAY.yaml index 9f98483..2f6dfbf 100755 --- a/config/ASTRA/v3_PEMS-BAY.yaml +++ b/config/ASTRA_v2/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: AEPSA_v3 + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/v2_SolarEnergy.yaml b/config/ASTRA_v2/SolarEnergy.yaml similarity index 97% rename from config/ASTRA/v2_SolarEnergy.yaml rename to config/ASTRA_v2/SolarEnergy.yaml index a45ad73..83a87c2 100644 --- a/config/ASTRA/v2_SolarEnergy.yaml +++ b/config/ASTRA_v2/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: AEPSA_v2 + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA_v3/AirQuality.yaml b/config/ASTRA_v3/AirQuality.yaml new file mode 100644 index 0000000..68e6acc --- /dev/null +++ b/config/ASTRA_v3/AirQuality.yaml @@ -0,0 +1,54 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 6 + n_heads: 1 + num_nodes: 35 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + weight_decay: 0 diff --git a/config/ASTRA_v3/BJTaxi-InFlow.yaml b/config/ASTRA_v3/BJTaxi-InFlow.yaml new file mode 100644 index 0000000..34abfd8 --- /dev/null +++ b/config/ASTRA_v3/BJTaxi-InFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA_v3/BJTaxi-OutFlow.yaml b/config/ASTRA_v3/BJTaxi-OutFlow.yaml new file mode 100644 index 0000000..8e6b30d --- /dev/null +++ b/config/ASTRA_v3/BJTaxi-OutFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA/v3_METR-LA.yaml b/config/ASTRA_v3/METR-LA.yaml similarity index 93% rename from config/ASTRA/v3_METR-LA.yaml rename to config/ASTRA_v3/METR-LA.yaml index 5d22820..2b5512b 100644 --- a/config/ASTRA/v3_METR-LA.yaml +++ b/config/ASTRA_v3/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:0 mode: train - model: AEPSA_v3 + model: ASTRA_v3 seed: 2023 data: @@ -19,11 +19,9 @@ data: val_ratio: 0.2 model: - chebyshev_order: 3 d_ff: 128 d_model: 64 dropout: 0.2 - graph_hidden_dim: 32 gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 diff --git a/config/ASTRA_v3/NYCBike-InFlow.yaml b/config/ASTRA_v3/NYCBike-InFlow.yaml new file mode 100644 index 0000000..18c4fa3 --- /dev/null +++ b/config/ASTRA_v3/NYCBike-InFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA_v3/NYCBike-OutFlow.yaml b/config/ASTRA_v3/NYCBike-OutFlow.yaml new file mode 100644 index 0000000..ff73662 --- /dev/null +++ b/config/ASTRA_v3/NYCBike-OutFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA/v2_PEMS-BAY.yaml b/config/ASTRA_v3/PEMS-BAY.yaml similarity index 92% rename from config/ASTRA/v2_PEMS-BAY.yaml rename to config/ASTRA_v3/PEMS-BAY.yaml index c40034d..6739aeb 100755 --- a/config/ASTRA/v2_PEMS-BAY.yaml +++ b/config/ASTRA_v3/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: AEPSA_v2 + model: ASTRA_v3 seed: 2023 data: @@ -19,11 +19,9 @@ data: val_ratio: 0.2 model: - chebyshev_order: 3 d_ff: 128 d_model: 64 dropout: 0.2 - graph_hidden_dim: 32 gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 diff --git a/config/ASTRA_v3/SolarEnergy.yaml b/config/ASTRA_v3/SolarEnergy.yaml new file mode 100644 index 0000000..289b839 --- /dev/null +++ b/config/ASTRA_v3/SolarEnergy.yaml @@ -0,0 +1,54 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 137 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/MTGNN/NYCBike-Inflow.yaml b/config/MTGNN/NYCBike-InFlow.yaml similarity index 100% rename from config/MTGNN/NYCBike-Inflow.yaml rename to config/MTGNN/NYCBike-InFlow.yaml diff --git a/config/MTGNN/NYCBike-Outflow.yaml b/config/MTGNN/NYCBike-OutFlow.yaml similarity index 100% rename from config/MTGNN/NYCBike-Outflow.yaml rename to config/MTGNN/NYCBike-OutFlow.yaml diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index a40e11e..c035a44 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -13,8 +13,8 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 - steps_per_day: 288 + num_nodes: 35 + steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index a29bfc3..0e9aebf 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -184,7 +184,7 @@ class ASTRA(nn.Module): def forward(self, x): # 数据处理 - x = x[..., :1] # [B,T,N,1] + x = x[..., :self.input_dim] x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T] # 图编码 @@ -203,7 +203,9 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(X_enc) # [B,N,pred_len] # 维度调整 - outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] - outputs = outputs.permute(0, 2, 1, 3) # [B,pred_len,N,1] + dec_out = self.out_mlp(enc_out) + outputs = dec_out.unsqueeze(dim=-1) + outputs = outputs.repeat(1, 1, 1, self.input_dim) + outputs = outputs.permute(0,2,1,3) return outputs \ No newline at end of file diff --git a/train.py b/train.py index 139cdfa..2d3a32f 100644 --- a/train.py +++ b/train.py @@ -60,25 +60,17 @@ def run(config): case _: raise ValueError(f"Unsupported mode: {config['basic']['mode']}") -def main(debug=False): - # 指定模型 - model_list = ["iTransformer"] - # 指定数据集 - # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - # dataset_list = ["AirQuality"] - dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] - +def main(model, data, debug=False): # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) - os.environ["TRY"] = str(False) - + os.environ["TRY"] = str(debug) + for model in model_list: - for dataset in dataset_list: + for dataset in data: config_path = f"./config/{model}/{dataset}.yaml" # 可去这个函数里面调整统一的config项,⚠️注意调设备,epochs config = read_config(config_path) print(f"\nRunning {model} on {dataset}") - # print(f"config: {config}") if os.environ.get("TRY") == "True": try: run(config) @@ -97,4 +89,9 @@ def main(debug=False): if __name__ == "__main__": # 调试用 - main(debug = True) \ No newline at end of file + # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] + # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # dataset_list = ["AirQuality"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] + main(model_list, dataset_list, debug = True) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index c427072..932d8b3 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -73,8 +73,12 @@ class Trainer: out = self.ts.inverse(out, b, t, n, c) if os.environ.get("TRY") == "True": - print(out.shape, label.shape) - assert out.shape == label.shape + if out.shape == label.shape: + print("shape true") + assert False + else: + print("shape false") + assert False loss = self.loss(out, label) d_out = self.scaler.inverse_transform(out) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 65980b9..cdd444b 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -1,106 +1,180 @@ -import os -import time -import copy -import torch +import os, time, copy, torch +from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics -from tqdm import tqdm class Trainer: - def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): - self.config, self.device, self.args = args, args["basic"]["device"], args["train"] - self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler - self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler - - log_dir = self.args["log_dir"] - self.best_path, self.best_test_path = [os.path.join(log_dir, f"best_{suffix}_model.pth") for suffix in ["", "test"]] - - if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True) - self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"]) - self.logger.info(f"Experiment log path in: {log_dir}") + def __init__(self, model, loss, optimizer, + train_loader, val_loader, test_loader, + scaler, args, lr_scheduler=None): - def train(self): - best_model = best_test_model = None - best_loss = best_test_loss = float("inf") - not_improved_count = 0 - - self.logger.info("Training process started") - - for epoch in range(1, self.args["epochs"] + 1): - train_loss = self._run_epoch(epoch, self.train_loader, "train") - val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val") - test_loss = self._run_epoch(epoch, self.test_loader, "test") - - if train_loss > 1e6: - self.logger.warning("Gradient explosion detected. Ending...") - break - - if val_loss < best_loss: - best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict()) - self.logger.info("Best validation model saved!") - elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]: - self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.") - break - - if test_loss < best_test_loss: - best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict()) - + self.config = args + self.device = args["basic"]["device"] + self.args = args["train"] - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") - - for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]: - self.model.load_state_dict(state_dict) - self.logger.info(f"Testing on {model_name} model") - self._run_epoch(None, self.test_loader, "test", log_horizon=True) - - def _run_epoch(self, epoch, dataloader, mode, log_horizon=False): - self.model.train() if mode == "train" else self.model.eval() - optimizer_step = mode == "train" - - total_loss, epoch_time = 0, time.time() + self.model = model.to(self.device) + self.loss = loss + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + self.train_loader = train_loader + self.val_loader = val_loader or test_loader + self.test_loader = test_loader + self.scaler = scaler + + self._init_paths() + self._init_logger() + + # ---------------- init ---------------- + def _init_paths(self): + d = self.args["log_dir"] + self.best_path = os.path.join(d, "best_model.pth") + self.best_test_path = os.path.join(d, "best_test_model.pth") + + def _init_logger(self): + if not self.args["debug"]: + os.makedirs(self.args["log_dir"], exist_ok=True) + self.logger = get_logger( + self.args["log_dir"], + name=self.model.__class__.__name__, + debug=self.args["debug"], + ) + + # ---------------- epoch ---------------- + def _run_epoch(self, epoch, loader, mode): + is_train = mode == "train" + self.model.train() if is_train else self.model.eval() + + total_loss, start = 0.0, time.time() y_pred, y_true = [], [] - - with torch.set_grad_enabled(optimizer_step): - for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode): + + with torch.set_grad_enabled(is_train): + for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - - output = self.model(data) - loss = self.loss(output, label) - d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label) - d_loss = self.loss(d_output, d_label) - + + out = self.model(data) + + if os.environ.get("TRY") == "True": + if out.shape == label.shape: + print(f"shape true, out: {out.shape}, label: {label.shape}") + assert False + else: + print(f"shape false, out: {out.shape}, label: {label.shape}") + assert False + + loss = self.loss(out, label) + d_out = self.scaler.inverse_transform(out) + d_lbl = self.scaler.inverse_transform(label) + d_loss = self.loss(d_out, d_lbl) + total_loss += d_loss.item() - y_pred.append(d_output.detach().cpu()) - y_true.append(d_label.detach().cpu()) - - if optimizer_step and self.optimizer: + y_pred.append(d_out.detach().cpu()) + y_true.append(d_lbl.detach().cpu()) + + if is_train and self.optimizer: self.optimizer.zero_grad() loss.backward() - if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.args["max_grad_norm"] + ) self.optimizer.step() - - y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0) - - if log_horizon: - for t in range(y_true.shape[1]): - mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"]) - self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - - avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) - - if epoch and mode: - self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s") - elif mode: - self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}") - - return total_loss / len(dataloader) - def test(self, path=None): - if path: - self.model.load_state_dict(torch.load(path)["state_dict"]) - self.model.to(self.device) + y_pred = torch.cat(y_pred) + y_true = torch.cat(y_true) + + mae, rmse, mape = all_metrics( + y_pred, y_true, + self.args["mae_thresh"], + self.args["mape_thresh"] + ) + + self.logger.info( + f"Epoch #{epoch:02d} {mode:<5} " + f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " + f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s" + ) + return total_loss / len(loader) + + # ---------------- train ---------------- + def train(self): + best, best_test = float("inf"), float("inf") + best_w, best_test_w = None, None + patience = 0 + + self.logger.info("Training started") + + for epoch in range(1, self.args["epochs"] + 1): + losses = { + "train": self._run_epoch(epoch, self.train_loader, "train"), + "val": self._run_epoch(epoch, self.val_loader, "val"), + "test": self._run_epoch(epoch, self.test_loader, "test"), + } + + if losses["train"] > 1e6: + self.logger.warning("Gradient explosion detected") + break + + if losses["val"] < best: + best, patience = losses["val"], 0 + best_w = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved") + else: + patience += 1 + + if self.args["early_stop"] and patience == self.args["early_stop_patience"]: + self.logger.info("Early stopping triggered") + break + + if losses["test"] < best_test: + best_test = losses["test"] + best_test_w = copy.deepcopy(self.model.state_dict()) + + if not self.args["debug"]: + torch.save(best_w, self.best_path) + torch.save(best_test_w, self.best_test_path) + + self._final_test(best_w, best_test_w) + + # ---------------- final test ---------------- + def _final_test(self, best_w, best_test_w): + for name, w in [("best val", best_w), ("best test", best_test_w)]: + self.model.load_state_dict(w) + self.logger.info(f"Testing on {name} model") + self.evaluate() + + # ---------------- evaluate ---------------- + def evaluate(self): + self.model.eval() + y_pred, y_true = [], [] + + with torch.no_grad(): + for data, target in self.test_loader: + data, target = data.to(self.device), target.to(self.device) + label = target[..., :self.args["output_dim"]] + + out = self.model(data) + + y_pred.append(out.cpu()) + y_true.append(label.cpu()) + + d_pred = self.scaler.inverse_transform(torch.cat(y_pred)) + d_true = self.scaler.inverse_transform(torch.cat(y_true)) + + for t in range(d_true.shape[1]): + mae, rmse, mape = all_metrics( + d_pred[:, t], d_true[:, t], + self.args["mae_thresh"], + self.args["mape_thresh"] + ) + self.logger.info( + f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" + ) - self._run_epoch(None, self.test_loader, "test", log_horizon=True) + avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info( + f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" + ) + -- 2.40.1 From b6d4f5daf5834307da1d85592c22a5d23e100be5 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 20:54:20 +0800 Subject: [PATCH 31/41] =?UTF-8?q?refactor(dataloader):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E6=95=B0=E6=8D=AE=E5=8A=A0=E8=BD=BD=E5=99=A8=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E4=BC=98=E5=8C=96=E6=BB=91=E5=8A=A8=E7=AA=97?= =?UTF-8?q?=E5=8F=A3=E7=94=9F=E6=88=90=E5=92=8C=E5=BD=92=E4=B8=80=E5=8C=96?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 重构PeMSD和EXP数据加载器,使用numpy的stride_tricks实现高效滑动窗口 统一数据预处理流程,简化代码结构并提高可维护性 优化归一化处理,支持多scaler通道独立处理 --- dataloader/EXPdataloader.py | 243 ++++++++++------------------------ dataloader/PeMSDdataloader.py | 210 ++++++++++------------------- train.py | 11 +- trainer/TSTrainer.py | 1 - trainer/Trainer.py | 46 ++++--- 5 files changed, 171 insertions(+), 340 deletions(-) diff --git a/dataloader/EXPdataloader.py b/dataloader/EXPdataloader.py index 237bf71..18ebf61 100755 --- a/dataloader/EXPdataloader.py +++ b/dataloader/EXPdataloader.py @@ -1,199 +1,90 @@ import numpy as np import torch -from utils.normalization import normalize_dataset from dataloader.data_selector import load_st_dataset +from utils.normalization import normalize_dataset -def get_dataloader(args, normalizer="std", single=True): - # args should now include 'cycle' - data = load_st_dataset(args["type"], args["sample"]) # [T, N, F] - L, N, F = data.shape - # compute cycle index - cycle_arr = np.arange(L) % args["cycle"] # length-L array +_device = "cuda" if torch.cuda.is_available() else "cpu" +to_tensor = lambda a: torch.as_tensor(a, dtype=torch.float32, device=_device) - # Step 1: sliding windows for X and Y - x = add_window_x(data, args["lag"], args["horizon"], single) - y = add_window_y(data, args["lag"], args["horizon"], single) - # window count = M = L - lag - horizon + 1 - M = x.shape[0] +# Sliding window (stride trick, zero copy) +window = lambda d, w, h, o=0: np.lib.stride_tricks.as_strided( + d[o:], + shape=(len(d) - w - h + 1, w, *d.shape[1:]), + strides=(d.strides[0], d.strides[0], *d.strides[1:]) +) - # Step 2: time features - time_in_day = np.tile( - np.array([i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)]), - (N, 1), - ).T.reshape(L, N, 1) - day_in_week = np.tile( - np.array( - [(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)] - ), - (N, 1), - ).T.reshape(L, N, 1) +# pad_with_last_sample=True +pad_last = lambda X, Y, bs: ( + (lambda r: ( + (np.concatenate([X, np.repeat(X[-1:], r, 0)], 0), + np.concatenate([Y, np.repeat(Y[-1:], r, 0)], 0)) + if r else (X, Y) + ))((-len(X)) % bs) +) - x_day = add_window_x(time_in_day, args["lag"], args["horizon"], single) - x_week = add_window_x(day_in_week, args["lag"], args["horizon"], single) - x = np.concatenate([x, x_day, x_week], axis=-1) - # del x_day, x_week - # gc.collect() +# Train / Val / Test split +split_by_ratio = lambda d, vr, tr: ( + d[:-(vl := int(len(d) * (vr + tr)))], + d[-vl:-(tl := int(len(d) * tr))], + d[-tl:] +) - # Step 3: extract cycle index per window: take value at end of sequence - cycle_win = np.array([cycle_arr[i + args["lag"]] for i in range(M)]) # shape [M] - # Step 4: split into train/val/test - if args["test_ratio"] > 1: - x_train, x_val, x_test = split_data_by_days( - x, args["val_ratio"], args["test_ratio"] - ) - y_train, y_val, y_test = split_data_by_days( - y, args["val_ratio"], args["test_ratio"] - ) - c_train, c_val, c_test = split_data_by_days( - cycle_win, args["val_ratio"], args["test_ratio"] - ) - else: - x_train, x_val, x_test = split_data_by_ratio( - x, args["val_ratio"], args["test_ratio"] - ) - y_train, y_val, y_test = split_data_by_ratio( - y, args["val_ratio"], args["test_ratio"] - ) - c_train, c_val, c_test = split_data_by_ratio( - cycle_win, args["val_ratio"], args["test_ratio"] - ) - # del x, y, cycle_win - # gc.collect() +def get_dataloader(config, normalizer="std", single_step=True): + data = load_st_dataset(config) + cfg = config["data"] - # Step 5: normalization on X only - scaler = normalize_dataset( - x_train[..., : args["input_dim"]], normalizer, args["column_wise"] - ) - x_train[..., : args["input_dim"]] = scaler.transform( - x_train[..., : args["input_dim"]] - ) - x_val[..., : args["input_dim"]] = scaler.transform(x_val[..., : args["input_dim"]]) - x_test[..., : args["input_dim"]] = scaler.transform( - x_test[..., : args["input_dim"]] + T, N, _ = data.shape + lag, horizon, batch_size, input_dim = ( + cfg["lag"], cfg["horizon"], cfg["batch_size"], cfg["input_dim"] ) - # add time features to Y - y_day = add_window_y(time_in_day, args["lag"], args["horizon"], single) - y_week = add_window_y(day_in_week, args["lag"], args["horizon"], single) - y = np.concatenate([y, y_day, y_week], axis=-1) - # del y_day, y_week, time_in_day, day_in_week - # gc.collect() - - # split Y time-augmented - if args["test_ratio"] > 1: - y_train, y_val, y_test = split_data_by_days( - y, args["val_ratio"], args["test_ratio"] - ) - else: - y_train, y_val, y_test = split_data_by_ratio( - y, args["val_ratio"], args["test_ratio"] - ) - # del y - - # Step 6: create dataloaders including cycle index - train_loader = data_loader_with_cycle( - x_train, y_train, c_train, args["batch_size"], shuffle=True, drop_last=True - ) - val_loader = data_loader_with_cycle( - x_val, y_val, c_val, args["batch_size"], shuffle=False, drop_last=True - ) - test_loader = data_loader_with_cycle( - x_test, y_test, c_test, args["batch_size"], shuffle=False, drop_last=False + # X / Y construction + X = window(data, lag, horizon) + Y = window( + data, + 1 if single_step else horizon, + horizon, + lag if not single_step else lag + horizon - 1 ) - return train_loader, val_loader, test_loader, scaler + # Time features + t = np.arange(T) + time_in_day = np.tile((t % cfg["steps_per_day"]) / cfg["steps_per_day"], (N, 1)).T + day_in_week = np.tile((t // cfg["steps_per_day"]) % cfg["days_per_week"], (N, 1)).T + tf = lambda z: window(z[..., None], lag, horizon) + X = np.concatenate([X, tf(time_in_day), tf(day_in_week)], -1) + Y = np.concatenate([Y, tf(time_in_day), tf(day_in_week)], -1) -def data_loader_with_cycle(X, Y, C, batch_size, shuffle=True, drop_last=True): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - X_t = torch.tensor(X, dtype=torch.float32, device=device) - Y_t = torch.tensor(Y, dtype=torch.float32, device=device) - C_t = torch.tensor(C, dtype=torch.long, device=device).unsqueeze(-1) # [B,1] - dataset = torch.utils.data.TensorDataset(X_t, Y_t, C_t) - loader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last - ) - return loader + # Split + X_train, X_val, X_test = split_by_ratio(X, cfg["val_ratio"], cfg["test_ratio"]) + Y_train, Y_val, Y_test = split_by_ratio(Y, cfg["val_ratio"], cfg["test_ratio"]) - -def split_data_by_days(data, val_days, test_days, interval=30): - t = int((24 * 60) / interval) - test_data = data[-t * int(test_days) :] - val_data = data[-t * int(test_days + val_days) : -t * int(test_days)] - train_data = data[: -t * int(test_days + val_days)] - return train_data, val_data, test_data - - -def split_data_by_ratio(data, val_ratio, test_ratio): - data_len = data.shape[0] - test_data = data[-int(data_len * test_ratio) :] - val_data = data[ - -int(data_len * (test_ratio + val_ratio)) : -int(data_len * test_ratio) + # Channel-wise normalization (fit on train only) + scalers = [ + normalize_dataset(X_train[..., i:i+1], normalizer, cfg["column_wise"]) + for i in range(input_dim) ] - train_data = data[: -int(data_len * (test_ratio + val_ratio))] - return train_data, val_data, test_data + for i, sc in enumerate(scalers): + for d in (X_train, X_val, X_test, Y_train, Y_val, Y_test): + d[..., i:i+1] = sc.transform(d[..., i:i+1]) + # Padding + X_train, Y_train = pad_last(X_train, Y_train, batch_size) + X_val, Y_val = pad_last(X_val, Y_val, batch_size) + X_test, Y_test = pad_last(X_test, Y_test, batch_size) -def data_loader(X, Y, batch_size, shuffle=True, drop_last=True): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - X = torch.tensor(X, dtype=torch.float32, device=device) - Y = torch.tensor(Y, dtype=torch.float32, device=device) - data = torch.utils.data.TensorDataset(X, Y) - dataloader = torch.utils.data.DataLoader( - data, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + # DataLoader + make_loader = lambda X, Y, shuffle: torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(to_tensor(X), to_tensor(Y)), + batch_size=batch_size, shuffle=shuffle, drop_last=False ) - return dataloader - -def add_window_x(data, window=3, horizon=1, single=False): - """ - Generate windowed X values from the input data. - - :param data: Input data, shape [B, ...] - :param window: Size of the sliding window - :param horizon: Horizon size - :param single: If True, generate single-step windows, else multi-step - :return: X with shape [B, W, ...] - """ - length = len(data) - end_index = length - horizon - window + 1 - x = [] # Sliding windows - index = 0 - - while index < end_index: - x.append(data[index : index + window]) - index += 1 - - return np.array(x) - - -def add_window_y(data, window=3, horizon=1, single=False): - """ - Generate windowed Y values from the input data. - - :param data: Input data, shape [B, ...] - :param window: Size of the sliding window - :param horizon: Horizon size - :param single: If True, generate single-step windows, else multi-step - :return: Y with shape [B, H, ...] - """ - length = len(data) - end_index = length - horizon - window + 1 - y = [] # Horizon values - index = 0 - - while index < end_index: - if single: - y.append(data[index + window + horizon - 1 : index + window + horizon]) - else: - y.append(data[index + window : index + window + horizon]) - index += 1 - - return np.array(y) - - -if __name__ == "__main__": - res = load_st_dataset("SD", 1) - k = 1 + return ( + make_loader(X_train, Y_train, True), + make_loader(X_val, Y_val, False), + make_loader(X_test, Y_test, False), + scalers + ) diff --git a/dataloader/PeMSDdataloader.py b/dataloader/PeMSDdataloader.py index 0e079e1..18ebf61 100755 --- a/dataloader/PeMSDdataloader.py +++ b/dataloader/PeMSDdataloader.py @@ -1,158 +1,90 @@ import numpy as np import torch - from dataloader.data_selector import load_st_dataset from utils.normalization import normalize_dataset -def get_dataloader(args, normalizer="std", single=True): - data = load_st_dataset(args) - args = args["data"] - L, N, F = data.shape +_device = "cuda" if torch.cuda.is_available() else "cpu" +to_tensor = lambda a: torch.as_tensor(a, dtype=torch.float32, device=_device) - # Generate sliding windows for main data and add time features - x, y = _prepare_data_with_windows(data, args, single) - - # Split data - split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio - x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"]) - y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"]) +# Sliding window (stride trick, zero copy) +window = lambda d, w, h, o=0: np.lib.stride_tricks.as_strided( + d[o:], + shape=(len(d) - w - h + 1, w, *d.shape[1:]), + strides=(d.strides[0], d.strides[0], *d.strides[1:]) +) - # Normalize x and y using the same scaler - scaler = _normalize_data(x_train, x_val, x_test, args, normalizer) - _apply_existing_scaler(y_train, y_val, y_test, scaler, args) +# pad_with_last_sample=True +pad_last = lambda X, Y, bs: ( + (lambda r: ( + (np.concatenate([X, np.repeat(X[-1:], r, 0)], 0), + np.concatenate([Y, np.repeat(Y[-1:], r, 0)], 0)) + if r else (X, Y) + ))((-len(X)) % bs) +) - # Create dataloaders - return ( - _create_dataloader(x_train, y_train, args["batch_size"], True, False), - _create_dataloader(x_val, y_val, args["batch_size"], False, False), - _create_dataloader(x_test, y_test, args["batch_size"], False, False), - scaler +# Train / Val / Test split +split_by_ratio = lambda d, vr, tr: ( + d[:-(vl := int(len(d) * (vr + tr)))], + d[-vl:-(tl := int(len(d) * tr))], + d[-tl:] +) + + +def get_dataloader(config, normalizer="std", single_step=True): + data = load_st_dataset(config) + cfg = config["data"] + + T, N, _ = data.shape + lag, horizon, batch_size, input_dim = ( + cfg["lag"], cfg["horizon"], cfg["batch_size"], cfg["input_dim"] ) + # X / Y construction + X = window(data, lag, horizon) + Y = window( + data, + 1 if single_step else horizon, + horizon, + lag if not single_step else lag + horizon - 1 + ) -def _prepare_data_with_windows(data, args, single): - # Generate sliding windows for main data - x = add_window_x(data, args["lag"], args["horizon"], single) - y = add_window_y(data, args["lag"], args["horizon"], single) + # Time features + t = np.arange(T) + time_in_day = np.tile((t % cfg["steps_per_day"]) / cfg["steps_per_day"], (N, 1)).T + day_in_week = np.tile((t // cfg["steps_per_day"]) % cfg["days_per_week"], (N, 1)).T + tf = lambda z: window(z[..., None], lag, horizon) - # Generate time features - time_features = _generate_time_features(data.shape[0], args) - - # Add time features to x and y - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - y = _add_time_features(y, time_features, args["lag"], args["horizon"], single, add_window_y) - - return x, y + X = np.concatenate([X, tf(time_in_day), tf(day_in_week)], -1) + Y = np.concatenate([Y, tf(time_in_day), tf(day_in_week)], -1) + # Split + X_train, X_val, X_test = split_by_ratio(X, cfg["val_ratio"], cfg["test_ratio"]) + Y_train, Y_val, Y_test = split_by_ratio(Y, cfg["val_ratio"], cfg["test_ratio"]) -def _generate_time_features(L, args): - N = args["num_nodes"] - time_in_day = [i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)] - time_in_day = np.tile(np.array(time_in_day), [1, N, 1]).transpose((2, 1, 0)) - - day_in_week = [(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)] - day_in_week = np.tile(np.array(day_in_week), [1, N, 1]).transpose((2, 1, 0)) - - return time_in_day, day_in_week - - -def _add_time_features(data, time_features, lag, horizon, single, window_fn): - time_in_day, day_in_week = time_features - time_day = window_fn(time_in_day, lag, horizon, single) - time_week = window_fn(day_in_week, lag, horizon, single) - return np.concatenate([data, time_day, time_week], axis=-1) - - -def _normalize_data(train_data, val_data, test_data, args, normalizer): - scaler = normalize_dataset(train_data[..., : args["input_dim"]], normalizer, args["column_wise"]) - - for data in [train_data, val_data, test_data]: - data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]]) - - return scaler - - -def _apply_existing_scaler(train_data, val_data, test_data, scaler, args): - for data in [train_data, val_data, test_data]: - data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]]) - - -def _create_dataloader(X_data, Y_data, batch_size, shuffle, drop_last): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - X_tensor = torch.tensor(X_data, dtype=torch.float32, device=device) - Y_tensor = torch.tensor(Y_data, dtype=torch.float32, device=device) - dataset = torch.utils.data.TensorDataset(X_tensor, Y_tensor) - return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) - - -def split_data_by_days(data, val_days, test_days, interval=30): - t = int((24 * 60) / interval) - test_data = data[-t * int(test_days) :] - val_data = data[-t * int(test_days + val_days) : -t * int(test_days)] - train_data = data[: -t * int(test_days + val_days)] - return train_data, val_data, test_data - - -def split_data_by_ratio(data, val_ratio, test_ratio): - data_len = data.shape[0] - test_data = data[-int(data_len * test_ratio) :] - val_data = data[ - -int(data_len * (test_ratio + val_ratio)) : -int(data_len * test_ratio) + # Channel-wise normalization (fit on train only) + scalers = [ + normalize_dataset(X_train[..., i:i+1], normalizer, cfg["column_wise"]) + for i in range(input_dim) ] - train_data = data[: -int(data_len * (test_ratio + val_ratio))] - return train_data, val_data, test_data + for i, sc in enumerate(scalers): + for d in (X_train, X_val, X_test, Y_train, Y_val, Y_test): + d[..., i:i+1] = sc.transform(d[..., i:i+1]) + # Padding + X_train, Y_train = pad_last(X_train, Y_train, batch_size) + X_val, Y_val = pad_last(X_val, Y_val, batch_size) + X_test, Y_test = pad_last(X_test, Y_test, batch_size) + # DataLoader + make_loader = lambda X, Y, shuffle: torch.utils.data.DataLoader( + torch.utils.data.TensorDataset(to_tensor(X), to_tensor(Y)), + batch_size=batch_size, shuffle=shuffle, drop_last=False + ) - -def _generate_windows(data, window=3, horizon=1, offset=0): - """ - Internal helper function to generate sliding windows. - - :param data: Input data - :param window: Window size - :param horizon: Horizon size - :param offset: Offset from window start - :return: Windowed data - """ - length = len(data) - end_index = length - horizon - window + 1 - windows = [] - index = 0 - - while index < end_index: - windows.append(data[index + offset : index + offset + window]) - index += 1 - - return np.array(windows) - -def add_window_x(data, window=3, horizon=1, single=False): - """ - Generate windowed X values from the input data. - - :param data: Input data, shape [B, ...] - :param window: Size of the sliding window - :param horizon: Horizon size - :param single: If True, generate single-step windows, else multi-step - :return: X with shape [B, W, ...] - """ - return _generate_windows(data, window, horizon, offset=0) - -def add_window_y(data, window=3, horizon=1, single=False): - """ - Generate windowed Y values from the input data. - - :param data: Input data, shape [B, ...] - :param window: Size of the sliding window - :param horizon: Horizon size - :param single: If True, generate single-step windows, else multi-step - :return: Y with shape [B, H, ...] - """ - offset = window if not single else window + horizon - 1 - return _generate_windows(data, window=1 if single else horizon, horizon=horizon, offset=offset) - -# if __name__ == "__main__": -# from dataloader.data_selector import load_st_dataset -# res = load_st_dataset({"dataset": "SD"}) -# print(f"Dataset shape: {res.shape}") + return ( + make_loader(X_train, Y_train, True), + make_loader(X_val, Y_val, False), + make_loader(X_test, Y_test, False), + scalers + ) diff --git a/train.py b/train.py index 2d3a32f..83d056a 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ def read_config(config_path): # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 1 + epochs = 1 # 训练轮数 # 拷贝项 config["basic"]["device"] = device @@ -90,8 +90,9 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] + # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] + model_list = ["MTGNN"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] - # dataset_list = ["AirQuality"] - dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] - main(model_list, dataset_list, debug = True) \ No newline at end of file + dataset_list = ["AirQuality"] + # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] + main(model_list, dataset_list, debug = False) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index 932d8b3..5ba71f2 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -22,7 +22,6 @@ class Trainer: train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): - self.config = args self.device = args["basic"]["device"] self.args = args["train"] diff --git a/trainer/Trainer.py b/trainer/Trainer.py index cdd444b..7c9aee0 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -3,6 +3,7 @@ from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics + class Trainer: def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, @@ -22,6 +23,16 @@ class Trainer: self.test_loader = test_loader self.scaler = scaler + # ===== 新增:统一反归一化接口(单 scaler / 多 scaler 通吃)===== + self.inv = ( + (lambda x: self.scaler.inverse_transform(x)) + if not isinstance(self.scaler, (list, tuple)) + else (lambda x: torch.cat( + [s.inverse_transform(x[..., i:i+1]) + for i, s in enumerate(self.scaler)], + dim=-1)) + ) + self._init_paths() self._init_logger() @@ -56,16 +67,14 @@ class Trainer: out = self.model(data) if os.environ.get("TRY") == "True": - if out.shape == label.shape: - print(f"shape true, out: {out.shape}, label: {label.shape}") - assert False - else: - print(f"shape false, out: {out.shape}, label: {label.shape}") - assert False + print(f"out: {out.shape}, label: {label.shape}") + assert False loss = self.loss(out, label) - d_out = self.scaler.inverse_transform(out) - d_lbl = self.scaler.inverse_transform(label) + + # ===== 修改点:反归一化 ===== + d_out = self.inv(out) + d_lbl = self.inv(label) d_loss = self.loss(d_out, d_lbl) total_loss += d_loss.item() @@ -120,12 +129,10 @@ class Trainer: if losses["val"] < best: best, patience = losses["val"], 0 best_w = copy.deepcopy(self.model.state_dict()) - self.logger.info("Best validation model saved") else: patience += 1 if self.args["early_stop"] and patience == self.args["early_stop_patience"]: - self.logger.info("Early stopping triggered") break if losses["test"] < best_test: @@ -154,14 +161,12 @@ class Trainer: for data, target in self.test_loader: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - - out = self.model(data) - - y_pred.append(out.cpu()) + y_pred.append(self.model(data).cpu()) y_true.append(label.cpu()) - d_pred = self.scaler.inverse_transform(torch.cat(y_pred)) - d_true = self.scaler.inverse_transform(torch.cat(y_true)) + # ===== 修改点:反归一化 ===== + d_pred = self.inv(torch.cat(y_pred)) + d_true = self.inv(torch.cat(y_true)) for t in range(d_true.shape[1]): mae, rmse, mape = all_metrics( @@ -172,9 +177,12 @@ class Trainer: self.logger.info( f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" ) - - avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) + + avg_mae, avg_rmse, avg_mape = all_metrics( + d_pred, d_true, + self.args["mae_thresh"], + self.args["mape_thresh"] + ) self.logger.info( f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" ) - -- 2.40.1 From 5e52f23c8d6c4f4c1275d79076c052e28efd99f3 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 21:23:04 +0800 Subject: [PATCH 32/41] =?UTF-8?q?fix(config):=20=E4=BF=AE=E6=AD=A3?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6=E5=91=BD=E5=90=8D=E4=B8=8D?= =?UTF-8?q?=E4=B8=80=E8=87=B4=E9=97=AE=E9=A2=98=E5=B9=B6=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor(trainer): 重构训练器代码,优化反归一化处理和形状转换逻辑 style(trainer): 简化代码格式,提高可读性 chore: 更新训练脚本中的模型和数据集列表 --- ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 train.py | 10 +- trainer/TSTrainer.py | 99 +++++++------ trainer/Trainer.py | 135 ++++-------------- 14 files changed, 85 insertions(+), 159 deletions(-) rename config/HI/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/HI/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/HI/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) rename config/Informer/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/Informer/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/Informer/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/Informer/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) rename config/iTransformer/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/iTransformer/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/iTransformer/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/iTransformer/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) diff --git a/config/HI/BJTaxi-Inflow.yaml b/config/HI/BJTaxi-InFlow.yaml similarity index 100% rename from config/HI/BJTaxi-Inflow.yaml rename to config/HI/BJTaxi-InFlow.yaml diff --git a/config/HI/NYCBike-Inflow.yaml b/config/HI/NYCBike-InFlow.yaml similarity index 100% rename from config/HI/NYCBike-Inflow.yaml rename to config/HI/NYCBike-InFlow.yaml diff --git a/config/HI/NYCBike-Outflow.yaml b/config/HI/NYCBike-OutFlow.yaml similarity index 100% rename from config/HI/NYCBike-Outflow.yaml rename to config/HI/NYCBike-OutFlow.yaml diff --git a/config/Informer/BJTaxi-Inflow.yaml b/config/Informer/BJTaxi-InFlow.yaml similarity index 100% rename from config/Informer/BJTaxi-Inflow.yaml rename to config/Informer/BJTaxi-InFlow.yaml diff --git a/config/Informer/BJTaxi-Outflow.yaml b/config/Informer/BJTaxi-OutFlow.yaml similarity index 100% rename from config/Informer/BJTaxi-Outflow.yaml rename to config/Informer/BJTaxi-OutFlow.yaml diff --git a/config/Informer/NYCBike-Inflow.yaml b/config/Informer/NYCBike-InFlow.yaml similarity index 100% rename from config/Informer/NYCBike-Inflow.yaml rename to config/Informer/NYCBike-InFlow.yaml diff --git a/config/Informer/NYCBike-Outflow.yaml b/config/Informer/NYCBike-OutFlow.yaml similarity index 100% rename from config/Informer/NYCBike-Outflow.yaml rename to config/Informer/NYCBike-OutFlow.yaml diff --git a/config/iTransformer/BJTaxi-Inflow.yaml b/config/iTransformer/BJTaxi-InFlow.yaml similarity index 100% rename from config/iTransformer/BJTaxi-Inflow.yaml rename to config/iTransformer/BJTaxi-InFlow.yaml diff --git a/config/iTransformer/BJTaxi-Outflow.yaml b/config/iTransformer/BJTaxi-OutFlow.yaml similarity index 100% rename from config/iTransformer/BJTaxi-Outflow.yaml rename to config/iTransformer/BJTaxi-OutFlow.yaml diff --git a/config/iTransformer/NYCBike-Inflow.yaml b/config/iTransformer/NYCBike-InFlow.yaml similarity index 100% rename from config/iTransformer/NYCBike-Inflow.yaml rename to config/iTransformer/NYCBike-InFlow.yaml diff --git a/config/iTransformer/NYCBike-Outflow.yaml b/config/iTransformer/NYCBike-OutFlow.yaml similarity index 100% rename from config/iTransformer/NYCBike-Outflow.yaml rename to config/iTransformer/NYCBike-OutFlow.yaml diff --git a/train.py b/train.py index 83d056a..b0b5af1 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ def read_config(config_path): # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 1 # 训练轮数 + epochs = 100 # 训练轮数 # 拷贝项 config["basic"]["device"] = device @@ -91,8 +91,8 @@ if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] - model_list = ["MTGNN"] - # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] - dataset_list = ["AirQuality"] + model_list = ["iTransformer"] + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # dataset_list = ["AirQuality"] # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] - main(model_list, dataset_list, debug = False) \ No newline at end of file + main(model_list, dataset_list, debug = True) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index 5ba71f2..81ee54f 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -3,25 +3,13 @@ from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics -class TSWrapper: - def __init__(self, args): - self.n = args['data']['num_nodes'] - - def forward(self, x): - # [b, t, n, c] -> [b*n, t, c] - b, t, n, c = x.shape - x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c-2) - return x, b, t, n, c - - def inverse(self, x, b, t, n, c): - return x.reshape(b, n, t, c-2).permute(0, 2, 1, 3) - class Trainer: def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): + self.config = args self.device = args["basic"]["device"] self.args = args["train"] @@ -35,7 +23,24 @@ class Trainer: self.test_loader = test_loader self.scaler = scaler - self.ts = TSWrapper(args) + # ---------- shape magic (replace TSWrapper) ---------- + self.pack = lambda x: ( + x[..., :-2] + .permute(0, 2, 1, 3) + .reshape(-1, x.size(1), x.size(3) - 2), + x.shape + ) + self.unpack = lambda y, s: ( + y.reshape(s[0], s[2], s[1], -1) + .permute(0, 2, 1, 3) + ) + + # ---------- inverse scaler ---------- + self.inv = lambda x: torch.cat( + [s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], + dim=-1 + ) + self._init_paths() self._init_logger() @@ -51,7 +56,7 @@ class Trainer: self.logger = get_logger( self.args["log_dir"], name=self.model.__class__.__name__, - debug=self.args["debug"], + debug=self.args["debug"] ) # ---------------- epoch ---------------- @@ -67,21 +72,17 @@ class Trainer: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - x, b, t, n, c = self.ts.forward(data) - out = self.model(x) - out = self.ts.inverse(out, b, t, n, c) + x, shp = self.pack(data) + out = self.unpack(self.model(x), shp) if os.environ.get("TRY") == "True": - if out.shape == label.shape: - print("shape true") - assert False - else: - print("shape false") - assert False + print(f"out:{out.shape} label:{label.shape}", + "✅" if out.shape == label.shape else "❌") + assert False loss = self.loss(out, label) - d_out = self.scaler.inverse_transform(out) - d_lbl = self.scaler.inverse_transform(label) + + d_out, d_lbl = self.inv(out), self.inv(label) d_loss = self.loss(d_out, d_lbl) total_loss += d_loss.item() @@ -98,9 +99,7 @@ class Trainer: ) self.optimizer.step() - y_pred = torch.cat(y_pred) - y_true = torch.cat(y_true) - + y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) mae, rmse, mape = all_metrics( y_pred, y_true, self.args["mae_thresh"], @@ -110,23 +109,28 @@ class Trainer: self.logger.info( f"Epoch #{epoch:02d} {mode:<5} " f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " - f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s" + f"MAPE:{mape:7.4f} " + f"Time:{time.time() - start:.2f}s" ) + return total_loss / len(loader) # ---------------- train ---------------- def train(self): - best, best_test = float("inf"), float("inf") - best_w, best_test_w = None, None + best = best_test = float("inf") + best_w = best_test_w = None patience = 0 self.logger.info("Training started") for epoch in range(1, self.args["epochs"] + 1): losses = { - "train": self._run_epoch(epoch, self.train_loader, "train"), - "val": self._run_epoch(epoch, self.val_loader, "val"), - "test": self._run_epoch(epoch, self.test_loader, "test"), + k: self._run_epoch(epoch, l, k) + for k, l in [ + ("train", self.train_loader), + ("val", self.val_loader), + ("test", self.test_loader) + ] } if losses["train"] > 1e6: @@ -171,15 +175,14 @@ class Trainer: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - x, b, t, n, c = self.ts.forward(data) - out = self.model(x) - out = self.ts.inverse(out, b, t, n, c) + x, shp = self.pack(data) + out = self.unpack(self.model(x), shp) y_pred.append(out.cpu()) y_true.append(label.cpu()) - d_pred = self.scaler.inverse_transform(torch.cat(y_pred)) - d_true = self.scaler.inverse_transform(torch.cat(y_true)) + d_pred = self.inv(torch.cat(y_pred)) + d_true = self.inv(torch.cat(y_true)) for t in range(d_true.shape[1]): mae, rmse, mape = all_metrics( @@ -188,11 +191,15 @@ class Trainer: self.args["mape_thresh"] ) self.logger.info( - f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" + f"Horizon {t+1:02d} " + f"MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" ) - - avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) - self.logger.info( - f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" - ) + mae, rmse, mape = all_metrics( + d_pred, d_true, + self.args["mae_thresh"], + self.args["mape_thresh"] + ) + self.logger.info( + f"AVG MAE:{mae:.4f} AVG RMSE:{rmse:.4f} AVG MAPE:{mape:.4f}" + ) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 7c9aee0..e0838b5 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -3,59 +3,29 @@ from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics - class Trainer: - def __init__(self, model, loss, optimizer, - train_loader, val_loader, test_loader, - scaler, args, lr_scheduler=None): - - self.config = args - self.device = args["basic"]["device"] - self.args = args["train"] - - self.model = model.to(self.device) - self.loss = loss - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - - self.train_loader = train_loader - self.val_loader = val_loader or test_loader - self.test_loader = test_loader + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): + self.device, self.args = args["basic"]["device"], args["train"] + self.model, self.loss, self.optimizer, self.lr_scheduler = model.to(self.device), loss, optimizer, lr_scheduler + self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader or test_loader, test_loader self.scaler = scaler - - # ===== 新增:统一反归一化接口(单 scaler / 多 scaler 通吃)===== - self.inv = ( - (lambda x: self.scaler.inverse_transform(x)) - if not isinstance(self.scaler, (list, tuple)) - else (lambda x: torch.cat( - [s.inverse_transform(x[..., i:i+1]) - for i, s in enumerate(self.scaler)], - dim=-1)) - ) - + self.inv = lambda x: torch.cat([s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], dim=-1) # 对每个维度调用反归一化器后cat self._init_paths() self._init_logger() # ---------------- init ---------------- def _init_paths(self): d = self.args["log_dir"] - self.best_path = os.path.join(d, "best_model.pth") - self.best_test_path = os.path.join(d, "best_test_model.pth") + self.best_path, self.best_test_path = os.path.join(d, "best_model.pth"), os.path.join(d, "best_test_model.pth") def _init_logger(self): - if not self.args["debug"]: - os.makedirs(self.args["log_dir"], exist_ok=True) - self.logger = get_logger( - self.args["log_dir"], - name=self.model.__class__.__name__, - debug=self.args["debug"], - ) + if not self.args["debug"]: os.makedirs(self.args["log_dir"], exist_ok=True) + self.logger = get_logger(self.args["log_dir"], name=self.model.__class__.__name__, debug=self.args["debug"]) # ---------------- epoch ---------------- def _run_epoch(self, epoch, loader, mode): is_train = mode == "train" self.model.train() if is_train else self.model.eval() - total_loss, start = 0.0, time.time() y_pred, y_true = [], [] @@ -63,20 +33,12 @@ class Trainer: for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - out = self.model(data) - - if os.environ.get("TRY") == "True": - print(f"out: {out.shape}, label: {label.shape}") - assert False - + if os.environ.get("TRY") == "True": print(f"out: {out.shape}, label: {label.shape} \ + {'✅' if out.shape == label.shape else '❌'}"); assert False loss = self.loss(out, label) - - # ===== 修改点:反归一化 ===== - d_out = self.inv(out) - d_lbl = self.inv(label) + d_out, d_lbl = self.inv(out), self.inv(label) # 反归一化 d_loss = self.loss(d_out, d_lbl) - total_loss += d_loss.item() y_pred.append(d_out.detach().cpu()) y_true.append(d_lbl.detach().cpu()) @@ -84,27 +46,12 @@ class Trainer: if is_train and self.optimizer: self.optimizer.zero_grad() loss.backward() - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.args["max_grad_norm"] - ) + if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) self.optimizer.step() - y_pred = torch.cat(y_pred) - y_true = torch.cat(y_true) - - mae, rmse, mape = all_metrics( - y_pred, y_true, - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - - self.logger.info( - f"Epoch #{epoch:02d} {mode:<5} " - f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " - f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s" - ) + y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"Epoch #{epoch:02d} {mode:<5} MAE:{mae:5.2f} RMSE:{rmse:5.2f} MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s") return total_loss / len(loader) # ---------------- train ---------------- @@ -112,37 +59,24 @@ class Trainer: best, best_test = float("inf"), float("inf") best_w, best_test_w = None, None patience = 0 - self.logger.info("Training started") for epoch in range(1, self.args["epochs"] + 1): losses = { "train": self._run_epoch(epoch, self.train_loader, "train"), - "val": self._run_epoch(epoch, self.val_loader, "val"), - "test": self._run_epoch(epoch, self.test_loader, "test"), + "val": self._run_epoch(epoch, self.val_loader, "val"), + "test": self._run_epoch(epoch, self.test_loader, "test"), } - if losses["train"] > 1e6: - self.logger.warning("Gradient explosion detected") - break - - if losses["val"] < best: - best, patience = losses["val"], 0 - best_w = copy.deepcopy(self.model.state_dict()) - else: - patience += 1 - - if self.args["early_stop"] and patience == self.args["early_stop_patience"]: - break - - if losses["test"] < best_test: - best_test = losses["test"] - best_test_w = copy.deepcopy(self.model.state_dict()) + if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break + if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict()) + else: patience += 1 + if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break + if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict()) if not self.args["debug"]: torch.save(best_w, self.best_path) torch.save(best_test_w, self.best_test_path) - self._final_test(best_w, best_test_w) # ---------------- final test ---------------- @@ -164,25 +98,10 @@ class Trainer: y_pred.append(self.model(data).cpu()) y_true.append(label.cpu()) - # ===== 修改点:反归一化 ===== - d_pred = self.inv(torch.cat(y_pred)) - d_true = self.inv(torch.cat(y_true)) - + d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true)) # 反归一化 for t in range(d_true.shape[1]): - mae, rmse, mape = all_metrics( - d_pred[:, t], d_true[:, t], - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - self.logger.info( - f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" - ) + mae, rmse, mape = all_metrics(d_pred[:, t], d_true[:, t], self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}") - avg_mae, avg_rmse, avg_mape = all_metrics( - d_pred, d_true, - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - self.logger.info( - f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" - ) + avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}") -- 2.40.1 From 659b41f6123a43b2513aabd325b35342df610f5f Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 21:33:28 +0800 Subject: [PATCH 33/41] =?UTF-8?q?refactor(config):=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6=E5=91=BD=E5=90=8D=E5=B9=B6?= =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=A8=A1=E5=9E=8B=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更新多个模型的配置文件命名格式,统一使用大驼峰格式 调整SolarEnergy和BJTaxi数据集的输入维度和批量大小 删除旧命名格式的配置文件并添加新的配置文件 修改训练脚本中的模型和数据集列表用于调试 --- ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 config/HI/SolarEnergy.yaml | 6 +- ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 config/PatchTST/SolarEnergy.yaml | 2 +- config/REPST/BJTaxi-InFlow.yaml | 5 +- config/REPST/BJTaxi-Inflow.yaml | 55 ------------------- ...{BJTaxi_Inflow.yaml => BJTaxi_InFlow.yaml} | 0 ...JTaxi_Outflow.yaml => BJTaxi_OutFlow.yaml} | 0 ...YCBike_Inflow.yaml => NYCBike_InFlow.yaml} | 0 ...Bike_Outflow.yaml => NYCBike_OutFlow.yaml} | 0 train.py | 8 +-- 16 files changed, 10 insertions(+), 66 deletions(-) rename config/HI/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/MTGNN/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/MTGNN/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/PatchTST/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/PatchTST/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/PatchTST/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/PatchTST/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) mode change 100644 => 100755 config/REPST/BJTaxi-InFlow.yaml delete mode 100755 config/REPST/BJTaxi-Inflow.yaml rename config/STID/{BJTaxi_Inflow.yaml => BJTaxi_InFlow.yaml} (100%) rename config/STID/{BJTaxi_Outflow.yaml => BJTaxi_OutFlow.yaml} (100%) rename config/STID/{NYCBike_Inflow.yaml => NYCBike_InFlow.yaml} (100%) rename config/STID/{NYCBike_Outflow.yaml => NYCBike_OutFlow.yaml} (100%) diff --git a/config/HI/BJTaxi-Outflow.yaml b/config/HI/BJTaxi-OutFlow.yaml similarity index 100% rename from config/HI/BJTaxi-Outflow.yaml rename to config/HI/BJTaxi-OutFlow.yaml diff --git a/config/HI/SolarEnergy.yaml b/config/HI/SolarEnergy.yaml index 1558d3a..aa07cf6 100644 --- a/config/HI/SolarEnergy.yaml +++ b/config/HI/SolarEnergy.yaml @@ -6,11 +6,11 @@ basic: seed: 2023 data: - batch_size: 512 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 @@ -25,7 +25,7 @@ model: train: - batch_size: 512 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/MTGNN/BJTaxi-Inflow.yaml b/config/MTGNN/BJTaxi-InFlow.yaml similarity index 100% rename from config/MTGNN/BJTaxi-Inflow.yaml rename to config/MTGNN/BJTaxi-InFlow.yaml diff --git a/config/MTGNN/BJTaxi-Outflow.yaml b/config/MTGNN/BJTaxi-OutFlow.yaml similarity index 100% rename from config/MTGNN/BJTaxi-Outflow.yaml rename to config/MTGNN/BJTaxi-OutFlow.yaml diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-InFlow.yaml similarity index 100% rename from config/PatchTST/BJTaxi-Inflow.yaml rename to config/PatchTST/BJTaxi-InFlow.yaml diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-OutFlow.yaml similarity index 100% rename from config/PatchTST/BJTaxi-Outflow.yaml rename to config/PatchTST/BJTaxi-OutFlow.yaml diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-InFlow.yaml similarity index 100% rename from config/PatchTST/NYCBike-Inflow.yaml rename to config/PatchTST/NYCBike-InFlow.yaml diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-OutFlow.yaml similarity index 100% rename from config/PatchTST/NYCBike-Outflow.yaml rename to config/PatchTST/NYCBike-OutFlow.yaml diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml index d31a458..b6ca055 100644 --- a/config/PatchTST/SolarEnergy.yaml +++ b/config/PatchTST/SolarEnergy.yaml @@ -10,7 +10,7 @@ data: column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 diff --git a/config/REPST/BJTaxi-InFlow.yaml b/config/REPST/BJTaxi-InFlow.yaml old mode 100644 new mode 100755 index 3191eba..e8a17fc --- a/config/REPST/BJTaxi-InFlow.yaml +++ b/config/REPST/BJTaxi-InFlow.yaml @@ -1,6 +1,6 @@ basic: dataset: BJTaxi-InFlow - device: cuda:1 + device: cuda:0 mode: train model: REPST seed: 2023 @@ -27,7 +27,6 @@ model: input_dim: 1 n_heads: 1 num_nodes: 1024 - output_dim: 1 patch_len: 6 pred_len: 24 seq_len: 24 @@ -41,7 +40,7 @@ train: early_stop_patience: 15 epochs: 100 grad_norm: false - log_step: 1000 + log_step: 100 loss_func: mae lr_decay: true lr_decay_rate: 0.3 diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml deleted file mode 100755 index e8a17fc..0000000 --- a/config/REPST/BJTaxi-Inflow.yaml +++ /dev/null @@ -1,55 +0,0 @@ -basic: - dataset: BJTaxi-InFlow - device: cuda:0 - mode: train - model: REPST - seed: 2023 - -data: - batch_size: 16 - column_wise: false - days_per_week: 7 - horizon: 24 - input_dim: 1 - lag: 24 - normalizer: std - num_nodes: 1024 - steps_per_day: 48 - test_ratio: 0.2 - val_ratio: 0.2 - -model: - d_ff: 128 - d_model: 64 - dropout: 0.2 - gpt_layers: 9 - gpt_path: ./GPT-2 - input_dim: 1 - n_heads: 1 - num_nodes: 1024 - patch_len: 6 - pred_len: 24 - seq_len: 24 - stride: 7 - word_num: 1000 - -train: - batch_size: 16 - debug: false - early_stop: true - early_stop_patience: 15 - epochs: 100 - grad_norm: false - log_step: 100 - loss_func: mae - lr_decay: true - lr_decay_rate: 0.3 - lr_decay_step: 5,20,40,70 - lr_init: 0.003 - mae_thresh: None - mape_thresh: 0.001 - max_grad_norm: 5 - output_dim: 1 - plot: false - real_value: true - weight_decay: 0 diff --git a/config/STID/BJTaxi_Inflow.yaml b/config/STID/BJTaxi_InFlow.yaml similarity index 100% rename from config/STID/BJTaxi_Inflow.yaml rename to config/STID/BJTaxi_InFlow.yaml diff --git a/config/STID/BJTaxi_Outflow.yaml b/config/STID/BJTaxi_OutFlow.yaml similarity index 100% rename from config/STID/BJTaxi_Outflow.yaml rename to config/STID/BJTaxi_OutFlow.yaml diff --git a/config/STID/NYCBike_Inflow.yaml b/config/STID/NYCBike_InFlow.yaml similarity index 100% rename from config/STID/NYCBike_Inflow.yaml rename to config/STID/NYCBike_InFlow.yaml diff --git a/config/STID/NYCBike_Outflow.yaml b/config/STID/NYCBike_OutFlow.yaml similarity index 100% rename from config/STID/NYCBike_Outflow.yaml rename to config/STID/NYCBike_OutFlow.yaml diff --git a/train.py b/train.py index b0b5af1..fcaaa6a 100644 --- a/train.py +++ b/train.py @@ -89,10 +89,10 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 - # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] - model_list = ["iTransformer"] - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # model_list = ["iTransformer"] + # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] - # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] main(model_list, dataset_list, debug = True) \ No newline at end of file -- 2.40.1 From 85257bc61ca08edf619dde5c20e32e13bdea2939 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 16 Dec 2025 16:40:40 +0800 Subject: [PATCH 34/41] =?UTF-8?q?refactor(trainer):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E5=99=A8=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84?= =?UTF-8?q?=E5=B9=B6=E6=B7=BB=E5=8A=A0=E8=BF=9B=E5=BA=A6=E6=9D=A1=E6=98=BE?= =?UTF-8?q?=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 调整训练器代码结构,减少冗余代码,提高可读性 为训练过程添加tqdm进度条,实时显示loss信息 统一TRY环境变量的输出格式 简化日志记录和模型保存逻辑 --- config/ASTRA_v2/SolarEnergy.yaml | 4 +- config/GWN/METR-LA.yaml | 4 +- config/GWN/SolarEnergy.yaml | 4 +- config/MTGNN/SolarEnergy.yaml | 2 +- config/REPST/SolarEnergy.yaml | 4 +- train.py | 4 +- trainer/TSTrainer.py | 163 +++++++------------------------ trainer/Trainer.py | 8 +- 8 files changed, 52 insertions(+), 141 deletions(-) diff --git a/config/ASTRA_v2/SolarEnergy.yaml b/config/ASTRA_v2/SolarEnergy.yaml index 83a87c2..9b6a223 100644 --- a/config/ASTRA_v2/SolarEnergy.yaml +++ b/config/ASTRA_v2/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: word_num: 1000 train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/GWN/METR-LA.yaml b/config/GWN/METR-LA.yaml index ef38574..fc93634 100644 --- a/config/GWN/METR-LA.yaml +++ b/config/GWN/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -40,7 +40,7 @@ model: supports: null train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/GWN/SolarEnergy.yaml b/config/GWN/SolarEnergy.yaml index cd1d043..4e572fa 100644 --- a/config/GWN/SolarEnergy.yaml +++ b/config/GWN/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -40,7 +40,7 @@ model: supports: null train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/MTGNN/SolarEnergy.yaml b/config/MTGNN/SolarEnergy.yaml index 2f60b8d..57e17c8 100644 --- a/config/MTGNN/SolarEnergy.yaml +++ b/config/MTGNN/SolarEnergy.yaml @@ -10,7 +10,7 @@ data: column_wise: false days_per_week: 7 horizon: 24 - input_dim: 6 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml index dd4579e..a96e58a 100755 --- a/config/REPST/SolarEnergy.yaml +++ b/config/REPST/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: word_num: 1000 train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/train.py b/train.py index fcaaa6a..b5b42b5 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ def read_config(config_path): # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 100 # 训练轮数 + epochs = 1 # 训练轮数 # 拷贝项 config["basic"]["device"] = device @@ -91,7 +91,7 @@ if __name__ == "__main__": # 调试用 model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] - # model_list = ["iTransformer"] + # model_list = ["MTGNN"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index 81ee54f..3ddf361 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -5,86 +5,46 @@ from utils.loss_function import all_metrics class Trainer: - def __init__(self, model, loss, optimizer, - train_loader, val_loader, test_loader, - scaler, args, lr_scheduler=None): - - self.config = args - self.device = args["basic"]["device"] - self.args = args["train"] - - self.model = model.to(self.device) - self.loss = loss - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - - self.train_loader = train_loader - self.val_loader = val_loader or test_loader - self.test_loader = test_loader + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): + self.device, self.args = args["basic"]["device"], args["train"] + self.model, self.loss, self.optimizer, self.lr_scheduler = model.to(self.device), loss, optimizer, lr_scheduler + self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader or test_loader, test_loader self.scaler = scaler - - # ---------- shape magic (replace TSWrapper) ---------- - self.pack = lambda x: ( - x[..., :-2] - .permute(0, 2, 1, 3) - .reshape(-1, x.size(1), x.size(3) - 2), - x.shape - ) - self.unpack = lambda y, s: ( - y.reshape(s[0], s[2], s[1], -1) - .permute(0, 2, 1, 3) - ) - - # ---------- inverse scaler ---------- - self.inv = lambda x: torch.cat( - [s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], - dim=-1 - ) - + self.inv = lambda x: torch.cat([s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], dim=-1) # 对每个维度调用反归一化器后cat self._init_paths() self._init_logger() + # ---------- shape magic (replace TSWrapper) ---------- + self.pack = lambda x:(x[..., :-2].permute(0, 2, 1, 3).reshape(-1, x.size(1), x.size(3) - 2), x.shape) + self.unpack = lambda y, s: (y.reshape(s[0], s[2], s[1], -1).permute(0, 2, 1, 3)) # ---------------- init ---------------- def _init_paths(self): d = self.args["log_dir"] - self.best_path = os.path.join(d, "best_model.pth") - self.best_test_path = os.path.join(d, "best_test_model.pth") + self.best_path, self.best_test_path = os.path.join(d, "best_model.pth"), os.path.join(d, "best_test_model.pth") def _init_logger(self): - if not self.args["debug"]: - os.makedirs(self.args["log_dir"], exist_ok=True) - self.logger = get_logger( - self.args["log_dir"], - name=self.model.__class__.__name__, - debug=self.args["debug"] - ) + if not self.args["debug"]: os.makedirs(self.args["log_dir"], exist_ok=True) + self.logger = get_logger(self.args["log_dir"], name=self.model.__class__.__name__, debug=self.args["debug"]) # ---------------- epoch ---------------- def _run_epoch(self, epoch, loader, mode): is_train = mode == "train" self.model.train() if is_train else self.model.eval() - total_loss, start = 0.0, time.time() y_pred, y_true = [], [] with torch.set_grad_enabled(is_train): - for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): + bar = tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)) + for data, target in bar: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - x, shp = self.pack(data) out = self.unpack(self.model(x), shp) - - if os.environ.get("TRY") == "True": - print(f"out:{out.shape} label:{label.shape}", - "✅" if out.shape == label.shape else "❌") - assert False - + if os.environ.get("TRY") == "True": print(f"{'[✅]' if out.shape == label.shape else '❌'} " + f"out: {out.shape}, label: {label.shape} \n"); assert False loss = self.loss(out, label) - - d_out, d_lbl = self.inv(out), self.inv(label) + d_out, d_lbl = self.inv(out), self.inv(label) # 反归一化 d_loss = self.loss(d_out, d_lbl) - total_loss += d_loss.item() y_pred.append(d_out.detach().cpu()) y_true.append(d_lbl.detach().cpu()) @@ -92,70 +52,38 @@ class Trainer: if is_train and self.optimizer: self.optimizer.zero_grad() loss.backward() - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.args["max_grad_norm"] - ) + if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) self.optimizer.step() + bar.set_postfix({"loss": f"{d_loss.item():.4f}"}) y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) - mae, rmse, mape = all_metrics( - y_pred, y_true, - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - - self.logger.info( - f"Epoch #{epoch:02d} {mode:<5} " - f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " - f"MAPE:{mape:7.4f} " - f"Time:{time.time() - start:.2f}s" - ) - + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"Epoch #{epoch:02d} {mode:<5} MAE:{mae:5.2f} RMSE:{rmse:5.2f} MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s") return total_loss / len(loader) # ---------------- train ---------------- def train(self): - best = best_test = float("inf") - best_w = best_test_w = None + best, best_test = float("inf"), float("inf") + best_w, best_test_w = None, None patience = 0 - self.logger.info("Training started") for epoch in range(1, self.args["epochs"] + 1): losses = { - k: self._run_epoch(epoch, l, k) - for k, l in [ - ("train", self.train_loader), - ("val", self.val_loader), - ("test", self.test_loader) - ] + "train": self._run_epoch(epoch, self.train_loader, "train"), + "val": self._run_epoch(epoch, self.val_loader, "val"), + "test": self._run_epoch(epoch, self.test_loader, "test"), } - if losses["train"] > 1e6: - self.logger.warning("Gradient explosion detected") - break - - if losses["val"] < best: - best, patience = losses["val"], 0 - best_w = copy.deepcopy(self.model.state_dict()) - self.logger.info("Best validation model saved") - else: - patience += 1 - - if self.args["early_stop"] and patience == self.args["early_stop_patience"]: - self.logger.info("Early stopping triggered") - break - - if losses["test"] < best_test: - best_test = losses["test"] - best_test_w = copy.deepcopy(self.model.state_dict()) + if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break + if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict()) + else: patience += 1 + if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break + if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict()) if not self.args["debug"]: torch.save(best_w, self.best_path) torch.save(best_test_w, self.best_test_path) - self._final_test(best_w, best_test_w) # ---------------- final test ---------------- @@ -174,32 +102,13 @@ class Trainer: for data, target in self.test_loader: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - - x, shp = self.pack(data) - out = self.unpack(self.model(x), shp) - - y_pred.append(out.cpu()) + y_pred.append(self.model(data).cpu()) y_true.append(label.cpu()) - d_pred = self.inv(torch.cat(y_pred)) - d_true = self.inv(torch.cat(y_true)) - + d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true)) # 反归一化 for t in range(d_true.shape[1]): - mae, rmse, mape = all_metrics( - d_pred[:, t], d_true[:, t], - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - self.logger.info( - f"Horizon {t+1:02d} " - f"MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" - ) + mae, rmse, mape = all_metrics(d_pred[:, t], d_true[:, t], self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}") - mae, rmse, mape = all_metrics( - d_pred, d_true, - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - self.logger.info( - f"AVG MAE:{mae:.4f} AVG RMSE:{rmse:.4f} AVG MAPE:{mape:.4f}" - ) + avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}") diff --git a/trainer/Trainer.py b/trainer/Trainer.py index e0838b5..7036e26 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -30,12 +30,13 @@ class Trainer: y_pred, y_true = [], [] with torch.set_grad_enabled(is_train): - for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): + bar = tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)) + for data, target in bar: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] out = self.model(data) - if os.environ.get("TRY") == "True": print(f"out: {out.shape}, label: {label.shape} \ - {'✅' if out.shape == label.shape else '❌'}"); assert False + if os.environ.get("TRY") == "True": print(f"{'[✅]' if out.shape == label.shape else '❌'} " + f"out: {out.shape}, label: {label.shape} \n"); assert False loss = self.loss(out, label) d_out, d_lbl = self.inv(out), self.inv(label) # 反归一化 d_loss = self.loss(d_out, d_lbl) @@ -48,6 +49,7 @@ class Trainer: loss.backward() if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) self.optimizer.step() + bar.set_postfix({"loss": f"{d_loss.item():.4f}"}) y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) -- 2.40.1 From b38e4a5da2790c7fdb463ea5350285bca1a2034f Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 16 Dec 2025 17:31:57 +0800 Subject: [PATCH 35/41] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dastra=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/ASTRA/AirQuality.yaml | 1 + config/ASTRA_v2/AirQuality.yaml | 1 + config/ASTRA_v3/AirQuality.yaml | 1 + config/STAEFormer/SolarEnergy.yaml | 2 +- model/ASTRA/astra.py | 83 +++++++++++------------------- model/ASTRA/astrav2.py | 10 ++-- model/ASTRA/astrav3.py | 10 ++-- 7 files changed, 43 insertions(+), 65 deletions(-) diff --git a/config/ASTRA/AirQuality.yaml b/config/ASTRA/AirQuality.yaml index 455fc4b..7d4868e 100644 --- a/config/ASTRA/AirQuality.yaml +++ b/config/ASTRA/AirQuality.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 6 train: batch_size: 16 diff --git a/config/ASTRA_v2/AirQuality.yaml b/config/ASTRA_v2/AirQuality.yaml index 10796d2..ed22962 100644 --- a/config/ASTRA_v2/AirQuality.yaml +++ b/config/ASTRA_v2/AirQuality.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 6 train: batch_size: 16 diff --git a/config/ASTRA_v3/AirQuality.yaml b/config/ASTRA_v3/AirQuality.yaml index 68e6acc..d4cb947 100644 --- a/config/ASTRA_v3/AirQuality.yaml +++ b/config/ASTRA_v3/AirQuality.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 6 train: batch_size: 16 diff --git a/config/STAEFormer/SolarEnergy.yaml b/config/STAEFormer/SolarEnergy.yaml index c1151ca..a3fed30 100644 --- a/config/STAEFormer/SolarEnergy.yaml +++ b/config/STAEFormer/SolarEnergy.yaml @@ -10,7 +10,7 @@ data: column_wise: false days_per_week: 7 horizon: 24 - input_dim: 137 + input_dim: 1 lag: 24 normalizer: std num_nodes: 137 diff --git a/model/ASTRA/astra.py b/model/ASTRA/astra.py index 0ed2333..71d4ee9 100644 --- a/model/ASTRA/astra.py +++ b/model/ASTRA/astra.py @@ -7,22 +7,15 @@ from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer import torch.nn.functional as F class DynamicGraphEnhancer(nn.Module): - """ - 动态图增强器,基于节点嵌入自动生成图结构 - 参考DDGCRN的设计,使用节点嵌入和特征信息动态计算邻接矩阵 - """ + """动态图增强编码器""" def __init__(self, num_nodes, in_dim, embed_dim=10): super().__init__() - self.num_nodes = num_nodes - self.embed_dim = embed_dim + self.num_nodes = num_nodes # 节点个数 + self.embed_dim = embed_dim # 节点嵌入维度 - # 节点嵌入参数 - self.node_embeddings = nn.Parameter( - torch.randn(num_nodes, embed_dim), requires_grad=True - ) + self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True) # 节点嵌入参数 - # 特征转换层,用于生成动态调整的嵌入 - self.feature_transform = nn.Sequential( + self.feature_transform = nn.Sequential( # 特征转换网络 nn.Linear(in_dim, 16), nn.Sigmoid(), nn.Linear(16, 2), @@ -30,48 +23,29 @@ class DynamicGraphEnhancer(nn.Module): nn.Linear(2, embed_dim) ) - # 注册单位矩阵作为固定的支持矩阵 - self.register_buffer("eye", torch.eye(num_nodes)) + self.register_buffer("eye", torch.eye(num_nodes)) # 注册单位矩阵 def get_laplacian(self, graph, I, normalize=True): - """ - 计算归一化拉普拉斯矩阵 - """ - # 计算度矩阵的逆平方根 - D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) + D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) # 度矩阵的逆平方根 D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题 - if normalize: - return torch.matmul(torch.matmul(D_inv, graph), D_inv) + return torch.matmul(torch.matmul(D_inv, graph), D_inv) # 归一化拉普拉斯矩阵 else: - return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) + return torch.matmul(torch.matmul(D_inv, graph + I), D_inv) # 带自环的归一化拉普拉斯矩阵 def forward(self, X): - """ - X: 输入特征 [B, N, D] - 返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N] - """ - batch_size = X.size(0) - laplacians = [] - - # 获取单位矩阵 - I = self.eye.to(X.device) + """生成动态拉普拉斯矩阵""" + batch_size = X.size(0) # 批次大小 + laplacians = [] # 存储各批次的拉普拉斯矩阵 + I = self.eye.to(X.device) # 移动单位矩阵到目标设备 for b in range(batch_size): - # 使用特征转换层生成动态嵌入调整因子 - filt = self.feature_transform(X[b]) # [N, embed_dim] - - # 计算节点嵌入向量 - nodevec = torch.tanh(self.node_embeddings * filt) - - # 通过节点嵌入的点积计算邻接矩阵 - adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) - - # 计算归一化拉普拉斯矩阵 - laplacian = self.get_laplacian(adj, I) + filt = self.feature_transform(X[b]) # 特征转换 + nodevec = torch.tanh(self.node_embeddings * filt) # 计算节点嵌入 + adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1))) # 计算邻接矩阵 + laplacian = self.get_laplacian(adj, I) # 计算拉普拉斯矩阵 laplacians.append(laplacian) - - return torch.stack(laplacians, dim=0) + return torch.stack(laplacians, dim=0) # 堆叠并返回 class GraphEnhancedEncoder(nn.Module): """ @@ -190,8 +164,8 @@ class ASTRA(nn.Module): # 添加动态图增强编码器 self.graph_encoder = GraphEnhancedEncoder( K=configs.get('chebyshev_order', 3), - in_dim=self.d_model, - hidden_dim=configs.get('graph_hidden_dim', 32), + in_dim=self.d_model * self.input_dim, + hidden_dim=self.d_model, num_nodes=self.num_nodes, embed_dim=configs.get('graph_embed_dim', 10), device=self.device @@ -199,14 +173,14 @@ class ASTRA(nn.Module): # 特征融合层 self.feature_fusion = nn.Linear( - self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), + self.d_model * self.input_dim + self.d_model * (configs.get('chebyshev_order', 3) + 1), self.d_model ) self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) for i, (name, param) in enumerate(self.gpts.named_parameters()): @@ -229,9 +203,9 @@ class ASTRA(nn.Module): x = x[..., :self.input_dim] x_enc = rearrange(x, 'b t n c -> b n c t') # 原版Patch - enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C) + enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, d_model * input_dim) # 应用图增强编码器(自动生成图结构) - graph_enhanced = self.graph_encoder(enc_out) + graph_enhanced = self.graph_encoder(enc_out) # (B, N, K * hidden_dim) # 特征融合 - 现在两个张量都是三维的 [B, N, d_model] enc_out = torch.cat([enc_out, graph_enhanced], dim=-1) enc_out = self.feature_fusion(enc_out) @@ -243,9 +217,10 @@ class ASTRA(nn.Module): enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, n_vars) - outputs = outputs.permute(0,2,1,3) + dec_out = self.out_mlp(enc_out) #[B, N, T*C] + + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index 6a47206..22e25b9 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -128,6 +128,7 @@ class ASTRA(nn.Module): self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 self.num_nodes = configs.get('num_nodes', 325) # 节点数量 + self.output_dim = configs.get('output_dim', 1) self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -169,7 +170,7 @@ class ASTRA(nn.Module): self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) # 设置参数可训练性 wps=word position embeddings @@ -202,9 +203,8 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(enc_out) # [B,N,pred_len] # 维度调整 - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, self.input_dim) - outputs = outputs.permute(0,2,1,3) + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs \ No newline at end of file diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index 0e9aebf..59fc11d 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -128,6 +128,7 @@ class ASTRA(nn.Module): self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 self.num_nodes = configs.get('num_nodes', 325) # 节点数量 + self.output_dim = configs.get('output_dim', 1) self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -169,7 +170,7 @@ class ASTRA(nn.Module): self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) # 设置参数可训练性 wps=word position embeddings @@ -203,9 +204,8 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(X_enc) # [B,N,pred_len] # 维度调整 - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, self.input_dim) - outputs = outputs.permute(0,2,1,3) + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs \ No newline at end of file -- 2.40.1 From b97111f5ea1db1e2c14f3cefaef438b6c2868815 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 16 Dec 2025 17:38:37 +0800 Subject: [PATCH 36/41] =?UTF-8?q?=E7=BB=9F=E4=B8=80REPST:AirQuality?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/REPST/AirQuality.yaml | 2 +- train.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index c035a44..ee690f1 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -50,7 +50,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 3 + output_dim: 6 plot: false real_value: true weight_decay: 0 diff --git a/train.py b/train.py index b5b42b5..da6c058 100644 --- a/train.py +++ b/train.py @@ -89,8 +89,8 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 - model_list = ["iTransformer", "PatchTST", "HI"] - # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] + # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["ASTRA_v3", "ASTRA_v2", "ASTRA", "REPST", "STAEFormer", "MTGNN", "iTransformer", "PatchTST", "HI"] # model_list = ["MTGNN"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] -- 2.40.1 From 1a13a32688503dabc99bb39cffc0da695bd2775d Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 16 Dec 2025 21:47:17 +0800 Subject: [PATCH 37/41] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E3=80=82ASTRA=20v2=20or=20v3=E4=BD=BF=E7=94=A8=E7=A1=AC?= =?UTF-8?q?=E5=8F=82=E6=95=B0=EF=BC=8C=E7=A1=AE=E4=BF=9D=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=AE=8C=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/ASTRA/BJTaxi-InFlow.yaml | 1 + config/ASTRA/BJTaxi-OutFlow.yaml | 3 ++- config/ASTRA/METR-LA.yaml | 1 + config/ASTRA/NYCBike-InFlow.yaml | 3 ++- config/ASTRA/NYCBike-OutFlow.yaml | 1 + config/ASTRA/PEMS-BAY.yaml | 1 + config/ASTRA/SolarEnergy.yaml | 1 + config/ASTRA_v2/AirQuality.yaml | 3 +++ config/ASTRA_v2/BJTaxi-InFlow.yaml | 4 ++++ config/ASTRA_v2/BJTaxi-OutFlow.yaml | 4 ++++ config/ASTRA_v2/METR-LA.yaml | 4 ++++ config/ASTRA_v2/NYCBike-InFlow.yaml | 4 ++++ config/ASTRA_v2/NYCBike-OutFlow.yaml | 4 ++++ config/ASTRA_v2/PEMS-BAY.yaml | 4 ++++ config/ASTRA_v2/SolarEnergy.yaml | 4 ++++ config/ASTRA_v3/AirQuality.yaml | 4 ++++ config/ASTRA_v3/BJTaxi-InFlow.yaml | 4 ++++ config/ASTRA_v3/BJTaxi-OutFlow.yaml | 4 ++++ config/ASTRA_v3/METR-LA.yaml | 4 ++++ config/ASTRA_v3/NYCBike-InFlow.yaml | 4 ++++ config/ASTRA_v3/NYCBike-OutFlow.yaml | 4 ++++ config/ASTRA_v3/PEMS-BAY.yaml | 4 ++++ config/ASTRA_v3/SolarEnergy.yaml | 8 ++++++-- model/ASTRA/astrav2.py | 15 +++++++++------ model/ASTRA/astrav3.py | 17 ++++++++++------- model/REPST/repst.py | 2 +- train.py | 2 +- 27 files changed, 95 insertions(+), 19 deletions(-) diff --git a/config/ASTRA/BJTaxi-InFlow.yaml b/config/ASTRA/BJTaxi-InFlow.yaml index c2766bb..8569919 100644 --- a/config/ASTRA/BJTaxi-InFlow.yaml +++ b/config/ASTRA/BJTaxi-InFlow.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA/BJTaxi-OutFlow.yaml b/config/ASTRA/BJTaxi-OutFlow.yaml index ee570f3..d8f0e5d 100644 --- a/config/ASTRA/BJTaxi-OutFlow.yaml +++ b/config/ASTRA/BJTaxi-OutFlow.yaml @@ -17,7 +17,8 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 - + output_dim: 1 + model: d_ff: 128 d_model: 64 diff --git a/config/ASTRA/METR-LA.yaml b/config/ASTRA/METR-LA.yaml index 87bf1ac..3ae73ec 100644 --- a/config/ASTRA/METR-LA.yaml +++ b/config/ASTRA/METR-LA.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA/NYCBike-InFlow.yaml b/config/ASTRA/NYCBike-InFlow.yaml index 1c80773..0099f8f 100644 --- a/config/ASTRA/NYCBike-InFlow.yaml +++ b/config/ASTRA/NYCBike-InFlow.yaml @@ -32,7 +32,8 @@ model: seq_len: 24 stride: 7 word_num: 1000 - + output_dim: 1 + train: batch_size: 32 debug: false diff --git a/config/ASTRA/NYCBike-OutFlow.yaml b/config/ASTRA/NYCBike-OutFlow.yaml index 1ece121..f46cece 100644 --- a/config/ASTRA/NYCBike-OutFlow.yaml +++ b/config/ASTRA/NYCBike-OutFlow.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA/PEMS-BAY.yaml b/config/ASTRA/PEMS-BAY.yaml index e111654..2b2384d 100755 --- a/config/ASTRA/PEMS-BAY.yaml +++ b/config/ASTRA/PEMS-BAY.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA/SolarEnergy.yaml b/config/ASTRA/SolarEnergy.yaml index 4160077..dd64d64 100644 --- a/config/ASTRA/SolarEnergy.yaml +++ b/config/ASTRA/SolarEnergy.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 64 diff --git a/config/ASTRA_v2/AirQuality.yaml b/config/ASTRA_v2/AirQuality.yaml index ed22962..9073676 100644 --- a/config/ASTRA_v2/AirQuality.yaml +++ b/config/ASTRA_v2/AirQuality.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -33,6 +34,8 @@ model: stride: 7 word_num: 1000 output_dim: 6 + graph_dim: 64 + graph_embed_dim: 10 train: batch_size: 16 diff --git a/config/ASTRA_v2/BJTaxi-InFlow.yaml b/config/ASTRA_v2/BJTaxi-InFlow.yaml index d1cc5ea..5968cca 100644 --- a/config/ASTRA_v2/BJTaxi-InFlow.yaml +++ b/config/ASTRA_v2/BJTaxi-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/BJTaxi-OutFlow.yaml b/config/ASTRA_v2/BJTaxi-OutFlow.yaml index d6e0723..03859eb 100644 --- a/config/ASTRA_v2/BJTaxi-OutFlow.yaml +++ b/config/ASTRA_v2/BJTaxi-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/METR-LA.yaml b/config/ASTRA_v2/METR-LA.yaml index dca4bb4..db6e3a8 100644 --- a/config/ASTRA_v2/METR-LA.yaml +++ b/config/ASTRA_v2/METR-LA.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v2/NYCBike-InFlow.yaml b/config/ASTRA_v2/NYCBike-InFlow.yaml index de5b6a1..caeccb7 100644 --- a/config/ASTRA_v2/NYCBike-InFlow.yaml +++ b/config/ASTRA_v2/NYCBike-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/NYCBike-OutFlow.yaml b/config/ASTRA_v2/NYCBike-OutFlow.yaml index dda718d..a586f9a 100644 --- a/config/ASTRA_v2/NYCBike-OutFlow.yaml +++ b/config/ASTRA_v2/NYCBike-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/PEMS-BAY.yaml b/config/ASTRA_v2/PEMS-BAY.yaml index 2f6dfbf..2705006 100755 --- a/config/ASTRA_v2/PEMS-BAY.yaml +++ b/config/ASTRA_v2/PEMS-BAY.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v2/SolarEnergy.yaml b/config/ASTRA_v2/SolarEnergy.yaml index 9b6a223..f6405a5 100644 --- a/config/ASTRA_v2/SolarEnergy.yaml +++ b/config/ASTRA_v2/SolarEnergy.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v3/AirQuality.yaml b/config/ASTRA_v3/AirQuality.yaml index d4cb947..c4481c0 100644 --- a/config/ASTRA_v3/AirQuality.yaml +++ b/config/ASTRA_v3/AirQuality.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -33,6 +34,9 @@ model: stride: 7 word_num: 1000 output_dim: 6 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 6 train: batch_size: 16 diff --git a/config/ASTRA_v3/BJTaxi-InFlow.yaml b/config/ASTRA_v3/BJTaxi-InFlow.yaml index 34abfd8..bb09013 100644 --- a/config/ASTRA_v3/BJTaxi-InFlow.yaml +++ b/config/ASTRA_v3/BJTaxi-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/BJTaxi-OutFlow.yaml b/config/ASTRA_v3/BJTaxi-OutFlow.yaml index 8e6b30d..0b4e8df 100644 --- a/config/ASTRA_v3/BJTaxi-OutFlow.yaml +++ b/config/ASTRA_v3/BJTaxi-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/METR-LA.yaml b/config/ASTRA_v3/METR-LA.yaml index 2b5512b..5efa494 100644 --- a/config/ASTRA_v3/METR-LA.yaml +++ b/config/ASTRA_v3/METR-LA.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v3/NYCBike-InFlow.yaml b/config/ASTRA_v3/NYCBike-InFlow.yaml index 18c4fa3..52008cc 100644 --- a/config/ASTRA_v3/NYCBike-InFlow.yaml +++ b/config/ASTRA_v3/NYCBike-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/NYCBike-OutFlow.yaml b/config/ASTRA_v3/NYCBike-OutFlow.yaml index ff73662..0977912 100644 --- a/config/ASTRA_v3/NYCBike-OutFlow.yaml +++ b/config/ASTRA_v3/NYCBike-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/PEMS-BAY.yaml b/config/ASTRA_v3/PEMS-BAY.yaml index 6739aeb..9ff0fd0 100755 --- a/config/ASTRA_v3/PEMS-BAY.yaml +++ b/config/ASTRA_v3/PEMS-BAY.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v3/SolarEnergy.yaml b/config/ASTRA_v3/SolarEnergy.yaml index 289b839..c3f8863 100644 --- a/config/ASTRA_v3/SolarEnergy.yaml +++ b/config/ASTRA_v3/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,9 +33,12 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index 22e25b9..f18ac90 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -127,8 +127,11 @@ class ASTRA(nn.Module): self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 - self.num_nodes = configs.get('num_nodes', 325) # 节点数量 - self.output_dim = configs.get('output_dim', 1) + self.num_nodes = configs['num_nodes'] # 节点数量 + self.output_dim = configs['output_dim'] + self.cheb = configs['cheb'] + self.graph_dim = configs['graph_dim'] + self.graph_embed_dim = configs['graph_embed_dim'] self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -152,18 +155,18 @@ class ASTRA(nn.Module): # 初始化图增强编码器 self.graph_encoder = GraphEnhancedEncoder( - K=configs.get('chebyshev_order', 3), # Chebyshev多项式阶数 + K=self.cheb, # Chebyshev多项式阶数 in_dim=self.d_model, # 输入特征维度 - hidden_dim=configs.get('graph_hidden_dim', 32), # 隐藏层维度 + hidden_dim=self.graph_dim, # 隐藏层维度 num_nodes=self.num_nodes, # 节点数量 - embed_dim=configs.get('graph_embed_dim', 10), # 节点嵌入维度 + embed_dim=self.graph_embed_dim, # 节点嵌入维度 device=self.device, # 运行设备 temporal_dim=self.seq_len, # 时间序列长度 num_features=self.input_dim # 特征通道数 ) self.graph_projection = nn.Linear( # 图特征投影层,每一k阶的切比雪夫权重映射到隐藏维度 - configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), # 输入维度 + self.graph_dim * (self.cheb + 1), # 输入维度 self.d_model # 输出维度 ) diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index 59fc11d..7f4317c 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -127,8 +127,11 @@ class ASTRA(nn.Module): self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 - self.num_nodes = configs.get('num_nodes', 325) # 节点数量 - self.output_dim = configs.get('output_dim', 1) + self.num_nodes = configs['num_nodes'] # 节点数量 + self.output_dim = configs['output_dim'] + self.cheb = configs['cheb'] + self.graph_dim = configs['graph_dim'] + self.graph_embed_dim = configs['graph_embed_dim'] self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -148,23 +151,23 @@ class ASTRA(nn.Module): self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) # 词嵌入权重 self.vocab_size = self.word_embeddings.shape[0] # 词汇表大小 self.mapping_layer = nn.Linear(self.vocab_size, 1) # 映射层 - self.reprogramming_layer = ReprogrammingLayer(self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), self.n_heads, self.d_keys, self.d_llm) # 重编程层 + self.reprogramming_layer = ReprogrammingLayer(self.d_model + self.graph_dim * (self.cheb + 1), self.n_heads, self.d_keys, self.d_llm) # 重编程层 # 初始化图增强编码器 self.graph_encoder = GraphEnhancedEncoder( K=configs.get('chebyshev_order', 3), # Chebyshev多项式阶数 in_dim=self.d_model, # 输入特征维度 - hidden_dim=configs.get('graph_hidden_dim', 32), # 隐藏层维度 + hidden_dim=self.graph_dim, # 隐藏层维度 num_nodes=self.num_nodes, # 节点数量 - embed_dim=configs.get('graph_embed_dim', 10), # 节点嵌入维度 + embed_dim=self.graph_embed_dim, # 节点嵌入维度 device=self.device, # 运行设备 temporal_dim=self.seq_len, # 时间序列长度 num_features=self.input_dim # 特征通道数 ) self.graph_projection = nn.Linear( # 图特征投影层,每一k阶的切比雪夫权重映射到隐藏维度 - configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), # 输入维度 - self.d_model # 输出维度 + self.graph_dim * (self.cheb + 1), # 输入维度 + self.d_model # 输出维度 ) self.out_mlp = nn.Sequential( diff --git a/model/REPST/repst.py b/model/REPST/repst.py index 5b709a4..9afbda1 100644 --- a/model/REPST/repst.py +++ b/model/REPST/repst.py @@ -19,7 +19,7 @@ class repst(nn.Module): self.gpt_layers = configs['gpt_layers'] self.d_ff = configs['d_ff'] self.gpt_path = configs['gpt_path'] - self.output_dim = configs.get('output_dim', 1) + self.output_dim = configs['output_dim'] self.word_choice = GumbelSoftmax(configs['word_num']) diff --git a/train.py b/train.py index da6c058..e9db08b 100644 --- a/train.py +++ b/train.py @@ -90,7 +90,7 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["ASTRA_v3", "ASTRA_v2", "ASTRA", "REPST", "STAEFormer", "MTGNN", "iTransformer", "PatchTST", "HI"] + model_list = ["ASTRA_v3"] # model_list = ["MTGNN"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] -- 2.40.1 From 21bc05e7636a5f39a9acaea8b802f94aa7c9b3e8 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 17 Dec 2025 17:13:16 +0800 Subject: [PATCH 38/41] =?UTF-8?q?=E4=BF=AE=E5=A4=8DTSLoader=20Bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/PatchTST/BJTaxi-InFlow.yaml | 4 +- config/PatchTST/BJTaxi-OutFlow.yaml | 4 +- config/STNorm/AirQuality.yaml | 64 +++++++++++++ config/STNorm/BJTaxi-InFlow.yaml | 64 +++++++++++++ config/STNorm/BJTaxi-OutFlow.yaml | 64 +++++++++++++ config/STNorm/METR-LA.yaml | 52 +++++++++++ config/STNorm/NYCBike-InFlow.yaml | 64 +++++++++++++ config/STNorm/NYCBike-OutFlow.yaml | 64 +++++++++++++ config/STNorm/PEMS-BAY.yaml | 64 +++++++++++++ config/STNorm/SolarEnergy.yaml | 64 +++++++++++++ model/STNorm/STNorm.py | 140 ++++++++++++++++++++++++++++ model/STNorm/model_config.json | 7 ++ train.py | 10 +- trainer/TSTrainer.py | 4 +- 14 files changed, 659 insertions(+), 10 deletions(-) create mode 100644 config/STNorm/AirQuality.yaml create mode 100644 config/STNorm/BJTaxi-InFlow.yaml create mode 100644 config/STNorm/BJTaxi-OutFlow.yaml create mode 100644 config/STNorm/METR-LA.yaml create mode 100644 config/STNorm/NYCBike-InFlow.yaml create mode 100644 config/STNorm/NYCBike-OutFlow.yaml create mode 100644 config/STNorm/PEMS-BAY.yaml create mode 100644 config/STNorm/SolarEnergy.yaml create mode 100644 model/STNorm/STNorm.py create mode 100644 model/STNorm/model_config.json diff --git a/config/PatchTST/BJTaxi-InFlow.yaml b/config/PatchTST/BJTaxi-InFlow.yaml index a4e0308..95ad0b1 100644 --- a/config/PatchTST/BJTaxi-InFlow.yaml +++ b/config/PatchTST/BJTaxi-InFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/BJTaxi-OutFlow.yaml b/config/PatchTST/BJTaxi-OutFlow.yaml index 68c8476..f416372 100644 --- a/config/PatchTST/BJTaxi-OutFlow.yaml +++ b/config/PatchTST/BJTaxi-OutFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/STNorm/AirQuality.yaml b/config/STNorm/AirQuality.yaml new file mode 100644 index 0000000..9846895 --- /dev/null +++ b/config/STNorm/AirQuality.yaml @@ -0,0 +1,64 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 35 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 6 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 6 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/BJTaxi-InFlow.yaml b/config/STNorm/BJTaxi-InFlow.yaml new file mode 100644 index 0000000..09e453a --- /dev/null +++ b/config/STNorm/BJTaxi-InFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/BJTaxi-OutFlow.yaml b/config/STNorm/BJTaxi-OutFlow.yaml new file mode 100644 index 0000000..1b62a4e --- /dev/null +++ b/config/STNorm/BJTaxi-OutFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/METR-LA.yaml b/config/STNorm/METR-LA.yaml new file mode 100644 index 0000000..6f118f2 --- /dev/null +++ b/config/STNorm/METR-LA.yaml @@ -0,0 +1,52 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: STNorm + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 207 + in_dim: 1 + out_dim: 24 + channels: 32 + kernel_size: 2 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/NYCBike-InFlow.yaml b/config/STNorm/NYCBike-InFlow.yaml new file mode 100644 index 0000000..95ae41b --- /dev/null +++ b/config/STNorm/NYCBike-InFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/NYCBike-OutFlow.yaml b/config/STNorm/NYCBike-OutFlow.yaml new file mode 100644 index 0000000..b1646ea --- /dev/null +++ b/config/STNorm/NYCBike-OutFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/PEMS-BAY.yaml b/config/STNorm/PEMS-BAY.yaml new file mode 100644 index 0000000..7f28aca --- /dev/null +++ b/config/STNorm/PEMS-BAY.yaml @@ -0,0 +1,64 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 325 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/SolarEnergy.yaml b/config/STNorm/SolarEnergy.yaml new file mode 100644 index 0000000..57e17c8 --- /dev/null +++ b/config/STNorm/SolarEnergy.yaml @@ -0,0 +1,64 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 137 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/model/STNorm/STNorm.py b/model/STNorm/STNorm.py new file mode 100644 index 0000000..71e7118 --- /dev/null +++ b/model/STNorm/STNorm.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class SNorm(nn.Module): + def __init__(self, channels): + super().__init__() + self.beta = nn.Parameter(torch.zeros(channels)) + self.gamma = nn.Parameter(torch.ones(channels)) + + def forward(self, x): + x_norm = (x - x.mean(2, keepdims=True)) / (x.var(2, keepdims=True, unbiased=True) + 1e-5) ** 0.5 + return x_norm * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1) + +class TNorm(nn.Module): + def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1): + super().__init__() + self.track_running_stats = track_running_stats + self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1)) + self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1)) + self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1)) + self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1)) + self.momentum = momentum + + def forward(self, x): + if self.track_running_stats: + mean = x.mean((0, 3), keepdims=True) + var = x.var((0, 3), keepdims=True, unbiased=False) + if self.training: + n = x.shape[3] * x.shape[0] + with torch.no_grad(): + self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean + self.running_var = self.momentum * var * n / (n - 1) + (1 - self.momentum) * self.running_var + else: + mean = self.running_mean + var = self.running_var + else: + mean = x.mean(3, keepdims=True) + var = x.var(3, keepdims=True, unbiased=True) + x_norm = (x - mean) / (var + 1e-5) ** 0.5 + return x_norm * self.gamma + self.beta + +class stnorm(nn.Module): + def __init__(self, args): + super().__init__() + self.dropout = args["dropout"] + self.blocks = args["blocks"] + self.layers = args["layers"] + self.snorm_bool = args["snorm_bool"] + self.tnorm_bool = args["tnorm_bool"] + self.num_nodes = args["num_nodes"] + in_dim = args["in_dim"] + out_dim = args["out_dim"] + channels = args["channels"] + kernel_size = args["kernel_size"] + + # 初始化卷积层 + self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1)) + + # 初始化模块列表 + self.filter_convs = nn.ModuleList() + self.gate_convs = nn.ModuleList() + self.residual_convs = nn.ModuleList() + self.skip_convs = nn.ModuleList() + self.sn = nn.ModuleList() if self.snorm_bool else None + self.tn = nn.ModuleList() if self.tnorm_bool else None + + # 计算感受野 + self.receptive_field = 1 + additional_scope = kernel_size - 1 + + # 构建网络层 + for b in range(self.blocks): + new_dilation = 1 + for i in range(self.layers): + if self.tnorm_bool: + self.tn.append(TNorm(self.num_nodes, channels)) + if self.snorm_bool: + self.sn.append(SNorm(channels)) + + # 膨胀卷积 - 直接使用channels作为输入通道,不再拼接多个特征 + self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, + kernel_size=(1, kernel_size), dilation=new_dilation)) + self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, + kernel_size=(1, kernel_size), dilation=new_dilation)) + + # 残差连接和跳跃连接 + self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) + self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) + + # 更新感受野 + self.receptive_field += additional_scope + additional_scope *= 2 + new_dilation *= 2 + + # 输出层 + self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1), bias=True) + self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim, kernel_size=(1, 1), bias=True) + + def forward(self, input): + # 输入处理:与GWN保持一致 (bs, features, n_nodes, n_timesteps) + x = input[..., 0:1].transpose(1, 3) + + # 处理感受野 + in_len = x.size(3) + if in_len < self.receptive_field: + x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0)) + + # 起始卷积 + x = self.start_conv(x) + skip = 0 + + # WaveNet层 + for i in range(self.blocks * self.layers): + residual = x + + # 添加空间和时间归一化(直接叠加到原始特征上,而不是拼接) + x_norm = x + if self.tnorm_bool: + x_norm += self.tn[i](x) + if self.snorm_bool: + x_norm += self.sn[i](x) + + # 膨胀卷积 + filter = torch.tanh(self.filter_convs[i](x_norm)) + gate = torch.sigmoid(self.gate_convs[i](x_norm)) + x = filter * gate + + # 跳跃连接 + s = self.skip_convs[i](x) + skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + + # 残差连接 + x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):] + + # 输出处理 + x = F.relu(skip) + x = F.relu(self.end_conv_1(x)) + x = self.end_conv_2(x) + return x \ No newline at end of file diff --git a/model/STNorm/model_config.json b/model/STNorm/model_config.json new file mode 100644 index 0000000..62ea48c --- /dev/null +++ b/model/STNorm/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STNorm", + "module": "model.STNorm.STNorm", + "entry": "stnorm" + } +] \ No newline at end of file diff --git a/train.py b/train.py index e9db08b..40b3bcb 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["ASTRA_v3"] - # model_list = ["MTGNN"] - # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # model_list = ["ASTRA_v3"] + model_list = ["PatchTST"] + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] - dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] - main(model_list, dataset_list, debug = True) \ No newline at end of file + # dataset_list = ["METR-LA"] + main(model_list, dataset_list, debug = False) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index 3ddf361..11cd431 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -102,7 +102,9 @@ class Trainer: for data, target in self.test_loader: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - y_pred.append(self.model(data).cpu()) + x, shp = self.pack(data) + out = self.unpack(self.model(x), shp) + y_pred.append(out.cpu()) y_true.append(label.cpu()) d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true)) # 反归一化 -- 2.40.1 From cc63e5078fe8e331f6dd44639ab17f8ef10c0103 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Fri, 19 Dec 2025 10:19:17 +0800 Subject: [PATCH 39/41] impl FPT --- config/FPT/AirQuality.yaml | 51 ++++++++++++++++++++++++++++++++ config/FPT/BJTaxi-InFlow.yaml | 51 ++++++++++++++++++++++++++++++++ config/FPT/BJTaxi-OutFlow.yaml | 51 ++++++++++++++++++++++++++++++++ config/FPT/METR-LA.yaml | 52 +++++++++++++++++++++++++++++++++ config/FPT/NYCBike-InFlow.yaml | 51 ++++++++++++++++++++++++++++++++ config/FPT/NYCBike-OutFlow.yaml | 51 ++++++++++++++++++++++++++++++++ config/FPT/PEMS-BAY.yaml | 51 ++++++++++++++++++++++++++++++++ config/FPT/SolarEnergy.yaml | 51 ++++++++++++++++++++++++++++++++ model/ASTRA/astra.py | 1 - model/FPT/fpt.py | 45 ++++++++++++++++++++++++++++ model/FPT/model_config.json | 7 +++++ train.py | 6 ++-- trainer/trainer_selector.py | 2 +- 13 files changed, 465 insertions(+), 5 deletions(-) create mode 100644 config/FPT/AirQuality.yaml create mode 100644 config/FPT/BJTaxi-InFlow.yaml create mode 100644 config/FPT/BJTaxi-OutFlow.yaml create mode 100644 config/FPT/METR-LA.yaml create mode 100644 config/FPT/NYCBike-InFlow.yaml create mode 100644 config/FPT/NYCBike-OutFlow.yaml create mode 100755 config/FPT/PEMS-BAY.yaml create mode 100644 config/FPT/SolarEnergy.yaml create mode 100644 model/FPT/fpt.py create mode 100644 model/FPT/model_config.json diff --git a/config/FPT/AirQuality.yaml b/config/FPT/AirQuality.yaml new file mode 100644 index 0000000..0604938 --- /dev/null +++ b/config/FPT/AirQuality.yaml @@ -0,0 +1,51 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 6 + n_heads: 1 + num_nodes: 35 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + weight_decay: 0 diff --git a/config/FPT/BJTaxi-InFlow.yaml b/config/FPT/BJTaxi-InFlow.yaml new file mode 100644 index 0000000..18abb67 --- /dev/null +++ b/config/FPT/BJTaxi-InFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/BJTaxi-OutFlow.yaml b/config/FPT/BJTaxi-OutFlow.yaml new file mode 100644 index 0000000..3e6765a --- /dev/null +++ b/config/FPT/BJTaxi-OutFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/METR-LA.yaml b/config/FPT/METR-LA.yaml new file mode 100644 index 0000000..0c22dcb --- /dev/null +++ b/config/FPT/METR-LA.yaml @@ -0,0 +1,52 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 207 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/FPT/NYCBike-InFlow.yaml b/config/FPT/NYCBike-InFlow.yaml new file mode 100644 index 0000000..41a8c8b --- /dev/null +++ b/config/FPT/NYCBike-InFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/NYCBike-OutFlow.yaml b/config/FPT/NYCBike-OutFlow.yaml new file mode 100644 index 0000000..cc52b1a --- /dev/null +++ b/config/FPT/NYCBike-OutFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/PEMS-BAY.yaml b/config/FPT/PEMS-BAY.yaml new file mode 100755 index 0000000..efe4d7c --- /dev/null +++ b/config/FPT/PEMS-BAY.yaml @@ -0,0 +1,51 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 325 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/SolarEnergy.yaml b/config/FPT/SolarEnergy.yaml new file mode 100644 index 0000000..fe1ea22 --- /dev/null +++ b/config/FPT/SolarEnergy.yaml @@ -0,0 +1,51 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 137 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/model/ASTRA/astra.py b/model/ASTRA/astra.py index 71d4ee9..f0d32e5 100644 --- a/model/ASTRA/astra.py +++ b/model/ASTRA/astra.py @@ -206,7 +206,6 @@ class ASTRA(nn.Module): enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, d_model * input_dim) # 应用图增强编码器(自动生成图结构) graph_enhanced = self.graph_encoder(enc_out) # (B, N, K * hidden_dim) - # 特征融合 - 现在两个张量都是三维的 [B, N, d_model] enc_out = torch.cat([enc_out, graph_enhanced], dim=-1) enc_out = self.feature_fusion(enc_out) diff --git a/model/FPT/fpt.py b/model/FPT/fpt.py new file mode 100644 index 0000000..941da6d --- /dev/null +++ b/model/FPT/fpt.py @@ -0,0 +1,45 @@ +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Model +from einops import rearrange + +class fpt(nn.Module): + def __init__(self, configs): + super(fpt, self).__init__() + self.patch_len = configs['patch_len'] + self.stride = configs['stride'] + self.input_dim = configs['input_dim'] + self.seq_len = configs['seq_len'] + self.pred_len = configs['pred_len'] + self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 + self.d_model = configs['d_model'] + self.gpt_path = configs['gpt_path'] + + self.patch_num = int((self.seq_len - self.patch_len) / self.stride + 2) # 补丁数量 + self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride)) + + self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) + self.gpts.h = self.gpts.h[:self.gpt_layers] + for i, (name, param) in enumerate(self.gpts.named_parameters()): + if 'wpe' in name: + param.requires_grad = True + else: + param.requires_grad = False + + self.in_layer = nn.Linear(self.patch_len, self.d_model) + self.out_layer = nn.Linear(self.d_model * self.patch_num, self.pred_len) + + def forward(self, x): + B, L, M = x.shape + x = x[..., :self.input_dim] + x = rearrange(x, 'b l m -> b m l') + + x = self.padding_patch_layer(x) + x = x.unfold(dimension = -1, size = self.patch_len, step = self.stride) + x = rearrange(x, 'b m n p -> (b m) n p') + + outputs = self.in_layer(x) + outputs = self.gpts(inputs_embeds=outputs).last_hidden_state + outputs = self.out_layer(outputs.reshape(B*M, -1)) + outputs = rearrange(outputs, '(b m) l -> b l m', b = B) + return outputs + diff --git a/model/FPT/model_config.json b/model/FPT/model_config.json new file mode 100644 index 0000000..a7d040c --- /dev/null +++ b/model/FPT/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "FPT", + "module": "model.FPT.fpt", + "entry": "fpt" + } +] \ No newline at end of file diff --git a/train.py b/train.py index 40b3bcb..7242ac0 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - # model_list = ["ASTRA_v3"] - model_list = ["PatchTST"] - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + model_list = ["FPT"] + # model_list = ["PatchTST"] + dataset_list = ["METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] # dataset_list = ["METR-LA"] main(model_list, dataset_list, debug = False) \ No newline at end of file diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 17aa81d..24d8a10 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -20,7 +20,7 @@ def select_trainer( scaler, args, lr_scheduler ) - if model_name in {"HI", "PatchTST", "iTransformer"}: + if model_name in {"HI", "PatchTST", "iTransformer", "FPT"}: return TSTrainer(*base_args) trainer_map = { -- 2.40.1 From 6a8d1f33f9ce93a3eefeaafa2bbbf174f585c79f Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Fri, 19 Dec 2025 11:54:07 +0800 Subject: [PATCH 40/41] fix repst:Air bug --- config/REPST/AirQuality.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index ee690f1..8eeba4f 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -27,7 +27,7 @@ model: input_dim: 6 n_heads: 1 num_nodes: 12 - output_dim: 3 + output_dim: 6 patch_len: 6 pred_len: 24 seq_len: 24 -- 2.40.1 From 9d3293cef7baf23be204144311b90f2b5e426dae Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sat, 20 Dec 2025 15:45:13 +0800 Subject: [PATCH 41/41] impl STNorm --- config/STNorm/AirQuality.yaml | 34 ++--- config/STNorm/BJTaxi-InFlow.yaml | 34 ++--- config/STNorm/BJTaxi-OutFlow.yaml | 34 ++--- config/STNorm/METR-LA.yaml | 2 +- config/STNorm/NYCBike-InFlow.yaml | 34 ++--- config/STNorm/NYCBike-OutFlow.yaml | 34 ++--- config/STNorm/PEMS-BAY.yaml | 34 ++--- config/STNorm/SolarEnergy.yaml | 34 ++--- model/STNorm/STNorm.py | 206 +++++++++++++++-------------- model/STNorm/model_config.json | 2 +- train.py | 10 +- 11 files changed, 191 insertions(+), 267 deletions(-) diff --git a/config/STNorm/AirQuality.yaml b/config/STNorm/AirQuality.yaml index 9846895..384633d 100644 --- a/config/STNorm/AirQuality.yaml +++ b/config/STNorm/AirQuality.yaml @@ -2,7 +2,7 @@ basic: dataset: AirQuality device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 35 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 6 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 6 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 35 + in_dim: 6 + out_dim: 6 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/BJTaxi-InFlow.yaml b/config/STNorm/BJTaxi-InFlow.yaml index 09e453a..13130be 100644 --- a/config/STNorm/BJTaxi-InFlow.yaml +++ b/config/STNorm/BJTaxi-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 1024 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 1024 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/BJTaxi-OutFlow.yaml b/config/STNorm/BJTaxi-OutFlow.yaml index 1b62a4e..fec550a 100644 --- a/config/STNorm/BJTaxi-OutFlow.yaml +++ b/config/STNorm/BJTaxi-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 1024 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 1024 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/METR-LA.yaml b/config/STNorm/METR-LA.yaml index 6f118f2..f48c978 100644 --- a/config/STNorm/METR-LA.yaml +++ b/config/STNorm/METR-LA.yaml @@ -26,7 +26,7 @@ model: tnorm_bool: True num_nodes: 207 in_dim: 1 - out_dim: 24 + out_dim: 1 channels: 32 kernel_size: 2 diff --git a/config/STNorm/NYCBike-InFlow.yaml b/config/STNorm/NYCBike-InFlow.yaml index 95ae41b..57ad401 100644 --- a/config/STNorm/NYCBike-InFlow.yaml +++ b/config/STNorm/NYCBike-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 128 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 128 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/NYCBike-OutFlow.yaml b/config/STNorm/NYCBike-OutFlow.yaml index b1646ea..4f32f0a 100644 --- a/config/STNorm/NYCBike-OutFlow.yaml +++ b/config/STNorm/NYCBike-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 128 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 128 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/PEMS-BAY.yaml b/config/STNorm/PEMS-BAY.yaml index 7f28aca..20f4b5d 100644 --- a/config/STNorm/PEMS-BAY.yaml +++ b/config/STNorm/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 325 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 325 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/SolarEnergy.yaml b/config/STNorm/SolarEnergy.yaml index 57e17c8..d1be59c 100644 --- a/config/STNorm/SolarEnergy.yaml +++ b/config/STNorm/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 137 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 137 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/model/STNorm/STNorm.py b/model/STNorm/STNorm.py index 71e7118..11a72e4 100644 --- a/model/STNorm/STNorm.py +++ b/model/STNorm/STNorm.py @@ -2,139 +2,147 @@ import torch import torch.nn as nn import torch.nn.functional as F + +# ========================= +# Spatial Normalization +# ========================= class SNorm(nn.Module): - def __init__(self, channels): + def __init__(self, channels, eps=1e-5): super().__init__() - self.beta = nn.Parameter(torch.zeros(channels)) - self.gamma = nn.Parameter(torch.ones(channels)) + self.gamma = nn.Parameter(torch.ones(1, channels, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1)) + self.eps = eps def forward(self, x): - x_norm = (x - x.mean(2, keepdims=True)) / (x.var(2, keepdims=True, unbiased=True) + 1e-5) ** 0.5 - return x_norm * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1) + # normalize over node dimension + mean = x.mean(dim=2, keepdim=True) + var = x.var(dim=2, keepdim=True, unbiased=False) + x = (x - mean) / torch.sqrt(var + self.eps) + return x * self.gamma + self.beta + +# ========================= +# Temporal Normalization +# ========================= class TNorm(nn.Module): - def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1): + def __init__(self, num_nodes, channels, momentum=0.1, eps=1e-5): super().__init__() - self.track_running_stats = track_running_stats - self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1)) self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1)) - self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1)) - self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1)) + self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1)) + self.register_buffer("running_mean", torch.zeros(1, channels, num_nodes, 1)) + self.register_buffer("running_var", torch.ones(1, channels, num_nodes, 1)) self.momentum = momentum + self.eps = eps def forward(self, x): - if self.track_running_stats: - mean = x.mean((0, 3), keepdims=True) - var = x.var((0, 3), keepdims=True, unbiased=False) - if self.training: - n = x.shape[3] * x.shape[0] - with torch.no_grad(): - self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean - self.running_var = self.momentum * var * n / (n - 1) + (1 - self.momentum) * self.running_var - else: - mean = self.running_mean - var = self.running_var + if self.training: + mean = x.mean(dim=(0, 3), keepdim=True) + var = x.var(dim=(0, 3), keepdim=True, unbiased=False) + # in-place update (VERY IMPORTANT) + self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean) + self.running_var.mul_(1 - self.momentum).add_(self.momentum * var) else: - mean = x.mean(3, keepdims=True) - var = x.var(3, keepdims=True, unbiased=True) - x_norm = (x - mean) / (var + 1e-5) ** 0.5 - return x_norm * self.gamma + self.beta + mean = self.running_mean + var = self.running_var -class stnorm(nn.Module): + x = (x - mean) / torch.sqrt(var + self.eps) + return x * self.gamma + self.beta + + +# ========================= +# STNorm WaveNet +# ========================= +class STNormNet(nn.Module): def __init__(self, args): super().__init__() - self.dropout = args["dropout"] self.blocks = args["blocks"] self.layers = args["layers"] - self.snorm_bool = args["snorm_bool"] - self.tnorm_bool = args["tnorm_bool"] + self.dropout = args["dropout"] self.num_nodes = args["num_nodes"] - in_dim = args["in_dim"] - out_dim = args["out_dim"] - channels = args["channels"] - kernel_size = args["kernel_size"] - # 初始化卷积层 - self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1)) - - # 初始化模块列表 + self.in_dim = args["in_dim"] + self.out_dim = args["out_dim"] + self.channels = args["channels"] + self.kernel_size = args["kernel_size"] + + self.use_snorm = args["snorm_bool"] + self.use_tnorm = args["tnorm_bool"] + + self.start_conv = nn.Conv2d(self.in_dim, self.channels, kernel_size=(1, 1)) + self.filter_convs = nn.ModuleList() self.gate_convs = nn.ModuleList() self.residual_convs = nn.ModuleList() self.skip_convs = nn.ModuleList() - self.sn = nn.ModuleList() if self.snorm_bool else None - self.tn = nn.ModuleList() if self.tnorm_bool else None - # 计算感受野 + self.snorms = nn.ModuleList() + self.tnorms = nn.ModuleList() + self.receptive_field = 1 - additional_scope = kernel_size - 1 - # 构建网络层 for b in range(self.blocks): - new_dilation = 1 - for i in range(self.layers): - if self.tnorm_bool: - self.tn.append(TNorm(self.num_nodes, channels)) - if self.snorm_bool: - self.sn.append(SNorm(channels)) - - # 膨胀卷积 - 直接使用channels作为输入通道,不再拼接多个特征 - self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, - kernel_size=(1, kernel_size), dilation=new_dilation)) - self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, - kernel_size=(1, kernel_size), dilation=new_dilation)) - - # 残差连接和跳跃连接 - self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) - self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) - - # 更新感受野 - self.receptive_field += additional_scope - additional_scope *= 2 - new_dilation *= 2 + dilation = 1 + rf_add = self.kernel_size - 1 + for _ in range(self.layers): + if self.use_snorm: + self.snorms.append(SNorm(self.channels)) + if self.use_tnorm: + self.tnorms.append(TNorm(self.num_nodes, self.channels)) - # 输出层 - self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1), bias=True) - self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim, kernel_size=(1, 1), bias=True) + self.filter_convs.append(nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)) + self.gate_convs.append(nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)) + self.residual_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1))) + self.skip_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1))) + + self.receptive_field += rf_add + rf_add *= 2 + dilation *= 2 + + self.end_conv_1 = nn.Conv2d(self.channels, self.channels, (1, 1)) + self.end_conv_2 = nn.Conv2d(self.channels, self.out_dim, (1, 1)) def forward(self, input): - # 输入处理:与GWN保持一致 (bs, features, n_nodes, n_timesteps) - x = input[..., 0:1].transpose(1, 3) - - # 处理感受野 - in_len = x.size(3) - if in_len < self.receptive_field: - x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0)) - - # 起始卷积 - x = self.start_conv(x) - skip = 0 + # (B, T, N, F) -> (B, F, N, T) + x = input[..., :self.in_dim].transpose(1, 3) - # WaveNet层 + # pad to receptive field + if x.size(3) < self.receptive_field: + x = F.pad(x, (self.receptive_field - x.size(3), 0, 0, 0)) + + x = self.start_conv(x) + skip = None + + norm_idx = 0 for i in range(self.blocks * self.layers): residual = x - - # 添加空间和时间归一化(直接叠加到原始特征上,而不是拼接) - x_norm = x - if self.tnorm_bool: - x_norm += self.tn[i](x) - if self.snorm_bool: - x_norm += self.sn[i](x) - - # 膨胀卷积 - filter = torch.tanh(self.filter_convs[i](x_norm)) - gate = torch.sigmoid(self.gate_convs[i](x_norm)) - x = filter * gate - - # 跳跃连接 + # ---------- STNorm (safe fusion) ---------- + if self.use_tnorm: + x = x + 0.5 * self.tnorms[norm_idx](x) + if self.use_snorm: + x = x + 0.5 * self.snorms[norm_idx](x) + norm_idx += 1 + # ---------- Dilated Conv ---------- + filter_out = torch.tanh(self.filter_convs[i](x)) + gate_out = torch.sigmoid(self.gate_convs[i](x)) + x = filter_out * gate_out + # ---------- Skip (TIME SAFE) ---------- s = self.skip_convs[i](x) - skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) - - # 残差连接 - x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):] + if skip is None: + skip = s + else: + skip = skip[..., -s.size(3) :] + s + # ---------- Residual (TIME SAFE) ---------- + x = self.residual_convs[i](x) + x = x + residual[..., -x.size(3) :] - # 输出处理 x = F.relu(skip) x = F.relu(self.end_conv_1(x)) - x = self.end_conv_2(x) - return x \ No newline at end of file + x = self.end_conv_2(x) # [B, 1, N, T] + T_out = x.size(3) + T_target = input.size(1) + + if T_out < T_target: + x = F.pad(x, (T_target - T_out, 0, 0, 0)) # left pad + + x = x.transpose(1, 3) + return x diff --git a/model/STNorm/model_config.json b/model/STNorm/model_config.json index 62ea48c..f860d07 100644 --- a/model/STNorm/model_config.json +++ b/model/STNorm/model_config.json @@ -2,6 +2,6 @@ { "name": "STNorm", "module": "model.STNorm.STNorm", - "entry": "stnorm" + "entry": "STNormNet" } ] \ No newline at end of file diff --git a/train.py b/train.py index 7242ac0..db9d8dd 100644 --- a/train.py +++ b/train.py @@ -11,7 +11,7 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cpu" # 指定设备为cuda:0 seed = 2023 # 随机种子 epochs = 1 # 训练轮数 @@ -90,9 +90,9 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["FPT"] + model_list = ["STNorm"] # model_list = ["PatchTST"] - dataset_list = ["METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] - # dataset_list = ["METR-LA"] - main(model_list, dataset_list, debug = False) \ No newline at end of file + dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] + # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] + main(model_list, dataset_list, debug = True) \ No newline at end of file -- 2.40.1