diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml
index 0eacf3e..60eb800 100755
--- a/config/REPST/PEMS-BAY.yaml
+++ b/config/REPST/PEMS-BAY.yaml
@@ -43,13 +43,14 @@ train:
   early_stop_patience: 15
   epochs: 100
   grad_norm: false
-  loss_func: mae
+  loss_func: mse
+  scheduler: cosineAnnealingLR
+  tmax: 5
   lr_decay: true
   lr_decay_rate: 0.3
   lr_decay_step: "5,20,40,70"
   lr_init: 0.003
   max_grad_norm: 5
-  real_value: true
   weight_decay: 0
 debug: false
 output_dim: 1
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
index c7abff8..626e9e9 100755
--- a/trainer/Trainer.py
+++ b/trainer/Trainer.py
@@ -203,9 +203,9 @@ class Trainer:
             self.stats.record_step_time(step_time, mode)
 
             # Accumulate the loss and the predictions
-            total_loss += loss.item()
-            y_pred.append(output.detach().cpu())
-            y_true.append(label.detach().cpu())
+            total_loss += d_loss.item()
+            y_pred.append(d_output.detach().cpu())
+            y_true.append(d_label.detach().cpu())
 
             # Update the progress bar
             progress_bar.set_postfix(loss=d_loss.item())
diff --git a/utils/initializer.py b/utils/initializer.py
index 2d72498..95b352d 100755
--- a/utils/initializer.py
+++ b/utils/initializer.py
@@ -23,21 +23,29 @@ def init_model(args):
 
 
 def init_optimizer(model, args):
-    optimizer = torch.optim.Adam(
-        params=model.parameters(),
-        lr=args["lr_init"],
-        eps=1.0e-8,
-        weight_decay=args["weight_decay"],
-        amsgrad=False,
-    )
-
-    lr_scheduler = None
-    if args["lr_decay"]:
-        lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
+    optim = args.get("optimizer", "Adam")
+    match optim:
+        case "Adam":
+            optimizer = torch.optim.Adam(
+                params=model.parameters(),
+                lr=args["lr_init"],
+                eps=1.0e-8,
+                weight_decay=args["weight_decay"],
+                amsgrad=False,
         )
-
+
+    scheduler = args.get("scheduler", "multistepLR")
+    match scheduler:
+        case "multistepLR":
+            lr_scheduler = None
+            if args["lr_decay"]:
+                lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
+                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                    optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
+                )
+        case "cosineAnnealingLR":
+            T_max = args.get("tmax", 5)
+            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=T_max, eta_min=1e-8)
 
     return optimizer, lr_scheduler
 
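
Note on the new `cosineAnnealingLR` path (not part of the diff): with `tmax: 5`, PyTorch's `CosineAnnealingLR` anneals the learning rate from `lr_init` down toward `eta_min` over five `scheduler.step()` calls and then rises back along the same cosine, so over the configured 100 epochs the learning rate oscillates rather than decaying monotonically; worth double-checking that this is intended. Below is a minimal, self-contained sketch for inspecting the schedule; the `Linear` stand-in model and the per-epoch `scheduler.step()` placement are illustrative assumptions, not taken from the trainer:

```python
import torch

# Stand-in model; any nn.Module works for inspecting the schedule.
model = torch.nn.Linear(8, 1)

# Mirrors the new PEMS-BAY settings: lr_init 0.003, scheduler cosineAnnealingLR, tmax 5.
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, eps=1.0e-8,
                             weight_decay=0, amsgrad=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-8)

for epoch in range(15):
    optimizer.step()   # actual training step omitted; step() is a no-op without gradients
    scheduler.step()
    print(f"epoch {epoch:2d}  lr = {scheduler.get_last_lr()[0]:.6f}")
```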
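Suggestion, sketched only: in the rewritten `init_optimizer`, a config value outside the handled cases (an `optimizer` other than `Adam`, or a `scheduler` other than `multistepLR`/`cosineAnnealingLR`) leaves `optimizer` or `lr_scheduler` unbound, so the function fails with a `NameError` at the `return`. An explicit fallback `case` would make the failure self-explanatory. The sketch below mirrors the structure of the new code under that assumption (like the PR, it needs Python ≥ 3.10 for `match`):

```python
import torch


def init_optimizer(model, args):
    # Sketch: same dispatch as the PR, plus explicit fallbacks for unknown names.
    match args.get("optimizer", "Adam"):
        case "Adam":
            optimizer = torch.optim.Adam(
                params=model.parameters(),
                lr=args["lr_init"],
                eps=1.0e-8,
                weight_decay=args["weight_decay"],
                amsgrad=False,
            )
        case other:
            raise ValueError(f"Unsupported optimizer: {other!r}")

    match args.get("scheduler", "multistepLR"):
        case "multistepLR":
            lr_scheduler = None
            if args["lr_decay"]:
                milestones = [int(step) for step in args["lr_decay_step"].split(",")]
                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, milestones=milestones, gamma=args["lr_decay_rate"]
                )
        case "cosineAnnealingLR":
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.get("tmax", 5), eta_min=1e-8
            )
        case other:
            raise ValueError(f"Unsupported scheduler: {other!r}")

    return optimizer, lr_scheduler
```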