diff --git a/lib/initializer.py b/lib/initializer.py index 7cee09a..c08bacc 100644 --- a/lib/initializer.py +++ b/lib/initializer.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn from model.model_selector import model_selector - +import random +import numpy as np def init_model(args, device): model = model_selector(args).to(device) @@ -34,3 +35,13 @@ def init_optimizer(model, args): return optimizer, lr_scheduler +def init_seed(seed): + ''' + Disable cudnn to maximize reproducibility + ''' + torch.cuda.cudnn_enabled = False + torch.backends.cudnn.deterministic = True + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) \ No newline at end of file diff --git a/run.py b/run.py index 16aaf9c..7499708 100644 --- a/run.py +++ b/run.py @@ -10,7 +10,7 @@ import torch from datetime import datetime # import time from config.args_parser import parse_args -from lib.initializer import init_model, init_optimizer +from lib.initializer import init_model, init_optimizer, init_seed from lib.loss_function import get_loss_function from dataloader.loader_selector import get_dataloader