diff --git a/README.md b/README.md
index 2ee8496..0be77b9 100755
--- a/README.md
+++ b/README.md
@@ -36,9 +36,9 @@ wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin
 
 # Run REPST
 
-On the first run, the program downloads the dataset automatically. Currently only PEMSD8 is supported.
+On the first run, the program downloads the dataset automatically. Currently only PEMSD8/PEMS-BAY are supported.
 
 ```bash
-python run.py --config ./config/REPST/PEMSD8.yaml
+python run.py --config ./config/REPST/PEMS-BAY.yaml
 ```
 
diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml
index 1333cb8..54e3c38 100755
--- a/config/REPST/PEMS-BAY.yaml
+++ b/config/REPST/PEMS-BAY.yaml
@@ -10,8 +10,8 @@ data:
   column_wise: false
   days_per_week: 7
   default_graph: true
-  horizon: 12
-  lag: 12
+  horizon: 24
+  lag: 24
   normalizer: std
   num_nodes: 325
   steps_per_day: 288
@@ -23,8 +23,8 @@ data:
   batch_size: 16
 
 model:
-  pred_len: 12
-  seq_len: 12
+  pred_len: 24
+  seq_len: 24
   patch_len: 6
   stride: 7
   dropout: 0.2
@@ -33,6 +33,7 @@ model:
   gpt_path: ./GPT-2
   d_model: 64
   n_heads: 1
+  input_dim: 1
 
 train:
   batch_size: 16
diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py
index f5e9663..fc4871f 100644
--- a/model/REPST/reprogramming.py
+++ b/model/REPST/reprogramming.py
@@ -15,12 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output
 
 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model):
+    def __init__(self, c_in, d_model, patch_num, input_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(2, 1)
+
+        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
 
         # if air_quality
         # self.confusion_layer = nn.Linear(42, 1)
@@ -34,19 +35,18 @@
         b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
         # 768,64,25
         x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
-
         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)
 
 
 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
diff --git a/model/REPST/repst.py b/model/REPST/repst.py
index 66468ea..2cade59 100644
--- a/model/REPST/repst.py
+++ b/model/REPST/repst.py
@@ -13,6 +13,7 @@
         self.pred_len = configs['pred_len']
         self.seq_len = configs['seq_len']
         self.patch_len = configs['patch_len']
+        self.input_dim = configs['input_dim']
         self.stride = configs['stride']
         self.dropout = configs['dropout']
         self.gpt_layers = configs['gpt_layers']
@@ -28,7 +29,7 @@
         self.head_nf = self.d_ff * self.patch_nums
 
         # 64,6,7,0.2
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
 
         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
         self.gpts.h = self.gpts.h[:self.gpt_layers]
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
index 6c62ee3..3048d1b 100755
--- a/trainer/Trainer.py
+++ b/trainer/Trainer.py
@@ -6,6 +6,7 @@
 import psutil
 import torch
 from utils.logger import get_logger
 from utils.loss_function import all_metrics
+from tqdm import tqdm
 
 class TrainingStats:
@@ -148,7 +149,8 @@ class Trainer:
         y_pred, y_true = [], []
 
         with torch.set_grad_enabled(optimizer_step):
-            for batch_idx, (data, target) in enumerate(dataloader):
+            progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}")
+            for batch_idx, (data, target) in progress_bar:
                 start_time = time.time()
 
                 label = target[..., : self.args["output_dim"]]
@@ -175,6 +177,9 @@
                 y_pred.append(output.detach().cpu())
                 y_true.append(label.detach().cpu())
 
+                # Update progress bar with the current loss
+                progress_bar.set_postfix(loss=loss.item())
+
         y_pred = torch.cat(y_pred, dim=0)
         y_true = torch.cat(y_true, dim=0)
 
diff --git a/utils/Download_data.py b/utils/Download_data.py
index 0dae1fd..fcd21f1 100755
--- a/utils/Download_data.py
+++ b/utils/Download_data.py
@@ -61,11 +61,6 @@
         missing_adj = True
         missing_main_files = True
     else:
-        # Check for the get_adj.py file in the root directory
-        if "get_adj.py" not in os.listdir(data_dir):
-            # print(f"Missing file get_adj.py in the root directory.")
-            missing_adj = True
-
         # Iterate over the expected file structure
         for subfolder, expected_files in expected_structure.items():
             subfolder_path = os.path.join(data_dir, subfolder)
@@ -100,8 +95,6 @@
 
     rearrange_dir()
 
-
-
     return True
 
 
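A quick shape sanity check for the `confusion_layer` generalization in `TokenEmbedding` above. This is a minimal standalone sketch, not code from the repo: `b` and `patch_num` are illustrative placeholders, while `patch_len=6`, `d_model=64`, `input_dim=1`, and the 325 nodes follow the PEMS-BAY config in this diff.

```python
# Sketch of the TokenEmbedding shape flow after this patch (assumed sizes).
import torch
import torch.nn as nn

b, n, input_dim, patch_num, patch_len, d_model = 2, 325, 1, 4, 6, 64

# Mirrors TokenEmbedding: Conv1d over (batch*node, patch_len, feature*patch_num).
token_conv = nn.Conv1d(in_channels=patch_len, out_channels=d_model,
                       kernel_size=3, padding=1, padding_mode='circular', bias=False)
# The old nn.Linear(2, 1) only matched inputs with input_dim * patch_num == 2;
# nn.Linear(patch_num * input_dim, 1) matches the conv output's last dimension.
confusion_layer = nn.Linear(patch_num * input_dim, 1)

x = torch.randn(b, n, input_dim, patch_num, patch_len)               # (b, n, m, pn, pl)
h = token_conv(x.reshape(b * n, patch_len, input_dim * patch_num))   # (b*n, d_model, m*pn)
h = confusion_layer(h)                                               # (b*n, d_model, 1)
print(h.reshape(b, n, -1).shape)                                     # torch.Size([2, 325, 64])
```

Since the last conv dimension is `input_dim * patch_num`, wiring `patch_nums` and the new `input_dim` config key through `PatchEmbedding` into `TokenEmbedding` keeps the linear layer consistent for any sequence length, rather than only the case where that product happens to equal 2.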