TrafficWheel/lib/initializer.py

49 lines
1.3 KiB
Python
Executable File

import torch
import torch.nn as nn
from model.model_selector import model_selector
import random
import numpy as np
def init_model(args, device):
    """Build the configured model on ``device`` and re-initialise its weights.

    Args:
        args: Configuration object forwarded unchanged to ``model_selector``.
        device: Target device handed to ``Module.to``.

    Returns:
        The constructed ``nn.Module`` with freshly initialised parameters.
    """
    net = model_selector(args)
    net = net.to(device)

    for param in net.parameters():
        # Multi-dimensional tensors get Xavier-uniform; 1-D and scalar tensors
        # fall back to nn.init.uniform_ with its default bounds.
        init_fn = nn.init.xavier_uniform_ if param.dim() > 1 else nn.init.uniform_
        init_fn(param)

    n_params = sum(param.numel() for param in net.parameters())
    print(f"Model has {n_params} parameters")
    return net
def _parse_milestones(lr_decay_step):
    """Normalise an ``lr_decay_step`` config value into a list of int milestones.

    Accepts a comma-separated string (``"5,20,40"``, surrounding whitespace and
    empty entries such as a trailing comma are ignored), a single int, or an
    iterable of int-like values.
    """
    if isinstance(lr_decay_step, str):
        tokens = (tok.strip() for tok in lr_decay_step.split(','))
        return [int(tok) for tok in tokens if tok]
    if isinstance(lr_decay_step, int):
        return [lr_decay_step]
    return [int(step) for step in lr_decay_step]


def init_optimizer(model, args):
    """Create an Adam optimizer and an optional MultiStepLR scheduler.

    Args:
        model: Module whose parameters are optimised.
        args: Mapping read for ``'lr_init'``, ``'weight_decay'`` and ``'lr_decay'``;
            when ``'lr_decay'`` is truthy, ``'lr_decay_step'`` (comma-separated
            string, int, or list of ints) and ``'lr_decay_rate'`` are read too.

    Returns:
        ``(optimizer, lr_scheduler)`` where ``lr_scheduler`` is ``None`` when
        learning-rate decay is disabled.
    """
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args['lr_init'],
        eps=1.0e-8,
        weight_decay=args['weight_decay'],
        amsgrad=False,
    )
    lr_scheduler = None
    if args['lr_decay']:
        # Configs may give milestones as "5,20,40", a single int, or a list.
        lr_decay_steps = _parse_milestones(args['lr_decay_step'])
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer,
            milestones=lr_decay_steps,
            gamma=args['lr_decay_rate'],
        )
    return optimizer, lr_scheduler
def init_seed(seed):
    """Seed every RNG used in training and make cuDNN reproducible.

    cuDNN is disabled entirely (and its autotuner switched off) to maximise
    run-to-run reproducibility, at the cost of slower GPU kernels.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch.
    """
    # The cuDNN switch lives on torch.backends.cudnn; assigning
    # ``torch.cuda.cudnn_enabled`` only creates an unused module attribute.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernels per run and can pick different algorithms.
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not only the current device.
    torch.cuda.manual_seed_all(seed)