添加固定种子
This commit is contained in:
parent
e8fc67b867
commit
229b6320b9
|
|
@ -1,7 +1,8 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from model.model_selector import model_selector
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
def init_model(args, device):
|
||||
model = model_selector(args).to(device)
|
||||
|
|
@ -34,3 +35,13 @@ def init_optimizer(model, args):
|
|||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
def init_seed(seed):
|
||||
'''
|
||||
Disable cudnn to maximize reproducibility
|
||||
'''
|
||||
torch.cuda.cudnn_enabled = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
2
run.py
2
run.py
|
|
@ -10,7 +10,7 @@ import torch
|
|||
from datetime import datetime
|
||||
# import time
|
||||
from config.args_parser import parse_args
|
||||
from lib.initializer import init_model, init_optimizer
|
||||
from lib.initializer import init_model, init_optimizer, init_seed
|
||||
from lib.loss_function import get_loss_function
|
||||
|
||||
from dataloader.loader_selector import get_dataloader
|
||||
|
|
|
|||
Loading…
Reference in New Issue