import os
import torch
import torch.nn as nn
import random
import yaml
import logging
from datetime import datetime
import numpy as np

from models.model_selector import model_selector
from data.data_selector import load_dataset
from data.dataloader import get_dataloader
import utils.loss_func as loss_func
from trainer.trainer_selector import select_trainer

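# A minimal sketch of the config dict these helpers expect. The keys are
# inferred from the accesses in this file; every value below is an
# illustrative placeholder, not a project default.
#
# config = {
#     'basic': {'device': 'cuda:0', 'dataset': '<dataset>', 'model': '<model>',
#               'debug': False},
#     'train': {'loss': 'mask_mae', 'lr_init': 1e-3, 'weight_decay': 0.0,
#               'lr_decay': True, 'lr_decay_step': '5,20,40', 'lr_decay_rate': 0.3},
# }
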
def seed(seed: int):
    """Fix all random seeds so experiments can be compared fairly."""
    # The cuDNN switches live under torch.backends.cudnn; disable it and
    # force deterministic kernels for reproducibility.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # print(f"seed is {seed}")


def device(device: str):
    """Initialize the device to use; falls back to CPU when CUDA is unavailable."""
    # Expects the 'cuda:N' form (e.g. 'cuda:0'); anything else runs on CPU.
    if torch.cuda.is_available() and device.startswith('cuda:'):
        torch.cuda.set_device(int(device.split(':')[1]))
        return device
    else:
        return 'cpu'


def model(config: dict):
    """Select the model and initialize its parameters."""
    device = config['basic']['device']
    model = model_selector(config).to(device)
    # Xavier init for weight matrices, uniform init for 1-D parameters (biases etc.).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model param count : {total_params}")
    return model


def dataloader(config: dict):
    """Initialize the train/val/test dataloaders and the data scaler."""
    data = load_dataset(config)
    train_loader, val_loader, test_loader, scaler = get_dataloader(config, data)
    return train_loader, val_loader, test_loader, scaler


def loss(config: dict, scaler):
    """Select the loss function named in the config."""
    loss_name = config['train']['loss']
    device = config['basic']['device']
    match loss_name:
        case 'mask_mae': func = loss_func.masked_mae_loss(scaler, mask_value=0.0)
        case 'mae': func = torch.nn.L1Loss()
        case 'mse': func = torch.nn.MSELoss()
        case 'Huber': func = torch.nn.HuberLoss()
        case _: raise NotImplementedError(f'No such loss function: {loss_name}')
    return func.to(device)


def optimizer(config, model):
    """Build the Adam optimizer and, optionally, a MultiStepLR scheduler."""
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config['train']['lr_init'],
        eps=1.0e-8,
        weight_decay=config['train']['weight_decay'],
        amsgrad=False
    )

    lr_scheduler = None
    if config['train']['lr_decay']:
        # 'lr_decay_step' is a comma-separated string of epoch milestones, e.g. '5,20,40'
        lr_decay_steps = [int(step) for step in config['train']['lr_decay_step'].split(',')]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer,
            milestones=lr_decay_steps,
            gamma=config['train']['lr_decay_rate']
        )

    return optimizer, lr_scheduler


def trainer(config, model, loss, optimizer,
            train_loader, val_loader, test_loader,
            scaler, lr_scheduler, kwargs):
    """Select the trainer that matches the configured model."""
    selected_trainer = select_trainer(config, model, loss, optimizer,
                                      train_loader, val_loader, test_loader,
                                      scaler, lr_scheduler, kwargs)
    return selected_trainer


class Logger:
    """
    Logger class. Use the member logger's info method to record messages,
    and all_metrics to compute all evaluation metrics.
    """
    def __init__(self, config, name=None, debug=True):
        self.config = config
        # Timestamp is kept free of '/' and ':' since it is embedded in a directory name
        cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
        cur_dir = os.getcwd()
        dataset_name = config['basic']['dataset']
        model_name = config['basic']['model']
        self.dir_path = os.path.join(cur_dir, 'exp', f'{dataset_name}_{model_name}_{cur_time}')
        config['train']['log_dir'] = self.dir_path
        os.makedirs(self.dir_path, exist_ok=True)
        # Dump the config into the experiment directory
        config_content = yaml.safe_dump(config)
        config_path = os.path.join(self.dir_path, "config.yaml")
        with open(config_path, 'w') as f:
            f.write(config_content)

        # logger
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s: %(message)s', "%m/%d %H:%M")

        # Console handler
        console_handler = logging.StreamHandler()
        if debug:
            console_handler.setLevel(logging.DEBUG)
        else:
            console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)

        # File handler - the log file is created whether or not debug is set
        logfile = os.path.join(self.dir_path, 'run.log')
        file_handler = logging.FileHandler(logfile, mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)

        # Attach handlers to the logger
        self.logger.addHandler(console_handler)
        self.logger.addHandler(file_handler)

    def set_log_dir(self):
        # Create the experiment directory if it does not exist yet (skipped in debug mode)
        if not os.path.isdir(self.dir_path) and not self.config['basic']['debug']:
            os.makedirs(self.dir_path, exist_ok=True)
        self.logger.info(f"Experiment log path in: {self.dir_path}")

    def mae_torch(self, pred, true, mask_value=None):
        """MAE, optionally restricted to entries where true > mask_value."""
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.mean(torch.abs(true - pred))

    def rmse_torch(self, pred, true, mask_value=None):
        """RMSE, optionally restricted to entries where true > mask_value."""
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.sqrt(torch.mean((pred - true) ** 2))

    def mape_torch(self, pred, true, mask_value=None):
        """MAPE, optionally masked; +0.001 in the denominator guards against division by zero."""
        if mask_value is not None:
            mask = torch.gt(true, mask_value)
            pred = torch.masked_select(pred, mask)
            true = torch.masked_select(true, mask)
        return torch.mean(torch.abs(torch.div((true - pred), (true + 0.001))))

    def all_metrics(self, pred, true, mask1, mask2):
        """Return MAE and RMSE (masked by mask1) and MAPE (masked by mask2)."""
        # Configs may pass the string 'None' rather than a real None
        if mask1 == 'None':
            mask1 = None
        if mask2 == 'None':
            mask2 = None
        mae = self.mae_torch(pred, true, mask1)
        rmse = self.rmse_torch(pred, true, mask1)
        mape = self.mape_torch(pred, true, mask2)
        return mae, rmse, mape
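

# Minimal end-to-end wiring sketch showing how these helpers compose. It
# assumes a 'config.yaml' beside this file matching the schema sketched after
# the imports, and a 'seed' key under 'basic'; both are assumptions, and the
# trainer's own API (e.g. a train() method) is defined elsewhere.
if __name__ == '__main__':
    with open('config.yaml') as f:
        config = yaml.safe_load(f)
    seed(config['basic'].get('seed', 2023))  # 2023 is a placeholder default
    config['basic']['device'] = device(config['basic']['device'])
    net = model(config)
    train_loader, val_loader, test_loader, scaler = dataloader(config)
    criterion = loss(config, scaler)
    optim, scheduler = optimizer(config, net)
    runner = trainer(config, net, criterion, optim,
                     train_loader, val_loader, test_loader,
                     scaler, scheduler, kwargs={})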