diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 82c90f5..508230e 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -316,13 +316,9 @@ class Trainer: def _log_model_params(self): """输出模型可训练参数数量""" - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) @@ -353,8 +349,8 @@ class Trainer: for data, target in data_loader: label = target[..., : args["output_dim"]] output = model(data) - y_pred.append(output) - y_true.append(label) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) @@ -368,17 +364,11 @@ class Trainer: args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") # 计算并记录平均指标 - mae, rmse, mape = all_metrics( - d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k):