修复confuse_layer硬编码bug

This commit is contained in:
czzhangheng 2025-11-11 17:26:05 +08:00
parent d0af46ea5f
commit 2ba061e57a
6 changed files with 20 additions and 20 deletions

View File

@ -36,9 +36,9 @@ wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin
# 跑REPST
第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8。
第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8/PEMS-BAY
```bash
python run.py --config ./config/REPST/PEMSD8.yaml
python run.py --config ./config/REPST/PEMS-BAY.yaml
```

View File

@ -10,8 +10,8 @@ data:
column_wise: false
days_per_week: 7
default_graph: true
horizon: 12
lag: 12
horizon: 24
lag: 24
normalizer: std
num_nodes: 325
steps_per_day: 288
@ -23,8 +23,8 @@ data:
batch_size: 16
model:
pred_len: 12
seq_len: 12
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
@ -33,6 +33,7 @@ model:
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
train:
batch_size: 16

View File

@ -15,12 +15,13 @@ class ReplicationPad1d(nn.Module):
return output
class TokenEmbedding(nn.Module):
def __init__(self, c_in, d_model):
def __init__(self, c_in, d_model, patch_num, input_dim):
super(TokenEmbedding, self).__init__()
padding = 1
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
kernel_size=3, padding=padding, padding_mode='circular', bias=False)
self.confusion_layer = nn.Linear(2, 1)
self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
# if air_quality
# self.confusion_layer = nn.Linear(42, 1)
@ -34,19 +35,18 @@ class TokenEmbedding(nn.Module):
b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
# 768,64,25
x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
x = self.confusion_layer(x)
return x.reshape(b, n, -1)
class PatchEmbedding(nn.Module):
def __init__(self, d_model, patch_len, stride, dropout):
def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
super(PatchEmbedding, self).__init__()
# Patching
self.patch_len = patch_len
self.stride = stride
self.padding_patch_layer = ReplicationPad1d((0, stride))
self.value_embedding = TokenEmbedding(patch_len, d_model)
self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):

View File

@ -13,6 +13,7 @@ class repst(nn.Module):
self.pred_len = configs['pred_len']
self.seq_len = configs['seq_len']
self.patch_len = configs['patch_len']
self.input_dim = configs['input_dim']
self.stride = configs['stride']
self.dropout = configs['dropout']
self.gpt_layers = configs['gpt_layers']
@ -28,7 +29,7 @@ class repst(nn.Module):
self.head_nf = self.d_ff * self.patch_nums
# 64,6,7,0.2
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout)
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
self.gpts.h = self.gpts.h[:self.gpt_layers]

View File

@ -6,6 +6,7 @@ import psutil
import torch
from utils.logger import get_logger
from utils.loss_function import all_metrics
from tqdm import tqdm
class TrainingStats:
@ -148,7 +149,8 @@ class Trainer:
y_pred, y_true = [], []
with torch.set_grad_enabled(optimizer_step):
for batch_idx, (data, target) in enumerate(dataloader):
progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}")
for batch_idx, (data, target) in progress_bar:
start_time = time.time()
label = target[..., : self.args["output_dim"]]
@ -175,6 +177,9 @@ class Trainer:
y_pred.append(output.detach().cpu())
y_true.append(label.detach().cpu())
# Update progress bar with the current loss
progress_bar.set_postfix(loss=loss.item())
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)

View File

@ -61,11 +61,6 @@ def check_and_download_data():
missing_adj = True
missing_main_files = True
else:
# 检查根目录下的 get_adj.py 文件
if "get_adj.py" not in os.listdir(data_dir):
# print(f"根目录下缺少文件 get_adj.py。")
missing_adj = True
# 遍历预期的文件结构
for subfolder, expected_files in expected_structure.items():
subfolder_path = os.path.join(data_dir, subfolder)
@ -100,8 +95,6 @@ def check_and_download_data():
rearrange_dir()
return True