Support cosineAnnealingLR
parent fe3fc186be
commit 95f81425b0
@@ -43,13 +43,14 @@ train:
   early_stop_patience: 15
   epochs: 100
   grad_norm: false
-  loss_func: mae
+  loss_func: mse
+  scheduler: cosineAnnealingLR
+  tmax: 5
   lr_decay: true
   lr_decay_rate: 0.3
   lr_decay_step: "5,20,40,70"
   lr_init: 0.003
   max_grad_norm: 5
-  real_value: true
   weight_decay: 0
   debug: false
   output_dim: 1
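
The two added keys select the LR scheduler by name and set its cycle length (the loss_func switch to mse and the removal of real_value are unrelated tweaks in the same hunk). Below is a minimal standalone sketch, not from this repo, of the schedule these values produce with PyTorch's CosineAnnealingLR; the Linear model and the bare loop are stand-ins:

import torch

model = torch.nn.Linear(4, 1)                               # hypothetical stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)  # lr_init from the config

# tmax: 5 -> the LR traces one half-cosine from lr_init down to eta_min over 5 epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-8)

for epoch in range(10):
    optimizer.step()                       # real training step elided
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # reaches eta_min at epoch 5, then climbs back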
@@ -203,9 +203,9 @@ class Trainer:
             self.stats.record_step_time(step_time, mode)

             # Accumulate loss and predictions
-            total_loss += loss.item()
-            y_pred.append(output.detach().cpu())
-            y_true.append(label.detach().cpu())
+            total_loss += d_loss.item()
+            y_pred.append(d_output.detach().cpu())
+            y_true.append(d_label.detach().cpu())

             # Update progress bar
             progress_bar.set_postfix(loss=d_loss.item())
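
The accumulation now reads from the same d_-prefixed tensors that the progress bar below it already used, so the running loss matches what the bar reports. As a minimal sketch of how such accumulators are typically reduced at epoch end; the dummy batches, num_batches, and the MAE reduction are assumptions for illustration, not this repo's code:

import torch

# Dummy per-batch tensors standing in for the d_output / d_label values
# appended inside the loop; shapes and batch count are assumptions.
y_pred = [torch.randn(32, 1) for _ in range(3)]
y_true = [torch.randn(32, 1) for _ in range(3)]
total_loss, num_batches = 2.7, 3  # assumed running sum and step count

y_pred = torch.cat(y_pred, dim=0)      # (96, 1): all predictions of the epoch
y_true = torch.cat(y_true, dim=0)
epoch_loss = total_loss / num_batches  # mean training loss over the epoch
mae = torch.mean(torch.abs(y_pred - y_true)).item()
print(f"epoch loss {epoch_loss:.4f}  MAE {mae:.4f}")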
@@ -23,21 +23,29 @@ def init_model(args):


 def init_optimizer(model, args):
-    optimizer = torch.optim.Adam(
-        params=model.parameters(),
-        lr=args["lr_init"],
-        eps=1.0e-8,
-        weight_decay=args["weight_decay"],
-        amsgrad=False,
-    )
-
-    lr_scheduler = None
-    if args["lr_decay"]:
-        lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
-        )
-
+    optim = args.get("optimizer", "Adam")
+    match optim:
+        case "Adam":
+            optimizer = torch.optim.Adam(
+                params=model.parameters(),
+                lr=args["lr_init"],
+                eps=1.0e-8,
+                weight_decay=args["weight_decay"],
+                amsgrad=False,
+            )
+
+    scheduler = args.get("scheduler", "multistepLR")
+    match scheduler:
+        case "multistepLR":
+            lr_scheduler = None
+            if args["lr_decay"]:
+                lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
+                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                    optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
+                )
+        case "cosineAnnealingLR":
+            T_max = args.get("tmax", 5)
+            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=T_max, eta_min=1e-8)
     return optimizer, lr_scheduler
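
The optimizer and scheduler are now both dispatched by name via match/case (Python >= 3.10), with the cosine branch reading the new tmax key. A hypothetical call mirroring the config hunk above; args is written as a plain dict on the assumption that the repo's config object supports .get and bracket access, and that init_optimizer is importable:

import torch
# from <this module> import init_optimizer  # import path assumed

args = {
    "optimizer": "Adam",               # default branch
    "scheduler": "cosineAnnealingLR",  # selects the new branch
    "tmax": 5,
    "lr_init": 0.003,
    "weight_decay": 0,
}
model = torch.nn.Linear(4, 1)  # stand-in model
optimizer, lr_scheduler = init_optimizer(model, args)
assert isinstance(lr_scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)

One caveat visible in the new code: neither match statement has a wildcard case _:, so an unrecognized optimizer or scheduler name leaves optimizer or lr_scheduler unbound and the function fails with a NameError; a default branch, or an explicit raise ValueError, would make that failure mode clearer.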