Compare commits

..

No commits in common. "5cd81f4d4ca299343f08b88406bf2a71a182871e" and "475a4788cd0a69eee2dea42963bc8330b9a73788" have entirely different histories.

2 changed files with 16 additions and 25 deletions

View File

@ -43,14 +43,13 @@ train:
early_stop_patience: 15 early_stop_patience: 15
epochs: 100 epochs: 100
grad_norm: false grad_norm: false
loss_func: mse loss_func: mae
scheduler: cosineAnnealingLR
tmax: 5
lr_decay: true lr_decay: true
lr_decay_rate: 0.3 lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70" lr_decay_step: "5,20,40,70"
lr_init: 0.003 lr_init: 0.003
max_grad_norm: 5 max_grad_norm: 5
real_value: true
weight_decay: 0 weight_decay: 0
debug: false debug: false
output_dim: 1 output_dim: 1

View File

@ -23,29 +23,21 @@ def init_model(args):
def init_optimizer(model, args): def init_optimizer(model, args):
optim = args.get("optimizer", "Adam") optimizer = torch.optim.Adam(
match optim : params=model.parameters(),
case "Adam": lr=args["lr_init"],
optimizer = torch.optim.Adam( eps=1.0e-8,
params=model.parameters(), weight_decay=args["weight_decay"],
lr=args["lr_init"], amsgrad=False,
eps=1.0e-8, )
weight_decay=args["weight_decay"],
amsgrad=False, lr_scheduler = None
if args["lr_decay"]:
lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
) )
scheduler = args.get("scheduler", "multistepLR")
match scheduler:
case "multistepLR":
lr_scheduler = None
if args["lr_decay"]:
lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
)
case "cosineAnnealingLR":
T_max = args.get("tmax", 5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=T_max, eta_min=1e-8)
return optimizer, lr_scheduler return optimizer, lr_scheduler