Merge branch 'REPST' of https://github.zhang-heng.com/czzhangheng/TrafficWheel into REPST
This commit is contained in:
commit
3b6dc18742
|
|
@ -1,7 +1,7 @@
|
|||
basic:
|
||||
dataset: "PEMS-BAY"
|
||||
mode : "train"
|
||||
device : "cuda:1"
|
||||
device : "cuda:0"
|
||||
model: "REPST"
|
||||
seed: 2023
|
||||
|
||||
|
|
|
|||
2
run.py
2
run.py
|
|
@ -48,7 +48,7 @@ def main():
|
|||
model.load_state_dict(
|
||||
torch.load(
|
||||
f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
|
||||
map_location=args["device"],
|
||||
map_location=args["basic"]["device"],
|
||||
weights_only=True,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -180,8 +180,8 @@ class Trainer:
|
|||
loss = self.loss(output, label)
|
||||
|
||||
# 反归一化
|
||||
self.scaler.inverse_transform(output)
|
||||
self.scaler.inverse_transform(label)
|
||||
d_output = self.scaler.inverse_transform(output)
|
||||
d_label = self.scaler.inverse_transform(label)
|
||||
|
||||
# 反向传播和优化(仅在训练模式)
|
||||
if optimizer_step and self.optimizer is not None:
|
||||
|
|
@ -194,6 +194,9 @@ class Trainer:
|
|||
self.model.parameters(), self.args["max_grad_norm"]
|
||||
)
|
||||
self.optimizer.step()
|
||||
|
||||
# 反归一化的loss
|
||||
d_loss = self.loss(d_output, d_label)
|
||||
|
||||
# 记录步骤时间和内存使用
|
||||
step_time = time.time() - start_time
|
||||
|
|
@ -205,7 +208,7 @@ class Trainer:
|
|||
y_true.append(label.detach().cpu())
|
||||
|
||||
# 更新进度条
|
||||
progress_bar.set_postfix(loss=loss.item())
|
||||
progress_bar.set_postfix(loss=d_loss.item())
|
||||
|
||||
# 合并所有批次的预测结果
|
||||
y_pred = torch.cat(y_pred, dim=0)
|
||||
|
|
|
|||
Loading…
Reference in New Issue