测试时汇总样本 使用detach到gpu 避免显存爆炸

This commit is contained in:
czzhangheng 2025-11-20 21:21:39 +08:00
parent 96f2ea1239
commit 4d87087147
1 changed files with 8 additions and 18 deletions

View File

@ -316,13 +316,9 @@ class Trainer:
def _log_model_params(self): def _log_model_params(self):
"""输出模型可训练参数数量""" """输出模型可训练参数数量"""
try: total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
total_params = sum( self.logger.info(f"Trainable params: {total_params}")
p.numel() for p in self.model.parameters() if p.requires_grad
)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model) self.model.load_state_dict(best_model)
@ -353,8 +349,8 @@ class Trainer:
for data, target in data_loader: for data, target in data_loader:
label = target[..., : args["output_dim"]] label = target[..., : args["output_dim"]]
output = model(data) output = model(data)
y_pred.append(output) y_pred.append(output.detach().cpu())
y_true.append(label) y_true.append(label.detach().cpu())
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
@ -368,17 +364,11 @@ class Trainer:
args["mae_thresh"], args["mae_thresh"],
args["mape_thresh"], args["mape_thresh"],
) )
logger.info( logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
)
# 计算并记录平均指标 # 计算并记录平均指标
mae, rmse, mape = all_metrics( mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"])
d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"] logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
)
logger.info(
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
)
@staticmethod @staticmethod
def _compute_sampling_threshold(global_step, k): def _compute_sampling_threshold(global_step, k):