import torch
import torch.nn as nn
import torch.nn.functional as F

from data.get_adj import get_adj


class gcn(torch.nn.Module):
    def __init__(self, k, d):
        super(gcn, self).__init__()
        D = k * d
        self.fc = torch.nn.Linear(2 * D, D)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, A):
        # X, STE: [B, T, N, D]; A: [N, N] (adaptive) adjacency matrix
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        H = torch.einsum('ncvl,vw->ncwl', H, A)
        return self.dropout(H.contiguous())


class randomGAT(torch.nn.Module):
    def __init__(self, k, d, adj, device):
        super(randomGAT, self).__init__()
        D = k * d
        self.d = d
        self.K = k
        num_nodes = adj.shape[0]
        self.device = device
        self.fc = torch.nn.Linear(2 * D, D)
        self.adj = adj
        # Learnable node embedding factors whose product forms the attention scores.
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True)

    def forward(self, X, STE):
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        # Split the feature dimension into K heads stacked along the batch axis.
        H = torch.cat(torch.split(H, self.d, dim=-1), dim=0)
        adp = torch.mm(self.nodevec1, self.nodevec2)
        zero_vec = torch.tensor(-9e15).to(self.device)
        # Mask scores where the prior adjacency has no edge, then normalize row-wise.
        adp = torch.where(self.adj > 0, adp, zero_vec)
        adj = F.softmax(adp, dim=-1)
        H = torch.einsum('vw,ncwl->ncvl', adj, H)
        # Merge the heads back into the feature dimension.
        H = torch.cat(torch.split(H, H.shape[0] // self.K, dim=0), dim=-1)
        return F.gelu(H.contiguous())


class STEmbModel(torch.nn.Module):
    def __init__(self, SEDims, TEDims, OutDims, device):
        super(STEmbModel, self).__init__()
        self.TEDims = TEDims
        self.fc3 = torch.nn.Linear(SEDims, OutDims)
        self.fc4 = torch.nn.Linear(OutDims, OutDims)
        self.fc5 = torch.nn.Linear(TEDims, OutDims)
        self.fc6 = torch.nn.Linear(OutDims, OutDims)
        self.device = device

    def forward(self, SE, TE):
        # SE: [N, SEDims] static node embedding; TE: [B, T, 2] = (day-of-week, time-of-day)
        SE = SE.unsqueeze(0).unsqueeze(0)
        SE = self.fc4(F.gelu(self.fc3(SE)))
        dayofweek = F.one_hot(TE[..., 0], num_classes=7)
        timeofday = F.one_hot(TE[..., 1], num_classes=self.TEDims - 7)
        TE = torch.cat((dayofweek, timeofday), dim=-1)
        TE = TE.unsqueeze(2).float().to(self.device)
        TE = self.fc6(F.gelu(self.fc5(TE)))
        return torch.add(SE, TE)


class SpatialAttentionModel(torch.nn.Module):
    def __init__(self, K, d, adj, dropout=0.3, mask=True):
        super(SpatialAttentionModel, self).__init__()
        D = K * d
        self.fc7 = torch.nn.Linear(2 * D, D)
        self.fc8 = torch.nn.Linear(2 * D, D)
        self.fc9 = torch.nn.Linear(2 * D, D)
        self.fc10 = torch.nn.Linear(D, D)
        self.fc11 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.adj = adj
        self.mask = mask
        self.dropout = dropout
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X, STE):
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc7(X))
        key = F.gelu(self.fc8(X))
        value = F.gelu(self.fc9(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        # Scaled dot-product attention over the node dimension.
        attention = torch.matmul(query, torch.transpose(key, 2, 3))
        attention /= (self.d ** 0.5)
        if self.mask:
            zero_vec = -9e15 * torch.ones_like(attention)
            attention = torch.where(self.adj > 0, attention, zero_vec)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.fc11(F.gelu(self.fc10(X)))
        return X


class TemporalAttentionModel(torch.nn.Module):
    def __init__(self, K, d, device):
        super(TemporalAttentionModel, self).__init__()
        D = K * d
        self.fc12 = torch.nn.Linear(2 * D, D)
        self.fc13 = torch.nn.Linear(2 * D, D)
        self.fc14 = torch.nn.Linear(2 * D, D)
        self.fc15 = torch.nn.Linear(D, D)
        self.fc16 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, Mask=True):
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc12(X))
        key = F.gelu(self.fc13(X))
        value = F.gelu(self.fc14(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        # Rearrange so attention is computed over the time dimension.
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        if Mask:
            # Causal mask: each step attends only to itself and earlier steps.
            num_steps = X.shape[1]
            mask = torch.ones(num_steps, num_steps).to(self.device)
            mask = torch.tril(mask)
            zero_vec = torch.tensor(-9e15).to(self.device)
            mask = mask.to(torch.bool)
            attention = torch.where(mask, attention, zero_vec)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.dropout(self.fc16(F.gelu(self.fc15(X))))
        return X


class GatedFusionModel(torch.nn.Module):
    def __init__(self, K, d):
        super(GatedFusionModel, self).__init__()
        D = K * d
        self.fc17 = torch.nn.Linear(D, D)
        self.fc18 = torch.nn.Linear(D, D)
        self.fc19 = torch.nn.Linear(D, D)
        self.fc20 = torch.nn.Linear(D, D)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, HS, HT):
        # Gate between the spatial (HS) and temporal (HT) representations.
        XS = self.fc17(HS)
        XT = self.fc18(HT)
        z = self.sigmoid(torch.add(XS, XT))
        H = torch.add((z * HS), ((1 - z) * HT))
        H = self.fc20(F.gelu(self.fc19(H)))
        return H


class STAttModel(torch.nn.Module):
    def __init__(self, K, d, adj, device):
        super(STAttModel, self).__init__()
        D = K * d
        self.fc30 = torch.nn.Linear(7 * D, D)
        self.gcn = gcn(K, d)
        self.gcn1 = randomGAT(K, d, adj[0], device)
        self.gcn2 = randomGAT(K, d, adj[0], device)
        self.gcn3 = randomGAT(K, d, adj[1], device)
        self.gcn4 = randomGAT(K, d, adj[1], device)
        self.temporalAttention = TemporalAttentionModel(K, d, device)
        self.gatedFusion = GatedFusionModel(K, d)

    def forward(self, X, STE, adp, Mask=True):
        # Two-hop propagation over each adjacency view plus the adaptive graph.
        HS1 = self.gcn1(X, STE)
        HS2 = self.gcn2(HS1, STE)
        HS3 = self.gcn3(X, STE)
        HS4 = self.gcn4(HS3, STE)
        HS5 = self.gcn(X, STE, adp)
        HS6 = self.gcn(HS5, STE, adp)
        HS = torch.cat((X, HS1, HS2, HS3, HS4, HS5, HS6), dim=-1)
        HS = F.gelu(self.fc30(HS))
        HT = self.temporalAttention(X, STE, Mask)
        H = self.gatedFusion(HS, HT)
        return torch.add(X, H)


class TransformAttentionModel(torch.nn.Module):
    def __init__(self, K, d):
        super(TransformAttentionModel, self).__init__()
        D = K * d
        self.fc21 = torch.nn.Linear(D, D)
        self.fc22 = torch.nn.Linear(D, D)
        self.fc23 = torch.nn.Linear(D, D)
        self.fc24 = torch.nn.Linear(D, D)
        self.fc25 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X, STE_P, STE_Q):
        # Attend from the prediction-window embeddings (query) to the
        # history-window embeddings (key) to map encoder output X to the horizon.
        query = F.gelu(self.fc21(STE_Q))
        key = F.gelu(self.fc22(STE_P))
        value = F.gelu(self.fc23(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.fc25(F.gelu(self.fc24(X)))
        return X


class RGDAN(nn.Module):
    def __init__(self, K, d, SEDims, TEDims, P, L,
                 device, adj, num_nodes):
        super(RGDAN, self).__init__()
        D = K * d
        self.fc1 = torch.nn.Linear(1, D)
        self.fc2 = torch.nn.Linear(D, D)
        self.STEmb = STEmbModel(SEDims, TEDims, K * d, device)
        self.STAttBlockEnc = STAttModel(K, d, adj, device)
        self.STAttBlockDec = STAttModel(K, d, adj, device)
        self.transformAttention = TransformAttentionModel(K, d)
        self.P = P
        self.L = L
        self.device = device
        self.fc26 = torch.nn.Linear(D, D)
        self.fc27 = torch.nn.Linear(D, 1)
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, SE, TE):
        # Adaptive adjacency learned from the node embedding factors.
        adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
        X = self.fc2(F.gelu(self.fc1(X)))
        STE = self.STEmb(SE, TE)
        STE_P = STE[:, :self.P]
        STE_Q = STE[:, self.P:]
        X = self.STAttBlockEnc(X, STE_P, adp, Mask=True)
        X = self.transformAttention(X, STE_P, STE_Q)
        X = self.STAttBlockDec(X, STE_Q, adp, Mask=True)
        X = self.fc27(self.dropout(F.gelu(self.fc26(X))))
        return X.squeeze(3)


class RGDANModel(nn.Module):
    """Wrapper to integrate RGDAN with the TrafficWheel pipeline.

    Expects the dataloader to provide tensors shaped as:
      - X: [B, T_in, N, F] where F >= 1 and channel 0 holds the signal
      - Y: [B, T_out, N, F]

    TE is synthesized internally via steps_per_day/days_per_week, and SE is a
    learnable node embedding initialized to zeros (could be extended).
    """

    def __init__(self, args):
        super(RGDANModel, self).__init__()
        self.args = args
        self.device = args.get('device', 'cpu')
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.P = args.get('lag', args.get('history', 12))
        self.L = args.get('horizon', 12)
        # RGDAN hyper-parameters with defaults
        self.K = args.get('K', 3)
        self.d = args.get('d', 8)
        self.SEDims = args.get('SEDims', 16)
        self.TEDims = args.get('TEDims', 288 + 7)
        # Adjacency set (two views expected by STAttModel): use the distance
        # matrix from get_adj,
        # and build two views from it: forward and backward (transposed) edges.
        adj_distance = get_adj({'num_nodes': self.num_nodes})
        if adj_distance is None:
            # Fall back to a fully connected graph when no distance matrix is available.
            base = torch.ones(self.num_nodes, self.num_nodes, device=self.device)
            adj = [base, base]
        else:
            base = torch.from_numpy(adj_distance).float().to(self.device)
            adj = [base, base.T]
        self.se_embedding = nn.Parameter(torch.zeros(self.num_nodes, self.SEDims), requires_grad=True)
        self.rgdan = RGDAN(
            K=self.K,
            d=self.d,
            SEDims=self.SEDims,
            TEDims=self.TEDims,
            P=self.P,
            L=self.L,
            device=self.device,
            adj=adj,
            num_nodes=self.num_nodes,
        )

    def forward(self, x):
        # x: [B, T_in, N, F_total]; channels = [orig_features..., time_in_day, day_in_week]
        x0 = x[..., 0:1]
        steps_per_day = self.args.get('steps_per_day', 288)
        days_per_week = self.args.get('days_per_week', 7)
        B, T_in, N, F_total = x.shape
        T_out = self.L
        # Extract TE for the observed window from the appended channels
        # (constant across nodes, so node 0 is representative).
        time_in_day_cont = x[:, :, 0, -2]  # [B, T_in]
        day_in_week_cont = x[:, :, 0, -1]  # [B, T_in]
        tod_idx = torch.round(time_in_day_cont * steps_per_day - 1e-6).clamp(0, steps_per_day - 1).long()
        dow_idx = torch.round(day_in_week_cont).clamp(0, days_per_week - 1).long()
        # Extrapolate TE over the prediction horizon, carrying day rollovers.
        last_tod = tod_idx[:, -1]  # [B]
        last_dow = dow_idx[:, -1]  # [B]
        offsets = torch.arange(1, T_out + 1, device=x.device)
        future_tod_linear = last_tod.unsqueeze(1) + offsets.unsqueeze(0)
        future_tod = (future_tod_linear % steps_per_day).long()
        carry_days = (future_tod_linear // steps_per_day).long()
        future_dow = (last_dow.unsqueeze(1) + carry_days) % days_per_week
        TE_P = torch.stack([dow_idx, tod_idx], dim=-1)        # [B, T_in, 2]
        TE_Q = torch.stack([future_dow, future_tod], dim=-1)  # [B, T_out, 2]
        TE = torch.cat([TE_P, TE_Q], dim=1)                   # [B, T_in + T_out, 2]
        # SE: static node embeddings [N, SEDims]
        SE = self.se_embedding
        y = self.rgdan(x0, SE, TE)  # [B, T_out, N]
        if y.dim() == 3:
            y = y.unsqueeze(-1)     # restore the trailing feature dimension
        return y
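

# Minimal smoke test (illustrative sketch only): the toy sizes, dense adjacency
# views, and random inputs below are assumptions for demonstration, not values
# from the TrafficWheel pipeline. It exercises the core RGDAN module directly,
# so no distance matrix or dataloader is needed (get_adj is never called).
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cpu"
    B, N, P, L, K, d = 2, 5, 12, 12, 3, 8                 # assumed toy sizes
    SEDims, steps_per_day = 16, 288
    TEDims = steps_per_day + 7                             # 7 day-of-week + time-of-day slots
    adj = [torch.ones(N, N), torch.ones(N, N)]             # two fully connected adjacency views
    model = RGDAN(K, d, SEDims, TEDims, P, L, device, adj, N)
    X = torch.randn(B, P, N, 1)                            # observed signal, single channel
    SE = torch.zeros(N, SEDims)                            # static node embedding
    dow = torch.randint(0, 7, (B, P + L, 1))               # day-of-week indices
    tod = torch.randint(0, steps_per_day, (B, P + L, 1))   # time-of-day indices
    TE = torch.cat([dow, tod], dim=-1)                     # [B, P+L, 2]
    out = model(X, SE, TE)
    print(out.shape)                                       # expected: torch.Size([2, 12, 5])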