# TrafficWheel/lib/initializer.py

import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import yaml

from lib.loss_function import masked_mae_loss
from model.model_selector import model_selector


def init_model(args):
    """Build the model, initialise its weights, and report its parameter count."""
    device = args["device"]
    model = model_selector(args).to(device)
    # Xavier init for weight matrices, uniform init for 1-D parameters (biases etc.).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {total_params}")
    return model
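
# Usage sketch (illustrative only): init_model expects the resolved device under
# args["device"] plus whatever keys model_selector() reads; the config layout
# shown below is an assumption about the surrounding project, not part of this module.
#
#     model = init_model({**cfg['model'], 'device': cfg['device']})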


def init_optimizer(model, args):
    """Create an Adam optimizer and, optionally, a MultiStepLR scheduler."""
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args['lr_init'],
        eps=1.0e-8,
        weight_decay=args['weight_decay'],
        amsgrad=False
    )
    lr_scheduler = None
    if args['lr_decay']:
        # 'lr_decay_step' is a comma-separated string of epoch milestones, e.g. "20,40,60".
        lr_decay_steps = [int(step) for step in args['lr_decay_step'].split(',')]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer,
            milestones=lr_decay_steps,
            gamma=args['lr_decay_rate']
        )
    return optimizer, lr_scheduler
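
# Example of the keys init_optimizer consumes (the values are illustrative, not
# defaults shipped with TrafficWheel):
#
#     train_cfg = {
#         'lr_init': 1e-3,
#         'weight_decay': 1e-4,
#         'lr_decay': True,
#         'lr_decay_step': '20,40,60',
#         'lr_decay_rate': 0.5,
#     }
#     optimizer, scheduler = init_optimizer(model, train_cfg)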


def init_seed(seed):
    """Seed all RNGs so results are reproducible."""
    # Disable cuDNN and force deterministic kernels so CUDA runs are repeatable.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def init_device(args):
    """Resolve the requested device, falling back to CPU when it is unavailable."""
    device_name = args['basic']['device']
    if 'model' not in args or not isinstance(args['model'], dict):
        args['model'] = {}  # Ensure args['model'] is a dictionary
    match device_name:
        case 'mps':
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                args['device'] = 'mps'
            else:
                args['device'] = 'cpu'
        case device if device.startswith('cuda'):
            if torch.cuda.is_available():
                # Only select an explicit GPU index when one is given, e.g. 'cuda:1'.
                if ':' in device:
                    torch.cuda.set_device(int(device.split(':')[1]))
                args['device'] = device
            else:
                args['device'] = 'cpu'
        case _:
            args['device'] = 'cpu'
    args['model']['device'] = args['device']
    return args
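
# Example: resolve the device from the 'basic' section of the config (a CUDA
# device that is not actually available falls back to CPU):
#
#     cfg = init_device({'basic': {'device': 'cuda:0'}})
#     print(cfg['device'], cfg['model']['device'])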


def init_loss(args, scaler):
    """Return the loss function named in the training config."""
    device = args['basic']['device']
    train_args = args['train']
    match train_args['loss_func']:
        case 'mask_mae':
            return masked_mae_loss(scaler, mask_value=None)
        case 'mae':
            return torch.nn.L1Loss().to(device)
        case 'mse':
            return torch.nn.MSELoss().to(device)
        case 'Huber':
            return torch.nn.HuberLoss().to(device)
        case _:
            raise ValueError(f"Unsupported loss function: {train_args['loss_func']}")
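
# The 'loss_func' entry in the 'train' section selects among 'mask_mae', 'mae',
# 'mse' and 'Huber'. A minimal call might look like the following; 'scaler' comes
# from the data pipeline and is only used by the masked-MAE variant:
#
#     loss_fn = init_loss(cfg, scaler)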


def create_logs(args):
    """Create a timestamped experiment directory and save the config there as YAML."""
    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    current_dir = os.path.dirname(os.path.realpath(__file__))
    args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['basic']['dataset'], current_time)
    config_filename = f"{args['basic']['dataset']}.yaml"
    os.makedirs(args['train']['log_dir'], exist_ok=True)
    config_content = yaml.safe_dump(args, default_flow_style=False)
    destination_path = os.path.join(args['train']['log_dir'], config_filename)
    # Save args to the log directory as a YAML file.
    with open(destination_path, 'w') as f:
        f.write(config_content)
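

# End-to-end usage sketch (illustrative; the config file name, the 'seed' key and
# the overall YAML layout are assumptions about the surrounding project, not
# something this module defines):
#
#     with open('config/PEMS04.yaml') as f:
#         cfg = yaml.safe_load(f)
#     init_seed(cfg['basic'].get('seed', 2024))
#     cfg = init_device(cfg)
#     model = init_model(cfg['model'])
#     optimizer, scheduler = init_optimizer(model, cfg['train'])
#     create_logs(cfg)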