add RGDAN

This commit is contained in:
czzhangheng 2025-08-19 15:37:14 +08:00
parent fa6eb90d65
commit 9e22712d77
10 changed files with 552 additions and 7 deletions

48
config/RGDAN/PEMSD3.yaml Normal file
View File

@ -0,0 +1,48 @@
data:
num_nodes: 358
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
K: 3
d: 8
SEDims: 16
TEDims: 295
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

48
config/RGDAN/PEMSD4.yaml Normal file
View File

@ -0,0 +1,48 @@
data:
num_nodes: 307
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
K: 3
d: 8
SEDims: 16
TEDims: 295 # 7 + 288
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

48
config/RGDAN/PEMSD7.yaml Normal file
View File

@ -0,0 +1,48 @@
data:
num_nodes: 883
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
K: 3
d: 8
SEDims: 16
TEDims: 295
train:
loss_func: mae
seed: 10
batch_size: 8 # larger graph may need smaller batch
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

48
config/RGDAN/PEMSD8.yaml Normal file
View File

@ -0,0 +1,48 @@
data:
num_nodes: 170
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
K: 3
d: 8
SEDims: 16
TEDims: 295
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

View File

@ -15,7 +15,7 @@ data:
days_per_week: 7 days_per_week: 7
sample: 1 sample: 1
input_dim: 3 input_dim: 3
batch_size: 8 batch_size: 64
model: model:
type: 'STEP' type: 'STEP'
@ -68,7 +68,7 @@ model:
train: train:
loss_func: mae loss_func: mae
seed: 10 seed: 10
batch_size: 8 batch_size: 64
epochs: 100 epochs: 100
lr_init: 0.002 lr_init: 0.002
weight_decay: 1.0e-5 weight_decay: 1.0e-5

View File

@ -141,10 +141,10 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
decoder_hidden_state) decoder_hidden_state)
decoder_input = decoder_output decoder_input = decoder_output
outputs.append(decoder_output) outputs.append(decoder_output)
if self.training and self.use_curriculum_learning: # if self.training and self.use_curriculum_learning:
c = np.random.uniform(0, 1) # c = np.random.uniform(0, 1)
if c < self._compute_sampling_threshold(batches_seen): # if c < self._compute_sampling_threshold(batches_seen):
decoder_input = labels[t] # decoder_input = labels[t]
outputs = torch.stack(outputs) outputs = torch.stack(outputs)
return outputs return outputs

348
model/RGDAN/RGDAN.py Normal file
View File

@ -0,0 +1,348 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from data.get_adj import get_adj
class gcn(torch.nn.Module):
def __init__(self, k, d):
super(gcn, self).__init__()
D = k * d
self.fc = torch.nn.Linear(2 * D, D)
self.dropout = nn.Dropout(p=0.1)
def forward(self, X, STE, A):
X = torch.cat((X, STE), dim=-1)
H = F.gelu(self.fc(X))
H = torch.einsum('ncvl,vw->ncwl', (H, A))
return self.dropout(H.contiguous())
class randomGAT(torch.nn.Module):
def __init__(self, k, d, adj, device):
super(randomGAT, self).__init__()
D = k * d
self.d = d
self.K = k
num_nodes = adj.shape[0]
self.device = device
self.fc = torch.nn.Linear(2 * D, D)
self.adj = adj
self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)
def forward(self, X, STE):
X = torch.cat((X, STE), dim=-1)
H = F.gelu(self.fc(X))
H = torch.cat(torch.split(H, self.d, dim=-1), dim=0)
adp = torch.mm(self.nodevec1, self.nodevec2)
zero_vec = torch.tensor(-9e15).to(self.device)
adp = torch.where(self.adj > 0, adp, zero_vec)
adj = F.softmax(adp, dim=-1)
H = torch.einsum('vw,ncwl->ncvl', (adj, H))
H = torch.cat(torch.split(H, H.shape[0] // self.K, dim=0), dim=-1)
return F.gelu(H.contiguous())
class STEmbModel(torch.nn.Module):
def __init__(self, SEDims, TEDims, OutDims, device):
super(STEmbModel, self).__init__()
self.TEDims = TEDims
self.fc3 = torch.nn.Linear(SEDims, OutDims)
self.fc4 = torch.nn.Linear(OutDims, OutDims)
self.fc5 = torch.nn.Linear(TEDims, OutDims)
self.fc6 = torch.nn.Linear(OutDims, OutDims)
self.device = device
def forward(self, SE, TE):
SE = SE.unsqueeze(0).unsqueeze(0)
SE = self.fc4(F.gelu(self.fc3(SE)))
dayofweek = F.one_hot(TE[..., 0], num_classes=7)
timeofday = F.one_hot(TE[..., 1], num_classes=self.TEDims - 7)
TE = torch.cat((dayofweek, timeofday), dim=-1)
TE = TE.unsqueeze(2).type(torch.FloatTensor).to(self.device)
TE = self.fc6(F.gelu(self.fc5(TE)))
sum_tensor = torch.add(SE, TE)
return sum_tensor
class SpatialAttentionModel(torch.nn.Module):
def __init__(self, K, d, adj, dropout=0.3, mask=True):
super(SpatialAttentionModel, self).__init__()
D = K * d
self.fc7 = torch.nn.Linear(2 * D, D)
self.fc8 = torch.nn.Linear(2 * D, D)
self.fc9 = torch.nn.Linear(2 * D, D)
self.fc10 = torch.nn.Linear(D, D)
self.fc11 = torch.nn.Linear(D, D)
self.K = K
self.d = d
self.adj = adj
self.mask = mask
self.dropout = dropout
self.softmax = torch.nn.Softmax(dim=-1)
def forward(self, X, STE):
X = torch.cat((X, STE), dim=-1)
query = F.gelu(self.fc7(X))
key = F.gelu(self.fc8(X))
value = F.gelu(self.fc9(X))
query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
attention = torch.matmul(query, torch.transpose(key, 2, 3))
attention /= (self.d ** 0.5)
if self.mask:
zero_vec = -9e15 * torch.ones_like(attention)
attention = torch.where(self.adj > 0, attention, zero_vec)
attention = self.softmax(attention)
X = torch.matmul(attention, value)
X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
X = self.fc11(F.gelu(self.fc10(X)))
return X
class TemporalAttentionModel(torch.nn.Module):
def __init__(self, K, d, device):
super(TemporalAttentionModel, self).__init__()
D = K * d
self.fc12 = torch.nn.Linear(2 * D, D)
self.fc13 = torch.nn.Linear(2 * D, D)
self.fc14 = torch.nn.Linear(2 * D, D)
self.fc15 = torch.nn.Linear(D, D)
self.fc16 = torch.nn.Linear(D, D)
self.K = K
self.d = d
self.device = device
self.softmax = torch.nn.Softmax(dim=-1)
self.dropout = nn.Dropout(p=0.1)
def forward(self, X, STE, Mask=True):
X = torch.cat((X, STE), dim=-1)
query = F.gelu(self.fc12(X))
key = F.gelu(self.fc13(X))
value = F.gelu(self.fc14(X))
query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
query = torch.transpose(query, 2, 1)
key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
value = torch.transpose(value, 2, 1)
attention = torch.matmul(query, key)
attention /= (self.d ** 0.5)
if Mask:
num_steps = X.shape[1]
mask = torch.ones(num_steps, num_steps).to(self.device)
mask = torch.tril(mask)
zero_vec = torch.tensor(-9e15).to(self.device)
mask = mask.to(torch.bool)
attention = torch.where(mask, attention, zero_vec)
attention = self.softmax(attention)
X = torch.matmul(attention, value)
X = torch.transpose(X, 2, 1)
X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
X = self.dropout(self.fc16(F.gelu(self.fc15(X))))
return X
class GatedFusionModel(torch.nn.Module):
def __init__(self, K, d):
super(GatedFusionModel, self).__init__()
D = K * d
self.fc17 = torch.nn.Linear(D, D)
self.fc18 = torch.nn.Linear(D, D)
self.fc19 = torch.nn.Linear(D, D)
self.fc20 = torch.nn.Linear(D, D)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, HS, HT):
XS = self.fc17(HS)
XT = self.fc18(HT)
z = self.sigmoid(torch.add(XS, XT))
H = torch.add((z * HS), ((1 - z) * HT))
H = self.fc20(F.gelu(self.fc19(H)))
return H
class STAttModel(torch.nn.Module):
def __init__(self, K, d, adj, device):
super(STAttModel, self).__init__()
D = K * d
self.fc30 = torch.nn.Linear(7 * D, D)
self.gcn = gcn(K, d)
self.gcn1 = randomGAT(K, d, adj[0], device)
self.gcn2 = randomGAT(K, d, adj[0], device)
self.gcn3 = randomGAT(K, d, adj[1], device)
self.gcn4 = randomGAT(K, d, adj[1], device)
self.temporalAttention = TemporalAttentionModel(K, d, device)
self.gatedFusion = GatedFusionModel(K, d)
def forward(self, X, STE, adp, Mask=True):
HS1 = self.gcn1(X, STE)
HS2 = self.gcn2(HS1, STE)
HS3 = self.gcn3(X, STE)
HS4 = self.gcn4(HS3, STE)
HS5 = self.gcn(X, STE, adp)
HS6 = self.gcn(HS5, STE, adp)
HS = torch.cat((X, HS1, HS2, HS3, HS4, HS5, HS6), dim=-1)
HS = F.gelu(self.fc30(HS))
HT = self.temporalAttention(X, STE, Mask)
H = self.gatedFusion(HS, HT)
return torch.add(X, H)
class TransformAttentionModel(torch.nn.Module):
def __init__(self, K, d):
super(TransformAttentionModel, self).__init__()
D = K * d
self.fc21 = torch.nn.Linear(D, D)
self.fc22 = torch.nn.Linear(D, D)
self.fc23 = torch.nn.Linear(D, D)
self.fc24 = torch.nn.Linear(D, D)
self.fc25 = torch.nn.Linear(D, D)
self.K = K
self.d = d
self.softmax = torch.nn.Softmax(dim=-1)
def forward(self, X, STE_P, STE_Q):
query = F.gelu(self.fc21(STE_Q))
key = F.gelu(self.fc22(STE_P))
value = F.gelu(self.fc23(X))
query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
query = torch.transpose(query, 2, 1)
key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
value = torch.transpose(value, 2, 1)
attention = torch.matmul(query, key)
attention /= (self.d ** 0.5)
attention = self.softmax(attention)
X = torch.matmul(attention, value)
X = torch.transpose(X, 2, 1)
X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
X = self.fc25(F.gelu(self.fc24(X)))
return X
class RGDAN(nn.Module):
def __init__(self, K, d, SEDims, TEDims, P, L, device, adj, num_nodes):
super(RGDAN, self).__init__()
D = K * d
self.fc1 = torch.nn.Linear(1, D)
self.fc2 = torch.nn.Linear(D, D)
self.STEmb = STEmbModel(SEDims, TEDims, K * d, device)
self.STAttBlockEnc = STAttModel(K, d, adj, device)
self.STAttBlockDec = STAttModel(K, d, adj, device)
self.transformAttention = TransformAttentionModel(K, d)
self.P = P
self.L = L
self.device = device
self.fc26 = torch.nn.Linear(D, D)
self.fc27 = torch.nn.Linear(D, 1)
self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)
self.dropout = nn.Dropout(p=0.1)
def forward(self, X, SE, TE):
adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
X = self.fc2(F.gelu(self.fc1(X)))
STE = self.STEmb(SE, TE)
STE_P = STE[:, : self.P]
STE_Q = STE[:, self.P:]
X = self.STAttBlockEnc(X, STE_P, adp, Mask=True)
X = self.transformAttention(X, STE_P, STE_Q)
X = self.STAttBlockDec(X, STE_Q, adp, Mask=True)
X = self.fc27(self.dropout(F.gelu(self.fc26(X))))
return X.squeeze(3)
class RGDANModel(nn.Module):
"""Wrapper to integrate RGDAN with TrafficWheel pipeline.
Expects dataloader to provide tensors shaped as:
- X: [B, T_in, N, F] where F>=1 and we use channel 0
- Y: [B, T_out, N, F]
We synthesize TE internally via steps_per_day/days_per_week and use learnable SE as zeros (or could be extended).
"""
def __init__(self, args):
super(RGDANModel, self).__init__()
self.args = args
self.device = args.get('device', 'cpu')
self.num_nodes = args['num_nodes']
self.input_dim = args['input_dim']
self.output_dim = args['output_dim']
self.P = args.get('lag', args.get('history', 12))
self.L = args.get('horizon', 12)
# RGDAN hyper-params with defaults
self.K = args.get('K', 3)
self.d = args.get('d', 8)
self.SEDims = args.get('SEDims', 16)
self.TEDims = args.get('TEDims', 288 + 7)
# adjacency set (two views expected by STAttModel)
# use distance matrix from get_adj. Build two masks: forward and backward edges
adj_distance = get_adj({'num_nodes': self.num_nodes})
adj = []
if adj_distance is None:
base = torch.ones(self.num_nodes, self.num_nodes, device=self.device)
adj = [base, base]
else:
base = torch.from_numpy(adj_distance).float().to(self.device)
adj = [base, base.T]
self.se_embedding = nn.Parameter(torch.zeros(self.num_nodes, self.SEDims), requires_grad=True)
self.rgdan = RGDAN(
K=self.K,
d=self.d,
SEDims=self.SEDims,
TEDims=self.TEDims,
P=self.P,
L=self.L,
device=self.device,
adj=adj,
num_nodes=self.num_nodes,
)
def forward(self, x):
# x: [B, T_in, N, F_total]; channels = [orig_features..., time_in_day, day_in_week]
x0 = x[..., 0:1]
steps_per_day = self.args.get('steps_per_day', 288)
days_per_week = self.args.get('days_per_week', 7)
B, T_in, N, F_total = x.shape
T_out = self.L
# Extract TE for observed window from appended channels (constant across nodes)
time_in_day_cont = x[:, :, 0, -2] # [B, T_in]
day_in_week_cont = x[:, :, 0, -1] # [B, T_in]
tod_idx = torch.round(time_in_day_cont * steps_per_day - 1e-6).clamp(0, steps_per_day - 1).long()
dow_idx = torch.round(day_in_week_cont).clamp(0, days_per_week - 1).long()
# Extrapolate TE for horizon
last_tod = tod_idx[:, -1] # [B]
last_dow = dow_idx[:, -1] # [B]
offsets = torch.arange(1, T_out + 1, device=x.device)
future_tod_linear = last_tod.unsqueeze(1) + offsets.unsqueeze(0)
future_tod = (future_tod_linear % steps_per_day).long()
carry_days = (future_tod_linear // steps_per_day).long()
future_dow = (last_dow.unsqueeze(1) + carry_days) % days_per_week
TE_P = torch.stack([dow_idx, tod_idx], dim=-1) # [B, T_in, 2]
TE_Q = torch.stack([future_dow, future_tod], dim=-1) # [B, T_out, 2]
TE = torch.cat([TE_P, TE_Q], dim=1) # [B, T_in+T_out, 2]
# SE: node static embeddings [N, SEDims]
SE = self.se_embedding
y = self.rgdan(x0, SE, TE)
# Output: [B, T_out, N]
if y.dim() == 3:
y = y.unsqueeze(-1)
return y

View File

@ -22,6 +22,8 @@ from model.MegaCRN.MegaCRNModel import MegaCRNModel
from model.ST_SSL.ST_SSL import STSSLModel from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet from model.STAWnet.STAWnet import STAWnet
from model.STEP.STEP import STEP
from model.RGDAN.RGDAN import RGDANModel
def model_selector(model): def model_selector(model):
match model['type']: match model['type']:
@ -49,4 +51,6 @@ def model_selector(model):
case 'ST_SSL': return STSSLModel(model) case 'ST_SSL': return STSSLModel(model)
case 'STGNRDE': return make_nrde_model(model) case 'STGNRDE': return make_nrde_model(model)
case 'STAWnet': return STAWnet(model) case 'STAWnet': return STAWnet(model)
case 'STEP': return STEP(model)
case 'RGDAN': return RGDANModel(model)

1
temp_repo/STEP Submodule

@ -0,0 +1 @@
Subproject commit 566e2738da2d83f055718d8edb609ad8dc325204

View File

@ -13,7 +13,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
lr_scheduler, kwargs[0], None) lr_scheduler, kwargs[0], None)
case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler, kwargs[0], None) lr_scheduler, kwargs[0], None)
case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case 'DCRNN': return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler) lr_scheduler)
case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler) lr_scheduler)