修复反归一化错误
This commit is contained in:
parent
a46edc79a5
commit
96f2ea1239
|
|
@ -11,8 +11,8 @@ data:
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
default_graph: true
|
default_graph: true
|
||||||
horizon: 12
|
horizon: 24
|
||||||
lag: 12
|
lag: 24
|
||||||
normalizer: std
|
normalizer: std
|
||||||
num_nodes: 7
|
num_nodes: 7
|
||||||
steps_per_day: 288
|
steps_per_day: 288
|
||||||
|
|
@ -24,8 +24,8 @@ data:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
||||||
model:
|
model:
|
||||||
pred_len: 12
|
pred_len: 24
|
||||||
seq_len: 12
|
seq_len: 24
|
||||||
patch_len: 6
|
patch_len: 6
|
||||||
stride: 7
|
stride: 7
|
||||||
dropout: 0.2
|
dropout: 0.2
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,8 @@ data:
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
default_graph: true
|
default_graph: true
|
||||||
horizon: 12
|
horizon: 24
|
||||||
lag: 12
|
lag: 24
|
||||||
normalizer: std
|
normalizer: std
|
||||||
num_nodes: 207
|
num_nodes: 207
|
||||||
steps_per_day: 288
|
steps_per_day: 288
|
||||||
|
|
@ -24,8 +24,8 @@ data:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
||||||
model:
|
model:
|
||||||
pred_len: 12
|
pred_len: 24
|
||||||
seq_len: 12
|
seq_len: 24
|
||||||
patch_len: 6
|
patch_len: 6
|
||||||
stride: 7
|
stride: 7
|
||||||
dropout: 0.2
|
dropout: 0.2
|
||||||
|
|
@ -41,7 +41,7 @@ train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
epochs: 100
|
epochs: 1
|
||||||
grad_norm: false
|
grad_norm: false
|
||||||
loss_func: mae
|
loss_func: mae
|
||||||
lr_decay: true
|
lr_decay: true
|
||||||
|
|
@ -52,7 +52,7 @@ train:
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
debug: false
|
debug: false
|
||||||
output_dim: 1
|
output_dim: 100
|
||||||
log_step: 1000
|
log_step: 1000
|
||||||
plot: false
|
plot: false
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,8 @@ data:
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
default_graph: true
|
default_graph: true
|
||||||
horizon: 12
|
horizon: 24
|
||||||
lag: 12
|
lag: 24
|
||||||
normalizer: std
|
normalizer: std
|
||||||
num_nodes: 137
|
num_nodes: 137
|
||||||
steps_per_day: 288
|
steps_per_day: 288
|
||||||
|
|
@ -24,8 +24,8 @@ data:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
||||||
model:
|
model:
|
||||||
pred_len: 12
|
pred_len: 24
|
||||||
seq_len: 12
|
seq_len: 24
|
||||||
patch_len: 6
|
patch_len: 6
|
||||||
stride: 7
|
stride: 7
|
||||||
dropout: 0.2
|
dropout: 0.2
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module):
|
||||||
return self.gumbel_softmax(logits, 1, self.k, self.hard)
|
return self.gumbel_softmax(logits, 1, self.k, self.hard)
|
||||||
|
|
||||||
def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
|
def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
|
||||||
|
|
||||||
y_soft = F.gumbel_softmax(logits, tau, hard)
|
y_soft = F.gumbel_softmax(logits, tau, hard)
|
||||||
|
|
||||||
if hard:
|
if hard:
|
||||||
# 生成硬掩码
|
# 生成硬掩码
|
||||||
_, indices = y_soft.topk(k, dim=0) # 选择Top-K
|
_, indices = y_soft.topk(k, dim=0) # 选择Top-K
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,6 @@ class PatchEmbedding(nn.Module):
|
||||||
x = self.padding_patch_layer(x)
|
x = self.padding_patch_layer(x)
|
||||||
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
|
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
|
||||||
x_value_embed = self.value_embedding(x)
|
x_value_embed = self.value_embedding(x)
|
||||||
|
|
||||||
return self.dropout(x_value_embed), n_vars
|
return self.dropout(x_value_embed), n_vars
|
||||||
|
|
||||||
class ReprogrammingLayer(nn.Module):
|
class ReprogrammingLayer(nn.Module):
|
||||||
|
|
@ -83,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
|
||||||
|
|
||||||
def reprogramming(self, target_embedding, source_embedding, value_embedding):
|
def reprogramming(self, target_embedding, source_embedding, value_embedding):
|
||||||
B, L, H, E = target_embedding.shape
|
B, L, H, E = target_embedding.shape
|
||||||
|
|
||||||
scale = 1. / sqrt(E)
|
scale = 1. / sqrt(E)
|
||||||
|
|
||||||
scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
|
scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
|
||||||
|
|
||||||
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
||||||
reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
|
reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
|
||||||
|
|
||||||
return reprogramming_embedding
|
return reprogramming_embedding
|
||||||
|
|
||||||
|
|
@ -204,8 +204,8 @@ class Trainer:
|
||||||
|
|
||||||
# 累积损失和预测结果
|
# 累积损失和预测结果
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
y_pred.append(output.detach().cpu())
|
y_pred.append(d_output.detach().cpu())
|
||||||
y_true.append(label.detach().cpu())
|
y_true.append(d_label.detach().cpu())
|
||||||
|
|
||||||
# 更新进度条
|
# 更新进度条
|
||||||
progress_bar.set_postfix(loss=d_loss.item())
|
progress_bar.set_postfix(loss=d_loss.item())
|
||||||
|
|
@ -356,18 +356,15 @@ class Trainer:
|
||||||
y_pred.append(output)
|
y_pred.append(output)
|
||||||
y_true.append(label)
|
y_true.append(label)
|
||||||
|
|
||||||
# 合并所有批次的预测结果
|
|
||||||
if args["real_value"]:
|
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
||||||
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
|
||||||
else:
|
|
||||||
y_pred = torch.cat(y_pred, dim=0)
|
|
||||||
y_true = torch.cat(y_true, dim=0)
|
|
||||||
|
|
||||||
# 计算并记录每个时间步的指标
|
# 计算并记录每个时间步的指标
|
||||||
for t in range(y_true.shape[1]):
|
for t in range(d_y_true.shape[1]):
|
||||||
mae, rmse, mape = all_metrics(
|
mae, rmse, mape = all_metrics(
|
||||||
y_pred[:, t, ...],
|
d_y_pred[:, t, ...],
|
||||||
y_true[:, t, ...],
|
d_y_true[:, t, ...],
|
||||||
args["mae_thresh"],
|
args["mae_thresh"],
|
||||||
args["mape_thresh"],
|
args["mape_thresh"],
|
||||||
)
|
)
|
||||||
|
|
@ -377,7 +374,7 @@ class Trainer:
|
||||||
|
|
||||||
# 计算并记录平均指标
|
# 计算并记录平均指标
|
||||||
mae, rmse, mape = all_metrics(
|
mae, rmse, mape = all_metrics(
|
||||||
y_pred, y_true, args["mae_thresh"], args["mape_thresh"]
|
d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
|
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue