修复反归一化错误

This commit is contained in:
czzhangheng 2025-11-20 20:50:35 +08:00
parent a46edc79a5
commit 96f2ea1239
6 changed files with 23 additions and 33 deletions

View File

@ -11,8 +11,8 @@ data:
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
default_graph: true default_graph: true
horizon: 12 horizon: 24
lag: 12 lag: 24
normalizer: std normalizer: std
num_nodes: 7 num_nodes: 7
steps_per_day: 288 steps_per_day: 288
@ -24,8 +24,8 @@ data:
batch_size: 16 batch_size: 16
model: model:
pred_len: 12 pred_len: 24
seq_len: 12 seq_len: 24
patch_len: 6 patch_len: 6
stride: 7 stride: 7
dropout: 0.2 dropout: 0.2

View File

@ -11,8 +11,8 @@ data:
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
default_graph: true default_graph: true
horizon: 12 horizon: 24
lag: 12 lag: 24
normalizer: std normalizer: std
num_nodes: 207 num_nodes: 207
steps_per_day: 288 steps_per_day: 288
@ -24,8 +24,8 @@ data:
batch_size: 16 batch_size: 16
model: model:
pred_len: 12 pred_len: 24
seq_len: 12 seq_len: 24
patch_len: 6 patch_len: 6
stride: 7 stride: 7
dropout: 0.2 dropout: 0.2
@ -41,7 +41,7 @@ train:
batch_size: 16 batch_size: 16
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
epochs: 100 epochs: 1
grad_norm: false grad_norm: false
loss_func: mae loss_func: mae
lr_decay: true lr_decay: true
@ -52,7 +52,7 @@ train:
real_value: true real_value: true
weight_decay: 0 weight_decay: 0
debug: false debug: false
output_dim: 1 output_dim: 100
log_step: 1000 log_step: 1000
plot: false plot: false
mae_thresh: None mae_thresh: None

View File

@ -11,8 +11,8 @@ data:
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
default_graph: true default_graph: true
horizon: 12 horizon: 24
lag: 12 lag: 24
normalizer: std normalizer: std
num_nodes: 137 num_nodes: 137
steps_per_day: 288 steps_per_day: 288
@ -24,8 +24,8 @@ data:
batch_size: 16 batch_size: 16
model: model:
pred_len: 12 pred_len: 24
seq_len: 12 seq_len: 24
patch_len: 6 patch_len: 6
stride: 7 stride: 7
dropout: 0.2 dropout: 0.2

View File

@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module):
return self.gumbel_softmax(logits, 1, self.k, self.hard) return self.gumbel_softmax(logits, 1, self.k, self.hard)
def gumbel_softmax(self, logits, tau=1, k=1000, hard=True): def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
y_soft = F.gumbel_softmax(logits, tau, hard) y_soft = F.gumbel_softmax(logits, tau, hard)
if hard: if hard:
# 生成硬掩码 # 生成硬掩码
_, indices = y_soft.topk(k, dim=0) # 选择Top-K _, indices = y_soft.topk(k, dim=0) # 选择Top-K

View File

@ -51,7 +51,6 @@ class PatchEmbedding(nn.Module):
x = self.padding_patch_layer(x) x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x_value_embed = self.value_embedding(x) x_value_embed = self.value_embedding(x)
return self.dropout(x_value_embed), n_vars return self.dropout(x_value_embed), n_vars
class ReprogrammingLayer(nn.Module): class ReprogrammingLayer(nn.Module):
@ -83,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
def reprogramming(self, target_embedding, source_embedding, value_embedding): def reprogramming(self, target_embedding, source_embedding, value_embedding):
B, L, H, E = target_embedding.shape B, L, H, E = target_embedding.shape
scale = 1. / sqrt(E) scale = 1. / sqrt(E)
scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding) scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
A = self.dropout(torch.softmax(scale * scores, dim=-1)) A = self.dropout(torch.softmax(scale * scores, dim=-1))
reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding) reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
return reprogramming_embedding return reprogramming_embedding

View File

@ -204,8 +204,8 @@ class Trainer:
# 累积损失和预测结果 # 累积损失和预测结果
total_loss += loss.item() total_loss += loss.item()
y_pred.append(output.detach().cpu()) y_pred.append(d_output.detach().cpu())
y_true.append(label.detach().cpu()) y_true.append(d_label.detach().cpu())
# 更新进度条 # 更新进度条
progress_bar.set_postfix(loss=d_loss.item()) progress_bar.set_postfix(loss=d_loss.item())
@ -356,18 +356,15 @@ class Trainer:
y_pred.append(output) y_pred.append(output)
y_true.append(label) y_true.append(label)
# 合并所有批次的预测结果
if args["real_value"]: d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
else:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
# 计算并记录每个时间步的指标 # 计算并记录每个时间步的指标
for t in range(y_true.shape[1]): for t in range(d_y_true.shape[1]):
mae, rmse, mape = all_metrics( mae, rmse, mape = all_metrics(
y_pred[:, t, ...], d_y_pred[:, t, ...],
y_true[:, t, ...], d_y_true[:, t, ...],
args["mae_thresh"], args["mae_thresh"],
args["mape_thresh"], args["mape_thresh"],
) )
@ -377,7 +374,7 @@ class Trainer:
# 计算并记录平均指标 # 计算并记录平均指标
mae, rmse, mape = all_metrics( mae, rmse, mape = all_metrics(
y_pred, y_true, args["mae_thresh"], args["mape_thresh"] d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]
) )
logger.info( logger.info(
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"