47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from model.model_selector import model_selector
|
|
import random
|
|
import numpy as np
|
|
|
|
def init_model(args, device):
    """Build the model chosen by ``model_selector``, move it to ``device``, and re-initialize its parameters.

    Parameters with more than one dimension (e.g. weight matrices) are drawn
    with Xavier/Glorot uniform initialization; every other parameter
    (1-D or scalar, e.g. biases) is drawn from ``nn.init.uniform_``'s default
    range U(0, 1).

    NOTE(review): the U(0, 1) rule also overwrites any 1-D normalization
    scales/offsets the selected model may have — confirm this is intended.

    Args:
        args: configuration forwarded unchanged to ``model_selector``.
        device: target passed to ``Module.to`` (e.g. ``"cpu"`` or a
            ``torch.device``).

    Returns:
        The initialized model, already placed on ``device``.
    """
    net = model_selector(args)
    net = net.to(device)

    for weight in net.parameters():
        initializer = nn.init.xavier_uniform_ if weight.dim() > 1 else nn.init.uniform_
        initializer(weight)

    return net
|
|
|
|
def _parse_milestones(spec):
    """Normalize an ``lr_decay_step`` setting into a list of integer milestones.

    Accepts a comma-separated string such as ``"5,20,40"``. Whitespace around
    tokens and empty tokens (e.g. from a trailing comma) are ignored. Also
    accepts a single int, or any iterable of int-convertible values.

    Raises:
        ValueError: if a token is not an integer or no milestones remain.
    """
    if isinstance(spec, str):
        milestones = [int(tok) for tok in spec.split(',') if tok.strip()]
    elif isinstance(spec, int):
        # A single milestone may arrive as a bare int (e.g. from a parsed config file).
        milestones = [spec]
    else:
        milestones = [int(step) for step in spec]
    if not milestones:
        raise ValueError(f"lr_decay_step {spec!r} contains no milestones")
    return milestones


def init_optimizer(model, args):
    """Create an Adam optimizer and, optionally, a MultiStepLR scheduler for ``model``.

    Reads these keys from ``args``:
        ``lr_init``: initial learning rate.
        ``weight_decay``: Adam weight decay.
        ``lr_decay``: when truthy, a step-decay scheduler is also built.
        ``lr_decay_step``: milestones (epochs/steps at which to decay), given
            as a comma-separated string, a single int, or an iterable of ints.
            Only read when ``lr_decay`` is truthy.
        ``lr_decay_rate``: multiplicative factor (``gamma``) applied at each
            milestone. Only read when ``lr_decay`` is truthy.

    Returns:
        ``(optimizer, lr_scheduler)``, where ``lr_scheduler`` is ``None`` when
        ``args['lr_decay']`` is falsy.

    Raises:
        ValueError: if ``lr_decay_step`` cannot be parsed into milestones.
    """
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args['lr_init'],
        eps=1.0e-8,
        weight_decay=args['weight_decay'],
        amsgrad=False,
    )

    lr_scheduler = None
    if args['lr_decay']:
        lr_decay_steps = _parse_milestones(args['lr_decay_step'])
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer,
            milestones=lr_decay_steps,
            gamma=args['lr_decay_rate'],
        )

    return optimizer, lr_scheduler
|
|
|
|
def init_seed(seed):
    """Seed every RNG in use and disable cuDNN to maximize reproducibility.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU RNG and all
    CUDA device RNGs with the same ``seed``. It then turns off cuDNN and its
    autotuner, and requests deterministic cuDNN kernels.

    Note: disabling cuDNN trades GPU speed for determinism.

    Args:
        seed: integer seed applied to all generators.
    """
    # The original `torch.cuda.cudnn_enabled = False` only set an unused
    # attribute on the torch.cuda module; the real switch lives in torch.backends.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Lazily seeds every visible GPU; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)