From 4d8708714738d6987b4b1b01260d36b2acc8357c Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 21:21:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=97=B6=E6=B1=87=E6=80=BB?= =?UTF-8?q?=E6=A0=B7=E6=9C=AC=20=E4=BD=BF=E7=94=A8detach=E5=88=B0gpu=20?= =?UTF-8?q?=E9=81=BF=E5=85=8D=E6=98=BE=E5=AD=98=E7=88=86=E7=82=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trainer/Trainer.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 82c90f5..508230e 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -316,13 +316,9 @@ class Trainer: def _log_model_params(self): """输出模型可训练参数数量""" - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) @@ -353,8 +349,8 @@ class Trainer: for data, target in data_loader: label = target[..., : args["output_dim"]] output = model(data) - y_pred.append(output) - y_true.append(label) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) @@ -368,17 +364,11 @@ class Trainer: args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") # 计算并记录平均指标 - mae, rmse, mape = all_metrics( - d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k):