添加固定种子

This commit is contained in:
czzhangheng 2025-04-04 14:34:41 +08:00
parent e8fc67b867
commit 229b6320b9
2 changed files with 13 additions and 2 deletions

View File

@ -1,7 +1,8 @@
import torch
import torch.nn as nn
from model.model_selector import model_selector
import random
import numpy as np
def init_model(args, device):
model = model_selector(args).to(device)
@ -34,3 +35,13 @@ def init_optimizer(model, args):
return optimizer, lr_scheduler
def init_seed(seed):
'''
Disable cudnn to maximize reproducibility
'''
torch.cuda.cudnn_enabled = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

2
run.py
View File

@ -10,7 +10,7 @@ import torch
from datetime import datetime
# import time
from config.args_parser import parse_args
from lib.initializer import init_model, init_optimizer
from lib.initializer import init_model, init_optimizer, init_seed
from lib.loss_function import get_loss_function
from dataloader.loader_selector import get_dataloader