From 9e22712d7705462377d918759d2d07a10f08f4a8 Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Tue, 19 Aug 2025 15:37:14 +0800
Subject: [PATCH] add RGDAN

Add the RGDAN model with per-dataset configs (PEMSD3/4/7/8) and register it
in the model selector. Also: raise the STEP PEMSD4 batch size from 8 to 64,
disable curriculum learning in DCRNN, route DCRNN through the generic
Trainer, and vendor the STEP reference repo as a submodule under temp_repo/.
---
 config/RGDAN/PEMSD3.yaml    |  48 +++++
 config/RGDAN/PEMSD4.yaml    |  48 +++++
 config/RGDAN/PEMSD7.yaml    |  48 +++++
 config/RGDAN/PEMSD8.yaml    |  48 +++++
 config/STEP/PEMSD4.yaml     |   4 +-
 model/DCRNN/dcrnn_model.py  |   8 +-
 model/RGDAN/RGDAN.py        | 348 ++++++++++++++++++++++++++++++++++++
 model/model_selector.py     |   4 +
 temp_repo/STEP              |   1 +
 trainer/trainer_selector.py |   2 +-
 10 files changed, 552 insertions(+), 7 deletions(-)
 create mode 100644 config/RGDAN/PEMSD3.yaml
 create mode 100644 config/RGDAN/PEMSD4.yaml
 create mode 100644 config/RGDAN/PEMSD7.yaml
 create mode 100644 config/RGDAN/PEMSD8.yaml
 create mode 100644 model/RGDAN/RGDAN.py
 create mode 160000 temp_repo/STEP
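Note: TEDims in the RGDAN configs is the width of the one-hot temporal
embedding consumed by STEmbModel, which encodes day-of-week into 7 classes
and time-of-day into TEDims - 7 classes; it must therefore equal
days_per_week + steps_per_day (7 + 288 = 295 for all four datasets). A
minimal sketch of deriving it rather than hard-coding it, assuming `cfg`
is the parsed YAML as a nested dict (hypothetical name):

    # keep the temporal-embedding width in sync with the data section
    cfg['model']['TEDims'] = cfg['data']['days_per_week'] + cfg['data']['steps_per_day']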
diff --git a/config/RGDAN/PEMSD3.yaml b/config/RGDAN/PEMSD3.yaml
new file mode 100644
index 0000000..dab3817
--- /dev/null
+++ b/config/RGDAN/PEMSD3.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 358
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 1
+  output_dim: 1
+  K: 3
+  d: 8
+  SEDims: 16
+  TEDims: 295
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 64
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
+
+
diff --git a/config/RGDAN/PEMSD4.yaml b/config/RGDAN/PEMSD4.yaml
new file mode 100644
index 0000000..f1d4a03
--- /dev/null
+++ b/config/RGDAN/PEMSD4.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 307
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 1
+  output_dim: 1
+  K: 3
+  d: 8
+  SEDims: 16
+  TEDims: 295  # 7 + 288
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 64
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
+
+
diff --git a/config/RGDAN/PEMSD7.yaml b/config/RGDAN/PEMSD7.yaml
new file mode 100644
index 0000000..90c3f38
--- /dev/null
+++ b/config/RGDAN/PEMSD7.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 883
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 1
+  output_dim: 1
+  K: 3
+  d: 8
+  SEDims: 16
+  TEDims: 295
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 8  # larger graph may need smaller batch
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
+
+
diff --git a/config/RGDAN/PEMSD8.yaml b/config/RGDAN/PEMSD8.yaml
new file mode 100644
index 0000000..0ec4744
--- /dev/null
+++ b/config/RGDAN/PEMSD8.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 170
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 1
+  output_dim: 1
+  K: 3
+  d: 8
+  SEDims: 16
+  TEDims: 295
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 64
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
+
+
diff --git a/config/STEP/PEMSD4.yaml b/config/STEP/PEMSD4.yaml
index 70d9ef3..f0e8e79 100644
--- a/config/STEP/PEMSD4.yaml
+++ b/config/STEP/PEMSD4.yaml
@@ -15,7 +15,7 @@ data:
   days_per_week: 7
   sample: 1
   input_dim: 3
-  batch_size: 8
+  batch_size: 64
 
 model:
   type: 'STEP'
@@ -68,7 +68,7 @@ model:
 train:
   loss_func: mae
   seed: 10
-  batch_size: 8
+  batch_size: 64
   epochs: 100
   lr_init: 0.002
   weight_decay: 1.0e-5
diff --git a/model/DCRNN/dcrnn_model.py b/model/DCRNN/dcrnn_model.py
index f7e7648..b0565c3 100755
--- a/model/DCRNN/dcrnn_model.py
+++ b/model/DCRNN/dcrnn_model.py
@@ -141,10 +141,10 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
                                                           decoder_hidden_state)
             decoder_input = decoder_output
             outputs.append(decoder_output)
-            if self.training and self.use_curriculum_learning:
-                c = np.random.uniform(0, 1)
-                if c < self._compute_sampling_threshold(batches_seen):
-                    decoder_input = labels[t]
+            # if self.training and self.use_curriculum_learning:
+            #     c = np.random.uniform(0, 1)
+            #     if c < self._compute_sampling_threshold(batches_seen):
+            #         decoder_input = labels[t]
 
         outputs = torch.stack(outputs)
         return outputs
diff --git a/model/RGDAN/RGDAN.py b/model/RGDAN/RGDAN.py
new file mode 100644
index 0000000..5365561
--- /dev/null
+++ b/model/RGDAN/RGDAN.py
@@ -0,0 +1,348 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from data.get_adj import get_adj
+
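+# Shape convention: X and STE are [B, T, N, D] tensors (batch, time steps,
+# nodes, channels with D = K * d); adjacency views are dense [N, N] tensors.
+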
+
+class gcn(torch.nn.Module):
+    def __init__(self, k, d):
+        super(gcn, self).__init__()
+        D = k * d
+        self.fc = torch.nn.Linear(2 * D, D)
+        self.dropout = nn.Dropout(p=0.1)
+
+    def forward(self, X, STE, A):
+        X = torch.cat((X, STE), dim=-1)
+        H = F.gelu(self.fc(X))
+        H = torch.einsum('ncvl,vw->ncwl', (H, A))
+        return self.dropout(H.contiguous())
+
+
+class randomGAT(torch.nn.Module):
+    def __init__(self, k, d, adj, device):
+        super(randomGAT, self).__init__()
+        D = k * d
+        self.d = d
+        self.K = k
+        num_nodes = adj.shape[0]
+        self.device = device
+        self.fc = torch.nn.Linear(2 * D, D)
+        self.adj = adj
+        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10, device=device))
+        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes, device=device))
+
+    def forward(self, X, STE):
+        X = torch.cat((X, STE), dim=-1)
+        H = F.gelu(self.fc(X))
+        H = torch.cat(torch.split(H, self.d, dim=-1), dim=0)
+        adp = torch.mm(self.nodevec1, self.nodevec2)
+        zero_vec = torch.tensor(-9e15).to(self.device)
+        adp = torch.where(self.adj > 0, adp, zero_vec)
+        adj = F.softmax(adp, dim=-1)
+        H = torch.einsum('vw,ncwl->ncvl', (adj, H))
+        H = torch.cat(torch.split(H, H.shape[0] // self.K, dim=0), dim=-1)
+        return F.gelu(H.contiguous())
+
+
+class STEmbModel(torch.nn.Module):
+    def __init__(self, SEDims, TEDims, OutDims, device):
+        super(STEmbModel, self).__init__()
+        self.TEDims = TEDims
+        self.fc3 = torch.nn.Linear(SEDims, OutDims)
+        self.fc4 = torch.nn.Linear(OutDims, OutDims)
+        self.fc5 = torch.nn.Linear(TEDims, OutDims)
+        self.fc6 = torch.nn.Linear(OutDims, OutDims)
+        self.device = device
+
+    def forward(self, SE, TE):
+        SE = SE.unsqueeze(0).unsqueeze(0)
+        SE = self.fc4(F.gelu(self.fc3(SE)))
+        dayofweek = F.one_hot(TE[..., 0], num_classes=7)
+        timeofday = F.one_hot(TE[..., 1], num_classes=self.TEDims - 7)
+        TE = torch.cat((dayofweek, timeofday), dim=-1)
+        TE = TE.unsqueeze(2).float().to(self.device)
+        TE = self.fc6(F.gelu(self.fc5(TE)))
+        sum_tensor = torch.add(SE, TE)
+        return sum_tensor
+
+
+class SpatialAttentionModel(torch.nn.Module):
+    def __init__(self, K, d, adj, dropout=0.3, mask=True):
+        super(SpatialAttentionModel, self).__init__()
+        D = K * d
+        self.fc7 = torch.nn.Linear(2 * D, D)
+        self.fc8 = torch.nn.Linear(2 * D, D)
+        self.fc9 = torch.nn.Linear(2 * D, D)
+        self.fc10 = torch.nn.Linear(D, D)
+        self.fc11 = torch.nn.Linear(D, D)
+        self.K = K
+        self.d = d
+        self.adj = adj
+        self.mask = mask
+        self.dropout = dropout
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, X, STE):
+        X = torch.cat((X, STE), dim=-1)
+        query = F.gelu(self.fc7(X))
+        key = F.gelu(self.fc8(X))
+        value = F.gelu(self.fc9(X))
+        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
+        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
+        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
+        attention = torch.matmul(query, torch.transpose(key, 2, 3))
+        attention /= (self.d ** 0.5)
+        if self.mask:
+            zero_vec = -9e15 * torch.ones_like(attention)
+            attention = torch.where(self.adj > 0, attention, zero_vec)
+        attention = self.softmax(attention)
+        X = torch.matmul(attention, value)
+        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
+        X = self.fc11(F.gelu(self.fc10(X)))
+        return X
+
+
+class TemporalAttentionModel(torch.nn.Module):
+    def __init__(self, K, d, device):
+        super(TemporalAttentionModel, self).__init__()
+        D = K * d
+        self.fc12 = torch.nn.Linear(2 * D, D)
+        self.fc13 = torch.nn.Linear(2 * D, D)
+        self.fc14 = torch.nn.Linear(2 * D, D)
+        self.fc15 = torch.nn.Linear(D, D)
+        self.fc16 = torch.nn.Linear(D, D)
+        self.K = K
+        self.d = d
+        self.device = device
+        self.softmax = torch.nn.Softmax(dim=-1)
+        self.dropout = nn.Dropout(p=0.1)
+
+    def forward(self, X, STE, Mask=True):
+        X = torch.cat((X, STE), dim=-1)
+        query = F.gelu(self.fc12(X))
+        key = F.gelu(self.fc13(X))
+        value = F.gelu(self.fc14(X))
+        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
+        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
+        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
+        query = torch.transpose(query, 2, 1)
+        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
+        value = torch.transpose(value, 2, 1)
+        attention = torch.matmul(query, key)
+        attention /= (self.d ** 0.5)
+        if Mask:
+            num_steps = X.shape[1]
+            mask = torch.ones(num_steps, num_steps).to(self.device)
+            mask = torch.tril(mask)
+            zero_vec = torch.tensor(-9e15).to(self.device)
+            mask = mask.to(torch.bool)
+            attention = torch.where(mask, attention, zero_vec)
+        attention = self.softmax(attention)
+        X = torch.matmul(attention, value)
+        X = torch.transpose(X, 2, 1)
+        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
+        X = self.dropout(self.fc16(F.gelu(self.fc15(X))))
+        return X
+
+
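+# Gated fusion of the spatial branch HS and temporal branch HT with a learned
+# gate: z = sigmoid(fc17(HS) + fc18(HT)), then H = z * HS + (1 - z) * HT,
+# followed by a two-layer projection.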
+class GatedFusionModel(torch.nn.Module):
+    def __init__(self, K, d):
+        super(GatedFusionModel, self).__init__()
+        D = K * d
+        self.fc17 = torch.nn.Linear(D, D)
+        self.fc18 = torch.nn.Linear(D, D)
+        self.fc19 = torch.nn.Linear(D, D)
+        self.fc20 = torch.nn.Linear(D, D)
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, HS, HT):
+        XS = self.fc17(HS)
+        XT = self.fc18(HT)
+        z = self.sigmoid(torch.add(XS, XT))
+        H = torch.add((z * HS), ((1 - z) * HT))
+        H = self.fc20(F.gelu(self.fc19(H)))
+        return H
+
+
+class STAttModel(torch.nn.Module):
+    def __init__(self, K, d, adj, device):
+        super(STAttModel, self).__init__()
+        D = K * d
+        self.fc30 = torch.nn.Linear(7 * D, D)
+        self.gcn = gcn(K, d)
+        self.gcn1 = randomGAT(K, d, adj[0], device)
+        self.gcn2 = randomGAT(K, d, adj[0], device)
+        self.gcn3 = randomGAT(K, d, adj[1], device)
+        self.gcn4 = randomGAT(K, d, adj[1], device)
+        self.temporalAttention = TemporalAttentionModel(K, d, device)
+        self.gatedFusion = GatedFusionModel(K, d)
+
+    def forward(self, X, STE, adp, Mask=True):
+        HS1 = self.gcn1(X, STE)
+        HS2 = self.gcn2(HS1, STE)
+        HS3 = self.gcn3(X, STE)
+        HS4 = self.gcn4(HS3, STE)
+        HS5 = self.gcn(X, STE, adp)
+        HS6 = self.gcn(HS5, STE, adp)
+        HS = torch.cat((X, HS1, HS2, HS3, HS4, HS5, HS6), dim=-1)
+        HS = F.gelu(self.fc30(HS))
+        HT = self.temporalAttention(X, STE, Mask)
+        H = self.gatedFusion(HS, HT)
+        return torch.add(X, H)
+
+
+class TransformAttentionModel(torch.nn.Module):
+    def __init__(self, K, d):
+        super(TransformAttentionModel, self).__init__()
+        D = K * d
+        self.fc21 = torch.nn.Linear(D, D)
+        self.fc22 = torch.nn.Linear(D, D)
+        self.fc23 = torch.nn.Linear(D, D)
+        self.fc24 = torch.nn.Linear(D, D)
+        self.fc25 = torch.nn.Linear(D, D)
+        self.K = K
+        self.d = d
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, X, STE_P, STE_Q):
+        query = F.gelu(self.fc21(STE_Q))
+        key = F.gelu(self.fc22(STE_P))
+        value = F.gelu(self.fc23(X))
+        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
+        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
+        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
+        query = torch.transpose(query, 2, 1)
+        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
+        value = torch.transpose(value, 2, 1)
+        attention = torch.matmul(query, key)
+        attention /= (self.d ** 0.5)
+        attention = self.softmax(attention)
+        X = torch.matmul(attention, value)
+        X = torch.transpose(X, 2, 1)
+        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
+        X = self.fc25(F.gelu(self.fc24(X)))
+        return X
+
+
+class RGDAN(nn.Module):
+    def __init__(self, K, d, SEDims, TEDims, P, L, device, adj, num_nodes):
+        super(RGDAN, self).__init__()
+        D = K * d
+        self.fc1 = torch.nn.Linear(1, D)
+        self.fc2 = torch.nn.Linear(D, D)
+        self.STEmb = STEmbModel(SEDims, TEDims, K * d, device)
+        self.STAttBlockEnc = STAttModel(K, d, adj, device)
+        self.STAttBlockDec = STAttModel(K, d, adj, device)
+        self.transformAttention = TransformAttentionModel(K, d)
+        self.P = P
+        self.L = L
+        self.device = device
+        self.fc26 = torch.nn.Linear(D, D)
+        self.fc27 = torch.nn.Linear(D, 1)
+        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10, device=device))
+        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes, device=device))
+        self.dropout = nn.Dropout(p=0.1)
+
+    def forward(self, X, SE, TE):
+        adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
+        X = self.fc2(F.gelu(self.fc1(X)))
+        STE = self.STEmb(SE, TE)
+        STE_P = STE[:, :self.P]
+        STE_Q = STE[:, self.P:]
+        X = self.STAttBlockEnc(X, STE_P, adp, Mask=True)
+        X = self.transformAttention(X, STE_P, STE_Q)
+        X = self.STAttBlockDec(X, STE_Q, adp, Mask=True)
+        X = self.fc27(self.dropout(F.gelu(self.fc26(X))))
+        return X.squeeze(3)
+
+
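+# End-to-end flow of RGDAN.forward, for batch B, P history steps, Q horizon
+# steps, N nodes and D = K * d channels:
+#   X [B, P, N, 1] -> fc1/fc2 -> [B, P, N, D] -> encoder block -> [B, P, N, D]
+#   -> transform attention -> [B, Q, N, D] -> decoder block -> [B, Q, N, D]
+#   -> fc26/fc27 -> [B, Q, N, 1] -> squeeze -> [B, Q, N]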
+class RGDANModel(nn.Module):
+    """Wrapper that integrates RGDAN with the TrafficWheel pipeline.
+
+    Expects the dataloader to provide tensors shaped as:
+      - X: [B, T_in, N, F], where F >= 1 and channel 0 is the traffic signal
+      - Y: [B, T_out, N, F]
+    TE is synthesized internally from the appended time-of-day/day-of-week
+    channels using steps_per_day/days_per_week; SE is a learnable node
+    embedding initialized to zeros (and could be extended).
+    """
+
+    def __init__(self, args):
+        super(RGDANModel, self).__init__()
+
+        self.args = args
+        self.device = args.get('device', 'cpu')
+        self.num_nodes = args['num_nodes']
+        self.input_dim = args['input_dim']
+        self.output_dim = args['output_dim']
+        self.P = args.get('lag', args.get('history', 12))
+        self.L = args.get('horizon', 12)
+
+        # RGDAN hyper-parameters with defaults
+        self.K = args.get('K', 3)
+        self.d = args.get('d', 8)
+        self.SEDims = args.get('SEDims', 16)
+        self.TEDims = args.get('TEDims', 288 + 7)
+
+        # Adjacency set (two views expected by STAttModel): use the distance
+        # matrix from get_adj and build forward and backward edge views.
+        adj_distance = get_adj({'num_nodes': self.num_nodes})
+        if adj_distance is None:
+            base = torch.ones(self.num_nodes, self.num_nodes, device=self.device)
+            adj = [base, base]
+        else:
+            base = torch.from_numpy(adj_distance).float().to(self.device)
+            adj = [base, base.T]
+
+        self.se_embedding = nn.Parameter(torch.zeros(self.num_nodes, self.SEDims), requires_grad=True)
+
+        self.rgdan = RGDAN(
+            K=self.K,
+            d=self.d,
+            SEDims=self.SEDims,
+            TEDims=self.TEDims,
+            P=self.P,
+            L=self.L,
+            device=self.device,
+            adj=adj,
+            num_nodes=self.num_nodes,
+        )
+
+    def forward(self, x):
+        # x: [B, T_in, N, F_total]; channels = [orig_features..., time_in_day, day_in_week]
+        x0 = x[..., 0:1]
+
+        steps_per_day = self.args.get('steps_per_day', 288)
+        days_per_week = self.args.get('days_per_week', 7)
+        B, T_in, N, F_total = x.shape
+        T_out = self.L
+
+        # Extract TE for the observed window from the appended channels
+        # (constant across nodes, so node 0 suffices)
+        time_in_day_cont = x[:, :, 0, -2]  # [B, T_in]
+        day_in_week_cont = x[:, :, 0, -1]  # [B, T_in]
+        tod_idx = torch.round(time_in_day_cont * steps_per_day - 1e-6).clamp(0, steps_per_day - 1).long()
+        dow_idx = torch.round(day_in_week_cont).clamp(0, days_per_week - 1).long()
+
+        # Extrapolate TE over the horizon
+        last_tod = tod_idx[:, -1]  # [B]
+        last_dow = dow_idx[:, -1]  # [B]
+        offsets = torch.arange(1, T_out + 1, device=x.device)
+        future_tod_linear = last_tod.unsqueeze(1) + offsets.unsqueeze(0)
+        future_tod = (future_tod_linear % steps_per_day).long()
+        carry_days = (future_tod_linear // steps_per_day).long()
+        future_dow = (last_dow.unsqueeze(1) + carry_days) % days_per_week
+
+        TE_P = torch.stack([dow_idx, tod_idx], dim=-1)        # [B, T_in, 2]
+        TE_Q = torch.stack([future_dow, future_tod], dim=-1)  # [B, T_out, 2]
+        TE = torch.cat([TE_P, TE_Q], dim=1)                   # [B, T_in + T_out, 2]
+
+        # SE: static node embeddings [N, SEDims]
+        SE = self.se_embedding
+
+        y = self.rgdan(x0, SE, TE)
+        # Output: [B, T_out, N] -> [B, T_out, N, 1]
+        if y.dim() == 3:
+            y = y.unsqueeze(-1)
+        return y
+
+
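+# A minimal smoke test for the core RGDAN module (the pipeline wrapper needs
+# a real adjacency from get_adj, so it is not exercised here). The toy
+# 5-node all-ones adjacency and random SE/TE inputs are assumptions for
+# illustration only.
+if __name__ == "__main__":
+    K, d, N, P, Q = 3, 8, 5, 12, 12
+    adj = [torch.ones(N, N), torch.ones(N, N)]  # two dense graph views
+    net = RGDAN(K=K, d=d, SEDims=16, TEDims=295, P=P, L=Q,
+                device="cpu", adj=adj, num_nodes=N)
+    X = torch.randn(2, P, N, 1)                 # [B, P, N, 1] history
+    SE = torch.randn(N, 16)                     # static node embeddings
+    dow = torch.randint(0, 7, (2, P + Q))       # day-of-week indices
+    tod = torch.randint(0, 288, (2, P + Q))     # time-of-day indices
+    TE = torch.stack([dow, tod], dim=-1)        # [B, P + Q, 2]
+    print(net(X, SE, TE).shape)                 # torch.Size([2, 12, 5])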
diff --git a/model/model_selector.py b/model/model_selector.py
index 4fba6bc..dfdae04 100755
--- a/model/model_selector.py
+++ b/model/model_selector.py
@@ -22,6 +22,8 @@ from model.MegaCRN.MegaCRNModel import MegaCRNModel
 from model.ST_SSL.ST_SSL import STSSLModel
 from model.STGNRDE.Make_model import make_model as make_nrde_model
 from model.STAWnet.STAWnet import STAWnet
+from model.STEP.STEP import STEP
+from model.RGDAN.RGDAN import RGDANModel
 
 def model_selector(model):
     match model['type']:
@@ -49,4 +51,6 @@ def model_selector(model):
         case 'ST_SSL': return STSSLModel(model)
         case 'STGNRDE': return make_nrde_model(model)
         case 'STAWnet': return STAWnet(model)
+        case 'STEP': return STEP(model)
+        case 'RGDAN': return RGDANModel(model)
 
diff --git a/temp_repo/STEP b/temp_repo/STEP
new file mode 160000
index 0000000..566e273
--- /dev/null
+++ b/temp_repo/STEP
@@ -0,0 +1 @@
+Subproject commit 566e2738da2d83f055718d8edb609ad8dc325204
diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py
index 5f66185..4c2d5e5 100755
--- a/trainer/trainer_selector.py
+++ b/trainer/trainer_selector.py
@@ -13,7 +13,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
                                               lr_scheduler, kwargs[0], None)
         case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                           lr_scheduler, kwargs[0], None)
-        case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
+        case 'DCRNN': return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                       lr_scheduler)
         case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                                lr_scheduler)
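Usage note: once applied, RGDAN is selected like any other model. A minimal
sketch; the flat dict below stands in for however the pipeline merges the
YAML sections into the model args, which may differ in practice, and
RGDANModel's constructor still needs the dataset adjacency that get_adj
reads from disk:

    from model.model_selector import model_selector

    cfg = {'type': 'RGDAN', 'num_nodes': 307, 'input_dim': 1, 'output_dim': 1,
           'lag': 12, 'horizon': 12, 'K': 3, 'd': 8, 'SEDims': 16,
           'TEDims': 295, 'device': 'cpu'}
    model = model_selector(cfg)  # returns an RGDANModel instance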