import os
import random
import logging
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import yaml

from models.model_selector import model_selector
from data.data_selector import load_dataset
from data.dataloader import get_dataloader
import utils.loss_func as loss_func
from trainer.trainer_selector import select_trainer


def seed(seed: int):
    """Fix random seeds for fair, reproducible comparison across runs."""
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # print(f"seed is {seed}")


def device(device: str):
    """Resolve the compute device; expects 'cuda:<idx>' or 'cpu'."""
    if torch.cuda.is_available() and device != 'cpu':
        torch.cuda.set_device(int(device.split(':')[1]))
        return device
    return 'cpu'


def model(config: dict):
    """Build the model and initialize its parameters."""
    device = config['basic']['device']
    model = model_selector(config).to(device)
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model param count: {total_params}")
    return model


def dataloader(config: dict):
    """Build the train/val/test dataloaders and the data scaler."""
    data = load_dataset(config)
    train_loader, val_loader, test_loader, scaler = get_dataloader(config, data)
    return train_loader, val_loader, test_loader, scaler


def loss(config: dict, scaler):
    """Select the training loss function by name from the config."""
    loss_name = config['train']['loss']
    device = config['basic']['device']
    match loss_name:
        case 'mask_mae':
            func = loss_func.masked_mae_loss(scaler, mask_value=0.0)
        case 'mae':
            func = torch.nn.L1Loss()
        case 'mse':
            func = torch.nn.MSELoss()
        case 'Huber':
            func = torch.nn.HuberLoss()
        case _:
            raise NotImplementedError(f'Unknown loss function: {loss_name}')
    return func.to(device)


def optimizer(config, model):
    """Build the Adam optimizer and an optional multi-step LR scheduler."""
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config['train']['lr_init'],
        eps=1.0e-8,
        weight_decay=config['train']['weight_decay'],
        amsgrad=False
    )
    lr_scheduler = None
    if config['train']['lr_decay']:
        lr_decay_steps = [int(step) for step in config['train']['lr_decay_step'].split(',')]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer,
            milestones=lr_decay_steps,
            gamma=config['train']['lr_decay_rate']
        )
    return optimizer, lr_scheduler


def trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader,
            scaler, lr_scheduler, kwargs):
    """Build the trainer matching the selected model."""
    selected_trainer = select_trainer(config, model, loss, optimizer, train_loader,
                                      val_loader, test_loader, scaler, lr_scheduler, kwargs)
    return selected_trainer
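
# For reference, a sketch of the config layout these helpers expect, inferred
# from the keys read above. The concrete values below are illustrative only:
#
#   basic:
#     dataset: PEMS04          # hypothetical dataset name
#     model: STGCN             # hypothetical model name
#     device: cuda:0           # or 'cpu'
#     debug: false
#   train:
#     loss: mask_mae           # one of: mask_mae, mae, mse, Huber
#     lr_init: 0.003
#     weight_decay: 0.0001
#     lr_decay: true
#     lr_decay_step: "20,40,60"
#     lr_decay_rate: 0.5
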
class Logger:
    """
    Logger wrapper: call the member object `logger`'s info() method to record
    messages, and use all_metrics() to compute all evaluation metrics.
    """
    def __init__(self, config, name=None, debug=True):
        self.config = config
        # Use a filesystem-safe timestamp: '/' would create nested directories
        # and ':' is invalid in paths on some platforms.
        cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
        cur_dir = os.getcwd()
        dataset_name = config['basic']['dataset']
        model_name = config['basic']['model']
        self.dir_path = os.path.join(cur_dir, 'exp', f'{dataset_name}_{model_name}_{cur_time}')
        config['train']['log_dir'] = self.dir_path
        os.makedirs(self.dir_path, exist_ok=True)

        # Dump the resolved config into the experiment directory
        config_content = yaml.safe_dump(config)
        config_path = os.path.join(self.dir_path, "config.yaml")
        with open(config_path, 'w') as f:
            f.write(config_content)

        # logger
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s: %(message)s', "%m/%d %H:%M")

        # Console handler
        console_handler = logging.StreamHandler()
        if debug:
            console_handler.setLevel(logging.DEBUG)
        else:
            console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)

        # File handler: always write a log file, regardless of debug mode
        logfile = os.path.join(self.dir_path, 'run.log')
        file_handler = logging.FileHandler(logfile, mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)

        # Attach handlers to the logger
        self.logger.addHandler(console_handler)
        self.logger.addHandler(file_handler)

    def set_log_dir(self):
        # Initialize the log directory
        if not os.path.isdir(self.dir_path) and not self.config['basic']['debug']:
            os.makedirs(self.dir_path, exist_ok=True)
        self.logger.info(f"Experiment log path in: {self.dir_path}")

    def mae_torch(self, pred, true, mask_value=None):
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.mean(torch.abs(true - pred))

    def rmse_torch(self, pred, true, mask_value=None):
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.sqrt(torch.mean((pred - true) ** 2))

    def mape_torch(self, pred, true, mask_value=None):
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        # The small epsilon in the denominator guards against division by zero
        return torch.mean(torch.abs(torch.div((true - pred), (true + 0.001))))

    def all_metrics(self, pred, true, mask1, mask2):
        # Config files may pass the string 'None' instead of a null value
        if mask1 == 'None':
            mask1 = None
        if mask2 == 'None':
            mask2 = None
        mae = self.mae_torch(pred, true, mask1)
        rmse = self.rmse_torch(pred, true, mask1)
        mape = self.mape_torch(pred, true, mask2)
        return mae, rmse, mape
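
# A minimal end-to-end usage sketch of the helpers above. The 'config.yaml'
# path is hypothetical, the seed value is illustrative, the contents of
# `kwargs` depend on the selected trainer, and train() is an assumed entry
# point on the object returned by select_trainer(); adapt to the actual
# interface in trainer/trainer_selector.py.
if __name__ == '__main__':
    with open('config.yaml') as f:
        config = yaml.safe_load(f)

    seed(42)  # illustrative seed value
    config['basic']['device'] = device(config['basic']['device'])

    logger = Logger(config, name='main', debug=config['basic']['debug'])
    logger.set_log_dir()

    train_loader, val_loader, test_loader, scaler = dataloader(config)
    net = model(config)
    criterion = loss(config, scaler)
    optim, lr_scheduler = optimizer(config, net)

    runner = trainer(config, net, criterion, optim, train_loader, val_loader,
                     test_loader, scaler, lr_scheduler, kwargs={'logger': logger})
    runner.train()  # assumed entry point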