添加固定种子
This commit is contained in:
parent
e8fc67b867
commit
229b6320b9
|
|
@ -1,7 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from model.model_selector import model_selector
|
from model.model_selector import model_selector
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def init_model(args, device):
|
def init_model(args, device):
|
||||||
model = model_selector(args).to(device)
|
model = model_selector(args).to(device)
|
||||||
|
|
@ -34,3 +35,13 @@ def init_optimizer(model, args):
|
||||||
|
|
||||||
return optimizer, lr_scheduler
|
return optimizer, lr_scheduler
|
||||||
|
|
||||||
|
def init_seed(seed):
|
||||||
|
'''
|
||||||
|
Disable cudnn to maximize reproducibility
|
||||||
|
'''
|
||||||
|
torch.cuda.cudnn_enabled = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
2
run.py
2
run.py
|
|
@ -10,7 +10,7 @@ import torch
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
# import time
|
# import time
|
||||||
from config.args_parser import parse_args
|
from config.args_parser import parse_args
|
||||||
from lib.initializer import init_model, init_optimizer
|
from lib.initializer import init_model, init_optimizer, init_seed
|
||||||
from lib.loss_function import get_loss_function
|
from lib.loss_function import get_loss_function
|
||||||
|
|
||||||
from dataloader.loader_selector import get_dataloader
|
from dataloader.loader_selector import get_dataloader
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue