Support cosineAnnealingLR

czzhangheng 2025-11-23 19:06:48 +08:00
parent fe3fc186be
commit 95f81425b0
3 changed files with 28 additions and 19 deletions


@@ -43,13 +43,14 @@ train:
   early_stop_patience: 15
   epochs: 100
   grad_norm: false
-  loss_func: mae
+  loss_func: mse
+  scheduler: cosineAnnealingLR
+  tmax: 5
   lr_decay: true
   lr_decay_rate: 0.3
   lr_decay_step: "5,20,40,70"
   lr_init: 0.003
   max_grad_norm: 5
   real_value: true
   weight_decay: 0
   debug: false
   output_dim: 1
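
A minimal sketch of how the two new keys are meant to be consumed, assuming PyYAML for loading and a placeholder model; only lr_init, scheduler, and tmax come from the config above, and everything else here is illustrative:

import torch
import yaml

cfg = yaml.safe_load("""
train:
  lr_init: 0.003
  scheduler: cosineAnnealingLR
  tmax: 5
""")["train"]

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr_init"])
if cfg["scheduler"] == "cosineAnnealingLR":
    # T_max is the number of scheduler steps over which the learning rate
    # anneals from lr_init down toward eta_min along a cosine curve
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg["tmax"], eta_min=1e-8
    )

With tmax: 5, the learning rate completes a half cosine period every 5 scheduler.step() calls and then rises again; whether that is per epoch or per batch depends on where the trainer calls step().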


@@ -203,9 +203,9 @@ class Trainer:
             self.stats.record_step_time(step_time, mode)
             # Accumulate loss and predictions
-            total_loss += loss.item()
-            y_pred.append(output.detach().cpu())
-            y_true.append(label.detach().cpu())
+            total_loss += d_loss.item()
+            y_pred.append(d_output.detach().cpu())
+            y_true.append(d_label.detach().cpu())
             # Update the progress bar
             progress_bar.set_postfix(loss=d_loss.item())
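
The renamed tensors aside, the pattern this hunk keeps is worth noting: each batch's loss and outputs are detached and moved to CPU before being accumulated, so the loop retains neither autograd history nor GPU-resident tensors across batches. A standalone sketch of that pattern with a placeholder model and data (nothing here beyond the detach/cpu/append idiom is taken from the trainer itself):

import torch

model = torch.nn.Linear(4, 1)  # placeholder model
loss_fn = torch.nn.MSELoss()   # matches the loss_func: mse setting above
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]  # placeholder loader

y_pred, y_true, total_loss = [], [], 0.0
with torch.no_grad():
    for x, label in batches:
        output = model(x)
        d_loss = loss_fn(output, label)
        # Accumulate loss and predictions; detach().cpu() drops the autograd
        # graph reference and moves the batch result off the GPU
        total_loss += d_loss.item()
        y_pred.append(output.detach().cpu())
        y_true.append(label.detach().cpu())
mean_loss = total_loss / len(batches)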


@@ -23,21 +23,29 @@ def init_model(args):
 def init_optimizer(model, args):
-    optimizer = torch.optim.Adam(
-        params=model.parameters(),
-        lr=args["lr_init"],
-        eps=1.0e-8,
-        weight_decay=args["weight_decay"],
-        amsgrad=False,
-    )
-    lr_scheduler = None
-    if args["lr_decay"]:
-        lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
-        )
+    optim = args.get("optimizer", "Adam")
+    match optim:
+        case "Adam":
+            optimizer = torch.optim.Adam(
+                params=model.parameters(),
+                lr=args["lr_init"],
+                eps=1.0e-8,
+                weight_decay=args["weight_decay"],
+                amsgrad=False,
+            )
+    scheduler = args.get("scheduler", "multistepLR")
+    match scheduler:
+        case "multistepLR":
+            lr_scheduler = None
+            if args["lr_decay"]:
+                lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
+                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                    optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
+                )
+        case "cosineAnnealingLR":
+            T_max = args.get("tmax", 5)
+            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=T_max, eta_min=1e-8)
     return optimizer, lr_scheduler
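
One caveat in the new dispatch: if args carries an unrecognized optimizer or scheduler name, no case matches and the final return raises UnboundLocalError on the unbound locals. A hedged sketch of a hardened variant (init_optimizer_safe and its error branches are illustrative, not part of this commit; like the commit's own code it needs Python 3.10+ for match):

import torch

def init_optimizer_safe(model, args):
    # Same dispatch as the commit's init_optimizer, plus explicit failures
    # for unknown names so the return below can never hit an unbound local.
    match args.get("optimizer", "Adam"):
        case "Adam":
            optimizer = torch.optim.Adam(
                params=model.parameters(),
                lr=args["lr_init"],
                eps=1.0e-8,
                weight_decay=args["weight_decay"],
                amsgrad=False,
            )
        case unknown:
            raise ValueError(f"unsupported optimizer: {unknown}")

    lr_scheduler = None  # default when lr_decay is off
    match args.get("scheduler", "multistepLR"):
        case "multistepLR":
            if args["lr_decay"]:
                milestones = [int(s) for s in args["lr_decay_step"].split(",")]
                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, milestones=milestones, gamma=args["lr_decay_rate"]
                )
        case "cosineAnnealingLR":
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.get("tmax", 5), eta_min=1e-8
            )
        case unknown:
            raise ValueError(f"unsupported scheduler: {unknown}")
    return optimizer, lr_scheduler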