import random

import numpy as np
import torch
import torch.nn as nn

from model.model_selector import model_selector


def init_model(args, device):
    """Build the model selected by ``args``, move it to ``device``, and
    initialize its parameters.

    Weight tensors (dim > 1) get Xavier-uniform initialization; 1-D tensors
    (biases, norms, etc.) get a plain uniform initialization, matching the
    original scheme.

    Args:
        args: configuration object/dict passed through to ``model_selector``.
        device: torch device the model is moved to.

    Returns:
        The initialized ``nn.Module`` living on ``device``.
    """
    model = model_selector(args).to(device)
    # Initialize model parameters.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params} parameters")
    return model


def init_optimizer(model, args):
    """Create an Adam optimizer and, optionally, a MultiStepLR scheduler.

    Args:
        model: model whose parameters are optimized.
        args: dict-like config with keys ``lr_init``, ``weight_decay``,
            ``lr_decay`` (truthy enables the scheduler) and — when decay is
            enabled — ``lr_decay_step`` (comma-separated milestone epochs,
            e.g. ``"30,60"``) and ``lr_decay_rate`` (multiplicative gamma).

    Returns:
        Tuple ``(optimizer, lr_scheduler)``; ``lr_scheduler`` is ``None``
        when ``args['lr_decay']`` is falsy.
    """
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args['lr_init'],
        eps=1.0e-8,
        weight_decay=args['weight_decay'],
        amsgrad=False,
    )
    lr_scheduler = None
    if args['lr_decay']:
        milestones = [int(step) for step in args['lr_decay_step'].split(',')]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer,
            milestones=milestones,
            gamma=args['lr_decay_rate'],
        )
    return optimizer, lr_scheduler


def init_seed(seed):
    """Seed all RNGs and disable cuDNN to maximize reproducibility."""
    # BUG FIX: the original set ``torch.cuda.cudnn_enabled = False``, which
    # is not a real PyTorch attribute — the assignment silently created a
    # dead attribute and cuDNN stayed enabled. The correct switch is
    # ``torch.backends.cudnn.enabled``.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds the current CUDA device; no-op on CPU-only installs.
    torch.cuda.manual_seed(seed)