修复反归一化错误
This commit is contained in:
parent
a46edc79a5
commit
96f2ea1239
|
|
@ -11,8 +11,8 @@ data:
|
|||
column_wise: false
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
horizon: 24
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 7
|
||||
steps_per_day: 288
|
||||
|
|
@ -24,8 +24,8 @@ data:
|
|||
batch_size: 16
|
||||
|
||||
model:
|
||||
pred_len: 12
|
||||
seq_len: 12
|
||||
pred_len: 24
|
||||
seq_len: 24
|
||||
patch_len: 6
|
||||
stride: 7
|
||||
dropout: 0.2
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ data:
|
|||
column_wise: false
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
horizon: 24
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 207
|
||||
steps_per_day: 288
|
||||
|
|
@ -24,8 +24,8 @@ data:
|
|||
batch_size: 16
|
||||
|
||||
model:
|
||||
pred_len: 12
|
||||
seq_len: 12
|
||||
pred_len: 24
|
||||
seq_len: 24
|
||||
patch_len: 6
|
||||
stride: 7
|
||||
dropout: 0.2
|
||||
|
|
@ -41,7 +41,7 @@ train:
|
|||
batch_size: 16
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
epochs: 1
|
||||
grad_norm: false
|
||||
loss_func: mae
|
||||
lr_decay: true
|
||||
|
|
@ -52,7 +52,7 @@ train:
|
|||
real_value: true
|
||||
weight_decay: 0
|
||||
debug: false
|
||||
output_dim: 1
|
||||
output_dim: 100
|
||||
log_step: 1000
|
||||
plot: false
|
||||
mae_thresh: None
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ data:
|
|||
column_wise: false
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
horizon: 12
|
||||
lag: 12
|
||||
horizon: 24
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 137
|
||||
steps_per_day: 288
|
||||
|
|
@ -24,8 +24,8 @@ data:
|
|||
batch_size: 16
|
||||
|
||||
model:
|
||||
pred_len: 12
|
||||
seq_len: 12
|
||||
pred_len: 24
|
||||
seq_len: 24
|
||||
patch_len: 6
|
||||
stride: 7
|
||||
dropout: 0.2
|
||||
|
|
|
|||
|
|
@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module):
|
|||
return self.gumbel_softmax(logits, 1, self.k, self.hard)
|
||||
|
||||
def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
|
||||
|
||||
y_soft = F.gumbel_softmax(logits, tau, hard)
|
||||
|
||||
if hard:
|
||||
# 生成硬掩码
|
||||
_, indices = y_soft.topk(k, dim=0) # 选择Top-K
|
||||
|
|
|
|||
|
|
@ -51,7 +51,6 @@ class PatchEmbedding(nn.Module):
|
|||
x = self.padding_patch_layer(x)
|
||||
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
|
||||
x_value_embed = self.value_embedding(x)
|
||||
|
||||
return self.dropout(x_value_embed), n_vars
|
||||
|
||||
class ReprogrammingLayer(nn.Module):
|
||||
|
|
@ -83,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
|
|||
|
||||
def reprogramming(self, target_embedding, source_embedding, value_embedding):
|
||||
B, L, H, E = target_embedding.shape
|
||||
|
||||
scale = 1. / sqrt(E)
|
||||
|
||||
scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
|
||||
|
||||
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
||||
reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
|
||||
|
||||
return reprogramming_embedding
|
||||
|
||||
|
|
@ -204,8 +204,8 @@ class Trainer:
|
|||
|
||||
# 累积损失和预测结果
|
||||
total_loss += loss.item()
|
||||
y_pred.append(output.detach().cpu())
|
||||
y_true.append(label.detach().cpu())
|
||||
y_pred.append(d_output.detach().cpu())
|
||||
y_true.append(d_label.detach().cpu())
|
||||
|
||||
# 更新进度条
|
||||
progress_bar.set_postfix(loss=d_loss.item())
|
||||
|
|
@ -356,18 +356,15 @@ class Trainer:
|
|||
y_pred.append(output)
|
||||
y_true.append(label)
|
||||
|
||||
# 合并所有批次的预测结果
|
||||
if args["real_value"]:
|
||||
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
||||
else:
|
||||
y_pred = torch.cat(y_pred, dim=0)
|
||||
y_true = torch.cat(y_true, dim=0)
|
||||
|
||||
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
||||
d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
|
||||
|
||||
# 计算并记录每个时间步的指标
|
||||
for t in range(y_true.shape[1]):
|
||||
for t in range(d_y_true.shape[1]):
|
||||
mae, rmse, mape = all_metrics(
|
||||
y_pred[:, t, ...],
|
||||
y_true[:, t, ...],
|
||||
d_y_pred[:, t, ...],
|
||||
d_y_true[:, t, ...],
|
||||
args["mae_thresh"],
|
||||
args["mape_thresh"],
|
||||
)
|
||||
|
|
@ -377,7 +374,7 @@ class Trainer:
|
|||
|
||||
# 计算并记录平均指标
|
||||
mae, rmse, mape = all_metrics(
|
||||
y_pred, y_true, args["mae_thresh"], args["mape_thresh"]
|
||||
d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]
|
||||
)
|
||||
logger.info(
|
||||
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
|
||||
|
|
|
|||
Loading…
Reference in New Issue