From 229b6320b9762329d6742c0c5ca12bd572433ab7 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Fri, 4 Apr 2025 14:34:41 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=9B=BA=E5=AE=9A=E7=A7=8D?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/initializer.py | 13 ++++++++++++- run.py | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/lib/initializer.py b/lib/initializer.py index 7cee09a..c08bacc 100644 --- a/lib/initializer.py +++ b/lib/initializer.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn from model.model_selector import model_selector - +import random +import numpy as np def init_model(args, device): model = model_selector(args).to(device) @@ -34,3 +35,13 @@ def init_optimizer(model, args): return optimizer, lr_scheduler +def init_seed(seed): + ''' + Disable cudnn to maximize reproducibility + ''' + torch.cuda.cudnn_enabled = False + torch.backends.cudnn.deterministic = True + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) \ No newline at end of file diff --git a/run.py b/run.py index 16aaf9c..7499708 100644 --- a/run.py +++ b/run.py @@ -10,7 +10,7 @@ import torch from datetime import datetime # import time from config.args_parser import parse_args -from lib.initializer import init_model, init_optimizer +from lib.initializer import init_model, init_optimizer, init_seed from lib.loss_function import get_loss_function from dataloader.loader_selector import get_dataloader