import torch
import torch.nn as nn

from model.model_selector import model_selector


def init_model(args, device):
    """Build the model selected by ``args`` and initialize its parameters.

    Parameters with more than one dimension (weight matrices/tensors) get
    Xavier-uniform initialization; all remaining 1-D parameters (e.g. biases)
    get a uniform init in PyTorch's default [0, 1) range.

    Args:
        args: configuration passed through to ``model_selector``.
        device: torch device the model is moved to before initialization.

    Returns:
        The initialized ``nn.Module``, already on ``device``.
    """
    net = model_selector(args).to(device)

    for param in net.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
        else:
            nn.init.uniform_(param)

    return net


def init_optimizer(model, args):
    """Create an Adam optimizer and an optional MultiStepLR scheduler.

    Args:
        model: module whose parameters are optimized.
        args: dict with keys ``lr_init``, ``weight_decay``, ``lr_decay``
            (truthy enables scheduling), ``lr_decay_step`` (comma-separated
            epoch milestones, e.g. ``"30,60"``), and ``lr_decay_rate``.

    Returns:
        Tuple ``(optimizer, scheduler)`` where ``scheduler`` is ``None``
        when ``args['lr_decay']`` is falsy.
    """
    adam = torch.optim.Adam(
        params=model.parameters(),
        lr=args['lr_init'],
        eps=1.0e-8,
        weight_decay=args['weight_decay'],
        amsgrad=False,
    )

    # No decay requested: hand back the optimizer without a scheduler.
    if not args['lr_decay']:
        return adam, None

    milestones = [int(step) for step in args['lr_decay_step'].split(',')]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=adam,
        milestones=milestones,
        gamma=args['lr_decay_rate'],
    )
    return adam, scheduler