From fe3fc186be5deee3fc0e3cb1a544f9d7085c2e67 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 18 Nov 2025 10:46:08 +0800 Subject: [PATCH] =?UTF-8?q?=E6=98=BE=E7=A4=BA=E5=8F=8D=E5=BD=92=E4=B8=80?= =?UTF-8?q?=E5=8C=96loss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/REPST/PEMS-BAY.yaml | 2 +- run.py | 2 +- trainer/Trainer.py | 9 ++++++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml index 3209bd0..0eacf3e 100755 --- a/config/REPST/PEMS-BAY.yaml +++ b/config/REPST/PEMS-BAY.yaml @@ -1,7 +1,7 @@ basic: dataset: "PEMS-BAY" mode : "train" - device : "cuda:1" + device : "cuda:0" model: "REPST" seed: 2023 diff --git a/run.py b/run.py index 175367f..741f9d7 100755 --- a/run.py +++ b/run.py @@ -46,7 +46,7 @@ def main(): model.load_state_dict( torch.load( f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth", - map_location=args["device"], + map_location=args["basic"]["device"], weights_only=True, ) ) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 3c0e2b1..c7abff8 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -180,8 +180,8 @@ class Trainer: loss = self.loss(output, label) # 反归一化 - self.scaler.inverse_transform(output) - self.scaler.inverse_transform(label) + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) # 反向传播和优化(仅在训练模式) if optimizer_step and self.optimizer is not None: @@ -194,6 +194,9 @@ class Trainer: self.model.parameters(), self.args["max_grad_norm"] ) self.optimizer.step() + + # 反归一化的loss + d_loss = self.loss(d_output, d_label) # 记录步骤时间和内存使用 step_time = time.time() - start_time @@ -205,7 +208,7 @@ class Trainer: y_true.append(label.detach().cpu()) # 更新进度条 - progress_bar.set_postfix(loss=loss.item()) + progress_bar.set_postfix(loss=d_loss.item()) # 合并所有批次的预测结果 y_pred = torch.cat(y_pred, dim=0)