From cf47acf2bd2b1e46b5024ee00288f21e1d322aea Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 9 Nov 2025 16:38:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dtest=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/REPST/PEMSD8.yaml | 2 +- trainer/Trainer.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config/REPST/PEMSD8.yaml b/config/REPST/PEMSD8.yaml index f4b970b..3663a32 100755 --- a/config/REPST/PEMSD8.yaml +++ b/config/REPST/PEMSD8.yaml @@ -38,7 +38,7 @@ train: batch_size: 64 early_stop: true early_stop_patience: 15 - epochs: 300 + epochs: 100 grad_norm: false loss_func: mae lr_decay: true diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 772f08c..84c7edd 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -274,12 +274,12 @@ class Trainer: with torch.no_grad(): for data, target in data_loader: - label = target[..., : args["train"]["output_dim"]] + label = target[..., : args["output_dim"]] output = model(data) y_pred.append(output) y_true.append(label) - if args["train"]["real_value"]: + if args["real_value"]: y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) else: y_pred = torch.cat(y_pred, dim=0) @@ -289,15 +289,15 @@ class Trainer: mae, rmse, mape = all_metrics( y_pred[:, t, ...], y_true[:, t, ...], - args["train"]["mae_thresh"], - args["train"]["mape_thresh"], + args["mae_thresh"], + args["mape_thresh"], ) logger.info( f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" ) mae, rmse, mape = all_metrics( - y_pred, y_true, args["train"]["mae_thresh"], args["train"]["mape_thresh"] + y_pred, y_true, args["mae_thresh"], args["mape_thresh"] ) logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"