Compare commits

..

No commits in common. "3095b7435b42d2833e002381a4320507504bb8f4" and "3b4acd49513140c34fb3109a4001e2136966ec13" have entirely different histories.

7 changed files with 717 additions and 333 deletions

View File

@ -2,80 +2,95 @@ import os
import numpy as np import numpy as np
import h5py import h5py
def load_st_dataset(config): def load_st_dataset(config):
dataset = config["basic"]["dataset"] dataset = config["basic"]["dataset"]
# sample = config["data"]["sample"]
# output B, N, D
match dataset:
case "BeijingAirQuality":
data_path = os.path.join("./data/BeijingAirQuality/data.dat")
data = np.memmap(data_path, dtype=np.float32, mode='r')
L, N, C = 36000, 7, 3
data = data.reshape(L, N, C)
case "AirQuality":
data_path = os.path.join("./data/AirQuality/data.dat")
data = np.memmap(data_path, dtype=np.float32, mode='r')
L, N, C = 8701,35,6
data = data.reshape(L, N, C)
case "PEMS-BAY":
data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
with h5py.File(data_path, 'r') as f:
data = f['speed']['block0_values'][:]
case "METR-LA":
data_path = os.path.join("./data/METR-LA/METR-LA.h5")
with h5py.File(data_path, 'r') as f:
data = f['df']['block0_values'][:]
case "SolarEnergy":
data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv")
data = np.loadtxt(data_path, delimiter=",")
case "PEMSD3":
data_path = os.path.join("./data/PEMS03/PEMS03.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD4":
data_path = os.path.join("./data/PEMS04/PEMS04.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7":
data_path = os.path.join("./data/PEMS07/PEMS07.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD8":
data_path = os.path.join("./data/PEMS08/PEMS08.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7(L)":
data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7(M)":
data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
data = np.genfromtxt(data_path, delimiter=",")
case "BJ":
data_path = os.path.join("./data/BJ/BJ500.csv")
data = np.genfromtxt(data_path, delimiter=",", skip_header=1)
case "Hainan":
data_path = os.path.join("./data/Hainan/Hainan.npz")
data = np.load(data_path)["data"][:, :, 0]
case "SD":
data_path = os.path.join("./data/SD/data.npz")
data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
case "BJTaxi-InFlow":
data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32)
case "BJTaxi-OutFlow":
data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32)
case "NYCBike-InFlow":
data_path = os.path.join("./data/NYCBike/NYC16x8.h5")
with h5py.File(data_path, 'r') as f:
data = f['data'][:].astype(np.float32)
data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2)
data = data[:, :, 0:1]
case "NYCBike-OutFlow":
data_path = os.path.join("./data/NYCBike/NYC16x8.h5")
with h5py.File(data_path, 'r') as f:
data = f['data'][:].astype(np.float32)
data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2)
data = data[:, :, 1:2]
case _:
raise ValueError(f"Unsupported dataset: {dataset}")
loaders = { # Ensure data shape compatibility
"BeijingAirQuality": lambda: _memmap("./data/BeijingAirQuality/data.dat", 36000, 7, 3), if len(data.shape) == 2:
"AirQuality": lambda: _memmap("./data/AirQuality/data.dat", 8701, 35, 6), data = np.expand_dims(data, axis=-1)
"PEMS-BAY": lambda: _h5("./data/PEMS-BAY/pems-bay.h5", ("speed", "block0_values")), print("加载 %s 数据集中... " % dataset)
"METR-LA": lambda: _h5("./data/METR-LA/METR-LA.h5", ("df", "block0_values")), # return data[::sample]
"SolarEnergy": lambda: np.loadtxt("./data/SolarEnergy/SolarEnergy.csv", delimiter=","),
"PEMSD3": lambda: _npz("./data/PEMS03/PEMS03.npz"),
"PEMSD4": lambda: _npz("./data/PEMS04/PEMS04.npz"),
"PEMSD7": lambda: _npz("./data/PEMS07/PEMS07.npz"),
"PEMSD8": lambda: _npz("./data/PEMS08/PEMS08.npz"),
"PEMSD7(L)": lambda: _npz("./data/PEMS07(L)/PEMS07L.npz"),
"PEMSD7(M)": lambda: np.genfromtxt("./data/PEMS07(M)/V_228.csv", delimiter=","),
"BJ": lambda: np.genfromtxt("./data/BJ/BJ500.csv", delimiter=",", skip_header=1),
"Hainan": lambda: _npz("./data/Hainan/Hainan.npz"),
"SD": lambda: _npz("./data/SD/data.npz", cast=True),
"BJTaxi-InFlow": lambda: read_BeijingTaxi()[:, :, 0:1].astype(np.float32),
"BJTaxi-OutFlow": lambda: read_BeijingTaxi()[:, :, 1:2].astype(np.float32),
"NYCBike-InFlow": lambda: _nyc_bike(0),
"NYCBike-OutFlow": lambda: _nyc_bike(1),
}
if dataset not in loaders:
raise ValueError(f"Unsupported dataset: {dataset}")
data = loaders[dataset]()
if data.ndim == 2:
data = data[..., None]
print(f"加载 {dataset} 数据集中... ")
return data return data
# ---------------- helpers ----------------
def _memmap(path, L, N, C):
data = np.memmap(path, dtype=np.float32, mode="r")
return data.reshape(L, N, C)
def _h5(path, keys):
with h5py.File(path, "r") as f:
return f[keys[0]][keys[1]][:]
def _npz(path, cast=False):
data = np.load(path)["data"][:, :, 0]
return data.astype(np.float32) if cast else data
def _nyc_bike(channel):
with h5py.File("./data/NYCBike/NYC16x8.h5", "r") as f:
data = f["data"][:].astype(np.float32)
data = data.transpose(0, 2, 3, 1).reshape(-1, 16 * 8, 2)
return data[:, :, channel:channel + 1]
def read_BeijingTaxi(): def read_BeijingTaxi():
files = [ files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
"TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
"TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy", all_data = []
] for file in files:
data = np.concatenate( data_path = os.path.join(f"./data/BeijingTaxi/{file}")
[np.load(f"./data/BeijingTaxi/{f}") for f in files], axis=0 data = np.load(data_path)
) all_data.append(data)
T = data.shape[0] all_data = np.concatenate(all_data, axis=0)
return data.transpose(0, 2, 3, 1).reshape(T, 32 * 32, 2) time_num = all_data.shape[0]
all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2)
return all_data

View File

@ -8,12 +8,21 @@ from dataloader.Informer_loader import get_dataloader as Informer_loader
def get_dataloader(config, normalizer, single): def get_dataloader(config, normalizer, single):
loader_map = { TS_model = ["iTransformer", "HI", "PatchTST"]
"STGNCDE": cde_loader, model_name = config["basic"]["model"]
"STGNRDE": nrde_loader, # if model_name == "Informer":
"DCRNN": DCRNN_loader, # return Informer_loader(config, normalizer, single)
"EXP": EXP_loader, # elif model_name in TS_model:
} # return TS_loader(config, normalizer, single)
return loader_map.get(config["basic"]["model"], normal_loader)( # else :
config, normalizer, single match model_name:
) case "STGNCDE":
return cde_loader(config, normalizer, single)
case "STGNRDE":
return nrde_loader(config, normalizer, single)
case "DCRNN":
return DCRNN_loader(config, normalizer, single)
case "EXP":
return EXP_loader(config, normalizer, single)
case _:
return normal_loader(config, normalizer, single)

View File

@ -6,14 +6,16 @@ import utils.initializer as init
from dataloader.loader_selector import get_dataloader from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer from trainer.trainer_selector import select_trainer
import cProfile
def read_config(config_path): def read_config(config_path):
with open(config_path, "r") as file: with open(config_path, "r") as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
# 全局配置 # 全局配置
device = "cuda:0" # 指定设备为cuda:0 device = "cuda:1" # 指定设备为cuda:0
seed = 2023 # 随机种子 seed = 2023 # 随机种子
epochs = 1 epochs = 120
# 拷贝项 # 拷贝项
config["basic"]["device"] = device config["basic"]["device"] = device
@ -97,4 +99,4 @@ def main(debug=False):
if __name__ == "__main__": if __name__ == "__main__":
# 调试用 # 调试用
main(debug = True) main(debug = False)

View File

@ -1,195 +1,296 @@
import os, time, copy, torch import math
from tqdm import tqdm import os
import time
import copy
import torch
from utils.logger import get_logger from utils.logger import get_logger
from utils.loss_function import all_metrics from utils.loss_function import all_metrics
from tqdm import tqdm
class TSWrapper: class TSWrapper:
def __init__(self, args): def __init__(self, args):
self.b = args['train']['batch_size']
self.t = args['data']['lag']
self.n = args['data']['num_nodes'] self.n = args['data']['num_nodes']
self.c = args['data']['input_dim']
def forward(self, x):
def transpose(self, x : torch.Tensor):
# [b, t, n, c] -> [b*n, t, c] # [b, t, n, c] -> [b*n, t, c]
b, t, n, c = x.shape self.b = x.shape[0]
x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c-2) x = x[..., :-2]
return x, b, t, n, c x = x.permute(0, 2, 1, 3)
x = x.reshape(self.b*self.n, self.t, self.c)
return x
def inverse(self, x, b, t, n, c): def inv_transpose(self, x : torch.Tensor):
return x.reshape(b, n, t, c-2).permute(0, 2, 1, 3) x = x.reshape(self.b, self.n, self.t, self.c)
x = x.permute(0, 2, 1, 3)
return x
class Trainer: class Trainer:
def __init__(self, model, loss, optimizer, """模型训练器,负责整个训练流程的管理"""
train_loader, val_loader, test_loader,
scaler, args, lr_scheduler=None):
def __init__(self, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler,
args, lr_scheduler=None,):
# 设备和基本参数
self.config = args self.config = args
self.device = args["basic"]["device"] self.device = args["basic"]["device"]
self.args = args["train"] train_args = args["train"]
# 模型和训练相关组件
self.model = model.to(self.device) self.model = model
self.loss = loss self.loss = loss
self.optimizer = optimizer self.optimizer = optimizer
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
# 数据加载器
self.train_loader = train_loader self.train_loader = train_loader
self.val_loader = val_loader or test_loader self.val_loader = val_loader
self.test_loader = test_loader self.test_loader = test_loader
# 数据处理工具
self.scaler = scaler self.scaler = scaler
self.args = train_args
self.ts_wrapper = TSWrapper(args)
# 初始化路径、日志和统计
self._initialize_paths(train_args)
self._initialize_logger(train_args)
self.ts = TSWrapper(args) def _initialize_paths(self, args):
self._init_paths() """初始化模型保存路径"""
self._init_logger() self.best_path = os.path.join(args["log_dir"], "best_model.pth")
self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth")
self.loss_figure_path = os.path.join(args["log_dir"], "loss.png")
# ---------------- init ---------------- def _initialize_logger(self, args):
def _init_paths(self): """初始化日志记录器"""
d = self.args["log_dir"] if not os.path.isdir(args["log_dir"]) and not args["debug"]:
self.best_path = os.path.join(d, "best_model.pth") os.makedirs(args["log_dir"], exist_ok=True)
self.best_test_path = os.path.join(d, "best_test_model.pth") self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
def _init_logger(self): def _run_epoch(self, epoch, dataloader, mode):
if not self.args["debug"]: """运行一个训练/验证/测试epoch"""
os.makedirs(self.args["log_dir"], exist_ok=True) # 设置模型模式和是否进行优化
self.logger = get_logger( if mode == "train": self.model.train(); optimizer_step = True
self.args["log_dir"], else: self.model.eval(); optimizer_step = False
name=self.model.__class__.__name__,
debug=self.args["debug"],
)
# ---------------- epoch ---------------- # 初始化变量
def _run_epoch(self, epoch, loader, mode): total_loss = 0
is_train = mode == "train" epoch_time = time.time()
self.model.train() if is_train else self.model.eval()
total_loss, start = 0.0, time.time()
y_pred, y_true = [], [] y_pred, y_true = [], []
with torch.set_grad_enabled(is_train): # 训练/验证循环
for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): with torch.set_grad_enabled(optimizer_step):
data, target = data.to(self.device), target.to(self.device) progress_bar = tqdm(
label = target[..., :self.args["output_dim"]] enumerate(dataloader),
total=len(dataloader),
x, b, t, n, c = self.ts.forward(data) desc=f"{mode.capitalize()} Epoch {epoch}"
out = self.model(x) )
out = self.ts.inverse(out, b, t, n, c) for _, (data, target) in progress_bar:
# 转移数据
data = data.to(self.device)
target = target.to(self.device)
label = target[..., : self.args["output_dim"]]
# 转换为 [b*n, t, c]
data = self.ts_wrapper.transpose(data)
# 计算loss和反归一化loss
output = self.model(data)
# 转换回[b, t, n, c]
output = self.ts_wrapper.inv_transpose(output)
# 我的调试开关
if os.environ.get("TRY") == "True": if os.environ.get("TRY") == "True":
print(out.shape, label.shape) print(f"[{'' if output.shape == label.shape else ''}]: output: {output.shape}, label: {label.shape}")
assert out.shape == label.shape assert False
loss = self.loss(output, label)
loss = self.loss(out, label) d_output = self.scaler.inverse_transform(output)
d_out = self.scaler.inverse_transform(out) d_label = self.scaler.inverse_transform(label)
d_lbl = self.scaler.inverse_transform(label) d_loss = self.loss(d_output, d_label)
d_loss = self.loss(d_out, d_lbl) # 累积损失和预测结果
total_loss += d_loss.item() total_loss += d_loss.item()
y_pred.append(d_out.detach().cpu()) y_pred.append(d_output.detach().cpu())
y_true.append(d_lbl.detach().cpu()) y_true.append(d_label.detach().cpu())
# 反向传播和优化(仅在训练模式)
if is_train and self.optimizer: if optimizer_step and self.optimizer is not None:
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() loss.backward()
# 梯度裁剪(如果需要)
if self.args["grad_norm"]: if self.args["grad_norm"]:
torch.nn.utils.clip_grad_norm_( torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
self.model.parameters(),
self.args["max_grad_norm"]
)
self.optimizer.step() self.optimizer.step()
# 更新进度条
progress_bar.set_postfix(loss=d_loss.item())
y_pred = torch.cat(y_pred) # 合并所有批次的预测结果
y_true = torch.cat(y_true) y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
mae, rmse, mape = all_metrics( # 计算损失并记录指标
y_pred, y_true, avg_loss = total_loss / len(dataloader)
self.args["mae_thresh"], mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
self.args["mape_thresh"]
)
self.logger.info( self.logger.info(
f"Epoch #{epoch:02d} {mode:<5} " f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s"
) )
return total_loss / len(loader) return avg_loss
def train_epoch(self, epoch):
return self._run_epoch(epoch, self.train_loader, "train")
def val_epoch(self, epoch):
return self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
def test_epoch(self, epoch):
return self._run_epoch(epoch, self.test_loader, "test")
# ---------------- train ----------------
def train(self): def train(self):
best, best_test = float("inf"), float("inf") # 初始化记录
best_w, best_test_w = None, None best_model, best_test_model = None, None
patience = 0 best_loss, best_test_loss = float("inf"), float("inf")
not_improved_count = 0
self.logger.info("Training started") # 开始训练
self.logger.info("Training process started")
# 训练循环
for epoch in range(1, self.args["epochs"] + 1): for epoch in range(1, self.args["epochs"] + 1):
losses = { # 训练、验证和测试一个epoch
"train": self._run_epoch(epoch, self.train_loader, "train"), train_epoch_loss = self.train_epoch(epoch)
"val": self._run_epoch(epoch, self.val_loader, "val"), val_epoch_loss = self.val_epoch(epoch)
"test": self._run_epoch(epoch, self.test_loader, "test"), test_epoch_loss = self.test_epoch(epoch)
} # 检查梯度爆炸
if train_epoch_loss > 1e6:
if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected. Ending...")
self.logger.warning("Gradient explosion detected")
break break
# 更新最佳验证模型
if losses["val"] < best: if val_epoch_loss < best_loss:
best, patience = losses["val"], 0 best_loss = val_epoch_loss
best_w = copy.deepcopy(self.model.state_dict()) not_improved_count = 0
self.logger.info("Best validation model saved") best_model = copy.deepcopy(self.model.state_dict())
self.logger.info("Best validation model saved!")
else: else:
patience += 1 not_improved_count += 1
# 早停
if self.args["early_stop"] and patience == self.args["early_stop_patience"]: if self._should_early_stop(not_improved_count):
self.logger.info("Early stopping triggered")
break break
# 更新最佳测试模型
if losses["test"] < best_test: if test_epoch_loss < best_test_loss:
best_test = losses["test"] best_test_loss = test_epoch_loss
best_test_w = copy.deepcopy(self.model.state_dict()) best_test_model = copy.deepcopy(self.model.state_dict())
# 保存最佳模型
if not self.args["debug"]: if not self.args["debug"]:
torch.save(best_w, self.best_path) self._save_best_models(best_model, best_test_model)
torch.save(best_test_w, self.best_test_path) # 最终评估
self._finalize_training(best_model, best_test_model)
self._final_test(best_w, best_test_w) def _should_early_stop(self, not_improved_count):
"""检查是否满足早停条件"""
if (
self.args["early_stop"]
and not_improved_count == self.args["early_stop_patience"]
):
self.logger.info(
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
)
return True
return False
# ---------------- final test ---------------- def _save_best_models(self, best_model, best_test_model):
def _final_test(self, best_w, best_test_w): """保存最佳模型到文件"""
for name, w in [("best val", best_w), ("best test", best_test_w)]: torch.save(best_model, self.best_path)
self.model.load_state_dict(w) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Testing on {name} model") self.logger.info(
self.evaluate() f"Best models saved at {self.best_path} and {self.best_test_path}"
)
# ---------------- evaluate ---------------- def _log_model_params(self):
def evaluate(self): """输出模型可训练参数数量"""
self.model.eval() total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
self.model.load_state_dict(best_test_model)
self.logger.info("Testing on best test model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
@staticmethod
def test(model, args, data_loader, scaler, logger, path=None):
"""对模型进行评估并输出性能指标"""
# 确定设备信息
device = None
output_dim = None
# 处理不同的参数格式
if isinstance(args, dict):
if "basic" in args:
# 完整配置情况
device = args["basic"]["device"]
output_dim = args["train"]["output_dim"]
else:
# 只有train_args情况从模型获取设备
device = next(model.parameters()).device
output_dim = args["output_dim"]
else:
raise ValueError(f"Unsupported args type: {type(args)}")
# 加载模型检查点(如果提供了路径)
if path:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["state_dict"])
model.to(device)
# 设置为评估模式
model.eval()
# 收集预测和真实标签
y_pred, y_true = [], [] y_pred, y_true = [], []
# 不计算梯度的情况下进行预测
with torch.no_grad(): with torch.no_grad():
for data, target in self.test_loader: for data, target in data_loader:
data, target = data.to(self.device), target.to(self.device) # 将数据和标签移动到指定设备
label = target[..., :self.args["output_dim"]] data = data.to(device)
target = target.to(device)
x, b, t, n, c = self.ts.forward(data) data = data[..., :-2]
out = self.model(x) b, t, n, c = data.shape
out = self.ts.inverse(out, b, t, n, c) data = data.permute(0, 2, 1, 3)
data = data.reshape(b*n, t, c)
label = target[..., : output_dim]
output = model(data)
output = output.reshape(b, n, t, c)
output = output.permute(0, 2, 1, 3)
y_pred.append(out.cpu()) y_pred.append(output.detach().cpu())
y_true.append(label.cpu()) y_true.append(label.detach().cpu())
d_pred = self.scaler.inverse_transform(torch.cat(y_pred)) d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
d_true = self.scaler.inverse_transform(torch.cat(y_true)) d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
for t in range(d_true.shape[1]): # 获取metrics参数
if "basic" in args:
# 完整配置情况
mae_thresh = args["train"]["mae_thresh"]
mape_thresh = args["train"]["mape_thresh"]
else:
# 只有train_args情况
mae_thresh = args["mae_thresh"]
mape_thresh = args["mape_thresh"]
# 计算并记录每个时间步的指标
for t in range(d_y_true.shape[1]):
mae, rmse, mape = all_metrics( mae, rmse, mape = all_metrics(
d_pred[:, t], d_true[:, t], d_y_pred[:, t, ...],
self.args["mae_thresh"], d_y_true[:, t, ...],
self.args["mape_thresh"] mae_thresh,
) mape_thresh,
self.logger.info(
f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
) )
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) # 计算并记录平均指标
self.logger.info( mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
)
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))

View File

@ -1,3 +1,4 @@
import math
import os import os
import time import time
import copy import copy
@ -7,100 +8,255 @@ from utils.loss_function import all_metrics
from tqdm import tqdm from tqdm import tqdm
class Trainer: class Trainer:
def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): """模型训练器,负责整个训练流程的管理"""
self.config, self.device, self.args = args, args["basic"]["device"], args["train"]
self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler
self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler
log_dir = self.args["log_dir"] def __init__(self, model, loss, optimizer,
self.best_path, self.best_test_path = [os.path.join(log_dir, f"best_{suffix}_model.pth") for suffix in ["", "test"]] train_loader, val_loader, test_loader, scaler,
args, lr_scheduler=None,):
# 设备和基本参数
self.config = args
self.device = args["basic"]["device"]
train_args = args["train"]
# 模型和训练相关组件
self.model = model
self.loss = loss
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
# 数据加载器
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
# 数据处理工具
self.scaler = scaler
self.args = train_args
# 初始化路径、日志和统计
self._initialize_paths(train_args)
self._initialize_logger(train_args)
if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True) def _initialize_paths(self, args):
self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"]) """初始化模型保存路径"""
self.logger.info(f"Experiment log path in: {log_dir}") self.best_path = os.path.join(args["log_dir"], "best_model.pth")
self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth")
self.loss_figure_path = os.path.join(args["log_dir"], "loss.png")
def train(self): def _initialize_logger(self, args):
best_model = best_test_model = None """初始化日志记录器"""
best_loss = best_test_loss = float("inf") if not os.path.isdir(args["log_dir"]) and not args["debug"]:
not_improved_count = 0 os.makedirs(args["log_dir"], exist_ok=True)
self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
self.logger.info("Training process started") def _run_epoch(self, epoch, dataloader, mode):
"""运行一个训练/验证/测试epoch"""
# 设置模型模式和是否进行优化
if mode == "train": self.model.train(); optimizer_step = True
else: self.model.eval(); optimizer_step = False
for epoch in range(1, self.args["epochs"] + 1): # 初始化变量
train_loss = self._run_epoch(epoch, self.train_loader, "train") total_loss = 0
val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val") epoch_time = time.time()
test_loss = self._run_epoch(epoch, self.test_loader, "test")
if train_loss > 1e6:
self.logger.warning("Gradient explosion detected. Ending...")
break
if val_loss < best_loss:
best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict())
self.logger.info("Best validation model saved!")
elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]:
self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
break
if test_loss < best_test_loss:
best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict())
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]:
self.model.load_state_dict(state_dict)
self.logger.info(f"Testing on {model_name} model")
self._run_epoch(None, self.test_loader, "test", log_horizon=True)
def _run_epoch(self, epoch, dataloader, mode, log_horizon=False):
self.model.train() if mode == "train" else self.model.eval()
optimizer_step = mode == "train"
total_loss, epoch_time = 0, time.time()
y_pred, y_true = [], [] y_pred, y_true = [], []
# 训练/验证循环
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode): progress_bar = tqdm(
data, target = data.to(self.device), target.to(self.device) enumerate(dataloader),
label = target[..., :self.args["output_dim"]] total=len(dataloader),
desc=f"{mode.capitalize()} Epoch {epoch}"
)
for _, (data, target) in progress_bar:
# 转移数据
data = data.to(self.device)
target = target.to(self.device)
label = target[..., : self.args["output_dim"]]
# 计算loss和反归一化loss
output = self.model(data) output = self.model(data)
# 我的调试开关
if os.environ.get("TRY") == "True":
print(f"[{'' if output.shape == label.shape else ''}]: output: {output.shape}, label: {label.shape}")
assert False
loss = self.loss(output, label) loss = self.loss(output, label)
d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label) d_output = self.scaler.inverse_transform(output)
d_label = self.scaler.inverse_transform(label)
d_loss = self.loss(d_output, d_label) d_loss = self.loss(d_output, d_label)
# 累积损失和预测结果
total_loss += d_loss.item() total_loss += d_loss.item()
y_pred.append(d_output.detach().cpu()) y_pred.append(d_output.detach().cpu())
y_true.append(d_label.detach().cpu()) y_true.append(d_label.detach().cpu())
# 反向传播和优化(仅在训练模式)
if optimizer_step and self.optimizer: if optimizer_step and self.optimizer is not None:
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() loss.backward()
if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) # 梯度裁剪(如果需要)
if self.args["grad_norm"]:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
self.optimizer.step() self.optimizer.step()
# 更新进度条
progress_bar.set_postfix(loss=d_loss.item())
y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0) # 合并所有批次的预测结果
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
# 计算损失并记录指标
avg_loss = total_loss / len(dataloader)
mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
self.logger.info(
f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
)
return avg_loss
if log_horizon: def train_epoch(self, epoch):
for t in range(y_true.shape[1]): return self._run_epoch(epoch, self.train_loader, "train")
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"])
self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) def val_epoch(self, epoch):
return self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
if epoch and mode: def test_epoch(self, epoch):
self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s") return self._run_epoch(epoch, self.test_loader, "test")
elif mode:
self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}")
return total_loss / len(dataloader) def train(self):
# 初始化记录
best_model, best_test_model = None, None
best_loss, best_test_loss = float("inf"), float("inf")
not_improved_count = 0
# 开始训练
self.logger.info("Training process started")
# 训练循环
for epoch in range(1, self.args["epochs"] + 1):
# 训练、验证和测试一个epoch
train_epoch_loss = self.train_epoch(epoch)
val_epoch_loss = self.val_epoch(epoch)
test_epoch_loss = self.test_epoch(epoch)
# 检查梯度爆炸
if train_epoch_loss > 1e6:
self.logger.warning("Gradient explosion detected. Ending...")
break
# 更新最佳验证模型
if val_epoch_loss < best_loss:
best_loss = val_epoch_loss
not_improved_count = 0
best_model = copy.deepcopy(self.model.state_dict())
self.logger.info("Best validation model saved!")
else:
not_improved_count += 1
# 早停
if self._should_early_stop(not_improved_count):
break
# 更新最佳测试模型
if test_epoch_loss < best_test_loss:
best_test_loss = test_epoch_loss
best_test_model = copy.deepcopy(self.model.state_dict())
# 保存最佳模型
if not self.args["debug"]:
self._save_best_models(best_model, best_test_model)
# 最终评估
self._finalize_training(best_model, best_test_model)
def test(self, path=None): def _should_early_stop(self, not_improved_count):
"""检查是否满足早停条件"""
if (
self.args["early_stop"]
and not_improved_count == self.args["early_stop_patience"]
):
self.logger.info(
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
)
return True
return False
def _save_best_models(self, best_model, best_test_model):
"""保存最佳模型到文件"""
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.info(
f"Best models saved at {self.best_path} and {self.best_test_path}"
)
def _log_model_params(self):
"""输出模型可训练参数数量"""
total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
self.model.load_state_dict(best_test_model)
self.logger.info("Testing on best test model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
@staticmethod
def test(model, args, data_loader, scaler, logger, path=None):
"""对模型进行评估并输出性能指标"""
# 确定设备信息
device = None
output_dim = None
# 处理不同的参数格式
if isinstance(args, dict):
if "basic" in args:
# 完整配置情况
device = args["basic"]["device"]
output_dim = args["train"]["output_dim"]
else:
# 只有train_args情况从模型获取设备
device = next(model.parameters()).device
output_dim = args["output_dim"]
else:
raise ValueError(f"Unsupported args type: {type(args)}")
# 加载模型检查点(如果提供了路径)
if path: if path:
self.model.load_state_dict(torch.load(path)["state_dict"]) checkpoint = torch.load(path)
self.model.to(self.device) model.load_state_dict(checkpoint["state_dict"])
model.to(device)
self._run_epoch(None, self.test_loader, "test", log_horizon=True) # 设置为评估模式
model.eval()
# 收集预测和真实标签
y_pred, y_true = [], []
# 不计算梯度的情况下进行预测
with torch.no_grad():
for data, target in data_loader:
# 将数据和标签移动到指定设备
data = data.to(device)
target = target.to(device)
label = target[..., : output_dim]
output = model(data)
y_pred.append(output.detach().cpu())
y_true.append(label.detach().cpu())
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
# 获取metrics参数
if "basic" in args:
# 完整配置情况
mae_thresh = args["train"]["mae_thresh"]
mape_thresh = args["train"]["mape_thresh"]
else:
# 只有train_args情况
mae_thresh = args["mae_thresh"]
mape_thresh = args["mape_thresh"]
# 计算并记录每个时间步的指标
for t in range(d_y_true.shape[1]):
mae, rmse, mape = all_metrics(
d_y_pred[:, t, ...],
d_y_true[:, t, ...],
mae_thresh,
mape_thresh,
)
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
# 计算并记录平均指标
mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))

View File

@ -7,31 +7,132 @@ from trainer.E32Trainer import Trainer as EXP_Trainer
from trainer.InformerTrainer import InformerTrainer from trainer.InformerTrainer import InformerTrainer
from trainer.TSTrainer import Trainer as TSTrainer from trainer.TSTrainer import Trainer as TSTrainer
def select_trainer( def select_trainer(
model, loss, optimizer, model,
train_loader, val_loader, test_loader, loss,
scaler, args, lr_scheduler, kwargs optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
kwargs,
): ):
model_name = args["basic"]["model"] model_name = args["basic"]["model"]
base_args = ( TS_model = ["HI", "PatchTST", "iTransformer"]
model, loss, optimizer, if model_name in TS_model:
train_loader, val_loader, test_loader, return TSTrainer(
scaler, args, lr_scheduler model,
) loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
if model_name in {"HI", "PatchTST", "iTransformer"}:
return TSTrainer(*base_args)
trainer_map = { match model_name:
"DCRNN": DCRNN_Trainer, case "STGNCDE":
"PDG2SEQ": PDG2SEQ_Trainer, return cdeTrainer(
"STMLP": STMLP_Trainer, model,
"EXP": EXP_Trainer, loss,
"Informer": InformerTrainer, optimizer,
} train_loader,
val_loader,
if model_name in {"STGNCDE", "STGNRDE"}: test_loader,
return cdeTrainer(*base_args, kwargs[0], None) scaler,
args,
return trainer_map.get(model_name, Trainer)(*base_args) lr_scheduler,
kwargs[0],
None,
)
case "STGNRDE":
return cdeTrainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
kwargs[0],
None,
)
case "DCRNN":
return DCRNN_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "PDG2SEQ":
return PDG2SEQ_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "STMLP":
return STMLP_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "EXP":
return EXP_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "Informer":
return InformerTrainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case _:
return Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)

View File

@ -18,7 +18,7 @@ def get_logger(root, name=None, debug=True):
logger.handlers.clear() logger.handlers.clear()
# 时间格式改为 年/月/日 时:分:秒 # 时间格式改为 年/月/日 时:分:秒
formatter = logging.Formatter("%(asctime)s - %(message)s", "%m/%d %H:%M") formatter = logging.Formatter("%(asctime)s - %(message)s", "%Y/%m/%d %H:%M:%S")
# 控制台输出 # 控制台输出
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()