refactor(trainer): 优化训练器代码结构并添加进度条显示

调整训练器代码结构,减少冗余代码,提高可读性
为训练过程添加tqdm进度条,实时显示loss信息
统一TRY环境变量的输出格式
简化日志记录和模型保存逻辑
This commit is contained in:
czzhangheng 2025-12-16 16:40:40 +08:00
parent 659b41f612
commit 85257bc61c
8 changed files with 52 additions and 141 deletions

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 64 batch_size: 16
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -34,7 +34,7 @@ model:
word_num: 1000 word_num: 1000
train: train:
batch_size: 64 batch_size: 16
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 64 batch_size: 16
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -40,7 +40,7 @@ model:
supports: null supports: null
train: train:
batch_size: 64 batch_size: 16
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 64 batch_size: 16
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -40,7 +40,7 @@ model:
supports: null supports: null
train: train:
batch_size: 64 batch_size: 16
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15

View File

@ -10,7 +10,7 @@ data:
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
input_dim: 6 input_dim: 1
lag: 24 lag: 24
normalizer: std normalizer: std
num_nodes: 137 num_nodes: 137

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 64 batch_size: 16
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -34,7 +34,7 @@ model:
word_num: 1000 word_num: 1000
train: train:
batch_size: 64 batch_size: 16
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15

View File

@ -13,7 +13,7 @@ def read_config(config_path):
# 全局配置 # 全局配置
device = "cuda:0" # 指定设备为cuda:0 device = "cuda:0" # 指定设备为cuda:0
seed = 2023 # 随机种子 seed = 2023 # 随机种子
epochs = 100 # 训练轮数 epochs = 1 # 训练轮数
# 拷贝项 # 拷贝项
config["basic"]["device"] = device config["basic"]["device"] = device
@ -91,7 +91,7 @@ if __name__ == "__main__":
# 调试用 # 调试用
model_list = ["iTransformer", "PatchTST", "HI"] model_list = ["iTransformer", "PatchTST", "HI"]
# model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"]
# model_list = ["iTransformer"] # model_list = ["MTGNN"]
# dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"]
# dataset_list = ["AirQuality"] # dataset_list = ["AirQuality"]
dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"]

View File

@ -5,86 +5,46 @@ from utils.loss_function import all_metrics
class Trainer: class Trainer:
def __init__(self, model, loss, optimizer, def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None):
train_loader, val_loader, test_loader, self.device, self.args = args["basic"]["device"], args["train"]
scaler, args, lr_scheduler=None): self.model, self.loss, self.optimizer, self.lr_scheduler = model.to(self.device), loss, optimizer, lr_scheduler
self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader or test_loader, test_loader
self.config = args
self.device = args["basic"]["device"]
self.args = args["train"]
self.model = model.to(self.device)
self.loss = loss
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_loader = train_loader
self.val_loader = val_loader or test_loader
self.test_loader = test_loader
self.scaler = scaler self.scaler = scaler
self.inv = lambda x: torch.cat([s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], dim=-1) # 对每个维度调用反归一化器后cat
# ---------- shape magic (replace TSWrapper) ----------
self.pack = lambda x: (
x[..., :-2]
.permute(0, 2, 1, 3)
.reshape(-1, x.size(1), x.size(3) - 2),
x.shape
)
self.unpack = lambda y, s: (
y.reshape(s[0], s[2], s[1], -1)
.permute(0, 2, 1, 3)
)
# ---------- inverse scaler ----------
self.inv = lambda x: torch.cat(
[s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)],
dim=-1
)
self._init_paths() self._init_paths()
self._init_logger() self._init_logger()
# ---------- shape magic (replace TSWrapper) ----------
self.pack = lambda x:(x[..., :-2].permute(0, 2, 1, 3).reshape(-1, x.size(1), x.size(3) - 2), x.shape)
self.unpack = lambda y, s: (y.reshape(s[0], s[2], s[1], -1).permute(0, 2, 1, 3))
# ---------------- init ---------------- # ---------------- init ----------------
def _init_paths(self): def _init_paths(self):
d = self.args["log_dir"] d = self.args["log_dir"]
self.best_path = os.path.join(d, "best_model.pth") self.best_path, self.best_test_path = os.path.join(d, "best_model.pth"), os.path.join(d, "best_test_model.pth")
self.best_test_path = os.path.join(d, "best_test_model.pth")
def _init_logger(self): def _init_logger(self):
if not self.args["debug"]: if not self.args["debug"]: os.makedirs(self.args["log_dir"], exist_ok=True)
os.makedirs(self.args["log_dir"], exist_ok=True) self.logger = get_logger(self.args["log_dir"], name=self.model.__class__.__name__, debug=self.args["debug"])
self.logger = get_logger(
self.args["log_dir"],
name=self.model.__class__.__name__,
debug=self.args["debug"]
)
# ---------------- epoch ---------------- # ---------------- epoch ----------------
def _run_epoch(self, epoch, loader, mode): def _run_epoch(self, epoch, loader, mode):
is_train = mode == "train" is_train = mode == "train"
self.model.train() if is_train else self.model.eval() self.model.train() if is_train else self.model.eval()
total_loss, start = 0.0, time.time() total_loss, start = 0.0, time.time()
y_pred, y_true = [], [] y_pred, y_true = [], []
with torch.set_grad_enabled(is_train): with torch.set_grad_enabled(is_train):
for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): bar = tqdm(loader, desc=f"{mode} {epoch}", total=len(loader))
for data, target in bar:
data, target = data.to(self.device), target.to(self.device) data, target = data.to(self.device), target.to(self.device)
label = target[..., :self.args["output_dim"]] label = target[..., :self.args["output_dim"]]
x, shp = self.pack(data) x, shp = self.pack(data)
out = self.unpack(self.model(x), shp) out = self.unpack(self.model(x), shp)
if os.environ.get("TRY") == "True": print(f"{'[✅]' if out.shape == label.shape else ''} "
if os.environ.get("TRY") == "True": f"out: {out.shape}, label: {label.shape} \n"); assert False
print(f"out:{out.shape} label:{label.shape}",
"" if out.shape == label.shape else "")
assert False
loss = self.loss(out, label) loss = self.loss(out, label)
d_out, d_lbl = self.inv(out), self.inv(label) # 反归一化
d_out, d_lbl = self.inv(out), self.inv(label)
d_loss = self.loss(d_out, d_lbl) d_loss = self.loss(d_out, d_lbl)
total_loss += d_loss.item() total_loss += d_loss.item()
y_pred.append(d_out.detach().cpu()) y_pred.append(d_out.detach().cpu())
y_true.append(d_lbl.detach().cpu()) y_true.append(d_lbl.detach().cpu())
@ -92,70 +52,38 @@ class Trainer:
if is_train and self.optimizer: if is_train and self.optimizer:
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() loss.backward()
if self.args["grad_norm"]: if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.args["max_grad_norm"]
)
self.optimizer.step() self.optimizer.step()
bar.set_postfix({"loss": f"{d_loss.item():.4f}"})
y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
mae, rmse, mape = all_metrics( mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
y_pred, y_true, self.logger.info(f"Epoch #{epoch:02d} {mode:<5} MAE:{mae:5.2f} RMSE:{rmse:5.2f} MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s")
self.args["mae_thresh"],
self.args["mape_thresh"]
)
self.logger.info(
f"Epoch #{epoch:02d} {mode:<5} "
f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
f"MAPE:{mape:7.4f} "
f"Time:{time.time() - start:.2f}s"
)
return total_loss / len(loader) return total_loss / len(loader)
# ---------------- train ---------------- # ---------------- train ----------------
def train(self): def train(self):
best = best_test = float("inf") best, best_test = float("inf"), float("inf")
best_w = best_test_w = None best_w, best_test_w = None, None
patience = 0 patience = 0
self.logger.info("Training started") self.logger.info("Training started")
for epoch in range(1, self.args["epochs"] + 1): for epoch in range(1, self.args["epochs"] + 1):
losses = { losses = {
k: self._run_epoch(epoch, l, k) "train": self._run_epoch(epoch, self.train_loader, "train"),
for k, l in [ "val": self._run_epoch(epoch, self.val_loader, "val"),
("train", self.train_loader), "test": self._run_epoch(epoch, self.test_loader, "test"),
("val", self.val_loader),
("test", self.test_loader)
]
} }
if losses["train"] > 1e6: if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
self.logger.warning("Gradient explosion detected") if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
break else: patience += 1
if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
if losses["val"] < best: if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
best, patience = losses["val"], 0
best_w = copy.deepcopy(self.model.state_dict())
self.logger.info("Best validation model saved")
else:
patience += 1
if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
self.logger.info("Early stopping triggered")
break
if losses["test"] < best_test:
best_test = losses["test"]
best_test_w = copy.deepcopy(self.model.state_dict())
if not self.args["debug"]: if not self.args["debug"]:
torch.save(best_w, self.best_path) torch.save(best_w, self.best_path)
torch.save(best_test_w, self.best_test_path) torch.save(best_test_w, self.best_test_path)
self._final_test(best_w, best_test_w) self._final_test(best_w, best_test_w)
# ---------------- final test ---------------- # ---------------- final test ----------------
@ -174,32 +102,13 @@ class Trainer:
for data, target in self.test_loader: for data, target in self.test_loader:
data, target = data.to(self.device), target.to(self.device) data, target = data.to(self.device), target.to(self.device)
label = target[..., :self.args["output_dim"]] label = target[..., :self.args["output_dim"]]
y_pred.append(self.model(data).cpu())
x, shp = self.pack(data)
out = self.unpack(self.model(x), shp)
y_pred.append(out.cpu())
y_true.append(label.cpu()) y_true.append(label.cpu())
d_pred = self.inv(torch.cat(y_pred)) d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true)) # 反归一化
d_true = self.inv(torch.cat(y_true))
for t in range(d_true.shape[1]): for t in range(d_true.shape[1]):
mae, rmse, mape = all_metrics( mae, rmse, mape = all_metrics(d_pred[:, t], d_true[:, t], self.args["mae_thresh"], self.args["mape_thresh"])
d_pred[:, t], d_true[:, t], self.logger.info(f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}")
self.args["mae_thresh"],
self.args["mape_thresh"]
)
self.logger.info(
f"Horizon {t+1:02d} "
f"MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
)
mae, rmse, mape = all_metrics( avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"])
d_pred, d_true, self.logger.info(f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}")
self.args["mae_thresh"],
self.args["mape_thresh"]
)
self.logger.info(
f"AVG MAE:{mae:.4f} AVG RMSE:{rmse:.4f} AVG MAPE:{mape:.4f}"
)

View File

@ -30,12 +30,13 @@ class Trainer:
y_pred, y_true = [], [] y_pred, y_true = [], []
with torch.set_grad_enabled(is_train): with torch.set_grad_enabled(is_train):
for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): bar = tqdm(loader, desc=f"{mode} {epoch}", total=len(loader))
for data, target in bar:
data, target = data.to(self.device), target.to(self.device) data, target = data.to(self.device), target.to(self.device)
label = target[..., :self.args["output_dim"]] label = target[..., :self.args["output_dim"]]
out = self.model(data) out = self.model(data)
if os.environ.get("TRY") == "True": print(f"out: {out.shape}, label: {label.shape} \ if os.environ.get("TRY") == "True": print(f"{'[✅]' if out.shape == label.shape else ''} "
{'' if out.shape == label.shape else ''}"); assert False f"out: {out.shape}, label: {label.shape} \n"); assert False
loss = self.loss(out, label) loss = self.loss(out, label)
d_out, d_lbl = self.inv(out), self.inv(label) # 反归一化 d_out, d_lbl = self.inv(out), self.inv(label) # 反归一化
d_loss = self.loss(d_out, d_lbl) d_loss = self.loss(d_out, d_lbl)
@ -48,6 +49,7 @@ class Trainer:
loss.backward() loss.backward()
if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
self.optimizer.step() self.optimizer.step()
bar.set_postfix({"loss": f"{d_loss.item():.4f}"})
y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])