diff --git a/config/REPST/BeijingAirQuality.yaml b/config/REPST/BeijingAirQuality.yaml index 1e69e2a..595c971 100755 --- a/config/REPST/BeijingAirQuality.yaml +++ b/config/REPST/BeijingAirQuality.yaml @@ -11,8 +11,8 @@ data: column_wise: false days_per_week: 7 default_graph: true - horizon: 12 - lag: 12 + horizon: 24 + lag: 24 normalizer: std num_nodes: 7 steps_per_day: 288 @@ -24,8 +24,8 @@ data: batch_size: 16 model: - pred_len: 12 - seq_len: 12 + pred_len: 24 + seq_len: 24 patch_len: 6 stride: 7 dropout: 0.2 diff --git a/config/REPST/METR-LA.yaml b/config/REPST/METR-LA.yaml index 2e57a1c..68340d1 100755 --- a/config/REPST/METR-LA.yaml +++ b/config/REPST/METR-LA.yaml @@ -11,8 +11,8 @@ data: column_wise: false days_per_week: 7 default_graph: true - horizon: 12 - lag: 12 + horizon: 24 + lag: 24 normalizer: std num_nodes: 207 steps_per_day: 288 @@ -24,8 +24,8 @@ data: batch_size: 16 model: - pred_len: 12 - seq_len: 12 + pred_len: 24 + seq_len: 24 patch_len: 6 stride: 7 dropout: 0.2 @@ -41,7 +41,7 @@ train: batch_size: 16 early_stop: true early_stop_patience: 15 - epochs: 100 + epochs: 1 grad_norm: false loss_func: mae lr_decay: true @@ -52,7 +52,7 @@ train: real_value: true weight_decay: 0 debug: false - output_dim: 1 + output_dim: 100 log_step: 1000 plot: false mae_thresh: None diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml index 465e53d..282c929 100755 --- a/config/REPST/SolarEnergy.yaml +++ b/config/REPST/SolarEnergy.yaml @@ -11,8 +11,8 @@ data: column_wise: false days_per_week: 7 default_graph: true - horizon: 12 - lag: 12 + horizon: 24 + lag: 24 normalizer: std num_nodes: 137 steps_per_day: 288 @@ -24,8 +24,8 @@ data: batch_size: 16 model: - pred_len: 12 - seq_len: 12 + pred_len: 24 + seq_len: 24 patch_len: 6 stride: 7 dropout: 0.2 diff --git a/model/REPST/normalizer.py b/model/REPST/normalizer.py index fb7e182..c112c7a 100644 --- a/model/REPST/normalizer.py +++ b/model/REPST/normalizer.py @@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module): return self.gumbel_softmax(logits, 1, self.k, self.hard) def gumbel_softmax(self, logits, tau=1, k=1000, hard=True): - y_soft = F.gumbel_softmax(logits, tau, hard) - if hard: # 生成硬掩码 _, indices = y_soft.topk(k, dim=0) # 选择Top-K diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py index 289c4b1..1ba7976 100644 --- a/model/REPST/reprogramming.py +++ b/model/REPST/reprogramming.py @@ -51,7 +51,6 @@ class PatchEmbedding(nn.Module): x = self.padding_patch_layer(x) x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) x_value_embed = self.value_embedding(x) - return self.dropout(x_value_embed), n_vars class ReprogrammingLayer(nn.Module): @@ -83,13 +82,9 @@ class ReprogrammingLayer(nn.Module): def reprogramming(self, target_embedding, source_embedding, value_embedding): B, L, H, E = target_embedding.shape - scale = 1. / sqrt(E) - scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding) - A = self.dropout(torch.softmax(scale * scores, dim=-1)) reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding) - return reprogramming_embedding \ No newline at end of file diff --git a/trainer/Trainer.py b/trainer/Trainer.py index c7abff8..82c90f5 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -204,8 +204,8 @@ class Trainer: # 累积损失和预测结果 total_loss += loss.item() - y_pred.append(output.detach().cpu()) - y_true.append(label.detach().cpu()) + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) # 更新进度条 progress_bar.set_postfix(loss=d_loss.item()) @@ -356,18 +356,15 @@ class Trainer: y_pred.append(output) y_true.append(label) - # 合并所有批次的预测结果 - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) # 计算并记录每个时间步的指标 - for t in range(y_true.shape[1]): + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) @@ -377,7 +374,7 @@ class Trainer: # 计算并记录平均指标 mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] + d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"] ) logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"