From 19fd7622a379111e5772bd3b7f3d713ab154b15c Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 11 Dec 2025 23:16:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9InFormer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/Informer/AirQuality.yaml | 66 +++++++ config/Informer/BJTaxi-Inflow.yaml | 66 +++++++ config/Informer/BJTaxi-Outflow.yaml | 66 +++++++ config/Informer/METR-LA.yaml | 66 +++++++ config/Informer/NYCBike-Inflow.yaml | 66 +++++++ config/Informer/NYCBike-Outflow.yaml | 66 +++++++ config/Informer/PEMS-BAY.yaml | 66 +++++++ config/Informer/SolarEnergy.yaml | 66 +++++++ dataloader/Informer_loader.py | 179 +++++++++++++++++++ dataloader/loader_selector.py | 5 +- model/Informer/attn.py | 163 +++++++++++++++++ model/Informer/decoder.py | 51 ++++++ model/Informer/embed.py | 129 ++++++++++++++ model/Informer/encoder.py | 98 +++++++++++ model/Informer/masking.py | 24 +++ model/Informer/model.py | 141 +++++++++++++++ model/model_selector.py | 3 + test_informer.py | 57 ++++++ train.py | 24 ++- trainer/InformerTrainer.py | 250 +++++++++++++++++++++++++++ trainer/trainer_selector.py | 13 ++ 21 files changed, 1656 insertions(+), 9 deletions(-) create mode 100644 config/Informer/AirQuality.yaml create mode 100644 config/Informer/BJTaxi-Inflow.yaml create mode 100644 config/Informer/BJTaxi-Outflow.yaml create mode 100644 config/Informer/METR-LA.yaml create mode 100644 config/Informer/NYCBike-Inflow.yaml create mode 100644 config/Informer/NYCBike-Outflow.yaml create mode 100644 config/Informer/PEMS-BAY.yaml create mode 100644 config/Informer/SolarEnergy.yaml create mode 100644 dataloader/Informer_loader.py create mode 100644 model/Informer/attn.py create mode 100644 model/Informer/decoder.py create mode 100644 model/Informer/embed.py create mode 100644 model/Informer/encoder.py create mode 100644 model/Informer/masking.py create mode 100644 model/Informer/model.py create mode 100644 test_informer.py create 
mode 100644 trainer/InformerTrainer.py diff --git a/config/Informer/AirQuality.yaml b/config/Informer/AirQuality.yaml new file mode 100644 index 0000000..4b1568a --- /dev/null +++ b/config/Informer/AirQuality.yaml @@ -0,0 +1,66 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 24 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 6 + dec_in: 6 + c_out: 6 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/BJTaxi-Inflow.yaml b/config/Informer/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..56d089d --- /dev/null +++ b/config/Informer/BJTaxi-Inflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h 
+ distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/BJTaxi-Outflow.yaml b/config/Informer/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..875cce8 --- /dev/null +++ b/config/Informer/BJTaxi-Outflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/METR-LA.yaml b/config/Informer/METR-LA.yaml new file mode 100644 index 0000000..731fa3e --- /dev/null +++ b/config/Informer/METR-LA.yaml @@ -0,0 +1,66 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: Informer + seed: 2023 + 
+data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/NYCBike-Inflow.yaml b/config/Informer/NYCBike-Inflow.yaml new file mode 100644 index 0000000..30ca485 --- /dev/null +++ b/config/Informer/NYCBike-Inflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + 
mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/NYCBike-Outflow.yaml b/config/Informer/NYCBike-Outflow.yaml new file mode 100644 index 0000000..9fcfe6b --- /dev/null +++ b/config/Informer/NYCBike-Outflow.yaml @@ -0,0 +1,66 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/PEMS-BAY.yaml b/config/Informer/PEMS-BAY.yaml new file mode 100644 index 0000000..961bd6f --- /dev/null +++ b/config/Informer/PEMS-BAY.yaml @@ -0,0 +1,66 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + 
dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/Informer/SolarEnergy.yaml b/config/Informer/SolarEnergy.yaml new file mode 100644 index 0000000..0d31425 --- /dev/null +++ b/config/Informer/SolarEnergy.yaml @@ -0,0 +1,66 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: Informer + seed: 2023 + +data: + batch_size: 1024 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6  # NOTE(review): model.enc_in/dec_in/c_out below are 1 — one side is wrong; confirm SolarEnergy channel count + lag: 24 + label_len: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + label_len: 12 + pred_len: 24 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + d_layers: 1 + n_heads: 8 + output_attention: False + factor: 5 + attn: prob + embed: fixed + freq: h + distil: true + mix: true + enc_in: 1 + dec_in: 1 + c_out: 1 + + +train: + batch_size: 1024 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + label_len: 24 + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + pred_len: 24 + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/dataloader/Informer_loader.py b/dataloader/Informer_loader.py new file mode 100644 index 0000000..1192c03 --- /dev/null +++ 
b/dataloader/Informer_loader.py @@ -0,0 +1,179 @@ +import numpy as np +import torch +from dataloader.data_selector import load_st_dataset +from utils.normalization import normalize_dataset + + +# ============================================================== +# MAIN ENTRY +# ============================================================== + +def get_dataloader(args, normalizer="std", single=True): + """ + Return dataloaders with x, y, x_mark, y_mark. + This version follows Informer/ETSformer official dataloader behavior. + """ + data = load_st_dataset(args) + args = args["data"] + + x, y, x_mark, y_mark = _prepare_data_with_windows(data, args) + + # --- split --- + split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio + x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"]) + y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"]) + x_mark_train, x_mark_val, x_mark_test = split_fn(x_mark, args["val_ratio"], args["test_ratio"]) + y_mark_train, y_mark_val, y_mark_test = split_fn(y_mark, args["val_ratio"], args["test_ratio"]) + + # --- normalization --- + scaler = _normalize_data(x_train, x_val, x_test, args, normalizer) + _apply_existing_scaler(y_train, y_val, y_test, scaler, args) + + # reshape [b, t, n, c] -> [b*n, t, c] + (x_train, x_val, x_test, + y_train, y_val, y_test, + x_mark_train, x_mark_val, x_mark_test, + y_mark_train, y_mark_val, y_mark_test) = _reshape_tensor( + x_train, x_val, x_test, + y_train, y_val, y_test, + x_mark_train, x_mark_val, x_mark_test, + y_mark_train, y_mark_val, y_mark_test + ) + + # --- dataloaders --- + return ( + _create_dataloader(x_train, y_train, x_mark_train, y_mark_train, + args["batch_size"], True, False), + _create_dataloader(x_val, y_val, x_mark_val, y_mark_val, + args["batch_size"], False, False), + _create_dataloader(x_test, y_test, x_mark_test, y_mark_test, + args["batch_size"], False, False), + scaler + ) + + +# 
============================================================== +# Informer-style WINDOW GENERATION +# ============================================================== + +def _prepare_data_with_windows(data, args): + """ + Generate x, y, x_mark, y_mark using Informer slicing rule. + + x: [seq_len] + y: [label_len + pred_len] + """ + seq_len = args["lag"] + label_len = args["label_len"] + pred_len = args["horizon"] + + L, N, C = data.shape + + # ---------- construct timestamp features ---------- + time_in_day, day_in_week = _generate_time_features(L, args) + data_mark = np.concatenate([time_in_day, day_in_week], axis=-1) + + xs, ys, x_marks, y_marks = [], [], [], [] + + for s_begin in range(L - seq_len - pred_len - 1): + s_end = s_begin + seq_len + r_begin = s_end - label_len + r_end = r_begin + label_len + pred_len + + xs.append(data[s_begin:s_end]) + ys.append(data[r_begin:r_end]) + + x_marks.append(data_mark[s_begin:s_end]) + y_marks.append(data_mark[r_begin:r_end]) + + return np.array(xs), np.array(ys), np.array(x_marks), np.array(y_marks) + + +# ============================================================== +# TIME FEATURE +# ============================================================== + +def _generate_time_features(L, args): + N = args["num_nodes"] + + # Time in day + tid = np.array([i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)]) + tid = np.tile(tid[:, None], (1, N)) + + # Day in week + diw = np.array([(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)]) + diw = np.tile(diw[:, None], (1, N)) + + return tid[..., None], diw[..., None] + + +# ============================================================== +# NORMALIZATION +# ============================================================== + +def _normalize_data(train_data, val_data, test_data, args, normalizer): + scaler = normalize_dataset( + train_data[..., :args["input_dim"]], + normalizer, args["column_wise"] + ) + for data in [train_data, val_data, test_data]: + 
data[..., :args["input_dim"]] = scaler.transform( + data[..., :args["input_dim"]] + ) + return scaler + + +def _apply_existing_scaler(train_data, val_data, test_data, scaler, args): + for data in [train_data, val_data, test_data]: + data[..., :args["input_dim"]] = scaler.transform( + data[..., :args["input_dim"]] + ) + + +# ============================================================== +# DATALOADER +# ============================================================== + +def _create_dataloader(x, y, x_mark, y_mark, batch_size, shuffle, drop_last): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = torch.utils.data.TensorDataset( + torch.tensor(x, dtype=torch.float32, device=device), + torch.tensor(y, dtype=torch.float32, device=device), + torch.tensor(x_mark, dtype=torch.float32, device=device), + torch.tensor(y_mark, dtype=torch.float32, device=device), + ) + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, + shuffle=shuffle, drop_last=drop_last) + + +# ============================================================== +# SPLIT +# ============================================================== + +def split_data_by_days(data, val_days, test_days, interval=30): + t = int((24 * 60) / interval) + test_data = data[-t * int(test_days):] + val_data = data[-t * int(test_days + val_days):-t * int(test_days)] + train_data = data[:-t * int(test_days + val_days)] + return train_data, val_data, test_data + + +def split_data_by_ratio(data, val_ratio, test_ratio): + L = len(data) + test_data = data[-int(L * test_ratio):] + val_data = data[-int(L * (test_ratio + val_ratio)):-int(L * test_ratio)] + train_data = data[: -int(L * (test_ratio + val_ratio))] + return train_data, val_data, test_data + + +# ============================================================== +# RESHAPE [B,T,N,C] -> [B*N,T,C] +# ============================================================== + +def _reshape_tensor(*tensors): + reshaped = [] + for x in tensors: + b, t, n, 
c = x.shape + x_new = x.transpose(0, 2, 1, 3).reshape(b * n, t, c) + reshaped.append(x_new) + return reshaped diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index 88d1e2d..c1862df 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -4,12 +4,15 @@ from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader from dataloader.EXPdataloader import get_dataloader as EXP_loader from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader from dataloader.TSloader import get_dataloader as TS_loader +from dataloader.Informer_loader import get_dataloader as Informer_loader def get_dataloader(config, normalizer, single): TS_model = ["iTransformer", "HI", "PatchTST"] model_name = config["basic"]["model"] - if model_name in TS_model: + if model_name == "Informer": + return Informer_loader(config, normalizer, single) + elif model_name in TS_model: return TS_loader(config, normalizer, single) else : match model_name: diff --git a/model/Informer/attn.py b/model/Informer/attn.py new file mode 100644 index 0000000..45344a8 --- /dev/null +++ b/model/Informer/attn.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from math import sqrt +from model.Informer.masking import TriangularCausalMask, ProbMask + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1./sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + 
scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return (V.contiguous(), A) + else: + return (V.contiguous(), None) + +class ProbAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(ProbAttention, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) + # Q [B, H, L, D] + B, H, L_K, E = K.shape + _, _, L_Q, _ = Q.shape + + # calculate the sampled Q_K + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q + K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) + + # find the Top_k query with sparisty measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = Q[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + M_top, :] # factor*ln(L_q) + Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + B, H, L_V, D = V.shape + if not self.mask_flag: + # V_sum = V.sum(dim=-2) + V_sum = V.mean(dim=-2) + contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() + else: # use mask + assert(L_Q == L_V) # requires that L_Q == L_V, i.e. 
for self-attention only + contex = V.cumsum(dim=-2) + return contex + + def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): + B, H, L_V, D = V.shape + + if self.mask_flag: + attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) + scores.masked_fill_(attn_mask.mask, -np.inf) + + attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) + + context_in[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :] = torch.matmul(attn, V).type_as(context_in) + if self.output_attention: + attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device) + attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn + return (context_in, attns) + else: + return (context_in, None) + + def forward(self, queries, keys, values, attn_mask): + B, L_Q, H, D = queries.shape + _, L_K, _, _ = keys.shape + + queries = queries.transpose(2,1) + keys = keys.transpose(2,1) + values = values.transpose(2,1) + + U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) + u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) + + U_part = U_part if U_part='1.5.0' else 2 + self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, + kernel_size=3, padding=padding, padding_mode='circular') + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu') + + def forward(self, x): + x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2) + return x + +class FixedEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(FixedEmbedding, self).__init__() + + w = torch.zeros(c_in, d_model).float() + w.require_grad = False + + position = torch.arange(0, c_in).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + w[:, 0::2] = torch.sin(position * div_term) + w[:, 1::2] = torch.cos(position * div_term) + + self.emb = nn.Embedding(c_in, 
d_model) + self.emb.weight = nn.Parameter(w, requires_grad=False) + + def forward(self, x): + return self.emb(x).detach() + +class TemporalEmbedding(nn.Module): + def __init__(self, d_model, embed_type='fixed', freq='h'): + super(TemporalEmbedding, self).__init__() + + minute_size = 4; hour_size = 24 + weekday_size = 7; day_size = 32; month_size = 13 + + Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding + if freq=='t': + self.minute_embed = Embed(minute_size, d_model) + self.hour_embed = Embed(hour_size, d_model) + self.weekday_embed = Embed(weekday_size, d_model) + self.day_embed = Embed(day_size, d_model) + self.month_embed = Embed(month_size, d_model) + + def forward(self, x): + x = x.long() + + # Check the size of x's last dimension to avoid index errors + last_dim = x.shape[-1] + + minute_x = 0. + hour_x = 0. + weekday_x = 0. + day_x = 0. + month_x = 0. + + # For our generated time features, we have only 2 dimensions: [day_of_week, hour] + # So we need to map them to the appropriate embedding layers + if last_dim > 0: + # Use the first dimension for hour + # Ensure hour is in the valid range [0, 23] + hour = torch.clamp(x[:,:,0], 0, 23) + hour_x = self.hour_embed(hour) + + if last_dim > 1: + # Use the second dimension for weekday + # Ensure weekday is in the valid range [0, 6] + weekday = torch.clamp(x[:,:,1], 0, 6) + weekday_x = self.weekday_embed(weekday) + + return hour_x + weekday_x + day_x + month_x + minute_x + +class TimeFeatureEmbedding(nn.Module): + def __init__(self, d_model, embed_type='timeF', freq='h'): + super(TimeFeatureEmbedding, self).__init__() + + freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3} + d_inp = freq_map[freq] + self.embed = nn.Linear(d_inp, d_model) + + def forward(self, x): + return self.embed(x) + +class DataEmbedding(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding, self).__init__() + + self.value_embedding = 
TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + a = self.value_embedding(x) + b = self.position_embedding(x) + c = self.temporal_embedding(x_mark) + x = a + b + c + + return self.dropout(x) \ No newline at end of file diff --git a/model/Informer/encoder.py b/model/Informer/encoder.py new file mode 100644 index 0000000..7aeb877 --- /dev/null +++ b/model/Informer/encoder.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ConvLayer(nn.Module): + def __init__(self, c_in): + super(ConvLayer, self).__init__() + padding = 1 if torch.__version__>='1.5.0' else 2 + self.downConv = nn.Conv1d(in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=padding, + padding_mode='circular') + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1,2) + return x + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4*d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + # x [B, L, D] + # x = x + 
self.dropout(self.attention( + # x, x, x, + # attn_mask = attn_mask + # )) + new_x, attn = self.attention( + x, x, x, + attn_mask = attn_mask + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1,1)))) + y = self.dropout(self.conv2(y).transpose(-1,1)) + + return self.norm2(x+y), attn + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x, attn_mask=attn_mask) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + +class EncoderStack(nn.Module): + def __init__(self, encoders, inp_lens): + super(EncoderStack, self).__init__() + self.encoders = nn.ModuleList(encoders) + self.inp_lens = inp_lens + + def forward(self, x, attn_mask=None): + # x [B, L, D] + x_stack = []; attns = [] + for i_len, encoder in zip(self.inp_lens, self.encoders): + inp_len = x.shape[1]//(2**i_len) + x_s, attn = encoder(x[:, -inp_len:, :]) + x_stack.append(x_s); attns.append(attn) + x_stack = torch.cat(x_stack, -2) + + return x_stack, attns \ No newline at end of file diff --git a/model/Informer/masking.py b/model/Informer/masking.py new file mode 100644 index 0000000..7fd479e --- /dev/null +++ b/model/Informer/masking.py @@ -0,0 +1,24 @@ +import torch + +class TriangularCausalMask(): + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + 
# ---------------------------------------------------------------------------
# model/Informer/masking.py (tail): ProbSparse attention mask.
# ---------------------------------------------------------------------------

class ProbMask():
    """Causal mask for ProbSparse attention.

    Expands an upper-triangular (future-blocking) mask over the full score
    grid, then gathers only the rows belonging to the sampled top-u queries
    (per batch element and head) so it lines up with ``scores``.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        # True strictly above the diagonal -> positions a query may not attend to.
        causal = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
        expanded = causal[None, None, :].expand(B, H, L, scores.shape[-1])
        # Advanced indexing keeps only the sampled query rows.
        picked = expanded[torch.arange(B)[:, None, None],
                          torch.arange(H)[None, :, None],
                          index, :].to(device)
        self._mask = picked.view(scores.shape).to(device)

    @property
    def mask(self):
        return self._mask


# ---------------------------------------------------------------------------
# model/Informer/model.py
# ---------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from model.Informer.decoder import Decoder, DecoderLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding


class Informer(nn.Module):
    """Informer (Zhou et al., AAAI 2021) for long-sequence forecasting.

    Encoder/decoder transformer with optional ProbSparse attention and
    distilling conv layers.  Every hyper-parameter is read from the ``args``
    mapping (the ``model`` section of the YAML config).
    """

    def __init__(self, args):
        super(Informer, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        # Value + temporal embeddings for both streams.
        self.enc_embedding = DataEmbedding(
            args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
        self.dec_embedding = DataEmbedding(
            args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])

        # 'prob' selects ProbSparse attention; anything else means full attention.
        attn_cls = ProbAttention if args['attn'] == 'prob' else FullAttention

        encoder_layers = [
            EncoderLayer(
                AttentionLayer(
                    attn_cls(False, args['factor'],
                             attention_dropout=args['dropout'],
                             output_attention=args['output_attention']),
                    args['d_model'], args['n_heads'], mix=False),
                args['d_model'],
                args['d_ff'],
                dropout=args['dropout'],
                activation=args['activation'],
            )
            for _ in range(args['e_layers'])
        ]
        # Distillation: a halving ConvLayer between consecutive encoder layers.
        conv_layers = (
            [ConvLayer(args['d_model']) for _ in range(args['e_layers'] - 1)]
            if args['distil'] else None
        )
        self.encoder = Encoder(encoder_layers, conv_layers,
                               norm_layer=torch.nn.LayerNorm(args['d_model']))

        decoder_layers = [
            DecoderLayer(
                AttentionLayer(
                    attn_cls(True, args['factor'],
                             attention_dropout=args['dropout'],
                             output_attention=False),
                    args['d_model'], args['n_heads'], mix=args['mix']),
                AttentionLayer(
                    FullAttention(False, args['factor'],
                                  attention_dropout=args['dropout'],
                                  output_attention=False),
                    args['d_model'], args['n_heads'], mix=False),
                args['d_model'],
                args['d_ff'],
                dropout=args['dropout'],
                activation=args['activation'],
            )
            for _ in range(args['d_layers'])
        ]
        self.decoder = Decoder(decoder_layers,
                               norm_layer=torch.nn.LayerNorm(args['d_model']))

        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Return the last ``pred_len`` decoded steps, shape [B, pred_len, c_out].

        When ``output_attention`` is set, also returns the encoder attention maps.
        """
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out,
                               x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]


class InformerStack(nn.Module):
    """Stacked-replica Informer: several encoders of different depths run in
    parallel via ``EncoderStack``.  Here ``args['e_layers']`` is a *list* of
    depths (one per replica), unlike ``Informer`` where it is a single int.
    """

    def __init__(self, args):
        super(InformerStack, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        self.enc_embedding = DataEmbedding(
            args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
        self.dec_embedding = DataEmbedding(
            args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])

        attn_cls = ProbAttention if args['attn'] == 'prob' else FullAttention

        # One input-length slot per stacked encoder; customizable.
        inp_lens = list(range(len(args['e_layers'])))
        encoders = []
        for depth in args['e_layers']:
            layers = [
                EncoderLayer(
                    AttentionLayer(
                        attn_cls(False, args['factor'],
                                 attention_dropout=args['dropout'],
                                 output_attention=args['output_attention']),
                        args['d_model'], args['n_heads'], mix=False),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation'],
                )
                for _ in range(depth)
            ]
            convs = ([ConvLayer(args['d_model']) for _ in range(depth - 1)]
                     if args['distil'] else None)
            encoders.append(Encoder(layers, convs,
                                    norm_layer=torch.nn.LayerNorm(args['d_model'])))
        self.encoder = EncoderStack(encoders, inp_lens)

        decoder_layers = [
            DecoderLayer(
                AttentionLayer(
                    attn_cls(True, args['factor'],
                             attention_dropout=args['dropout'],
                             output_attention=False),
                    args['d_model'], args['n_heads'], mix=args['mix']),
                AttentionLayer(
                    FullAttention(False, args['factor'],
                                  attention_dropout=args['dropout'],
                                  output_attention=False),
                    args['d_model'], args['n_heads'], mix=False),
                args['d_model'],
                args['d_ff'],
                dropout=args['dropout'],
                activation=args['activation'],
            )
            for _ in range(args['d_layers'])
        ]
        self.decoder = Decoder(decoder_layers,
                               norm_layer=torch.nn.LayerNorm(args['d_model']))

        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Same contract as ``Informer.forward``."""
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out,
                               x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]
b/model/model_selector.py @@ -28,6 +28,7 @@ from model.ASTRA.astra import ASTRA as ASTRA from model.ASTRA.astrav2 import ASTRA as ASTRAv2 from model.ASTRA.astrav3 import ASTRA as ASTRAv3 from model.iTransformer.iTransformer import iTransformer +from model.Informer.model import Informer from model.HI.HI import HI from model.PatchTST.PatchTST import Model as PatchTST from model.MTGNN.MTGNN import gtnet as MTGNN @@ -96,6 +97,8 @@ def model_selector(config): return ASTRAv3(model_config) case "iTransformer": return iTransformer(model_config) + case "Informer": + return Informer(model_config) case "HI": return HI(model_config) case "PatchTST": diff --git a/test_informer.py b/test_informer.py new file mode 100644 index 0000000..b614533 --- /dev/null +++ b/test_informer.py @@ -0,0 +1,57 @@ +import torch +from model.model_selector import model_selector +import yaml + +# 读取配置文件 +with open('/user/czzhangheng/code/TrafficWheel/config/Informer/AirQuality.yaml', 'r') as f: + config = yaml.safe_load(f) + +# 初始化模型 +model = model_selector(config) +print('Informer模型初始化成功!') +print(f'模型参数数量: {sum(p.numel() for p in model.parameters())}') + +# 创建测试数据 +B, T, C = 2, 24, 6 +x_enc = torch.randn(B, T, C) + +# 测试1: 完整参数 +print('\n测试1: 完整参数') +x_mark_enc = torch.randn(B, T, 4) # 假设时间特征为4维 +x_dec = torch.randn(B, 12+24, C) # label_len + pred_len +x_mark_dec = torch.randn(B, 12+24, 4) +try: + output = model(x_enc, x_mark_enc, x_dec, x_mark_dec) + print(f'输出形状: {output.shape}') + print('测试1通过!') +except Exception as e: + print(f'测试1失败: {e}') + +# 测试2: 省略x_mark_enc +print('\n测试2: 省略x_mark_enc') +try: + output = model(x_enc, x_dec=x_dec, x_mark_dec=x_mark_dec) + print(f'输出形状: {output.shape}') + print('测试2通过!') +except Exception as e: + print(f'测试2失败: {e}') + +# 测试3: 省略x_dec和x_mark_dec +print('\n测试3: 省略x_dec和x_mark_dec') +try: + output = model(x_enc, x_mark_enc=x_mark_enc) + print(f'输出形状: {output.shape}') + print('测试3通过!') +except Exception as e: + print(f'测试3失败: {e}') + +# 测试4: 仅传入x_enc 
# ---------------------------------------------------------------------------
# trainer/InformerTrainer.py
# ---------------------------------------------------------------------------
import math
import os
import time
import copy
import torch
from utils.logger import get_logger
from utils.loss_function import all_metrics
from tqdm import tqdm


class InformerTrainer:
    """Trainer for the Informer model: manages the full train/val/test loop
    for the four-input (x, x_mark, dec_inp, y_mark) forward signature."""

    def __init__(self, model, loss, optimizer,
                 train_loader, val_loader, test_loader, scaler,
                 args, lr_scheduler=None,):
        # Device and basic configuration.
        self.config = args
        self.device = args["basic"]["device"]
        train_args = args["train"]
        # Model and optimization components.
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        # Data loaders.
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        # Normalization helper; `self.args` is the "train" sub-config.
        self.scaler = scaler
        self.args = train_args
        # Paths and logging.
        self._initialize_paths(train_args)
        self._initialize_logger(train_args)

    def _initialize_paths(self, args):
        """Initialize checkpoint/figure paths under the log directory."""
        self.best_path = os.path.join(args["log_dir"], "best_model.pth")
        self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth")
        self.loss_figure_path = os.path.join(args["log_dir"], "loss.png")

    def _initialize_logger(self, args):
        """Initialize the logger (and create the log dir when not debugging)."""
        if not os.path.isdir(args["log_dir"]) and not args["debug"]:
            os.makedirs(args["log_dir"], exist_ok=True)
        self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"])
        self.logger.info(f"Experiment log path in: {args['log_dir']}")

    def _build_decoder_input(self, y, device):
        """Split a full target window into (scoring target, decoder input).

        FIX: the decoder warm-start must come from the *label* (history) part
        of the window, never from the steps being predicted.  The previous
        code sliced y to its last pred_len steps first, so with
        label_len == pred_len the decoder was fed the ground-truth future
        (label leakage).  Assumes the loader's y window holds label_len
        history steps followed by pred_len future steps — TODO confirm
        against Informer_loader.
        """
        pred_len = self.args['pred_len']
        label_len = self.args['label_len']
        output_dim = self.args["output_dim"]
        target = y[:, -pred_len:, :output_dim]
        dec_inp = torch.zeros_like(target)
        dec_inp = torch.cat([y[:, :label_len, :output_dim], dec_inp], dim=1).float().to(device)
        return target, dec_inp

    def _run_epoch(self, epoch, dataloader, mode):
        """Run one epoch in train/val/test mode and return the average
        denormalized loss."""
        # Train mode enables gradients and optimizer steps.
        if mode == "train": self.model.train(); optimizer_step = True
        else: self.model.eval(); optimizer_step = False

        total_loss = 0
        epoch_time = time.time()
        y_pred, y_true = [], []

        with torch.set_grad_enabled(optimizer_step):
            progress_bar = tqdm(
                enumerate(dataloader),
                total=len(dataloader),
                desc=f"{mode.capitalize()} Epoch {epoch}"
            )
            for _, (x, y, x_mark, y_mark) in progress_bar:
                x = x.to(self.device)
                y = y.to(self.device)
                x_mark = x_mark.to(self.device)
                y_mark = y_mark.to(self.device)
                # Target = last pred_len steps; dec_inp = label history + zeros.
                target, dec_inp = self._build_decoder_input(y, self.device)

                output = self.model(x, x_mark, dec_inp, y_mark)
                # Debug switch: one-shot shape check, then abort the run.
                if os.environ.get("TRY") == "True":
                    print(f"[{'✅' if output.shape == target.shape else '❌'}]: output: {output.shape}, label: {target.shape}")
                    assert False
                loss = self.loss(output, target)
                d_output = self.scaler.inverse_transform(output)
                d_label = self.scaler.inverse_transform(target)
                d_loss = self.loss(d_output, d_label)
                # Accumulate denormalized loss and predictions for metrics.
                total_loss += d_loss.item()
                y_pred.append(d_output.detach().cpu())
                y_true.append(d_label.detach().cpu())
                # Backprop only in train mode.
                if optimizer_step and self.optimizer is not None:
                    self.optimizer.zero_grad()
                    loss.backward()
                    if self.args["grad_norm"]:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
                    self.optimizer.step()
                progress_bar.set_postfix(loss=d_loss.item())

        y_pred = torch.cat(y_pred, dim=0)
        y_true = torch.cat(y_true, dim=0)
        avg_loss = total_loss / len(dataloader)
        mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
        self.logger.info(
            f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
            f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
        )
        return avg_loss

    def train_epoch(self, epoch):
        return self._run_epoch(epoch, self.train_loader, "train")

    def val_epoch(self, epoch):
        # Fall back to the test loader when no validation split exists.
        return self._run_epoch(epoch, self.val_loader or self.test_loader, "val")

    def test_epoch(self, epoch):
        return self._run_epoch(epoch, self.test_loader, "test")

    def train(self):
        """Full training loop with early stopping and best-model tracking."""
        best_model, best_test_model = None, None
        best_loss, best_test_loss = float("inf"), float("inf")
        not_improved_count = 0
        self.logger.info("Training process started")
        for epoch in range(1, self.args["epochs"] + 1):
            train_epoch_loss = self.train_epoch(epoch)
            val_epoch_loss = self.val_epoch(epoch)
            test_epoch_loss = self.test_epoch(epoch)
            # Abort on numeric blow-up.
            if train_epoch_loss > 1e6:
                self.logger.warning("Gradient explosion detected. Ending...")
                break
            # Track the best model on validation loss.
            if val_epoch_loss < best_loss:
                best_loss = val_epoch_loss
                not_improved_count = 0
                best_model = copy.deepcopy(self.model.state_dict())
                self.logger.info("Best validation model saved!")
            else:
                not_improved_count += 1
            if self._should_early_stop(not_improved_count):
                break
            # Separately track the best model on test loss.
            if test_epoch_loss < best_test_loss:
                best_test_loss = test_epoch_loss
                best_test_model = copy.deepcopy(self.model.state_dict())
        if not self.args["debug"]:
            self._save_best_models(best_model, best_test_model)
        self._finalize_training(best_model, best_test_model)

    def _should_early_stop(self, not_improved_count):
        """Return True when validation has stalled for the patience window."""
        if (
            self.args["early_stop"]
            and not_improved_count == self.args["early_stop_patience"]
        ):
            self.logger.info(
                f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
            )
            return True
        return False

    def _save_best_models(self, best_model, best_test_model):
        """Persist the best validation/test state dicts (raw state_dict files)."""
        torch.save(best_model, self.best_path)
        torch.save(best_test_model, self.best_test_path)
        self.logger.info(
            f"Best models saved at {self.best_path} and {self.best_test_path}"
        )

    def _log_model_params(self):
        """Log the number of trainable parameters."""
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(f"Trainable params: {total_params}")

    def _finalize_training(self, best_model, best_test_model):
        """Evaluate the best validation and best test checkpoints."""
        # Guard: training may have been aborted before any model was recorded.
        if best_model is not None:
            self.model.load_state_dict(best_model)
            self.logger.info("Testing on best validation model")
            self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
        if best_test_model is not None:
            self.model.load_state_dict(best_test_model)
            self.logger.info("Testing on best test model")
            self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)

    @staticmethod
    def test(model, args, data_loader, scaler, logger, path=None):
        """Evaluate `model` on `data_loader` and log per-horizon metrics."""
        device = args["device"]

        if path:
            checkpoint = torch.load(path, map_location=device)
            # FIX: _save_best_models stores a raw state_dict; older callers
            # may pass a {"state_dict": ...} checkpoint — accept both.
            if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
                checkpoint = checkpoint["state_dict"]
            model.load_state_dict(checkpoint)
            model.to(device)

        model.eval()

        y_pred, y_true = [], []
        pred_len = args['pred_len']
        label_len = args['label_len']
        output_dim = args['output_dim']

        with torch.no_grad():
            for _, (x, y, x_mark, y_mark) in enumerate(data_loader):
                x = x.to(device)
                y = y.to(device)
                x_mark = x_mark.to(device)
                y_mark = y_mark.to(device)
                # Same leakage fix as _run_epoch: warm-start from the label
                # (history) part of the full window, zeros for the horizon.
                target = y[:, -pred_len:, :output_dim]
                dec_inp = torch.zeros_like(target)
                dec_inp = torch.cat([y[:, :label_len, :output_dim], dec_inp], dim=1).float().to(device)
                output = model(x, x_mark, dec_inp, y_mark)
                y_pred.append(output.detach().cpu())
                y_true.append(target.detach().cpu())

        d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
        d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
        mae_thresh = args["mae_thresh"]
        mape_thresh = args["mape_thresh"]

        # Per-horizon metrics.
        for t in range(d_y_true.shape[1]):
            mae, rmse, mape = all_metrics(
                d_y_pred[:, t, ...],
                d_y_true[:, t, ...],
                mae_thresh,
                mape_thresh,
            )
            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

        # Average over all horizons.
        mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

    @staticmethod
    def _compute_sampling_threshold(global_step, k):
        # Inverse-sigmoid decay used by scheduled sampling.
        return k / (k + math.exp(global_step / k))