parent c9a5a54d90
commit 86fabd4ca7
@@ -0,0 +1,51 @@
data:
  num_nodes: 358
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  embed_dim: 10
  rnn_units: 64
  num_layers: 1
  cheb_order: 2
  use_day: True
  use_week: True
  graph_size: 30
  expert_nums: 8
  top_k: 2

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 2000
  plot: False
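The train section encodes a multi-step learning-rate schedule as a comma-separated string. A minimal sketch of consuming it, assuming PyYAML and a PyTorch optimizer; the "config.yaml" path and the placeholder model are hypothetical, not part of this commit:

# Minimal sketch, assuming PyYAML; "config.yaml" is a hypothetical path.
import yaml
import torch

cfg = yaml.safe_load(open("config.yaml"))["train"]
model = torch.nn.Linear(1, 1)  # placeholder model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr_init"],
                             weight_decay=cfg["weight_decay"])
if cfg["lr_decay"]:
    milestones = [int(s) for s in cfg["lr_decay_step"].split(",")]  # [5, 20, 40, 70]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=cfg["lr_decay_rate"])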
@@ -14,17 +14,24 @@ data:
   days_per_week: 7
 
 model:
+  batch_size: 64
   input_dim: 1
   output_dim: 1
-  embed_dim: 10
-  rnn_units: 64
-  num_layers: 1
-  cheb_order: 2
-  use_day: True
-  use_week: True
-  graph_size: 30
-  expert_nums: 8
-  top_k: 2
+  in_len: 12
+  dropout: 0.3
+  supports: null
+  gcn_bool: true
+  addaptadj: true
+  aptinit: null
+  in_dim: 2
+  out_dim: 12
+  residual_channels: 32
+  dilation_channels: 32
+  skip_channels: 256
+  end_channels: 512
+  kernel_size: 2
+  blocks: 4
+  layers: 2
 
 train:
   loss_func: mae
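The replacement keys (supports, gcn_bool, addaptadj, aptinit, residual/dilation/skip/end channels, kernel_size, blocks, layers) match the constructor arguments of the Graph WaveNet reference implementation. A hedged sketch of feeding this section to such a model, assuming the repo vendors it as gwnet; the config path and device handling are assumptions:

# Sketch under the assumption of the Graph WaveNet reference constructor;
# "config.yaml" is a hypothetical path.
import yaml
import torch

cfg = yaml.safe_load(open("config.yaml"))
mcfg = dict(cfg["model"])
mcfg.pop("batch_size", None)  # consumed by the dataloader, not the model
mcfg.pop("in_len", None)      # consumed by the data pipeline
# model = gwnet(device=torch.device("cpu"),
#               num_nodes=cfg["data"]["num_nodes"], **mcfg)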
@@ -0,0 +1,52 @@
data:
  num_nodes: 716
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  embed_dim: 12
  rnn_units: 64
  num_layers: 1
  cheb_order: 2
  use_day: True
  use_week: True
  graph_size: 30
  expert_nums: 8
  top_k: 2
  hidden_dim: 64

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 2000
  plot: False
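expert_nums, top_k and hidden_dim feed the mixture-of-experts EXP model introduced later in this commit (EXP9). A hedged wiring sketch; the glue code below is assumed, not part of this commit. Note the MoE gate indexes args['in_len'] directly, so that key must be supplied even though the experts themselves fall back to a default of 12:

# Assumed glue code: cfg is the parsed YAML above.
args = dict(cfg['model'])
args.update(num_nodes=cfg['data']['num_nodes'],
            horizon=cfg['data']['horizon'],
            in_len=cfg['data']['lag'])
# model = EXP(args, num_experts=args['expert_nums'], top_k=args['top_k'])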
@@ -121,6 +121,9 @@ def load_st_dataset(dataset, sample):
         case 'Hainan':
             data_path = os.path.join('./data/Hainan/Hainan.npz')
             data = np.load(data_path)['data'][:, :, 0]
+        case 'SD':
+            data_path = os.path.join('./data/SD/data.npz')
+            data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
         case _:
             raise ValueError(f"Unsupported dataset: {dataset}")
@@ -204,3 +207,6 @@ def add_window_y(data, window=3, horizon=1, single=False):
     return np.array(y)
 
 
+if __name__ == '__main__':
+    res = load_st_dataset('SD', 1)
+    k = 1
@ -0,0 +1,267 @@
|
||||||
|
import pickle
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import gc
|
||||||
|
# ! X shape: (B, T, N, C)
|
||||||
|
|
||||||
|
def load_pkl(pickle_file: str) -> object:
|
||||||
|
"""
|
||||||
|
Load data from a pickle file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pickle_file (str): Path to the pickle file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: Loaded object from the pickle file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(pickle_file, "rb") as f:
|
||||||
|
pickle_data = pickle.load(f)
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
with open(pickle_file, "rb") as f:
|
||||||
|
pickle_data = pickle.load(f, encoding="latin1")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unable to load data from {pickle_file}: {e}")
|
||||||
|
raise
|
||||||
|
return pickle_data
|
||||||
|
|
||||||
|
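A minimal sketch of the assumed helpers, consistent with how they are used below (vrange must return a 2-D index array so that data[x_train_index] yields (num_samples, steps, N, C)); treat this as an assumption about the repo's utils module, not its actual code:

# Hedged sketch of the assumed helpers, not the repo's actual utils module.
def vrange(starts, stops):
    """Vectorized np.arange over equal-length ranges: row i = arange(starts[i], stops[i])."""
    starts = np.asarray(starts)
    stops = np.asarray(stops)
    lengths = stops - starts
    assert lengths.min() == lengths.max(), "all ranges must have equal length"
    flat = np.repeat(stops - lengths.cumsum(), lengths) + np.arange(lengths.sum())
    return flat.reshape(-1, lengths[0])


class StandardScaler:
    """Z-score normalization with fixed mean/std (fit on the train split)."""
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, x):
        return (x - self.mean) / self.std

    def inverse_transform(self, x):
        return x * self.std + self.mean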
def get_dataloaders_from_index_data(
    data_dir, tod=False, dow=False, batch_size=64, log=None, train_size=0.6
):
    data = np.load(os.path.join(data_dir, "data.npz"))["data"].astype(np.float32)

    # Channel 0 is the raw signal; 1 and 2 are optional time-of-day /
    # day-of-week encodings.
    features = [0]
    if tod:
        features.append(1)
    if dow:
        features.append(2)
    # if dom:
    #     features.append(3)
    data = data[..., features]

    index = np.load(os.path.join(data_dir, "index.npz"))

    # Each row is an (x_start, x_end/y_start, y_end) triple of time indices.
    train_index = index["train"]  # (num_samples, 3)
    val_index = index["val"]
    test_index = index["test"]

    x_train_index = vrange(train_index[:, 0], train_index[:, 1])
    y_train_index = vrange(train_index[:, 1], train_index[:, 2])
    x_val_index = vrange(val_index[:, 0], val_index[:, 1])
    y_val_index = vrange(val_index[:, 1], val_index[:, 2])
    x_test_index = vrange(test_index[:, 0], test_index[:, 1])
    y_test_index = vrange(test_index[:, 1], test_index[:, 2])

    x_train = data[x_train_index]
    y_train = data[y_train_index][..., :1]
    x_val = data[x_val_index]
    y_val = data[y_val_index][..., :1]
    x_test = data[x_test_index]
    y_test = data[y_test_index][..., :1]

    # Normalize the signal channel only, with statistics from the train split.
    scaler = StandardScaler(mean=x_train[..., 0].mean(), std=x_train[..., 0].std())

    x_train[..., 0] = scaler.transform(x_train[..., 0])
    x_val[..., 0] = scaler.transform(x_val[..., 0])
    x_test[..., 0] = scaler.transform(x_test[..., 0])

    print_log(f"Trainset:\tx-{x_train.shape}\ty-{y_train.shape}", log=log)
    print_log(f"Valset: \tx-{x_val.shape} \ty-{y_val.shape}", log=log)
    print_log(f"Testset:\tx-{x_test.shape}\ty-{y_test.shape}", log=log)

    trainset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_train), torch.FloatTensor(y_train)
    )
    valset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_val), torch.FloatTensor(y_val)
    )
    testset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_test), torch.FloatTensor(y_test)
    )

    # Non-default train sizes can leave ragged final batches; drop them.
    drop_last = train_size != 0.6

    trainset_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, drop_last=drop_last
    )
    valset_loader = torch.utils.data.DataLoader(
        valset, batch_size=batch_size, shuffle=False, drop_last=drop_last
    )
    testset_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, drop_last=drop_last
    )

    return trainset_loader, valset_loader, testset_loader, scaler
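For reference, a hedged usage sketch of the loader above; it assumes a ./data/SD directory containing the data.npz and index.npz files the function reads:

loaders = get_dataloaders_from_index_data("./data/SD", tod=True, dow=True, batch_size=64)
train_loader, val_loader, test_loader, scaler = loaders
for x, y in train_loader:
    # x: (B, in_steps, N, C) with channel 0 normalized; y: (B, out_steps, N, 1) raw
    print(x.shape, y.shape)
    break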
def get_dataloaders_from_index_data_MTS(
    data_dir,
    in_steps=12,
    out_steps=12,
    tod=False,
    dow=False,
    y_tod=False,
    y_dow=False,
    batch_size=64,
    log=None,
):
    data = np.load(os.path.join(data_dir, "data.npz"))["data"].astype(np.float32)
    index = np.load(os.path.join(data_dir, f"index_{in_steps}_{out_steps}.npz"))

    x_features = [0]
    if tod:
        x_features.append(1)
    if dow:
        x_features.append(2)

    y_features = [0]
    if y_tod:
        y_features.append(1)
    if y_dow:
        y_features.append(2)

    train_index = index["train"]  # (num_samples, 3)
    val_index = index["val"]
    test_index = index["test"]

    # Parallel (vectorized) indexing, kept for reference:
    # x_train_index = vrange(train_index[:, 0], train_index[:, 1])
    # y_train_index = vrange(train_index[:, 1], train_index[:, 2])
    # x_val_index = vrange(val_index[:, 0], val_index[:, 1])
    # y_val_index = vrange(val_index[:, 1], val_index[:, 2])
    # x_test_index = vrange(test_index[:, 0], test_index[:, 1])
    # y_test_index = vrange(test_index[:, 1], test_index[:, 2])

    # x_train = data[x_train_index][..., x_features]
    # y_train = data[y_train_index][..., y_features]
    # x_val = data[x_val_index][..., x_features]
    # y_val = data[y_val_index][..., y_features]
    # x_test = data[x_test_index][..., x_features]
    # y_test = data[y_test_index][..., y_features]

    # Iterative slicing (lower peak memory than the vectorized variant):
    x_train = np.stack([data[idx[0] : idx[1]] for idx in train_index])[..., x_features]
    y_train = np.stack([data[idx[1] : idx[2]] for idx in train_index])[..., y_features]
    x_val = np.stack([data[idx[0] : idx[1]] for idx in val_index])[..., x_features]
    y_val = np.stack([data[idx[1] : idx[2]] for idx in val_index])[..., y_features]
    x_test = np.stack([data[idx[0] : idx[1]] for idx in test_index])[..., x_features]
    y_test = np.stack([data[idx[1] : idx[2]] for idx in test_index])[..., y_features]

    scaler = StandardScaler(mean=x_train[..., 0].mean(), std=x_train[..., 0].std())

    x_train[..., 0] = scaler.transform(x_train[..., 0])
    x_val[..., 0] = scaler.transform(x_val[..., 0])
    x_test[..., 0] = scaler.transform(x_test[..., 0])

    print_log(f"Trainset:\tx-{x_train.shape}\ty-{y_train.shape}", log=log)
    print_log(f"Valset: \tx-{x_val.shape} \ty-{y_val.shape}", log=log)
    print_log(f"Testset:\tx-{x_test.shape}\ty-{y_test.shape}", log=log)

    trainset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_train), torch.FloatTensor(y_train)
    )
    valset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_val), torch.FloatTensor(y_val)
    )
    testset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_test), torch.FloatTensor(y_test)
    )

    trainset_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True
    )
    valset_loader = torch.utils.data.DataLoader(
        valset, batch_size=batch_size, shuffle=False
    )
    testset_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False
    )

    return trainset_loader, valset_loader, testset_loader, scaler
def get_dataloaders_from_index_data_Test(
    data_dir,
    in_steps=12,
    out_steps=12,
    tod=False,
    dow=False,
    y_tod=False,
    y_dow=False,
    batch_size=64,
    log=None,
):
    # Test-only variant: builds just the test loader, but still fits the scaler
    # on the (un-normalized) train inputs so evaluation matches training stats.
    data = np.load(os.path.join(data_dir, "data.npz"))["data"].astype(np.float32)
    index = np.load(os.path.join(data_dir, f"index_{in_steps}_{out_steps}.npz"))

    x_features = [0]
    if tod:
        x_features.append(1)
    if dow:
        x_features.append(2)

    y_features = [0]
    if y_tod:
        y_features.append(1)
    if y_dow:
        y_features.append(2)

    train_index = index["train"]  # (num_samples, 3)
    # val_index = index["val"]
    test_index = index["test"]

    # Parallel (vectorized) indexing, kept for reference:
    # x_train_index = vrange(train_index[:, 0], train_index[:, 1])
    # y_train_index = vrange(train_index[:, 1], train_index[:, 2])
    # x_val_index = vrange(val_index[:, 0], val_index[:, 1])
    # y_val_index = vrange(val_index[:, 1], val_index[:, 2])
    # x_test_index = vrange(test_index[:, 0], test_index[:, 1])
    # y_test_index = vrange(test_index[:, 1], test_index[:, 2])

    # x_train = data[x_train_index][..., x_features]
    # y_train = data[y_train_index][..., y_features]
    # x_val = data[x_val_index][..., x_features]
    # y_val = data[y_val_index][..., y_features]
    # x_test = data[x_test_index][..., x_features]
    # y_test = data[y_test_index][..., y_features]

    # Iterative slicing
    x_train = np.stack([data[idx[0] : idx[1]] for idx in train_index])[..., x_features]
    # y_train = np.stack([data[idx[1] : idx[2]] for idx in train_index])[..., y_features]
    # x_val = np.stack([data[idx[0] : idx[1]] for idx in val_index])[..., x_features]
    # y_val = np.stack([data[idx[1] : idx[2]] for idx in val_index])[..., y_features]
    x_test = np.stack([data[idx[0] : idx[1]] for idx in test_index])[..., x_features]
    y_test = np.stack([data[idx[1] : idx[2]] for idx in test_index])[..., y_features]

    scaler = StandardScaler(mean=x_train[..., 0].mean(), std=x_train[..., 0].std())

    # x_train[..., 0] = scaler.transform(x_train[..., 0])
    # x_val[..., 0] = scaler.transform(x_val[..., 0])
    x_test[..., 0] = scaler.transform(x_test[..., 0])

    # print_log(f"Trainset:\tx-{x_train.shape}\ty-{y_train.shape}", log=log)
    # print_log(f"Valset: \tx-{x_val.shape} \ty-{y_val.shape}", log=log)
    print_log(f"Testset:\tx-{x_test.shape}\ty-{y_test.shape}", log=log)

    # trainset = torch.utils.data.TensorDataset(
    #     torch.FloatTensor(x_train), torch.FloatTensor(y_train)
    # )
    # valset = torch.utils.data.TensorDataset(
    #     torch.FloatTensor(x_val), torch.FloatTensor(y_val)
    # )
    testset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_test), torch.FloatTensor(y_test)
    )

    # trainset_loader = torch.utils.data.DataLoader(
    #     trainset, batch_size=batch_size, shuffle=True
    # )
    # valset_loader = torch.utils.data.DataLoader(
    #     valset, batch_size=batch_size, shuffle=False
    # )
    testset_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False
    )

    return testset_loader, scaler
@ -0,0 +1,115 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicGraphConstructor(nn.Module):
|
||||||
|
def __init__(self, node_num, embed_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
|
||||||
|
self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
# (N, D) @ (D, N) -> (N, N)
|
||||||
|
adj = torch.matmul(self.nodevec1, self.nodevec2.T)
|
||||||
|
adj = F.relu(adj)
|
||||||
|
adj = F.softmax(adj, dim=-1)
|
||||||
|
return adj
|
||||||
|
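This is the same adaptive-adjacency trick as Graph WaveNet's self-learned graph: the softmax over the last dimension makes each row a distribution over neighbors. A quick hypothetical shape check, with the node count taken from the 358-node config above:

# Hypothetical smoke test for the learned adjacency.
g = DynamicGraphConstructor(node_num=358, embed_dim=16)
adj = g()  # (358, 358)
assert torch.allclose(adj.sum(dim=-1), torch.ones(358), atol=1e-5)  # rows sum to 1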
class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        # x: (B, N, C) / adj: (N, N)
        res = x
        x = torch.matmul(adj, x)  # (B, N, C)
        x = self.theta(x)

        # Residual connection (projected when the dims differ)
        if self.residual:
            x = x + res
        else:
            x = x + self.res_proj(res)

        return F.relu(x)
class MANBA_Block(nn.Module):
    """Despite the name, this is a standard post-LN Transformer encoder block
    (multi-head self-attention + FFN, each with residual and LayerNorm)."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x: (B, T, C)
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)

        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)

        return x
class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']

        # Dynamic graph construction
        self.graph = DynamicGraphConstructor(self.num_nodes, embed_dim=16)

        # Input projection layer
        self.input_proj = nn.Linear(self.seq_len, self.hidden_dim)

        # Graph convolution
        self.gc = GraphConvBlock(self.hidden_dim, self.hidden_dim)

        # MANBA block
        self.manba = MANBA_Block(self.hidden_dim, self.hidden_dim * 2)

        # Output projection
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        # x: (B, T, N, D_total)
        x = x[..., 0]  # keep only the primary channel: (B, T, N)
        B, T, N = x.shape
        assert T == self.seq_len

        # Input projection: (B, T, N) -> (B, N, T) -> (B*N, T) -> (B*N, H)
        x = x.permute(0, 2, 1).reshape(B * N, T)
        h = self.input_proj(x)  # (B*N, hidden_dim)
        h = h.view(B, N, self.hidden_dim)

        # Dynamic graph construction
        adj = self.graph()  # (N, N)

        # Spatial modeling: graph convolution
        h = self.gc(h, adj)  # (B, N, hidden_dim)

        # Temporal modeling: MANBA (note the time axis is already collapsed,
        # so attention here actually runs over the node axis)
        h = self.manba(h)  # (B, N, hidden_dim)

        # Output projection
        out = self.out_proj(h)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out  # (B, horizon, N, output_dim)
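A hypothetical smoke test of the forward pass, with sizes taken from the 358-node config above:

args = {'horizon': 12, 'output_dim': 1, 'in_len': 12, 'hidden_dim': 64, 'num_nodes': 358}
model = EXP(args)
x = torch.randn(8, 12, 358, 2)  # (B, T, N, C); only channel 0 is used
y = model(x)
print(y.shape)  # torch.Size([8, 12, 358, 1])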
@ -0,0 +1,121 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicGraphConstructor(nn.Module):
|
||||||
|
def __init__(self, node_num, embed_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim))
|
||||||
|
self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim))
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
adj = torch.matmul(self.nodevec1, self.nodevec2.T)
|
||||||
|
adj = F.relu(adj)
|
||||||
|
adj = F.softmax(adj, dim=-1)
|
||||||
|
return adj
|
||||||
|
|
||||||
|
|
||||||
|
class GraphConvBlock(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.theta = nn.Linear(input_dim, output_dim)
|
||||||
|
self.residual = input_dim == output_dim
|
||||||
|
if not self.residual:
|
||||||
|
self.res_proj = nn.Linear(input_dim, output_dim)
|
||||||
|
|
||||||
|
def forward(self, x, adj):
|
||||||
|
res = x
|
||||||
|
x = torch.matmul(adj, x)
|
||||||
|
x = self.theta(x)
|
||||||
|
if self.residual:
|
||||||
|
x = x + res
|
||||||
|
else:
|
||||||
|
x = x + self.res_proj(res)
|
||||||
|
return F.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MANBA_Block(nn.Module):
|
||||||
|
def __init__(self, input_dim, hidden_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
|
||||||
|
self.ffn = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_dim, input_dim)
|
||||||
|
)
|
||||||
|
self.norm1 = nn.LayerNorm(input_dim)
|
||||||
|
self.norm2 = nn.LayerNorm(input_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = x
|
||||||
|
x_attn, _ = self.attn(x, x, x)
|
||||||
|
x = self.norm1(res + x_attn)
|
||||||
|
|
||||||
|
res2 = x
|
||||||
|
x_ffn = self.ffn(x)
|
||||||
|
x = self.norm2(res2 + x_ffn)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class EXPExpert(nn.Module):  # renamed from the original EXP
    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']

        self.graph = DynamicGraphConstructor(self.num_nodes, embed_dim=16)
        self.input_proj = nn.Linear(self.seq_len, self.hidden_dim)
        self.gc = GraphConvBlock(self.hidden_dim, self.hidden_dim)
        self.manba = MANBA_Block(self.hidden_dim, self.hidden_dim * 2)
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        x = x[..., 0]  # (B, T, N)
        B, T, N = x.shape
        x = x.permute(0, 2, 1).reshape(B * N, T)
        h = self.input_proj(x).view(B, N, -1)
        adj = self.graph()
        h = self.gc(h, adj)
        h = self.manba(h)
        out = self.out_proj(h)
        return out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
class EXP(nn.Module):
    def __init__(self, args, num_experts=4, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([EXPExpert(args) for _ in range(num_experts)])
        # Gate over the flattened input window; note args['in_len'] is indexed
        # directly here, so the key must be present in the config.
        self.gate = nn.Sequential(
            nn.Linear(args['in_len'] * args['num_nodes'], 128),
            nn.ReLU(),
            nn.Linear(128, num_experts)
        )

    def forward(self, x):
        B = x.size(0)
        # Flatten input for gating
        gate_input = x[..., 0].reshape(B, -1)  # (B, T*N)
        gate_logits = self.gate(gate_input)  # (B, num_experts)
        gate_scores = F.softmax(gate_logits, dim=-1)  # soft selection

        # Get top-k experts
        topk_val, topk_idx = torch.topk(gate_scores, self.top_k, dim=-1)  # (B, k)

        # Runs expert 0 on the full batch purely to allocate the output shape.
        outputs = torch.zeros_like(self.experts[0](x))  # (B, H, N, D_out)
        for i in range(self.top_k):
            idx = topk_idx[:, i]
            # Dispatch each sample to its i-th selected expert, batched by expert id.
            for expert_id in torch.unique(idx):
                mask = idx == expert_id
                if mask.sum() == 0:
                    continue
                selected_x = x[mask]
                expert_output = self.experts[expert_id](selected_x)
                outputs[mask] += topk_val[mask, i].unsqueeze(1).unsqueeze(1).unsqueeze(1) * expert_output

        return outputs  # (B, H, N, D_out)
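A hypothetical smoke test for the MoE wrapper, again with sizes from the 358-node config; the gate's input layer resolves to Linear(12 * 358, 128):

args = {'horizon': 12, 'output_dim': 1, 'in_len': 12, 'hidden_dim': 64, 'num_nodes': 358}
moe = EXP(args, num_experts=4, top_k=2)
x = torch.randn(8, 12, 358, 2)
y = moe(x)
print(y.shape)  # torch.Size([8, 12, 358, 1])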
@@ -13,7 +13,7 @@ from model.STFGNN.STFGNN import STFGNN
 from model.STSGCN.STSGCN import STSGCN
 from model.STGODE.STGODE import ODEGCN
 from model.PDG2SEQ.PDG2Seq import PDG2Seq
-from model.EXP.EXP7 import EXP as EXP
+from model.EXP.EXP9 import EXP as EXP
 
 def model_selector(model):
     match model['type']: