Fix hardcoded confuse_layer bug
commit 2ba061e57a
parent d0af46ea5f
@@ -36,9 +36,9 @@ wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin
 # Run REPST
-On the first run, the program downloads the dataset automatically. Currently only PEMSD8 is supported.
+On the first run, the program downloads the dataset automatically. Currently PEMSD8/PEMS-BAY are supported.
 ```bash
 python run.py --config ./config/REPST/PEMSD8.yaml
+python run.py --config ./config/REPST/PEMS-BAY.yaml
 ```
@@ -10,8 +10,8 @@ data:
   column_wise: false
   days_per_week: 7
   default_graph: true
-  horizon: 12
-  lag: 12
+  horizon: 24
+  lag: 24
   normalizer: std
   num_nodes: 325
   steps_per_day: 288
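In this config, `lag` is the length of the input window and `horizon` the number of steps to forecast; at `steps_per_day: 288` (5-minute bins) the move from 12 to 24 extends both from one hour to two. A minimal sketch of the standard sliding-window slicing these two values usually drive (names are illustrative, not taken from this repo):

```python
import numpy as np

def make_windows(series, lag=24, horizon=24):
    """series: (T, nodes, features) -> inputs (N, lag, ...), targets (N, horizon, ...)."""
    xs, ys = [], []
    for t in range(lag, len(series) - horizon + 1):
        xs.append(series[t - lag:t])       # past `lag` steps as model input
        ys.append(series[t:t + horizon])   # next `horizon` steps as target
    return np.stack(xs), np.stack(ys)

x, y = make_windows(np.random.rand(288, 325, 1))
print(x.shape, y.shape)  # (241, 24, 325, 1) (241, 24, 325, 1)
```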
@@ -23,8 +23,8 @@ data:
   batch_size: 16

 model:
-  pred_len: 12
-  seq_len: 12
+  pred_len: 24
+  seq_len: 24
   patch_len: 6
   stride: 7
   dropout: 0.2
@@ -33,6 +33,7 @@ model:
   gpt_path: ./GPT-2
   d_model: 64
   n_heads: 1
+  input_dim: 1

 train:
   batch_size: 16
@@ -15,12 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output


 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model):
+    def __init__(self, c_in, d_model, patch_num, input_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(2, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        # if air_quality
+        # self.confusion_layer = nn.Linear(42, 1)
@@ -34,19 +35,18 @@ class TokenEmbedding(nn.Module):
         b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
         # 768, 64, 25
         x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num

         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)


 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
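The hardcoded `nn.Linear(2, 1)` silently encoded `patch_num * input_dim` for the old PEMSD8 setting, where `seq_len: 12`, `patch_len: 6`, `stride: 7`, and one input feature give exactly two patches; any other dataset or window size breaks on a shape mismatch. A sketch of the shape flow through the fixed layer, using dummy sizes that mirror the new config:

```python
# Dummy sizes are illustrative; d_model/patch_len mirror the config above.
import torch
import torch.nn as nn

b, n, m, pn, pl = 2, 325, 1, 4, 6   # batch, node, feature(=input_dim), patch_num, patch_len
d_model = 64

token_conv = nn.Conv1d(pl, d_model, kernel_size=3, padding=1,
                       padding_mode='circular', bias=False)
confusion_layer = nn.Linear(pn * m, 1)   # the fix: sized from patch_num * input_dim

x = torch.randn(b, n, m, pn, pl)
h = token_conv(x.reshape(b * n, pl, m * pn))   # -> (b*n, d_model, m*pn)
out = confusion_layer(h).reshape(b, n, -1)     # -> (b, n, d_model)
print(out.shape)  # torch.Size([2, 325, 64])
```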
@@ -13,6 +13,7 @@ class repst(nn.Module):
         self.pred_len = configs['pred_len']
         self.seq_len = configs['seq_len']
         self.patch_len = configs['patch_len']
+        self.input_dim = configs['input_dim']
         self.stride = configs['stride']
         self.dropout = configs['dropout']
         self.gpt_layers = configs['gpt_layers']
@@ -28,7 +29,7 @@ class repst(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums

         # 64, 6, 7, 0.2
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)

         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
         self.gpts.h = self.gpts.h[:self.gpt_layers]
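`self.patch_nums` is computed upstream of this call and is not visible in this diff, but with the `ReplicationPad1d((0, stride))` padding shown earlier, the usual PatchTST-style count reproduces both the old value (2, i.e. the hardcoded `Linear(2, 1)`) and the new one. A sketch, assuming that formula:

```python
# Patch count implied by the pad-then-unfold scheme above (assumed, not
# copied from the repo).
seq_len, patch_len, stride = 24, 6, 7              # values from the config above

padded_len = seq_len + stride                      # ReplicationPad1d((0, stride))
patch_nums = (padded_len - patch_len) // stride + 1
print(patch_nums)  # 4  (the old seq_len of 12 gives 2)
```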
@@ -6,6 +6,7 @@ import psutil
 import torch
 from utils.logger import get_logger
 from utils.loss_function import all_metrics
+from tqdm import tqdm


 class TrainingStats:
@@ -148,7 +149,8 @@ class Trainer:
         y_pred, y_true = [], []

         with torch.set_grad_enabled(optimizer_step):
-            for batch_idx, (data, target) in enumerate(dataloader):
+            progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}")
+            for batch_idx, (data, target) in progress_bar:
                 start_time = time.time()

                 label = target[..., : self.args["output_dim"]]
@@ -175,6 +177,9 @@ class Trainer:
                 y_pred.append(output.detach().cpu())
                 y_true.append(label.detach().cpu())

+                # Update progress bar with the current loss
+                progress_bar.set_postfix(loss=loss.item())
+
         y_pred = torch.cat(y_pred, dim=0)
         y_true = torch.cat(y_true, dim=0)
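The progress bar wraps the batch loop once per epoch, and `set_postfix` refreshes the displayed loss on every step. A minimal standalone version of the same tqdm pattern, with a placeholder in place of the real training step:

```python
from tqdm import tqdm

data = range(100)
progress_bar = tqdm(enumerate(data), total=len(data), desc="Train Epoch 1")
for batch_idx, batch in progress_bar:
    loss = 1.0 / (batch_idx + 1)          # placeholder for the real loss
    progress_bar.set_postfix(loss=loss)   # shows e.g. `loss=0.0101` in the bar
```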
@@ -61,11 +61,6 @@ def check_and_download_data():
         missing_adj = True
         missing_main_files = True
     else:
-        # Check for get_adj.py in the data root directory
-        if "get_adj.py" not in os.listdir(data_dir):
-            # print(f"Missing file get_adj.py in the data root directory.")
-            missing_adj = True
-
         # Walk the expected file structure
         for subfolder, expected_files in expected_structure.items():
             subfolder_path = os.path.join(data_dir, subfolder)
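With the ad-hoc `get_adj.py` check gone, missing files are caught by the `expected_structure` walk alone. A sketch of that pattern; the mapping contents here are hypothetical, the real one lives in the repo:

```python
import os

expected_structure = {"PEMSD8": ["pems08.npz", "adj_pems08.csv"]}  # hypothetical contents
data_dir = "./data"

missing = [
    os.path.join(sub, f)
    for sub, files in expected_structure.items()
    for f in files
    if not os.path.isfile(os.path.join(data_dir, sub, f))
]
print("missing files:", missing)
```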
@@ -100,8 +95,6 @@ def check_and_download_data():

     rearrange_dir()

-
-
     return True