# Project-I/utils/init.py
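"""Initialization utilities for the training pipeline: random seeding,
device selection, and factory helpers for the model, dataloaders, loss,
optimizer/scheduler, and trainer, plus an experiment Logger.
"""
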
import os
import torch
import torch.nn as nn
import random
import yaml
import logging
from datetime import datetime
import numpy as np
from models.model_selector import model_selector
from data.data_selector import load_dataset
from data.dataloader import get_dataloader
import utils.loss_func as loss_func
from trainer.trainer_selector import select_trainer


def seed(seed: int):
    """Fix all random seeds so that runs are reproducible and comparable."""
    # Disable cuDNN and force deterministic algorithms for reproducibility.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # print(f"seed is {seed}")


def device(device: str):
    """Resolve the compute device; fall back to CPU when CUDA is unavailable."""
    if torch.cuda.is_available() and device != 'cpu':
        # Expects strings of the form 'cuda:0'; select that GPU.
        torch.cuda.set_device(int(device.split(':')[1]))
        return device
    else:
        return 'cpu'


def model(config: dict):
    """Select the model and initialize its parameters."""
    device = config['basic']['device']
    model = model_selector(config).to(device)
    # Xavier init for weight matrices, uniform init for vectors (e.g. biases).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model param count: {total_params}")
    return model


def dataloader(config: dict):
    """Initialize the train/val/test dataloaders and the data scaler."""
    data = load_dataset(config)
    train_loader, val_loader, test_loader, scaler = get_dataloader(config, data)
    return train_loader, val_loader, test_loader, scaler


def loss(config: dict, scaler):
    """Build the loss function named in the config."""
    loss_name = config['train']['loss']
    device = config['basic']['device']
    match loss_name:
        case 'mask_mae':
            func = loss_func.masked_mae_loss(scaler, mask_value=0.0)
        case 'mae':
            func = torch.nn.L1Loss()
        case 'mse':
            func = torch.nn.MSELoss()
        case 'Huber':
            func = torch.nn.HuberLoss()
        case _:
            raise NotImplementedError(f"Unknown loss function: {loss_name}")
    return func.to(device)


def optimizer(config, model):
    """Build the Adam optimizer and, optionally, a MultiStepLR scheduler."""
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config['train']['lr_init'],
        eps=1.0e-8,
        weight_decay=config['train']['weight_decay'],
        amsgrad=False
    )
    lr_scheduler = None
    if config['train']['lr_decay']:
        # 'lr_decay_step' is a comma-separated list of epoch milestones.
        lr_decay_steps = [int(step) for step in config['train']['lr_decay_step'].split(',')]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer,
            milestones=lr_decay_steps,
            gamma=config['train']['lr_decay_rate']
        )
    return optimizer, lr_scheduler
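
# Illustrative config fragment (hypothetical values; the real YAML schema
# lives with the repo's config files). These are the keys the factories
# above actually read:
#
#   basic:
#     device: "cuda:0"            # or "cpu"
#     dataset: ...
#     model: ...
#   train:
#     loss: mask_mae              # one of: mask_mae, mae, mse, Huber
#     lr_init: 0.003
#     weight_decay: 0.0001
#     lr_decay: true
#     lr_decay_step: "20,40,60"   # comma-separated epoch milestones
#     lr_decay_rate: 0.3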


def trainer(config, model, loss, optimizer,
            train_loader, val_loader, test_loader,
            scaler, lr_scheduler, kwargs):
    selected_trainer = select_trainer(config, model, loss, optimizer,
                                      train_loader, val_loader, test_loader,
                                      scaler, lr_scheduler, kwargs)
    return selected_trainer


class Logger:
    """
    The Logger class records messages mainly through the `info` method of its
    member `logger`; call `all_metrics` to compute all evaluation losses.
    """
    def __init__(self, config, name=None, debug=True):
        self.config = config
        # Avoid '/' and ':' in the timestamp: '/' would create nested
        # directories and ':' is illegal in Windows paths.
        cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
        cur_dir = os.getcwd()
        dataset_name = config['basic']['dataset']
        model_name = config['basic']['model']
        self.dir_path = os.path.join(cur_dir, 'exp', f'{dataset_name}_{model_name}_{cur_time}')
        config['train']['log_dir'] = self.dir_path
        os.makedirs(self.dir_path, exist_ok=True)
        # Dump the resolved config into the experiment directory.
        config_content = yaml.safe_dump(config)
        config_path = os.path.join(self.dir_path, "config.yaml")
        with open(config_path, 'w') as f:
            f.write(config_content)
        # logger
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s: %(message)s', "%m/%d %H:%M")
        # Console handler.
        console_handler = logging.StreamHandler()
        if debug:
            console_handler.setLevel(logging.DEBUG)
        else:
            console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        # File handler: a log file is created whether or not debug is set.
        logfile = os.path.join(self.dir_path, 'run.log')
        file_handler = logging.FileHandler(logfile, mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        # Attach both handlers to the logger.
        self.logger.addHandler(console_handler)
        self.logger.addHandler(file_handler)

    def set_log_dir(self):
        # Ensure the log directory exists (outside debug mode) and announce it.
        if not os.path.isdir(self.dir_path) and not self.config['basic']['debug']:
            os.makedirs(self.dir_path, exist_ok=True)
        self.logger.info(f"Experiment log path in: {self.dir_path}")

    def mae_torch(self, pred, true, mask_value=None):
        """Mean absolute error; entries with true <= mask_value are ignored."""
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.mean(torch.abs(true - pred))

    def rmse_torch(self, pred, true, mask_value=None):
        """Root mean squared error over the (optionally masked) entries."""
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.sqrt(torch.mean((pred - true) ** 2))

    def mape_torch(self, pred, true, mask_value=None):
        """Mean absolute percentage error; +0.001 stabilizes near-zero targets."""
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.mean(torch.abs(torch.div((true - pred), (true + 0.001))))

    def all_metrics(self, pred, true, mask1, mask2):
        # Configs may pass the literal string 'None' to disable masking.
        if mask1 == 'None':
            mask1 = None
        if mask2 == 'None':
            mask2 = None
        mae = self.mae_torch(pred, true, mask1)
        rmse = self.rmse_torch(pred, true, mask1)
        mape = self.mape_torch(pred, true, mask2)
        return mae, rmse, mape
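

# --- Hypothetical usage sketch, not part of the original pipeline. ---
# Shows how the factories above compose into a driver. The config path,
# the 'seed'/'debug' keys, and all literal values are assumptions; only
# the call signatures come from this module.
if __name__ == '__main__':
    with open('config.yaml') as f:           # assumed config location
        cfg = yaml.safe_load(f)
    seed(cfg['basic'].get('seed', 2024))     # 'seed' key is an assumption
    cfg['basic']['device'] = device(cfg['basic']['device'])
    net = model(cfg)
    train_loader, val_loader, test_loader, scaler = dataloader(cfg)
    criterion = loss(cfg, scaler)
    optim, scheduler = optimizer(cfg, net)
    log = Logger(cfg, name='run', debug=cfg['basic'].get('debug', True))
    runner = trainer(cfg, net, criterion, optim,
                     train_loader, val_loader, test_loader,
                     scaler, scheduler, kwargs={})
    # runner.train()  # the trainer API depends on select_trainer's return type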