TrafficWheel/model/RGDAN/RGDAN.py

349 lines
13 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from data.get_adj import get_adj
class gcn(torch.nn.Module):
    """Graph convolution with a caller-supplied node-mixing matrix.

    The input features and the spatio-temporal embedding are concatenated,
    projected from ``2 * k * d`` back to ``k * d`` channels by a GELU-activated
    linear layer, mixed across nodes with ``A`` and passed through dropout
    (p=0.1).

    Args:
        k: number of heads; only the product ``k * d`` matters here.
        d: channels per head.
    """

    def __init__(self, k, d):
        super().__init__()
        width = k * d
        self.fc = torch.nn.Linear(2 * width, width)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, A):
        """Project ``cat(X, STE)`` and propagate it over the node axis.

        Args:
            X, STE: rank-4 tensors (assumed ``[B, T, N, k*d]`` each) that are
                concatenated on the last axis; axis 2 is treated as the node axis.
            A: ``[N, N]`` mixing matrix; output node ``w`` receives
                ``sum_v H[..., v, :] * A[v, w]`` (i.e. ``A`` is applied transposed).

        Returns:
            Tensor with the same layout as ``X``, after dropout.
        """
        projected = F.gelu(self.fc(torch.cat((X, STE), dim=-1)))
        mixed = torch.einsum('ncvl,vw->ncwl', projected, A)
        return self.dropout(mixed.contiguous())
class randomGAT(torch.nn.Module):
    """Graph attention over nodes with a learnable affinity restricted to ``adj``.

    The affinity ``nodevec1 @ nodevec2`` (``[N, N]``) is kept only where
    ``adj > 0``; every other entry is replaced by -9e15 before a row-wise
    softmax. The resulting weights are shared by all ``k`` heads, time steps
    and batch elements. A row of ``adj`` with no positive entry ends up with a
    uniform weight over all nodes, because every entry in that row is -9e15.

    Args:
        k: number of heads.
        d: channels per head; the model width is ``D = k * d``.
        adj: ``[N, N]`` tensor; only its ``> 0`` pattern is used.
        device: device on which the learnable node vectors are created.
    """

    def __init__(self, k, d, adj, device):
        super(randomGAT, self).__init__()
        D = k * d
        self.d = d
        self.K = k
        adj = torch.as_tensor(adj)
        num_nodes = adj.shape[0]
        self.device = device
        self.fc = torch.nn.Linear(2 * D, D)
        # Non-persistent buffer: moved by module.to()/.cuda() and replicated by
        # DataParallel, yet kept out of state_dict so existing checkpoints still load.
        self.register_buffer('adj', adj, persistent=False)
        # Sample on the default generator, then move: keeps the original RNG stream.
        # The Parameter is already on `device`, so no trailing .to() is needed.
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True)

    def forward(self, X, STE):
        """Attend over nodes.

        Args:
            X, STE: rank-4 tensors (assumed ``[B, T, N, D]``) concatenated on the
                last axis; axis 2 is the node axis.

        Returns:
            GELU of the attended features, ``[..., D]`` with the same layout as ``X``.
        """
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        # Fold the K heads (d channels each) into the batch axis.
        H = torch.cat(torch.split(H, self.d, dim=-1), dim=0)
        adp = torch.mm(self.nodevec1, self.nodevec2)
        # 0-dim fill created from adp, so it always matches adp's device and dtype.
        adp = torch.where(self.adj > 0, adp, adp.new_tensor(-9e15))
        adj = F.softmax(adp, dim=-1)
        H = torch.einsum('vw,ncwl->ncvl', adj, H)
        # Unfold the heads back onto the channel axis.
        H = torch.cat(torch.split(H, H.shape[0] // self.K, dim=0), dim=-1)
        return F.gelu(H.contiguous())
class STEmbModel(torch.nn.Module):
    """Spatio-temporal embedding: node embedding plus time-step embedding.

    Args:
        SEDims: width of the node embedding ``SE``.
        TEDims: total one-hot width; 7 day-of-week classes plus
            ``TEDims - 7`` time-of-day classes.
        OutDims: output width.
        device: kept for API compatibility; ``forward`` places the one-hot time
            encoding on the device/dtype of its own weights.
    """

    def __init__(self, SEDims, TEDims, OutDims, device):
        super(STEmbModel, self).__init__()
        self.TEDims = TEDims
        self.fc3 = torch.nn.Linear(SEDims, OutDims)
        self.fc4 = torch.nn.Linear(OutDims, OutDims)
        self.fc5 = torch.nn.Linear(TEDims, OutDims)
        self.fc6 = torch.nn.Linear(OutDims, OutDims)
        self.device = device

    def forward(self, SE, TE):
        """Embed nodes and time steps and return their broadcast sum.

        Args:
            SE: node embedding, expected ``[N, SEDims]``; it is unsqueezed to
                ``[1, 1, N, SEDims]`` before fc3 -> GELU -> fc4.
            TE: integer (int64) tensor whose last axis holds
                ``(day_of_week, time_of_day)`` indices, e.g. ``[B, T, 2]``.

        Returns:
            ``SE + TE`` broadcast, i.e. ``[B, T, N, OutDims]`` for ``TE`` of shape ``[B, T, 2]``.
        """
        SE = SE.unsqueeze(0).unsqueeze(0)
        SE = self.fc4(F.gelu(self.fc3(SE)))
        dayofweek = F.one_hot(TE[..., 0], num_classes=7)
        timeofday = F.one_hot(TE[..., 1], num_classes=self.TEDims - 7)
        TE = torch.cat((dayofweek, timeofday), dim=-1)
        # Cast straight onto fc5's device and dtype. The old `.type(torch.FloatTensor)`
        # forced a GPU -> CPU -> GPU copy on every call and hard-coded float32.
        TE = TE.unsqueeze(2).to(self.fc5.weight)
        TE = self.fc6(F.gelu(self.fc5(TE)))
        return SE + TE
class SpatialAttentionModel(torch.nn.Module):
    """Multi-head self-attention across nodes, optionally masked by ``adj``.

    Args:
        K: number of heads.
        d: channels per head; model width ``D = K * d``.
        adj: ``[N, N]`` tensor; when ``mask`` is true only pairs with ``adj > 0``
            may attend to each other. Stored as a non-persistent buffer.
        dropout: stored as ``self.dropout``.
            NOTE(review): it is never applied in ``forward``.
        mask: whether to apply the adjacency mask.
    """

    def __init__(self, K, d, adj, dropout=0.3, mask=True):
        super(SpatialAttentionModel, self).__init__()
        D = K * d
        self.fc7 = torch.nn.Linear(2 * D, D)
        self.fc8 = torch.nn.Linear(2 * D, D)
        self.fc9 = torch.nn.Linear(2 * D, D)
        self.fc10 = torch.nn.Linear(D, D)
        self.fc11 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        # Non-persistent buffer: follows module.to()/.cuda()/DataParallel but stays
        # out of state_dict, so existing checkpoints load unchanged.
        self.register_buffer('adj', None if adj is None else torch.as_tensor(adj), persistent=False)
        self.mask = mask
        self.dropout = dropout
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X, STE):
        """Attend over the node axis (axis 2 of the assumed ``[B, T, N, D]`` inputs).

        Returns:
            ``fc11(GELU(fc10(attended)))`` with width ``D``.
        """
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc7(X))
        key = F.gelu(self.fc8(X))
        value = F.gelu(self.fc9(X))
        # Fold heads into the batch axis: [K*B, T, N, d].
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        attention = torch.matmul(query, torch.transpose(key, 2, 3))
        attention /= (self.d ** 0.5)
        if self.mask:
            # A broadcast 0-dim fill gives the same values as the former full-size
            # `-9e15 * ones_like(attention)` without allocating [K*B, T, N, N] twice.
            attention = torch.where(self.adj > 0, attention, attention.new_tensor(-9e15))
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.fc11(F.gelu(self.fc10(X)))
        return X
class TemporalAttentionModel(torch.nn.Module):
    """Multi-head self-attention along the time axis, independently per node.

    With ``Mask=True`` a lower-triangular mask lets step ``t`` attend only to
    steps ``<= t``. The attended features pass through fc15 -> GELU -> fc16 and
    dropout (p=0.1).

    Args:
        K: number of heads.
        d: channels per head; model width ``D = K * d``.
        device: kept for API compatibility; the mask is built on the device of
            the attention scores, so the module keeps working after ``.to()``.
    """

    def __init__(self, K, d, device):
        super(TemporalAttentionModel, self).__init__()
        D = K * d
        self.fc12 = torch.nn.Linear(2 * D, D)
        self.fc13 = torch.nn.Linear(2 * D, D)
        self.fc14 = torch.nn.Linear(2 * D, D)
        self.fc15 = torch.nn.Linear(D, D)
        self.fc16 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, Mask=True):
        """Attend over axis 1 (time) of the assumed ``[B, T, N, D]`` inputs.

        Returns:
            Tensor with the same layout as ``X`` and width ``D``.
        """
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc12(X))
        key = F.gelu(self.fc13(X))
        value = F.gelu(self.fc14(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        # Move time next to the channel axis: [K*B, N, T, d] (key: [K*B, N, d, T]).
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        if Mask:
            num_steps = X.shape[1]
            # Build the causal mask and fill value from `attention` itself rather than
            # the constructor's `device`, which goes stale after module.to().
            mask = torch.tril(torch.ones(num_steps, num_steps, device=attention.device)).to(torch.bool)
            attention = torch.where(mask, attention, attention.new_tensor(-9e15))
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.dropout(self.fc16(F.gelu(self.fc15(X))))
        return X
class GatedFusionModel(torch.nn.Module):
    """Blend a spatial and a temporal representation with a learned gate.

    ``gate = sigmoid(fc17(HS) + fc18(HT))`` and
    ``fused = gate * HS + (1 - gate) * HT``; the result is then projected by
    fc19 -> GELU -> fc20. All layers keep the width ``K * d``.
    """

    def __init__(self, K, d):
        super().__init__()
        width = K * d
        self.fc17 = torch.nn.Linear(width, width)
        self.fc18 = torch.nn.Linear(width, width)
        self.fc19 = torch.nn.Linear(width, width)
        self.fc20 = torch.nn.Linear(width, width)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, HS, HT):
        """Return the gated, projected combination of ``HS`` and ``HT``."""
        gate = self.sigmoid(self.fc17(HS) + self.fc18(HT))
        fused = gate * HS + (1 - gate) * HT
        return self.fc20(F.gelu(self.fc19(fused)))
class STAttModel(torch.nn.Module):
    """One spatio-temporal block: multi-graph spatial path + temporal attention.

    The spatial path concatenates ``X`` with six graph-propagated views:
    two stacked ``randomGAT`` layers masked by ``adj[0]``, two masked by
    ``adj[1]``, and the single (weight-shared) ``gcn`` applied twice with the
    caller's matrix ``adp``. ``fc30`` projects the ``7 * D`` channels back to
    ``D``. The spatial and temporal results are gated together and added to
    ``X`` as a residual.
    """

    def __init__(self, K, d, adj, device):
        super().__init__()
        width = K * d
        # Construction order is kept as-is: it fixes parameter init and state_dict order.
        self.fc30 = torch.nn.Linear(7 * width, width)
        self.gcn = gcn(K, d)
        self.gcn1 = randomGAT(K, d, adj[0], device)
        self.gcn2 = randomGAT(K, d, adj[0], device)
        self.gcn3 = randomGAT(K, d, adj[1], device)
        self.gcn4 = randomGAT(K, d, adj[1], device)
        self.temporalAttention = TemporalAttentionModel(K, d, device)
        self.gatedFusion = GatedFusionModel(K, d)

    def forward(self, X, STE, adp, Mask=True):
        """Return ``X + gatedFusion(spatial, temporal)``; ``Mask`` is forwarded to temporal attention."""
        views = [X]
        for first, second in ((self.gcn1, self.gcn2), (self.gcn3, self.gcn4)):
            hidden = first(X, STE)
            views.extend((hidden, second(hidden, STE)))
        hidden = self.gcn(X, STE, adp)
        views.extend((hidden, self.gcn(hidden, STE, adp)))
        spatial = F.gelu(self.fc30(torch.cat(views, dim=-1)))
        temporal = self.temporalAttention(X, STE, Mask)
        return X + self.gatedFusion(spatial, temporal)
class TransformAttentionModel(torch.nn.Module):
    """Cross-attention over time that maps history steps onto target steps.

    Queries come from the target-step embedding ``STE_Q``, keys from the
    history-step embedding ``STE_P`` and values from ``X``; attention is
    computed per node with ``K`` heads of width ``d`` and no mask. The result
    passes through fc24 -> GELU -> fc25.
    """

    def __init__(self, K, d):
        super().__init__()
        width = K * d
        self.fc21 = torch.nn.Linear(width, width)
        self.fc22 = torch.nn.Linear(width, width)
        self.fc23 = torch.nn.Linear(width, width)
        self.fc24 = torch.nn.Linear(width, width)
        self.fc25 = torch.nn.Linear(width, width)
        self.K = K
        self.d = d
        self.softmax = torch.nn.Softmax(dim=-1)

    def _to_heads(self, t):
        # Fold heads into the batch axis and swap axes 1/2:
        # assumed [B, T, N, K*d] -> [K*B, N, T, d].
        return torch.cat(torch.split(t, self.d, dim=-1), dim=0).transpose(1, 2)

    def forward(self, X, STE_P, STE_Q):
        """Return features laid out like ``STE_Q`` (assumed ``[B, Q, N, K*d]``)."""
        q = self._to_heads(F.gelu(self.fc21(STE_Q)))
        k = self._to_heads(F.gelu(self.fc22(STE_P)))
        v = self._to_heads(F.gelu(self.fc23(X)))
        scores = torch.matmul(q, k.transpose(2, 3)) / (self.d ** 0.5)
        weights = self.softmax(scores)
        merged = torch.matmul(weights, v).transpose(1, 2)
        merged = torch.cat(torch.split(merged, merged.shape[0] // self.K, dim=0), dim=-1)
        return self.fc25(F.gelu(self.fc24(merged)))
class RGDAN(nn.Module):
    """RGDAN network: input projection, ST embedding, encoder, transform attention, decoder, head.

    Args:
        K, d: number of heads and channels per head (width ``D = K * d``).
        SEDims, TEDims: forwarded to ``STEmbModel``.
        P: number of history steps; the ST embedding is split at index ``P``
            along axis 1 into history and target parts.
        L: stored only; ``forward`` does not read it.
        device: forwarded to sub-modules and used to place the node vectors.
        adj: pair of ``[N, N]`` adjacency tensors (indexed ``adj[0]`` / ``adj[1]``).
        num_nodes: number of nodes for the learnable adaptive adjacency.
    """

    def __init__(self, K, d, SEDims, TEDims, P, L, device, adj, num_nodes):
        super().__init__()
        width = K * d
        # Creation order is preserved: it determines parameter init and state_dict order.
        self.fc1 = torch.nn.Linear(1, width)
        self.fc2 = torch.nn.Linear(width, width)
        self.STEmb = STEmbModel(SEDims, TEDims, K * d, device)
        self.STAttBlockEnc = STAttModel(K, d, adj, device)
        self.STAttBlockDec = STAttModel(K, d, adj, device)
        self.transformAttention = TransformAttentionModel(K, d)
        self.P = P
        self.L = L
        self.device = device
        self.fc26 = torch.nn.Linear(width, width)
        self.fc27 = torch.nn.Linear(width, 1)
        # Sampled on the default generator then moved, so the Parameter already lives on `device`.
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, SE, TE):
        """Run the network.

        Args:
            X: input whose last axis has size 1 (``fc1`` is ``Linear(1, D)``).
            SE, TE: node embedding and time indices passed to ``STEmbModel``.

        Returns:
            ``fc27`` output with axis 3 (size 1) squeezed away.
        """
        affinity = F.relu(self.nodevec1 @ self.nodevec2)
        adp = F.softmax(affinity, dim=1)
        hidden = self.fc2(F.gelu(self.fc1(X)))
        ste = self.STEmb(SE, TE)
        ste_hist, ste_future = ste[:, :self.P], ste[:, self.P:]
        hidden = self.STAttBlockEnc(hidden, ste_hist, adp, Mask=True)
        hidden = self.transformAttention(hidden, ste_hist, ste_future)
        hidden = self.STAttBlockDec(hidden, ste_future, adp, Mask=True)
        out = self.fc27(self.dropout(F.gelu(self.fc26(hidden))))
        return out.squeeze(3)
class RGDANModel(nn.Module):
    """Adapter that runs :class:`RGDAN` inside the TrafficWheel pipeline.

    ``forward`` expects ``x`` of shape ``[B, T_in, N, F_total]``:

    * channel ``0`` is the signal fed to RGDAN (other leading channels are ignored);
    * channel ``-2`` is time-of-day, treated as a fraction of a day (it is
      multiplied by ``steps_per_day`` and rounded to a slot index);
    * channel ``-1`` is day-of-week, treated as an integer index (it is rounded,
      not rescaled). NOTE(review): confirm the loader does not store ``dow / 7``.

    Both time channels are read from node 0 only, i.e. they are assumed equal
    for every node. Time indices for the ``horizon`` future steps are
    extrapolated from the last observed step, rolling over day boundaries into
    the next weekday. The node embedding ``SE`` is a learnable
    ``[num_nodes, SEDims]`` parameter initialised to zeros.

    Returns ``[B, horizon, N, 1]`` (assuming ``T_in == lag``). ``input_dim`` and
    ``output_dim`` are stored but do not shape the network.

    Raises:
        ValueError: if ``TEDims - 7 < steps_per_day`` or ``days_per_week > 7``.
            STEmbModel one-hot encodes day-of-week with 7 classes and time-of-day
            with ``TEDims - 7`` classes, so either case would otherwise fail
            inside ``F.one_hot`` during training.
    """

    def __init__(self, args):
        super(RGDANModel, self).__init__()
        self.args = args
        self.device = args.get('device', 'cpu')
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.P = args.get('lag', args.get('history', 12))
        self.L = args.get('horizon', 12)
        # RGDAN hyper-params with defaults
        self.K = args.get('K', 3)
        self.d = args.get('d', 8)
        self.SEDims = args.get('SEDims', 16)
        steps_per_day = args.get('steps_per_day', 288)
        days_per_week = args.get('days_per_week', 7)
        # Keep the historical 288 + 7 default (same parameter shapes for existing
        # configs), but widen it when a day has more than 288 slots.
        self.TEDims = args.get('TEDims', max(288, steps_per_day) + 7)
        if self.TEDims - 7 < steps_per_day:
            raise ValueError(
                f"TEDims={self.TEDims} leaves {self.TEDims - 7} time-of-day classes but "
                f"steps_per_day={steps_per_day}; TEDims must be >= steps_per_day + 7"
            )
        if days_per_week > 7:
            raise ValueError(
                f"days_per_week={days_per_week} exceeds the 7 day-of-week classes of STEmbModel"
            )
        # Two adjacency views for STAttModel: the get_adj matrix and its transpose
        # (forward / backward edges); downstream only entries > 0 count as edges.
        # get_adj is assumed to return a NumPy array or None (fully connected fallback).
        adj_distance = get_adj({'num_nodes': self.num_nodes})
        if adj_distance is None:
            base = torch.ones(self.num_nodes, self.num_nodes, device=self.device)
            adj = [base, base]
        else:
            base = torch.from_numpy(adj_distance).float().to(self.device)
            adj = [base, base.T]
        self.se_embedding = nn.Parameter(torch.zeros(self.num_nodes, self.SEDims), requires_grad=True)
        self.rgdan = RGDAN(
            K=self.K,
            d=self.d,
            SEDims=self.SEDims,
            TEDims=self.TEDims,
            P=self.P,
            L=self.L,
            device=self.device,
            adj=adj,
            num_nodes=self.num_nodes,
        )

    def forward(self, x):
        """Predict ``horizon`` steps from ``x`` (``[B, T_in, N, F_total]``, see class docstring)."""
        steps_per_day = self.args.get('steps_per_day', 288)
        days_per_week = self.args.get('days_per_week', 7)
        T_out = self.L
        x0 = x[..., 0:1]
        # Observed time indices, read from node 0.
        time_in_day_cont = x[:, :, 0, -2]  # [B, T_in]
        day_in_week_cont = x[:, :, 0, -1]  # [B, T_in]
        tod_idx = torch.round(time_in_day_cont * steps_per_day - 1e-6).clamp(0, steps_per_day - 1).long()
        dow_idx = torch.round(day_in_week_cont).clamp(0, days_per_week - 1).long()
        # Extrapolate time indices for the horizon, carrying whole days into the weekday.
        last_tod = tod_idx[:, -1]  # [B]
        last_dow = dow_idx[:, -1]  # [B]
        offsets = torch.arange(1, T_out + 1, device=x.device)
        future_tod_linear = last_tod.unsqueeze(1) + offsets.unsqueeze(0)
        future_tod = (future_tod_linear % steps_per_day).long()
        carry_days = (future_tod_linear // steps_per_day).long()
        future_dow = (last_dow.unsqueeze(1) + carry_days) % days_per_week
        TE_P = torch.stack([dow_idx, tod_idx], dim=-1)  # [B, T_in, 2]
        TE_Q = torch.stack([future_dow, future_tod], dim=-1)  # [B, T_out, 2]
        TE = torch.cat([TE_P, TE_Q], dim=1)  # [B, T_in + T_out, 2]
        y = self.rgdan(x0, self.se_embedding, TE)
        # RGDAN squeezes its size-1 output channel; restore it -> [B, T_out, N, 1].
        if y.dim() == 3:
            y = y.unsqueeze(-1)
        return y