349 lines
13 KiB
Python
349 lines
13 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from data.get_adj import get_adj
|
|
|
|
|
|
class gcn(torch.nn.Module):
    """One graph-propagation hop over an adjacency matrix supplied at call time.

    The input features and the spatio-temporal embedding are concatenated,
    projected from ``2 * k * d`` down to ``k * d`` channels with a GELU, mixed
    along the node axis with ``A``, and passed through dropout (p=0.1).
    """

    def __init__(self, k, d):
        super(gcn, self).__init__()
        width = k * d
        self.fc = torch.nn.Linear(2 * width, width)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, A):
        """Propagate ``X`` (conditioned on ``STE``) over the adjacency ``A``.

        ``X`` and ``STE`` are concatenated on the last axis. After projection the
        hidden tensor must be 4-D and ``A`` 2-D (fixed by the einsum below).
        """
        fused = torch.cat((X, STE), dim=-1)
        hidden = F.gelu(self.fc(fused))
        # out[n, c, w, l] = sum_v hidden[n, c, v, l] * A[v, w]
        propagated = torch.einsum('ncvl,vw->ncwl', hidden, A)
        return self.dropout(propagated.contiguous())
|
|
|
|
|
|
class randomGAT(torch.nn.Module):
    """Multi-head graph attention over a learnable, adjacency-masked node affinity.

    Node-to-node scores come from the product of two learnable embedding
    matrices (``nodevec1 @ nodevec2``). Pairs where ``adj <= 0`` are masked to
    a huge negative value before a row-wise softmax. The resulting matrix mixes
    node features independently for each of the ``k`` heads of width ``d``.

    Args:
        k: number of heads.
        d: width of each head (model width is ``k * d``).
        adj: adjacency tensor; ``adj.shape[0]`` gives the node count and
            ``adj > 0`` selects the pairs that may attend to each other.
        device: device the node embeddings are initialised on.
    """

    def __init__(self, k, d, adj, device):
        super(randomGAT, self).__init__()
        D = k * d
        self.d = d
        self.K = k
        num_nodes = adj.shape[0]
        self.device = device
        self.fc = torch.nn.Linear(2 * D, D)
        # Non-persistent buffer: moves with .to()/.cuda()/.double() like the
        # rest of the module, but adds no key to state_dict, so checkpoints
        # saved before this change still load with strict=True.
        self.register_buffer('adj', adj, persistent=False)
        # Sampled on the CPU generator then moved, exactly as before, so seeded
        # initialisation is unchanged; the result is still an nn.Parameter.
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device))
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device))

    def forward(self, X, STE):
        """Attend over neighbouring nodes.

        ``X`` and ``STE`` are concatenated on the last axis and projected to
        ``K * d`` channels. The output has the same layout as the projection.
        """
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        # Split channels into width-d heads and stack the heads on dim 0.
        H = torch.cat(torch.split(H, self.d, dim=-1), dim=0)

        adp = torch.mm(self.nodevec1, self.nodevec2)
        # Mask non-edges. The fill value is created on adp's own device/dtype
        # rather than the constructor's `device`, so the module keeps working
        # after being moved.
        adp = torch.where(self.adj > 0, adp, adp.new_tensor(-9e15))
        adj = F.softmax(adp, dim=-1)

        # out[n, c, v, l] = sum_w adj[v, w] * H[n, c, w, l]
        H = torch.einsum('vw,ncwl->ncvl', adj, H)
        # Undo the head stacking: split dim 0 back into K chunks, concat on channels.
        H = torch.cat(torch.split(H, H.shape[0] // self.K, dim=0), dim=-1)
        return F.gelu(H.contiguous())
|
|
|
|
|
|
class STEmbModel(torch.nn.Module):
    """Spatio-temporal embedding: a projected node embedding plus a projected time encoding.

    ``SE`` is passed through ``fc3 -> GELU -> fc4``. ``TE`` holds integer
    (day_of_week, time_of_day) indices. It is one-hot encoded into ``7`` plus
    ``TEDims - 7`` slots and passed through ``fc5 -> GELU -> fc6``. The two
    projections are broadcast-added.
    """

    def __init__(self, SEDims, TEDims, OutDims, device):
        super(STEmbModel, self).__init__()
        self.TEDims = TEDims
        self.fc3 = torch.nn.Linear(SEDims, OutDims)
        self.fc4 = torch.nn.Linear(OutDims, OutDims)
        self.fc5 = torch.nn.Linear(TEDims, OutDims)
        self.fc6 = torch.nn.Linear(OutDims, OutDims)
        self.device = device

    def forward(self, SE, TE):
        """Return ``proj(SE)`` broadcast-added to ``proj(one_hot(TE))``.

        Args:
            SE: node embedding; two leading singleton axes are prepended.
            TE: integer tensor whose last axis is (day_of_week, time_of_day);
                a singleton axis is inserted at position 2 after one-hot
                encoding so it broadcasts against the node axis of ``SE``.
        """
        SE = SE.unsqueeze(0).unsqueeze(0)
        SE = self.fc4(F.gelu(self.fc3(SE)))

        dayofweek = F.one_hot(TE[..., 0], num_classes=7)
        timeofday = F.one_hot(TE[..., 1], num_classes=self.TEDims - 7)
        TE = torch.cat((dayofweek, timeofday), dim=-1)
        # Cast straight to the projection layer's device and dtype. Going via
        # torch.FloatTensor forced a copy through host memory and pinned the
        # dtype to float32, which broke half/double models.
        weight = self.fc5.weight
        TE = TE.unsqueeze(2).to(device=weight.device, dtype=weight.dtype)
        TE = self.fc6(F.gelu(self.fc5(TE)))

        return SE + TE
|
|
|
|
|
|
class SpatialAttentionModel(torch.nn.Module):
    """Multi-head self-attention across the node axis, optionally limited to graph edges.

    Args:
        K: number of heads.
        d: width of each head (model width is ``K * d``).
        adj: adjacency; when ``mask`` is true, pairs with ``adj <= 0`` are
            masked out before the softmax.
        dropout: stored on the instance but not applied anywhere in forward().
        mask: whether to apply the adjacency mask.
    """

    def __init__(self, K, d, adj, dropout=0.3, mask=True):
        super(SpatialAttentionModel, self).__init__()
        model_dim = K * d
        self.fc7 = torch.nn.Linear(2 * model_dim, model_dim)
        self.fc8 = torch.nn.Linear(2 * model_dim, model_dim)
        self.fc9 = torch.nn.Linear(2 * model_dim, model_dim)
        self.fc10 = torch.nn.Linear(model_dim, model_dim)
        self.fc11 = torch.nn.Linear(model_dim, model_dim)
        self.K = K
        self.d = d
        self.adj = adj
        self.mask = mask
        # NOTE(review): a float, not an nn.Dropout, and never used in forward().
        self.dropout = dropout
        self.softmax = torch.nn.Softmax(dim=-1)

    def _stack_heads(self, tensor):
        """Split the channel axis into width-d heads stacked along dim 0."""
        return torch.cat(torch.split(tensor, self.d, dim=-1), dim=0)

    def forward(self, X, STE):
        """Scaled dot-product attention between nodes; returns ``K * d`` channels."""
        inputs = torch.cat((X, STE), dim=-1)
        query, key, value = (
            self._stack_heads(F.gelu(projection(inputs)))
            for projection in (self.fc7, self.fc8, self.fc9)
        )

        scores = torch.matmul(query, key.transpose(2, 3)) / (self.d ** 0.5)
        if self.mask:
            # Huge negative score -> ~0 weight after softmax for non-edges.
            scores = torch.where(self.adj > 0, scores, torch.full_like(scores, -9e15))
        weights = self.softmax(scores)

        context = torch.matmul(weights, value)
        # Reverse _stack_heads: K equal chunks on dim 0, joined on channels.
        head_chunks = torch.split(context, context.shape[0] // self.K, dim=0)
        merged = torch.cat(head_chunks, dim=-1)
        return self.fc11(F.gelu(self.fc10(merged)))
|
|
|
|
|
|
class TemporalAttentionModel(torch.nn.Module):
    """Multi-head self-attention along the time axis with an optional causal mask.

    Args:
        K: number of heads.
        d: width of each head (model width is ``K * d``).
        device: stored for callers. forward() builds its mask on the device of
            the attention scores, so it no longer depends on this value.
    """

    def __init__(self, K, d, device):
        super(TemporalAttentionModel, self).__init__()
        D = K * d
        self.fc12 = torch.nn.Linear(2 * D, D)
        self.fc13 = torch.nn.Linear(2 * D, D)
        self.fc14 = torch.nn.Linear(2 * D, D)
        self.fc15 = torch.nn.Linear(D, D)
        self.fc16 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, Mask=True):
        """Attend over time steps (axis 1 of the concatenated input).

        When ``Mask`` is true, step ``t`` can only attend to steps ``<= t``
        (lower-triangular mask).
        """
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc12(X))
        key = F.gelu(self.fc13(X))
        value = F.gelu(self.fc14(X))

        # Split channels into width-d heads stacked on dim 0.
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)

        # Move the time axis next to the channel axis so matmul contracts over
        # channels and produces a (time x time) score matrix.
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)

        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)

        if Mask:
            num_steps = X.shape[1]
            # Build the mask and fill value where the scores live, instead of on
            # the constructor's `device`, so a moved module still works.
            causal = torch.tril(
                torch.ones(num_steps, num_steps, device=attention.device)
            ).bool()
            attention = torch.where(causal, attention, attention.new_tensor(-9e15))

        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        # Undo the head stacking.
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.dropout(self.fc16(F.gelu(self.fc15(X))))
        return X
|
|
|
|
|
|
class GatedFusionModel(torch.nn.Module):
    """Blend two same-width feature tensors with a learned sigmoid gate.

    ``gate = sigmoid(fc17(HS) + fc18(HT))``, then
    ``fc20(GELU(fc19(gate * HS + (1 - gate) * HT)))``.
    """

    def __init__(self, K, d):
        super(GatedFusionModel, self).__init__()
        width = K * d
        self.fc17 = torch.nn.Linear(width, width)
        self.fc18 = torch.nn.Linear(width, width)
        self.fc19 = torch.nn.Linear(width, width)
        self.fc20 = torch.nn.Linear(width, width)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, HS, HT):
        """Return the gated mixture of ``HS`` and ``HT`` passed through a 2-layer MLP."""
        gate = self.sigmoid(self.fc17(HS) + self.fc18(HT))
        mixed = gate * HS + (1 - gate) * HT
        return self.fc20(F.gelu(self.fc19(mixed)))
|
|
|
|
|
|
class STAttModel(torch.nn.Module):
    """Spatio-temporal block: several graph branches fused with temporal attention.

    Spatial side:
      * two 2-hop chains of ``randomGAT``: one masked by ``adj[0]``
        (gcn1 -> gcn2), one by ``adj[1]`` (gcn3 -> gcn4);
      * a 2-hop ``gcn`` over the ``adp`` matrix passed to forward(). The same
        ``gcn`` module (shared weights) is applied for both hops.
    The input and the six branch outputs are concatenated and projected by
    ``fc30``. Temporal attention runs on the raw input. The two sides are
    combined by ``GatedFusionModel`` and added back to the input (residual).
    """

    def __init__(self, K, d, adj, device):
        super(STAttModel, self).__init__()
        width = K * d
        self.fc30 = torch.nn.Linear(7 * width, width)
        self.gcn = gcn(K, d)
        self.gcn1 = randomGAT(K, d, adj[0], device)
        self.gcn2 = randomGAT(K, d, adj[0], device)
        self.gcn3 = randomGAT(K, d, adj[1], device)
        self.gcn4 = randomGAT(K, d, adj[1], device)
        self.temporalAttention = TemporalAttentionModel(K, d, device)
        self.gatedFusion = GatedFusionModel(K, d)

    def forward(self, X, STE, adp, Mask=True):
        """Return ``X`` plus the gated fusion of spatial and temporal features."""
        # Branch evaluation order is kept fixed: the gcn hops apply dropout,
        # so reordering would change which random masks each call draws.
        first_view_hop1 = self.gcn1(X, STE)
        first_view_hop2 = self.gcn2(first_view_hop1, STE)
        second_view_hop1 = self.gcn3(X, STE)
        second_view_hop2 = self.gcn4(second_view_hop1, STE)
        adaptive_hop1 = self.gcn(X, STE, adp)
        adaptive_hop2 = self.gcn(adaptive_hop1, STE, adp)

        branches = (
            X,
            first_view_hop1,
            first_view_hop2,
            second_view_hop1,
            second_view_hop2,
            adaptive_hop1,
            adaptive_hop2,
        )
        spatial = F.gelu(self.fc30(torch.cat(branches, dim=-1)))
        temporal = self.temporalAttention(X, STE, Mask)
        return X + self.gatedFusion(spatial, temporal)
|
|
|
|
|
|
class TransformAttentionModel(torch.nn.Module):
    """Cross-attention that maps history features onto future time steps.

    Queries come from the future embedding ``STE_Q`` and keys from the history
    embedding ``STE_P``. Values come from the history features ``X``. Attention
    runs along the time axis independently per node and per head.
    """

    def __init__(self, K, d):
        super(TransformAttentionModel, self).__init__()
        width = K * d
        self.fc21 = torch.nn.Linear(width, width)
        self.fc22 = torch.nn.Linear(width, width)
        self.fc23 = torch.nn.Linear(width, width)
        self.fc24 = torch.nn.Linear(width, width)
        self.fc25 = torch.nn.Linear(width, width)
        self.K = K
        self.d = d
        self.softmax = torch.nn.Softmax(dim=-1)

    def _stack_heads(self, tensor):
        """Split the channel axis into width-d heads stacked along dim 0."""
        return torch.cat(torch.split(tensor, self.d, dim=-1), dim=0)

    def forward(self, X, STE_P, STE_Q):
        """Return features laid out on the time steps of ``STE_Q``."""
        q = self._stack_heads(F.gelu(self.fc21(STE_Q))).transpose(2, 1)
        k = self._stack_heads(F.gelu(self.fc22(STE_P))).transpose(1, 2).transpose(2, 3)
        v = self._stack_heads(F.gelu(self.fc23(X))).transpose(2, 1)

        scores = torch.matmul(q, k) / (self.d ** 0.5)
        mixed = torch.matmul(self.softmax(scores), v).transpose(2, 1)

        # Reverse the head stacking: K equal chunks on dim 0, joined on channels.
        per_head = torch.split(mixed, mixed.shape[0] // self.K, dim=0)
        merged = torch.cat(per_head, dim=-1)
        return self.fc25(F.gelu(self.fc24(merged)))
|
|
|
|
|
|
class RGDAN(nn.Module):
    """Encoder / transform-attention / decoder network over a node graph.

    The scalar input is lifted to ``K * d`` channels and encoded with an
    ``STAttModel`` on the first ``P`` embedding steps. It is then mapped onto
    the remaining steps with ``TransformAttentionModel``, decoded with a second
    ``STAttModel``, and projected back to a single channel that is squeezed away.
    """

    def __init__(self, K, d, SEDims, TEDims, P, L, device, adj, num_nodes):
        super(RGDAN, self).__init__()
        width = K * d
        self.fc1 = torch.nn.Linear(1, width)
        self.fc2 = torch.nn.Linear(width, width)
        self.STEmb = STEmbModel(SEDims, TEDims, width, device)
        self.STAttBlockEnc = STAttModel(K, d, adj, device)
        self.STAttBlockDec = STAttModel(K, d, adj, device)
        self.transformAttention = TransformAttentionModel(K, d)
        self.P = P
        self.L = L
        self.device = device
        self.fc26 = torch.nn.Linear(width, width)
        self.fc27 = torch.nn.Linear(width, 1)
        # Learnable node embeddings for the adaptive adjacency. They are sampled
        # on the CPU generator, then moved (construction order is kept so seeded
        # initialisation matches).
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device))
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device))
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, SE, TE):
        """Run the encoder/decoder and drop the trailing singleton channel axis.

        ``TE`` must cover the history and the forecast steps along axis 1. The
        first ``P`` steps feed the encoder and the rest feed the decoder.
        """
        # Row-normalised, non-negative adaptive adjacency shared by both blocks.
        adaptive_adj = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)

        hidden = self.fc2(F.gelu(self.fc1(X)))
        embedding = self.STEmb(SE, TE)
        history_emb = embedding[:, : self.P]
        future_emb = embedding[:, self.P:]

        hidden = self.STAttBlockEnc(hidden, history_emb, adaptive_adj, Mask=True)
        hidden = self.transformAttention(hidden, history_emb, future_emb)
        hidden = self.STAttBlockDec(hidden, future_emb, adaptive_adj, Mask=True)

        prediction = self.fc27(self.dropout(F.gelu(self.fc26(hidden))))
        return prediction.squeeze(3)
|
|
|
|
|
|
class RGDANModel(nn.Module):
    """Wrapper to integrate RGDAN with TrafficWheel pipeline.

    Expects dataloader to provide tensors shaped as:
    - X: [B, T_in, N, F]. Channel 0 is the signal fed to RGDAN. The last two
      channels are read as time encodings: channel -2 is time-of-day as a
      fraction of a day (presumably in [0, 1)), and channel -1 is the
      day-of-week index. Both are read from node 0 only.
    - Y: [B, T_out, N, F]

    TE is synthesized internally from those two channels via
    steps_per_day/days_per_week. SE is a learnable node embedding initialised
    to zeros (could be extended).
    """

    def __init__(self, args):
        super(RGDANModel, self).__init__()

        self.args = args
        self.device = args.get('device', 'cpu')
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.P = args.get('lag', args.get('history', 12))
        self.L = args.get('horizon', 12)

        # RGDAN hyper-params with defaults
        self.K = args.get('K', 3)
        self.d = args.get('d', 8)
        self.SEDims = args.get('SEDims', 16)
        # The time embedding one-hot has 7 day-of-week slots followed by
        # TEDims - 7 time-of-day slots. It needs at least steps_per_day of the
        # latter, or F.one_hot fails at forward time (previously any
        # steps_per_day > 288 crashed). 288 stays the floor so configs that
        # relied on the old fixed default keep the same width (and checkpoints).
        steps_per_day = args.get('steps_per_day', 288)
        self.TEDims = args.get('TEDims', max(steps_per_day, 288) + 7)

        # Adjacency set (two views expected by STAttModel): the matrix from
        # get_adj and its transpose, or all-ones when get_adj returns None.
        adj_distance = get_adj({'num_nodes': self.num_nodes})
        if adj_distance is None:
            base = torch.ones(self.num_nodes, self.num_nodes, device=self.device)
            adj = [base, base]
        else:
            base = torch.from_numpy(adj_distance).float().to(self.device)
            adj = [base, base.T]

        self.se_embedding = nn.Parameter(torch.zeros(self.num_nodes, self.SEDims), requires_grad=True)

        self.rgdan = RGDAN(
            K=self.K,
            d=self.d,
            SEDims=self.SEDims,
            TEDims=self.TEDims,
            P=self.P,
            L=self.L,
            device=self.device,
            adj=adj,
            num_nodes=self.num_nodes,
        )

    @staticmethod
    def _build_te(x, horizon, steps_per_day, days_per_week):
        """Return integer (day_of_week, time_of_day) indices for history + horizon.

        Observed indices come from channels -2 / -1 of node 0. The ``horizon``
        future steps are extrapolated from the last observed step, carrying
        into the next day (and wrapping the week) when time-of-day overflows.

        Returns:
            Long tensor of shape [B, T_in + horizon, 2].
        """
        time_in_day_cont = x[:, :, 0, -2]  # [B, T_in]
        day_in_week_cont = x[:, :, 0, -1]  # [B, T_in]
        # NOTE(review): the -1e-6 bias is below float32 resolution for large
        # indices, so it only matters for small ones.
        tod_idx = torch.round(time_in_day_cont * steps_per_day - 1e-6).clamp(0, steps_per_day - 1).long()
        dow_idx = torch.round(day_in_week_cont).clamp(0, days_per_week - 1).long()

        offsets = torch.arange(1, horizon + 1, device=x.device)
        future_tod_linear = tod_idx[:, -1:] + offsets  # [B, horizon]
        future_tod = future_tod_linear % steps_per_day
        carry_days = future_tod_linear // steps_per_day
        future_dow = (dow_idx[:, -1:] + carry_days) % days_per_week

        te_past = torch.stack([dow_idx, tod_idx], dim=-1)  # [B, T_in, 2]
        te_future = torch.stack([future_dow, future_tod], dim=-1)  # [B, horizon, 2]
        return torch.cat([te_past, te_future], dim=1)

    def forward(self, x):
        """Predict ``[B, T_out, N, 1]`` from ``x`` of shape ``[B, T_in, N, F_total]``.

        Channels of ``x`` are ``[orig_features..., time_in_day, day_in_week]``.
        """
        if x.dim() != 4:
            raise ValueError(f"expected x of shape [B, T_in, N, F], got {tuple(x.shape)}")
        x0 = x[..., 0:1]

        TE = self._build_te(
            x,
            self.L,
            self.args.get('steps_per_day', 288),
            self.args.get('days_per_week', 7),
        )

        # SE: node static embeddings [N, SEDims]
        y = self.rgdan(x0, self.se_embedding, TE)
        # RGDAN returns [B, T_out, N]; restore a trailing channel axis.
        if y.dim() == 3:
            y = y.unsqueeze(-1)
        return y
|
|
|
|
|