diff --git a/config/ASTRA/v2_AirQuality.yaml b/config/ASTRA_v2/AirQuality.yaml similarity index 100% rename from config/ASTRA/v2_AirQuality.yaml rename to config/ASTRA_v2/AirQuality.yaml diff --git a/config/ASTRA/v2_BJTaxi-InFlow.yaml b/config/ASTRA_v2/BJTaxi-InFlow.yaml similarity index 100% rename from config/ASTRA/v2_BJTaxi-InFlow.yaml rename to config/ASTRA_v2/BJTaxi-InFlow.yaml diff --git a/config/ASTRA/v2_BJTaxi-OutFlow.yaml b/config/ASTRA_v2/BJTaxi-OutFlow.yaml similarity index 100% rename from config/ASTRA/v2_BJTaxi-OutFlow.yaml rename to config/ASTRA_v2/BJTaxi-OutFlow.yaml diff --git a/config/ASTRA/v2_METR-LA.yaml b/config/ASTRA_v2/METR-LA.yaml similarity index 97% rename from config/ASTRA/v2_METR-LA.yaml rename to config/ASTRA_v2/METR-LA.yaml index bf92089..dca4bb4 100644 --- a/config/ASTRA/v2_METR-LA.yaml +++ b/config/ASTRA_v2/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:0 mode: train - model: AEPSA_v2 + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/v2_NYCBike-InFlow.yaml b/config/ASTRA_v2/NYCBike-InFlow.yaml similarity index 100% rename from config/ASTRA/v2_NYCBike-InFlow.yaml rename to config/ASTRA_v2/NYCBike-InFlow.yaml diff --git a/config/ASTRA/v2_NYCBike-OutFlow.yaml b/config/ASTRA_v2/NYCBike-OutFlow.yaml similarity index 100% rename from config/ASTRA/v2_NYCBike-OutFlow.yaml rename to config/ASTRA_v2/NYCBike-OutFlow.yaml diff --git a/config/ASTRA/v3_PEMS-BAY.yaml b/config/ASTRA_v2/PEMS-BAY.yaml similarity index 97% rename from config/ASTRA/v3_PEMS-BAY.yaml rename to config/ASTRA_v2/PEMS-BAY.yaml index 9f98483..2f6dfbf 100755 --- a/config/ASTRA/v3_PEMS-BAY.yaml +++ b/config/ASTRA_v2/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: AEPSA_v3 + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA/v2_SolarEnergy.yaml b/config/ASTRA_v2/SolarEnergy.yaml similarity index 97% rename from config/ASTRA/v2_SolarEnergy.yaml rename to config/ASTRA_v2/SolarEnergy.yaml index a45ad73..83a87c2 100644 --- a/config/ASTRA/v2_SolarEnergy.yaml +++ b/config/ASTRA_v2/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: AEPSA_v2 + model: ASTRA_v2 seed: 2023 data: diff --git a/config/ASTRA_v3/AirQuality.yaml b/config/ASTRA_v3/AirQuality.yaml new file mode 100644 index 0000000..68e6acc --- /dev/null +++ b/config/ASTRA_v3/AirQuality.yaml @@ -0,0 +1,54 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 6 + n_heads: 1 + num_nodes: 35 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + weight_decay: 0 diff --git a/config/ASTRA_v3/BJTaxi-InFlow.yaml b/config/ASTRA_v3/BJTaxi-InFlow.yaml new file mode 100644 index 0000000..34abfd8 --- /dev/null +++ b/config/ASTRA_v3/BJTaxi-InFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA_v3/BJTaxi-OutFlow.yaml b/config/ASTRA_v3/BJTaxi-OutFlow.yaml new file mode 100644 index 0000000..8e6b30d --- /dev/null +++ b/config/ASTRA_v3/BJTaxi-OutFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA/v3_METR-LA.yaml b/config/ASTRA_v3/METR-LA.yaml similarity index 93% rename from config/ASTRA/v3_METR-LA.yaml rename to config/ASTRA_v3/METR-LA.yaml index 5d22820..2b5512b 100644 --- a/config/ASTRA/v3_METR-LA.yaml +++ b/config/ASTRA_v3/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:0 mode: train - model: AEPSA_v3 + model: ASTRA_v3 seed: 2023 data: @@ -19,11 +19,9 @@ data: val_ratio: 0.2 model: - chebyshev_order: 3 d_ff: 128 d_model: 64 dropout: 0.2 - graph_hidden_dim: 32 gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 diff --git a/config/ASTRA_v3/NYCBike-InFlow.yaml b/config/ASTRA_v3/NYCBike-InFlow.yaml new file mode 100644 index 0000000..18c4fa3 --- /dev/null +++ b/config/ASTRA_v3/NYCBike-InFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA_v3/NYCBike-OutFlow.yaml b/config/ASTRA_v3/NYCBike-OutFlow.yaml new file mode 100644 index 0000000..ff73662 --- /dev/null +++ b/config/ASTRA_v3/NYCBike-OutFlow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/ASTRA/v2_PEMS-BAY.yaml b/config/ASTRA_v3/PEMS-BAY.yaml similarity index 92% rename from config/ASTRA/v2_PEMS-BAY.yaml rename to config/ASTRA_v3/PEMS-BAY.yaml index c40034d..6739aeb 100755 --- a/config/ASTRA/v2_PEMS-BAY.yaml +++ b/config/ASTRA_v3/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: AEPSA_v2 + model: ASTRA_v3 seed: 2023 data: @@ -19,11 +19,9 @@ data: val_ratio: 0.2 model: - chebyshev_order: 3 d_ff: 128 d_model: 64 dropout: 0.2 - graph_hidden_dim: 32 gpt_layers: 9 gpt_path: ./GPT-2 input_dim: 1 diff --git a/config/ASTRA_v3/SolarEnergy.yaml b/config/ASTRA_v3/SolarEnergy.yaml new file mode 100644 index 0000000..289b839 --- /dev/null +++ b/config/ASTRA_v3/SolarEnergy.yaml @@ -0,0 +1,54 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: ASTRA_v3 + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_ff: 128 + d_model: 64 + dropout: 0.2 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 137 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + word_num: 1000 + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/MTGNN/NYCBike-Inflow.yaml b/config/MTGNN/NYCBike-InFlow.yaml similarity index 100% rename from config/MTGNN/NYCBike-Inflow.yaml rename to config/MTGNN/NYCBike-InFlow.yaml diff --git a/config/MTGNN/NYCBike-Outflow.yaml b/config/MTGNN/NYCBike-OutFlow.yaml similarity index 100% rename from config/MTGNN/NYCBike-Outflow.yaml rename to config/MTGNN/NYCBike-OutFlow.yaml diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml index a40e11e..c035a44 100755 --- a/config/REPST/AirQuality.yaml +++ b/config/REPST/AirQuality.yaml @@ -13,8 +13,8 @@ data: input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 - steps_per_day: 288 + num_nodes: 35 + steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index a29bfc3..0e9aebf 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -184,7 +184,7 @@ class ASTRA(nn.Module): def forward(self, x): # 数据处理 - x = x[..., :1] # [B,T,N,1] + x = x[..., :self.input_dim] x_enc = rearrange(x, 'b t n c -> b n c t') # [B,N,1,T] # 图编码 @@ -203,7 +203,9 @@ class ASTRA(nn.Module): dec_out = self.out_mlp(X_enc) # [B,N,pred_len] # 维度调整 - outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1] - outputs = outputs.permute(0, 2, 1, 3) # [B,pred_len,N,1] + dec_out = self.out_mlp(enc_out) + outputs = dec_out.unsqueeze(dim=-1) + outputs = outputs.repeat(1, 1, 1, self.input_dim) + outputs = outputs.permute(0,2,1,3) return outputs \ No newline at end of file diff --git a/train.py b/train.py index 139cdfa..2d3a32f 100644 --- a/train.py +++ b/train.py @@ -60,25 +60,17 @@ def run(config): case _: raise ValueError(f"Unsupported mode: {config['basic']['mode']}") -def main(debug=False): - # 指定模型 - model_list = ["iTransformer"] - # 指定数据集 - # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - # dataset_list = ["AirQuality"] - dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] - +def main(model, data, debug=False): # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) - os.environ["TRY"] = str(False) - + os.environ["TRY"] = str(debug) + for model in model_list: - for dataset in dataset_list: + for dataset in data: config_path = f"./config/{model}/{dataset}.yaml" # 可去这个函数里面调整统一的config项,⚠️注意调设备,epochs config = read_config(config_path) print(f"\nRunning {model} on {dataset}") - # print(f"config: {config}") if os.environ.get("TRY") == "True": try: run(config) @@ -97,4 +89,9 @@ def main(debug=False): if __name__ == "__main__": # 调试用 - main(debug = True) \ No newline at end of file + # model_list = ["iTransformer", "PatchTST", "HI"] + model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] + # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # dataset_list = ["AirQuality"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] + main(model_list, dataset_list, debug = True) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index c427072..932d8b3 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -73,8 +73,12 @@ class Trainer: out = self.ts.inverse(out, b, t, n, c) if os.environ.get("TRY") == "True": - print(out.shape, label.shape) - assert out.shape == label.shape + if out.shape == label.shape: + print("shape true") + assert False + else: + print("shape false") + assert False loss = self.loss(out, label) d_out = self.scaler.inverse_transform(out) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 65980b9..cdd444b 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -1,106 +1,180 @@ -import os -import time -import copy -import torch +import os, time, copy, torch +from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics -from tqdm import tqdm class Trainer: - def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): - self.config, self.device, self.args = args, args["basic"]["device"], args["train"] - self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler - self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler - - log_dir = self.args["log_dir"] - self.best_path, self.best_test_path = [os.path.join(log_dir, f"best_{suffix}_model.pth") for suffix in ["", "test"]] - - if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True) - self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"]) - self.logger.info(f"Experiment log path in: {log_dir}") + def __init__(self, model, loss, optimizer, + train_loader, val_loader, test_loader, + scaler, args, lr_scheduler=None): - def train(self): - best_model = best_test_model = None - best_loss = best_test_loss = float("inf") - not_improved_count = 0 - - self.logger.info("Training process started") - - for epoch in range(1, self.args["epochs"] + 1): - train_loss = self._run_epoch(epoch, self.train_loader, "train") - val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val") - test_loss = self._run_epoch(epoch, self.test_loader, "test") - - if train_loss > 1e6: - self.logger.warning("Gradient explosion detected. Ending...") - break - - if val_loss < best_loss: - best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict()) - self.logger.info("Best validation model saved!") - elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]: - self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.") - break - - if test_loss < best_test_loss: - best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict()) - + self.config = args + self.device = args["basic"]["device"] + self.args = args["train"] - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") - - for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]: - self.model.load_state_dict(state_dict) - self.logger.info(f"Testing on {model_name} model") - self._run_epoch(None, self.test_loader, "test", log_horizon=True) - - def _run_epoch(self, epoch, dataloader, mode, log_horizon=False): - self.model.train() if mode == "train" else self.model.eval() - optimizer_step = mode == "train" - - total_loss, epoch_time = 0, time.time() + self.model = model.to(self.device) + self.loss = loss + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + self.train_loader = train_loader + self.val_loader = val_loader or test_loader + self.test_loader = test_loader + self.scaler = scaler + + self._init_paths() + self._init_logger() + + # ---------------- init ---------------- + def _init_paths(self): + d = self.args["log_dir"] + self.best_path = os.path.join(d, "best_model.pth") + self.best_test_path = os.path.join(d, "best_test_model.pth") + + def _init_logger(self): + if not self.args["debug"]: + os.makedirs(self.args["log_dir"], exist_ok=True) + self.logger = get_logger( + self.args["log_dir"], + name=self.model.__class__.__name__, + debug=self.args["debug"], + ) + + # ---------------- epoch ---------------- + def _run_epoch(self, epoch, loader, mode): + is_train = mode == "train" + self.model.train() if is_train else self.model.eval() + + total_loss, start = 0.0, time.time() y_pred, y_true = [], [] - - with torch.set_grad_enabled(optimizer_step): - for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode): + + with torch.set_grad_enabled(is_train): + for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - - output = self.model(data) - loss = self.loss(output, label) - d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label) - d_loss = self.loss(d_output, d_label) - + + out = self.model(data) + + if os.environ.get("TRY") == "True": + if out.shape == label.shape: + print(f"shape true, out: {out.shape}, label: {label.shape}") + assert False + else: + print(f"shape false, out: {out.shape}, label: {label.shape}") + assert False + + loss = self.loss(out, label) + d_out = self.scaler.inverse_transform(out) + d_lbl = self.scaler.inverse_transform(label) + d_loss = self.loss(d_out, d_lbl) + total_loss += d_loss.item() - y_pred.append(d_output.detach().cpu()) - y_true.append(d_label.detach().cpu()) - - if optimizer_step and self.optimizer: + y_pred.append(d_out.detach().cpu()) + y_true.append(d_lbl.detach().cpu()) + + if is_train and self.optimizer: self.optimizer.zero_grad() loss.backward() - if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.args["max_grad_norm"] + ) self.optimizer.step() - - y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0) - - if log_horizon: - for t in range(y_true.shape[1]): - mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"]) - self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - - avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) - - if epoch and mode: - self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s") - elif mode: - self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}") - - return total_loss / len(dataloader) - def test(self, path=None): - if path: - self.model.load_state_dict(torch.load(path)["state_dict"]) - self.model.to(self.device) + y_pred = torch.cat(y_pred) + y_true = torch.cat(y_true) + + mae, rmse, mape = all_metrics( + y_pred, y_true, + self.args["mae_thresh"], + self.args["mape_thresh"] + ) + + self.logger.info( + f"Epoch #{epoch:02d} {mode:<5} " + f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " + f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s" + ) + return total_loss / len(loader) + + # ---------------- train ---------------- + def train(self): + best, best_test = float("inf"), float("inf") + best_w, best_test_w = None, None + patience = 0 + + self.logger.info("Training started") + + for epoch in range(1, self.args["epochs"] + 1): + losses = { + "train": self._run_epoch(epoch, self.train_loader, "train"), + "val": self._run_epoch(epoch, self.val_loader, "val"), + "test": self._run_epoch(epoch, self.test_loader, "test"), + } + + if losses["train"] > 1e6: + self.logger.warning("Gradient explosion detected") + break + + if losses["val"] < best: + best, patience = losses["val"], 0 + best_w = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved") + else: + patience += 1 + + if self.args["early_stop"] and patience == self.args["early_stop_patience"]: + self.logger.info("Early stopping triggered") + break + + if losses["test"] < best_test: + best_test = losses["test"] + best_test_w = copy.deepcopy(self.model.state_dict()) + + if not self.args["debug"]: + torch.save(best_w, self.best_path) + torch.save(best_test_w, self.best_test_path) + + self._final_test(best_w, best_test_w) + + # ---------------- final test ---------------- + def _final_test(self, best_w, best_test_w): + for name, w in [("best val", best_w), ("best test", best_test_w)]: + self.model.load_state_dict(w) + self.logger.info(f"Testing on {name} model") + self.evaluate() + + # ---------------- evaluate ---------------- + def evaluate(self): + self.model.eval() + y_pred, y_true = [], [] + + with torch.no_grad(): + for data, target in self.test_loader: + data, target = data.to(self.device), target.to(self.device) + label = target[..., :self.args["output_dim"]] + + out = self.model(data) + + y_pred.append(out.cpu()) + y_true.append(label.cpu()) + + d_pred = self.scaler.inverse_transform(torch.cat(y_pred)) + d_true = self.scaler.inverse_transform(torch.cat(y_true)) + + for t in range(d_true.shape[1]): + mae, rmse, mape = all_metrics( + d_pred[:, t], d_true[:, t], + self.args["mae_thresh"], + self.args["mape_thresh"] + ) + self.logger.info( + f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" + ) - self._run_epoch(None, self.test_loader, "test", log_horizon=True) + avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info( + f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" + ) +