测试时汇总样本 使用detach到gpu 避免显存爆炸
This commit is contained in:
parent
96f2ea1239
commit
4d87087147
|
|
@ -316,13 +316,9 @@ class Trainer:
|
||||||
|
|
||||||
def _log_model_params(self):
|
def _log_model_params(self):
|
||||||
"""输出模型可训练参数数量"""
|
"""输出模型可训练参数数量"""
|
||||||
try:
|
total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||||
total_params = sum(
|
self.logger.info(f"Trainable params: {total_params}")
|
||||||
p.numel() for p in self.model.parameters() if p.requires_grad
|
|
||||||
)
|
|
||||||
self.logger.info(f"Trainable params: {total_params}")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _finalize_training(self, best_model, best_test_model):
|
def _finalize_training(self, best_model, best_test_model):
|
||||||
self.model.load_state_dict(best_model)
|
self.model.load_state_dict(best_model)
|
||||||
|
|
@ -353,8 +349,8 @@ class Trainer:
|
||||||
for data, target in data_loader:
|
for data, target in data_loader:
|
||||||
label = target[..., : args["output_dim"]]
|
label = target[..., : args["output_dim"]]
|
||||||
output = model(data)
|
output = model(data)
|
||||||
y_pred.append(output)
|
y_pred.append(output.detach().cpu())
|
||||||
y_true.append(label)
|
y_true.append(label.detach().cpu())
|
||||||
|
|
||||||
|
|
||||||
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
||||||
|
|
@ -368,17 +364,11 @@ class Trainer:
|
||||||
args["mae_thresh"],
|
args["mae_thresh"],
|
||||||
args["mape_thresh"],
|
args["mape_thresh"],
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||||
f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 计算并记录平均指标
|
# 计算并记录平均指标
|
||||||
mae, rmse, mape = all_metrics(
|
mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"])
|
||||||
d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]
|
logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _compute_sampling_threshold(global_step, k):
|
def _compute_sampling_threshold(global_step, k):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue