修复反归一化错误

This commit is contained in:
czzhangheng 2025-11-20 20:50:35 +08:00
parent a46edc79a5
commit 96f2ea1239
6 changed files with 23 additions and 33 deletions

View File

@ -11,8 +11,8 @@ data:
column_wise: false
days_per_week: 7
default_graph: true
horizon: 12
lag: 12
horizon: 24
lag: 24
normalizer: std
num_nodes: 7
steps_per_day: 288
@ -24,8 +24,8 @@ data:
batch_size: 16
model:
pred_len: 12
seq_len: 12
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2

View File

@ -11,8 +11,8 @@ data:
column_wise: false
days_per_week: 7
default_graph: true
horizon: 12
lag: 12
horizon: 24
lag: 24
normalizer: std
num_nodes: 207
steps_per_day: 288
@ -24,8 +24,8 @@ data:
batch_size: 16
model:
pred_len: 12
seq_len: 12
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
@ -41,7 +41,7 @@ train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
epochs: 1
grad_norm: false
loss_func: mae
lr_decay: true
@ -52,7 +52,7 @@ train:
real_value: true
weight_decay: 0
debug: false
output_dim: 1
output_dim: 100
log_step: 1000
plot: false
mae_thresh: None

View File

@ -11,8 +11,8 @@ data:
column_wise: false
days_per_week: 7
default_graph: true
horizon: 12
lag: 12
horizon: 24
lag: 24
normalizer: std
num_nodes: 137
steps_per_day: 288
@ -24,8 +24,8 @@ data:
batch_size: 16
model:
pred_len: 12
seq_len: 12
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2

View File

@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module):
return self.gumbel_softmax(logits, 1, self.k, self.hard)
def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
y_soft = F.gumbel_softmax(logits, tau, hard)
if hard:
# 生成硬掩码
_, indices = y_soft.topk(k, dim=0) # 选择Top-K

View File

@ -51,7 +51,6 @@ class PatchEmbedding(nn.Module):
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x_value_embed = self.value_embedding(x)
return self.dropout(x_value_embed), n_vars
class ReprogrammingLayer(nn.Module):
@ -83,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
def reprogramming(self, target_embedding, source_embedding, value_embedding):
B, L, H, E = target_embedding.shape
scale = 1. / sqrt(E)
scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
return reprogramming_embedding

View File

@ -204,8 +204,8 @@ class Trainer:
# 累积损失和预测结果
total_loss += loss.item()
y_pred.append(output.detach().cpu())
y_true.append(label.detach().cpu())
y_pred.append(d_output.detach().cpu())
y_true.append(d_label.detach().cpu())
# 更新进度条
progress_bar.set_postfix(loss=d_loss.item())
@ -356,18 +356,15 @@ class Trainer:
y_pred.append(output)
y_true.append(label)
# 合并所有批次的预测结果
if args["real_value"]:
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
else:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
# 计算并记录每个时间步的指标
for t in range(y_true.shape[1]):
for t in range(d_y_true.shape[1]):
mae, rmse, mape = all_metrics(
y_pred[:, t, ...],
y_true[:, t, ...],
d_y_pred[:, t, ...],
d_y_true[:, t, ...],
args["mae_thresh"],
args["mape_thresh"],
)
@ -377,7 +374,7 @@ class Trainer:
# 计算并记录平均指标
mae, rmse, mape = all_metrics(
y_pred, y_true, args["mae_thresh"], args["mape_thresh"]
d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]
)
logger.info(
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"