import torch import torch.nn as nn from model.model_selector import model_selector from lib.loss_function import masked_mae_loss import random import numpy as np from datetime import datetime import os import yaml def init_model(args): device = args["device"] model = model_selector(args).to(device) for p in model.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) else: nn.init.uniform_(p) total_params = sum(p.numel() for p in model.parameters()) print(f"模型参数量: {total_params} ") return model def init_optimizer(model, args): optimizer = torch.optim.Adam( params=model.parameters(), lr=args['lr_init'], eps=1.0e-8, weight_decay=args['weight_decay'], amsgrad=False ) lr_scheduler = None if args['lr_decay']: lr_decay_steps = [int(step) for step in args['lr_decay_step'].split(',')] lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer=optimizer, milestones=lr_decay_steps, gamma=args['lr_decay_rate'] ) return optimizer, lr_scheduler def init_seed(seed): """初始化种子,保证结果可复现""" torch.cuda.cudnn_enabled = False torch.backends.cudnn.deterministic = True random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) def init_device(args): device_name = args['basic']['device'] if 'model' not in args or not isinstance(args['model'], dict): args['model'] = {} # Ensure args['model'] is a dictionary match device_name: case 'mps': if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): args['device'] = 'mps' else: args['device'] = 'cpu' case device if device.startswith('cuda'): if torch.cuda.is_available(): torch.cuda.set_device(int(device.split(':')[1])) args['device'] = device else: args['device'] = 'cpu' case _: args['device'] = 'cpu' args['model']['device'] = args['device'] return args def init_loss(args, scaler): device = args['basic']['device'] args = args['train'] match args['loss_func']: case 'mask_mae': return masked_mae_loss(scaler, mask_value=None) case 'mae': return torch.nn.L1Loss().to(device) case 'mse': return torch.nn.MSELoss().to(device) case 'Huber': return torch.nn.HuberLoss().to(device) case _: raise ValueError(f"Unsupported loss function: {args['loss_func']}") def create_logs(args): current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') current_dir = os.path.dirname(os.path.realpath(__file__)) args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['basic']['dataset'], current_time) config_filename = f"{args['basic']['dataset']}.yaml" os.makedirs(args['train']['log_dir'], exist_ok=True) config_content = yaml.safe_dump(args, default_flow_style=False) destination_path = os.path.join(args['train']['log_dir'], config_filename) # 将 args 保存为 YAML 文件 with open(destination_path, 'w') as f: f.write(config_content)