import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import yaml

from model.model_selector import model_selector
from utils.loss_function import masked_mae_loss


def init_model(args):
    """Build the model and initialize its parameters."""
    device = args["device"]
    model = model_selector(args).to(device)
    # Xavier-initialize weight matrices; uniform-initialize 1-D params (biases etc.).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total model parameters: {total_params}")
    return model


def init_optimizer(model, args):
    """Create an Adam optimizer and, optionally, a MultiStepLR scheduler."""
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args["lr_init"],
        eps=1.0e-8,
        weight_decay=args["weight_decay"],
        amsgrad=False,
    )
    lr_scheduler = None
    if args["lr_decay"]:
        # "lr_decay_step" is a comma-separated string of epoch milestones, e.g. "20,40".
        lr_decay_steps = [int(step) for step in args["lr_decay_step"].split(",")]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=lr_decay_steps, gamma=args["lr_decay_rate"]
        )
    return optimizer, lr_scheduler


def init_seed(seed):
    """Initialize random seeds so results are reproducible."""
    # Was `torch.cuda.cudnn_enabled = False`, which silently sets a nonexistent
    # attribute; the real switch lives under torch.backends.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seed every GPU, not just the current one


def init_device(args):
    """Resolve the configured device to one that is actually available."""
    device_name = args["basic"]["device"]
    if "model" not in args or not isinstance(args["model"], dict):
        args["model"] = {}  # ensure args["model"] is a dictionary
    match device_name:
        case "mps":
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                args["device"] = "mps"
            else:
                args["device"] = "cpu"
        case device if device.startswith("cuda"):
            if torch.cuda.is_available():
                if ":" in device:  # bare "cuda" carries no index to select
                    torch.cuda.set_device(int(device.split(":")[1]))
                args["device"] = device
            else:
                args["device"] = "cpu"
        case _:
            args["device"] = "cpu"
    args["model"]["device"] = args["device"]
    return args


def init_loss(args, scaler):
    """Return the loss function named by args["train"]["loss_func"]."""
    # Prefer the device resolved by init_device, falling back to the raw config
    # value (the original read args["basic"]["device"] unconditionally, which
    # can name a device that is not actually available on this machine).
    device = args.get("device", args["basic"]["device"])
    train_args = args["train"]
    match train_args["loss_func"]:
        case "mask_mae":
            return masked_mae_loss(scaler, mask_value=None)
        case "mae":
            return torch.nn.L1Loss().to(device)
        case "mse":
            return torch.nn.MSELoss().to(device)
        case "Huber":
            return torch.nn.HuberLoss().to(device)
        case _:
            raise ValueError(f"Unsupported loss function: {train_args['loss_func']}")


def create_logs(args):
    """Create a timestamped log directory and save the run config as YAML."""
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    current_dir = os.path.abspath(os.getcwd())
    model_name = args["basic"]["model"]
    dataset_name = args["basic"]["dataset"]
    args["train"]["log_dir"] = os.path.join(
        current_dir, "experiments", model_name, dataset_name, current_time
    )
    os.makedirs(args["train"]["log_dir"], exist_ok=True)
    # Save args as a YAML file next to the run's logs.
    config_filename = f"{dataset_name}.yaml"
    destination_path = os.path.join(args["train"]["log_dir"], config_filename)
    with open(destination_path, "w") as f:
        yaml.safe_dump(args, f, default_flow_style=False)
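

# ---------------------------------------------------------------------------
# Minimal usage sketch. The config shape (args["basic"], args["train"], and
# the flat optimizer keys) mirrors what the helpers above actually read; the
# literal values below are illustrative assumptions, not part of this module
# or of any real project config. init_model/init_optimizer are skipped because
# they need a concrete model_selector config. Note that init_loss's "mae"
# branch ignores the scaler, so passing None is safe here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        "basic": {"device": "cuda:0", "model": "demo", "dataset": "demo_dataset"},
        "train": {"loss_func": "mae"},
        "lr_init": 1e-3,
        "weight_decay": 1e-4,
        "lr_decay": True,
        "lr_decay_step": "20,40",
        "lr_decay_rate": 0.5,
    }
    init_seed(42)             # reproducibility first, before any model/optimizer state
    args = init_device(args)  # resolves cuda/mps/cpu and mirrors it into args["model"]
    loss_fn = init_loss(args, scaler=None)
    create_logs(args)         # writes experiments/demo/demo_dataset/<timestamp>/demo_dataset.yaml
    print(args["device"], type(loss_fn).__name__, args["train"]["log_dir"])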