Add EXP9 mixture of experts

exp8 dynamic-graph MANBA
This commit is contained in:
czzhangheng 2025-04-17 18:41:57 +08:00
parent c9a5a54d90
commit 86fabd4ca7
8 changed files with 629 additions and 10 deletions

config/EXP/PEMSD3.yaml (new file, 51 lines)

@@ -0,0 +1,51 @@
data:
  num_nodes: 358
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  input_dim: 1
  output_dim: 1
  embed_dim: 10
  rnn_units: 64
  num_layers: 1
  cheb_order: 2
  use_day: True
  use_week: True
  graph_size: 30
  expert_nums: 8
  top_k: 2
train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 2000
  plot: False
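For reference, a minimal sketch of turning such a config into the flat args dict that the EXP models below index into. This assumes plain PyYAML and a hypothetical merge of the model and data sections; the repository's own config loader is not part of this commit and may differ.

# Hypothetical loader sketch, not the repository's own code.
import yaml

with open("config/EXP/PEMSD3.yaml") as f:
    cfg = yaml.safe_load(f)

# The EXP models read a flat dict, so merge the model hyperparameters
# with the data-shape keys they also need.
args = {**cfg["model"],
        "num_nodes": cfg["data"]["num_nodes"],
        "horizon": cfg["data"]["horizon"],
        "in_len": cfg["data"]["lag"]}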

(modified config file; path not captured in this view)

@@ -14,17 +14,24 @@ data:
   days_per_week: 7
 model:
+  batch_size: 64
   input_dim: 1
   output_dim: 1
-  embed_dim: 10
-  rnn_units: 64
-  num_layers: 1
-  cheb_order: 2
-  use_day: True
-  use_week: True
-  graph_size: 30
-  expert_nums: 8
-  top_k: 2
+  in_len: 12
+  dropout: 0.3
+  supports: null
+  gcn_bool: true
+  addaptadj: true
+  aptinit: null
+  in_dim: 2
+  out_dim: 12
+  residual_channels: 32
+  dilation_channels: 32
+  skip_channels: 256
+  end_channels: 512
+  kernel_size: 2
+  blocks: 4
+  layers: 2
 train:
   loss_func: mae

config/EXP/SD.yaml (new file, 52 lines)

@@ -0,0 +1,52 @@
data:
  num_nodes: 716
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  input_dim: 1
  output_dim: 1
  embed_dim: 12
  rnn_units: 64
  num_layers: 1
  cheb_order: 2
  use_day: True
  use_week: True
  graph_size: 30
  expert_nums: 8
  top_k: 2
  hidden_dim: 64
train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 2000
  plot: False

(modified data-loading module; path not captured in this view)

@@ -121,6 +121,9 @@ def load_st_dataset(dataset, sample):
         case 'Hainan':
             data_path = os.path.join('./data/Hainan/Hainan.npz')
             data = np.load(data_path)['data'][:, :, 0]
+        case 'SD':
+            data_path = os.path.join('./data/SD/data.npz')
+            data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
         case _:
             raise ValueError(f"Unsupported dataset: {dataset}")

@@ -204,3 +207,6 @@ def add_window_y(data, window=3, horizon=1, single=False):
     return np.array(y)
+
+if __name__ == '__main__':
+    res = load_st_dataset('SD', 1)
+    k = 1

lib/LargeST.py (new file, 267 lines)

@@ -0,0 +1,267 @@
import pickle
import torch
import numpy as np
import os
import gc

# vrange, StandardScaler and print_log are project-local helpers; they are
# assumed to live in lib.utils (as in the LargeST/STAEformer codebases).
# Without this import the module does not run.
from lib.utils import vrange, StandardScaler, print_log

# ! X shape: (B, T, N, C)


def load_pkl(pickle_file: str) -> object:
    """
    Load data from a pickle file.

    Args:
        pickle_file (str): Path to the pickle file.

    Returns:
        object: Loaded object from the pickle file.
    """
    try:
        with open(pickle_file, "rb") as f:
            pickle_data = pickle.load(f)
    except UnicodeDecodeError:
        with open(pickle_file, "rb") as f:
            pickle_data = pickle.load(f, encoding="latin1")
    except Exception as e:
        print(f"Unable to load data from {pickle_file}: {e}")
        raise
    return pickle_data


def get_dataloaders_from_index_data(
    data_dir, tod=False, dow=False, batch_size=64, log=None, train_size=0.6
):
    data = np.load(os.path.join(data_dir, "data.npz"))["data"].astype(np.float32)

    features = [0]
    if tod:
        features.append(1)
    if dow:
        features.append(2)
    # if dom:
    #     features.append(3)
    data = data[..., features]

    index = np.load(os.path.join(data_dir, "index.npz"))

    train_index = index["train"]  # (num_samples, 3)
    val_index = index["val"]
    test_index = index["test"]

    # Each index row is (x_start, x_end, y_end); vrange expands the
    # per-sample ranges into full index arrays.
    x_train_index = vrange(train_index[:, 0], train_index[:, 1])
    y_train_index = vrange(train_index[:, 1], train_index[:, 2])
    x_val_index = vrange(val_index[:, 0], val_index[:, 1])
    y_val_index = vrange(val_index[:, 1], val_index[:, 2])
    x_test_index = vrange(test_index[:, 0], test_index[:, 1])
    y_test_index = vrange(test_index[:, 1], test_index[:, 2])

    x_train = data[x_train_index]
    y_train = data[y_train_index][..., :1]
    x_val = data[x_val_index]
    y_val = data[y_val_index][..., :1]
    x_test = data[x_test_index]
    y_test = data[y_test_index][..., :1]

    scaler = StandardScaler(mean=x_train[..., 0].mean(), std=x_train[..., 0].std())

    x_train[..., 0] = scaler.transform(x_train[..., 0])
    x_val[..., 0] = scaler.transform(x_val[..., 0])
    x_test[..., 0] = scaler.transform(x_test[..., 0])

    print_log(f"Trainset:\tx-{x_train.shape}\ty-{y_train.shape}", log=log)
    print_log(f"Valset: \tx-{x_val.shape} \ty-{y_val.shape}", log=log)
    print_log(f"Testset:\tx-{x_test.shape}\ty-{y_test.shape}", log=log)

    trainset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_train), torch.FloatTensor(y_train)
    )
    valset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_val), torch.FloatTensor(y_val)
    )
    testset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_test), torch.FloatTensor(y_test)
    )

    # Non-default train sizes can leave ragged final batches; drop them then.
    drop_last = train_size != 0.6

    trainset_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, drop_last=drop_last
    )
    valset_loader = torch.utils.data.DataLoader(
        valset, batch_size=batch_size, shuffle=False, drop_last=drop_last
    )
    testset_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, drop_last=drop_last
    )

    return trainset_loader, valset_loader, testset_loader, scaler


def get_dataloaders_from_index_data_MTS(
    data_dir,
    in_steps=12,
    out_steps=12,
    tod=False,
    dow=False,
    y_tod=False,
    y_dow=False,
    batch_size=64,
    log=None,
):
    data = np.load(os.path.join(data_dir, "data.npz"))["data"].astype(np.float32)
    index = np.load(os.path.join(data_dir, f"index_{in_steps}_{out_steps}.npz"))

    x_features = [0]
    if tod:
        x_features.append(1)
    if dow:
        x_features.append(2)

    y_features = [0]
    if y_tod:
        y_features.append(1)
    if y_dow:
        y_features.append(2)

    train_index = index["train"]  # (num_samples, 3)
    val_index = index["val"]
    test_index = index["test"]

    # Parallel (vectorised) indexing, kept for reference:
    # x_train_index = vrange(train_index[:, 0], train_index[:, 1])
    # y_train_index = vrange(train_index[:, 1], train_index[:, 2])
    # x_val_index = vrange(val_index[:, 0], val_index[:, 1])
    # y_val_index = vrange(val_index[:, 1], val_index[:, 2])
    # x_test_index = vrange(test_index[:, 0], test_index[:, 1])
    # y_test_index = vrange(test_index[:, 1], test_index[:, 2])
    # x_train = data[x_train_index][..., x_features]
    # y_train = data[y_train_index][..., y_features]
    # x_val = data[x_val_index][..., x_features]
    # y_val = data[y_val_index][..., y_features]
    # x_test = data[x_test_index][..., x_features]
    # y_test = data[y_test_index][..., y_features]

    # Iterative indexing (lower peak memory):
    x_train = np.stack([data[idx[0] : idx[1]] for idx in train_index])[..., x_features]
    y_train = np.stack([data[idx[1] : idx[2]] for idx in train_index])[..., y_features]
    x_val = np.stack([data[idx[0] : idx[1]] for idx in val_index])[..., x_features]
    y_val = np.stack([data[idx[1] : idx[2]] for idx in val_index])[..., y_features]
    x_test = np.stack([data[idx[0] : idx[1]] for idx in test_index])[..., x_features]
    y_test = np.stack([data[idx[1] : idx[2]] for idx in test_index])[..., y_features]

    scaler = StandardScaler(mean=x_train[..., 0].mean(), std=x_train[..., 0].std())

    x_train[..., 0] = scaler.transform(x_train[..., 0])
    x_val[..., 0] = scaler.transform(x_val[..., 0])
    x_test[..., 0] = scaler.transform(x_test[..., 0])

    print_log(f"Trainset:\tx-{x_train.shape}\ty-{y_train.shape}", log=log)
    print_log(f"Valset: \tx-{x_val.shape} \ty-{y_val.shape}", log=log)
    print_log(f"Testset:\tx-{x_test.shape}\ty-{y_test.shape}", log=log)

    trainset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_train), torch.FloatTensor(y_train)
    )
    valset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_val), torch.FloatTensor(y_val)
    )
    testset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_test), torch.FloatTensor(y_test)
    )

    trainset_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True
    )
    valset_loader = torch.utils.data.DataLoader(
        valset, batch_size=batch_size, shuffle=False
    )
    testset_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False
    )

    return trainset_loader, valset_loader, testset_loader, scaler


def get_dataloaders_from_index_data_Test(
    data_dir,
    in_steps=12,
    out_steps=12,
    tod=False,
    dow=False,
    y_tod=False,
    y_dow=False,
    batch_size=64,
    log=None,
):
    data = np.load(os.path.join(data_dir, "data.npz"))["data"].astype(np.float32)
    index = np.load(os.path.join(data_dir, f"index_{in_steps}_{out_steps}.npz"))

    x_features = [0]
    if tod:
        x_features.append(1)
    if dow:
        x_features.append(2)

    y_features = [0]
    if y_tod:
        y_features.append(1)
    if y_dow:
        y_features.append(2)

    train_index = index["train"]  # (num_samples, 3)
    # val_index = index["val"]
    test_index = index["test"]

    # Test-only variant: training windows are sliced solely to fit the scaler;
    # the train/val dataset and loader code stays commented out upstream.
    x_train = np.stack([data[idx[0] : idx[1]] for idx in train_index])[..., x_features]
    x_test = np.stack([data[idx[0] : idx[1]] for idx in test_index])[..., x_features]
    y_test = np.stack([data[idx[1] : idx[2]] for idx in test_index])[..., y_features]

    scaler = StandardScaler(mean=x_train[..., 0].mean(), std=x_train[..., 0].std())
    x_test[..., 0] = scaler.transform(x_test[..., 0])

    print_log(f"Testset:\tx-{x_test.shape}\ty-{y_test.shape}", log=log)

    testset = torch.utils.data.TensorDataset(
        torch.FloatTensor(x_test), torch.FloatTensor(y_test)
    )
    testset_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False
    )

    return testset_loader, scaler
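A hedged usage sketch for the first loader above, assuming a data/SD/ directory containing the data.npz and index.npz files these functions expect; paths and feature flags depend on how the data was preprocessed.

# Illustrative only; not from the repository.
train_loader, val_loader, test_loader, scaler = get_dataloaders_from_index_data(
    "data/SD", tod=True, dow=True, batch_size=64
)

for x, y in train_loader:
    # x: (B, in_steps, N, num_features), y: (B, out_steps, N, 1)
    print(x.shape, y.shape)
    break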

model/EXP/EXP8.py (new file, 115 lines)

@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim), requires_grad=True)

    def forward(self):
        # (N, D) @ (D, N) -> (N, N)
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)
        return adj


class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        # x: (B, N, C) / adj: (N, N)
        res = x
        x = torch.matmul(adj, x)  # (B, N, C)
        x = self.theta(x)
        # Residual connection
        if self.residual:
            x = x + res
        else:
            x = x + self.res_proj(res)
        return F.relu(x)


class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        # x: (B, L, C); post-norm residual self-attention followed by an FFN
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x


class EXP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']

        # Dynamic graph construction
        self.graph = DynamicGraphConstructor(self.num_nodes, embed_dim=16)
        # Input projection
        self.input_proj = nn.Linear(self.seq_len, self.hidden_dim)
        # Graph convolution
        self.gc = GraphConvBlock(self.hidden_dim, self.hidden_dim)
        # MANBA block
        self.manba = MANBA_Block(self.hidden_dim, self.hidden_dim * 2)
        # Output projection
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        # x: (B, T, N, D_total)
        x = x[..., 0]  # keep only the primary channel: (B, T, N)
        B, T, N = x.shape
        assert T == self.seq_len

        # Input projection: (B, T, N) -> (B, N, T) -> (B*N, T) -> (B*N, H)
        x = x.permute(0, 2, 1).reshape(B * N, T)
        h = self.input_proj(x)  # (B*N, hidden_dim)
        h = h.view(B, N, self.hidden_dim)

        # Dynamic graph construction
        adj = self.graph()  # (N, N)

        # Spatial modelling: graph convolution
        h = self.gc(h, adj)  # (B, N, hidden_dim)

        # MANBA block: note the attention runs over the node axis here,
        # since input_proj already folded the time dimension into hidden_dim
        h = self.manba(h)  # (B, N, hidden_dim)

        # Output projection
        out = self.out_proj(h)  # (B, N, horizon * output_dim)
        out = out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out  # (B, horizon, N, output_dim)
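A minimal smoke test for this module; the arg values are hypothetical placeholders matching the keys the constructor reads, not the repository's training pipeline.

# Hypothetical shape check, not repository code.
args = {'horizon': 12, 'output_dim': 1, 'in_len': 12,
        'hidden_dim': 64, 'num_nodes': 358}
model = EXP(args)

x = torch.randn(8, 12, 358, 3)   # (B, T, N, D_total)
y = model(x)
print(y.shape)                   # torch.Size([8, 12, 358, 1])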

model/EXP/EXP9.py (new file, 121 lines)

@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicGraphConstructor(nn.Module):
    def __init__(self, node_num, embed_dim):
        super().__init__()
        self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim))
        self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim))

    def forward(self):
        adj = torch.matmul(self.nodevec1, self.nodevec2.T)
        adj = F.relu(adj)
        adj = F.softmax(adj, dim=-1)
        return adj


class GraphConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.theta = nn.Linear(input_dim, output_dim)
        self.residual = input_dim == output_dim
        if not self.residual:
            self.res_proj = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        res = x
        x = torch.matmul(adj, x)
        x = self.theta(x)
        if self.residual:
            x = x + res
        else:
            x = x + self.res_proj(res)
        return F.relu(x)


class MANBA_Block(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        res = x
        x_attn, _ = self.attn(x, x, x)
        x = self.norm1(res + x_attn)
        res2 = x
        x_ffn = self.ffn(x)
        x = self.norm2(res2 + x_ffn)
        return x


class EXPExpert(nn.Module):  # the former EXP module from EXP8, renamed
    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        self.seq_len = args.get('in_len', 12)
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']

        self.graph = DynamicGraphConstructor(self.num_nodes, embed_dim=16)
        self.input_proj = nn.Linear(self.seq_len, self.hidden_dim)
        self.gc = GraphConvBlock(self.hidden_dim, self.hidden_dim)
        self.manba = MANBA_Block(self.hidden_dim, self.hidden_dim * 2)
        self.out_proj = nn.Linear(self.hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        x = x[..., 0]  # (B, T, N)
        B, T, N = x.shape
        x = x.permute(0, 2, 1).reshape(B * N, T)
        h = self.input_proj(x).view(B, N, -1)
        adj = self.graph()
        h = self.gc(h, adj)
        h = self.manba(h)
        out = self.out_proj(h)
        return out.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)


class EXP(nn.Module):
    def __init__(self, args, num_experts=4, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([EXPExpert(args) for _ in range(num_experts)])
        self.gate = nn.Sequential(
            # default in_len matches EXPExpert.seq_len, so configs without
            # an explicit in_len key still construct
            nn.Linear(args.get('in_len', 12) * args['num_nodes'], 128),
            nn.ReLU(),
            nn.Linear(128, num_experts)
        )

    def forward(self, x):
        B = x.size(0)

        # Flatten input for gating
        gate_input = x[..., 0].reshape(B, -1)         # (B, T*N)
        gate_logits = self.gate(gate_input)           # (B, num_experts)
        gate_scores = F.softmax(gate_logits, dim=-1)  # soft selection

        # Get top-k experts per sample
        topk_val, topk_idx = torch.topk(gate_scores, self.top_k, dim=-1)  # (B, k)

        # Runs expert 0 on the full batch once just to size the output buffer
        outputs = torch.zeros_like(self.experts[0](x))  # (B, H, N, D_out)

        for i in range(self.top_k):
            idx = topk_idx[:, i]
            # Batch together the samples routed to the same expert
            for expert_id in torch.unique(idx):
                mask = idx == expert_id
                if mask.sum() == 0:
                    continue
                selected_x = x[mask]
                expert_output = self.experts[expert_id](selected_x)
                outputs[mask] += topk_val[mask, i].unsqueeze(1).unsqueeze(1).unsqueeze(1) * expert_output

        return outputs  # (B, H, N, D_out)
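As above, a minimal sketch exercising the gating path. The values are hypothetical; note also that the constructor defaults (num_experts=4, top_k=2) apply unless the caller wires through the config's expert_nums and top_k keys.

# Hypothetical shape check, not repository code.
args = {'horizon': 12, 'output_dim': 1, 'in_len': 12,
        'hidden_dim': 64, 'num_nodes': 716}
moe = EXP(args, num_experts=4, top_k=2)

x = torch.randn(4, 12, 716, 3)   # (B, T, N, D_total)
y = moe(x)                       # each sample: weighted sum of its top-2 experts
print(y.shape)                   # torch.Size([4, 12, 716, 1])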

(modified model selector module; path not captured in this view)

@@ -13,7 +13,7 @@ from model.STFGNN.STFGNN import STFGNN
 from model.STSGCN.STSGCN import STSGCN
 from model.STGODE.STGODE import ODEGCN
 from model.PDG2SEQ.PDG2Seq import PDG2Seq
-from model.EXP.EXP7 import EXP as EXP
+from model.EXP.EXP9 import EXP as EXP

 def model_selector(model):
     match model['type']: