Merge branch 'REPST' of https://github.zhang-heng.com/czzhangheng/TrafficWheel into REPST
This commit is contained in:
commit
3b6dc18742
|
|
@ -1,7 +1,7 @@
|
||||||
basic:
|
basic:
|
||||||
dataset: "PEMS-BAY"
|
dataset: "PEMS-BAY"
|
||||||
mode : "train"
|
mode : "train"
|
||||||
device : "cuda:1"
|
device : "cuda:0"
|
||||||
model: "REPST"
|
model: "REPST"
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
|
|
|
||||||
2
run.py
2
run.py
|
|
@ -48,7 +48,7 @@ def main():
|
||||||
model.load_state_dict(
|
model.load_state_dict(
|
||||||
torch.load(
|
torch.load(
|
||||||
f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
|
f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
|
||||||
map_location=args["device"],
|
map_location=args["basic"]["device"],
|
||||||
weights_only=True,
|
weights_only=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -180,8 +180,8 @@ class Trainer:
|
||||||
loss = self.loss(output, label)
|
loss = self.loss(output, label)
|
||||||
|
|
||||||
# 反归一化
|
# 反归一化
|
||||||
self.scaler.inverse_transform(output)
|
d_output = self.scaler.inverse_transform(output)
|
||||||
self.scaler.inverse_transform(label)
|
d_label = self.scaler.inverse_transform(label)
|
||||||
|
|
||||||
# 反向传播和优化(仅在训练模式)
|
# 反向传播和优化(仅在训练模式)
|
||||||
if optimizer_step and self.optimizer is not None:
|
if optimizer_step and self.optimizer is not None:
|
||||||
|
|
@ -195,6 +195,9 @@ class Trainer:
|
||||||
)
|
)
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
|
# 反归一化的loss
|
||||||
|
d_loss = self.loss(d_output, d_label)
|
||||||
|
|
||||||
# 记录步骤时间和内存使用
|
# 记录步骤时间和内存使用
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
self.stats.record_step_time(step_time, mode)
|
self.stats.record_step_time(step_time, mode)
|
||||||
|
|
@ -205,7 +208,7 @@ class Trainer:
|
||||||
y_true.append(label.detach().cpu())
|
y_true.append(label.detach().cpu())
|
||||||
|
|
||||||
# 更新进度条
|
# 更新进度条
|
||||||
progress_bar.set_postfix(loss=loss.item())
|
progress_bar.set_postfix(loss=d_loss.item())
|
||||||
|
|
||||||
# 合并所有批次的预测结果
|
# 合并所有批次的预测结果
|
||||||
y_pred = torch.cat(y_pred, dim=0)
|
y_pred = torch.cat(y_pred, dim=0)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue