diff --git a/config/REPST/PEMSD8.yaml b/config/REPST/PEMSD8.yaml index f4b970b..3663a32 100755 --- a/config/REPST/PEMSD8.yaml +++ b/config/REPST/PEMSD8.yaml @@ -38,7 +38,7 @@ train: batch_size: 64 early_stop: true early_stop_patience: 15 - epochs: 300 + epochs: 100 grad_norm: false loss_func: mae lr_decay: true diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 772f08c..84c7edd 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -274,12 +274,12 @@ class Trainer: with torch.no_grad(): for data, target in data_loader: - label = target[..., : args["train"]["output_dim"]] + label = target[..., : args["output_dim"]] output = model(data) y_pred.append(output) y_true.append(label) - if args["train"]["real_value"]: + if args["real_value"]: y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) else: y_pred = torch.cat(y_pred, dim=0) @@ -289,15 +289,15 @@ class Trainer: mae, rmse, mape = all_metrics( y_pred[:, t, ...], y_true[:, t, ...], - args["train"]["mae_thresh"], - args["train"]["mape_thresh"], + args["mae_thresh"], + args["mape_thresh"], ) logger.info( f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" ) mae, rmse, mape = all_metrics( - y_pred, y_true, args["train"]["mae_thresh"], args["train"]["mape_thresh"] + y_pred, y_true, args["mae_thresh"], args["mape_thresh"] ) logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"