add RGDAN
This commit is contained in:
parent
fa6eb90d65
commit
9e22712d77
@@ -0,0 +1,48 @@
data:
  num_nodes: 358
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
@@ -0,0 +1,48 @@
data:
  num_nodes: 307
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295  # 7 + 288

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
@@ -0,0 +1,48 @@
data:
  num_nodes: 883
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295

train:
  loss_func: mae
  seed: 10
  batch_size: 8  # larger graph may need smaller batch
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
@@ -0,0 +1,48 @@
data:
  num_nodes: 170
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
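
For reference, a minimal sketch (not part of this commit) of how one of these new RGDAN configs could be read with PyYAML; the file path is hypothetical, and the actual pipeline may flatten or merge the sections differently. It also checks the relation the model section relies on: TEDims = days_per_week + steps_per_day = 7 + 288 = 295.

import yaml  # PyYAML, assumed available

with open('path/to/rgdan_config.yaml') as f:  # hypothetical path
    cfg = yaml.safe_load(f)

data_cfg, model_cfg = cfg['data'], cfg['model']
# Temporal-embedding width = day-of-week one-hot + time-of-day one-hot.
assert model_cfg['TEDims'] == data_cfg['days_per_week'] + data_cfg['steps_per_day']  # 7 + 288 = 295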
@@ -15,7 +15,7 @@ data:
   days_per_week: 7
   sample: 1
   input_dim: 3
-  batch_size: 8
+  batch_size: 64
 
 model:
   type: 'STEP'
@@ -68,7 +68,7 @@ model:
 train:
   loss_func: mae
   seed: 10
-  batch_size: 8
+  batch_size: 64
   epochs: 100
   lr_init: 0.002
   weight_decay: 1.0e-5
@@ -141,10 +141,10 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
                                                                       decoder_hidden_state)
             decoder_input = decoder_output
             outputs.append(decoder_output)
-            if self.training and self.use_curriculum_learning:
-                c = np.random.uniform(0, 1)
-                if c < self._compute_sampling_threshold(batches_seen):
-                    decoder_input = labels[t]
+            # if self.training and self.use_curriculum_learning:
+            #     c = np.random.uniform(0, 1)
+            #     if c < self._compute_sampling_threshold(batches_seen):
+            #         decoder_input = labels[t]
         outputs = torch.stack(outputs)
         return outputs
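
Note: the change above comments out DCRNN's scheduled-sampling (curriculum learning) branch, so the decoder now always feeds back its own predictions during training. For context, a hedged sketch of the inverse-sigmoid decay that _compute_sampling_threshold typically implements in the reference DCRNN code; the exact body and the cl_decay_steps default are not shown in this diff and are assumptions here.

import numpy as np

def compute_sampling_threshold(batches_seen, cl_decay_steps=2000):
    # Probability of feeding ground truth to the decoder: close to 1 early in
    # training and decaying toward 0 as batches_seen grows (inverse sigmoid).
    return cl_decay_steps / (cl_decay_steps + np.exp(batches_seen / cl_decay_steps))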
@@ -0,0 +1,348 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from data.get_adj import get_adj


class gcn(torch.nn.Module):
    def __init__(self, k, d):
        super(gcn, self).__init__()
        D = k * d
        self.fc = torch.nn.Linear(2 * D, D)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, A):
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        H = torch.einsum('ncvl,vw->ncwl', (H, A))
        return self.dropout(H.contiguous())

class randomGAT(torch.nn.Module):
    def __init__(self, k, d, adj, device):
        super(randomGAT, self).__init__()
        D = k * d
        self.d = d
        self.K = k
        num_nodes = adj.shape[0]
        self.device = device
        self.fc = torch.nn.Linear(2 * D, D)
        self.adj = adj
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)

    def forward(self, X, STE):
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        H = torch.cat(torch.split(H, self.d, dim=-1), dim=0)
        adp = torch.mm(self.nodevec1, self.nodevec2)
        zero_vec = torch.tensor(-9e15).to(self.device)
        adp = torch.where(self.adj > 0, adp, zero_vec)
        adj = F.softmax(adp, dim=-1)
        H = torch.einsum('vw,ncwl->ncvl', (adj, H))
        H = torch.cat(torch.split(H, H.shape[0] // self.K, dim=0), dim=-1)
        return F.gelu(H.contiguous())

class STEmbModel(torch.nn.Module):
    def __init__(self, SEDims, TEDims, OutDims, device):
        super(STEmbModel, self).__init__()
        self.TEDims = TEDims
        self.fc3 = torch.nn.Linear(SEDims, OutDims)
        self.fc4 = torch.nn.Linear(OutDims, OutDims)
        self.fc5 = torch.nn.Linear(TEDims, OutDims)
        self.fc6 = torch.nn.Linear(OutDims, OutDims)
        self.device = device

    def forward(self, SE, TE):
        SE = SE.unsqueeze(0).unsqueeze(0)
        SE = self.fc4(F.gelu(self.fc3(SE)))
        dayofweek = F.one_hot(TE[..., 0], num_classes=7)
        timeofday = F.one_hot(TE[..., 1], num_classes=self.TEDims - 7)
        TE = torch.cat((dayofweek, timeofday), dim=-1)
        TE = TE.unsqueeze(2).type(torch.FloatTensor).to(self.device)
        TE = self.fc6(F.gelu(self.fc5(TE)))
        sum_tensor = torch.add(SE, TE)
        return sum_tensor

class SpatialAttentionModel(torch.nn.Module):
    def __init__(self, K, d, adj, dropout=0.3, mask=True):
        super(SpatialAttentionModel, self).__init__()
        D = K * d
        self.fc7 = torch.nn.Linear(2 * D, D)
        self.fc8 = torch.nn.Linear(2 * D, D)
        self.fc9 = torch.nn.Linear(2 * D, D)
        self.fc10 = torch.nn.Linear(D, D)
        self.fc11 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.adj = adj
        self.mask = mask
        self.dropout = dropout
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X, STE):
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc7(X))
        key = F.gelu(self.fc8(X))
        value = F.gelu(self.fc9(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        attention = torch.matmul(query, torch.transpose(key, 2, 3))
        attention /= (self.d ** 0.5)
        if self.mask:
            zero_vec = -9e15 * torch.ones_like(attention)
            attention = torch.where(self.adj > 0, attention, zero_vec)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.fc11(F.gelu(self.fc10(X)))
        return X

class TemporalAttentionModel(torch.nn.Module):
    def __init__(self, K, d, device):
        super(TemporalAttentionModel, self).__init__()
        D = K * d
        self.fc12 = torch.nn.Linear(2 * D, D)
        self.fc13 = torch.nn.Linear(2 * D, D)
        self.fc14 = torch.nn.Linear(2 * D, D)
        self.fc15 = torch.nn.Linear(D, D)
        self.fc16 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, Mask=True):
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc12(X))
        key = F.gelu(self.fc13(X))
        value = F.gelu(self.fc14(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        if Mask:
            num_steps = X.shape[1]
            mask = torch.ones(num_steps, num_steps).to(self.device)
            mask = torch.tril(mask)
            zero_vec = torch.tensor(-9e15).to(self.device)
            mask = mask.to(torch.bool)
            attention = torch.where(mask, attention, zero_vec)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.dropout(self.fc16(F.gelu(self.fc15(X))))
        return X

class GatedFusionModel(torch.nn.Module):
    def __init__(self, K, d):
        super(GatedFusionModel, self).__init__()
        D = K * d
        self.fc17 = torch.nn.Linear(D, D)
        self.fc18 = torch.nn.Linear(D, D)
        self.fc19 = torch.nn.Linear(D, D)
        self.fc20 = torch.nn.Linear(D, D)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, HS, HT):
        XS = self.fc17(HS)
        XT = self.fc18(HT)
        z = self.sigmoid(torch.add(XS, XT))
        H = torch.add((z * HS), ((1 - z) * HT))
        H = self.fc20(F.gelu(self.fc19(H)))
        return H

class STAttModel(torch.nn.Module):
    def __init__(self, K, d, adj, device):
        super(STAttModel, self).__init__()
        D = K * d
        self.fc30 = torch.nn.Linear(7 * D, D)
        self.gcn = gcn(K, d)
        self.gcn1 = randomGAT(K, d, adj[0], device)
        self.gcn2 = randomGAT(K, d, adj[0], device)
        self.gcn3 = randomGAT(K, d, adj[1], device)
        self.gcn4 = randomGAT(K, d, adj[1], device)
        self.temporalAttention = TemporalAttentionModel(K, d, device)
        self.gatedFusion = GatedFusionModel(K, d)

    def forward(self, X, STE, adp, Mask=True):
        HS1 = self.gcn1(X, STE)
        HS2 = self.gcn2(HS1, STE)
        HS3 = self.gcn3(X, STE)
        HS4 = self.gcn4(HS3, STE)
        HS5 = self.gcn(X, STE, adp)
        HS6 = self.gcn(HS5, STE, adp)
        HS = torch.cat((X, HS1, HS2, HS3, HS4, HS5, HS6), dim=-1)
        HS = F.gelu(self.fc30(HS))
        HT = self.temporalAttention(X, STE, Mask)
        H = self.gatedFusion(HS, HT)
        return torch.add(X, H)

class TransformAttentionModel(torch.nn.Module):
    def __init__(self, K, d):
        super(TransformAttentionModel, self).__init__()
        D = K * d
        self.fc21 = torch.nn.Linear(D, D)
        self.fc22 = torch.nn.Linear(D, D)
        self.fc23 = torch.nn.Linear(D, D)
        self.fc24 = torch.nn.Linear(D, D)
        self.fc25 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X, STE_P, STE_Q):
        query = F.gelu(self.fc21(STE_Q))
        key = F.gelu(self.fc22(STE_P))
        value = F.gelu(self.fc23(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.fc25(F.gelu(self.fc24(X)))
        return X

class RGDAN(nn.Module):
    def __init__(self, K, d, SEDims, TEDims, P, L, device, adj, num_nodes):
        super(RGDAN, self).__init__()
        D = K * d
        self.fc1 = torch.nn.Linear(1, D)
        self.fc2 = torch.nn.Linear(D, D)
        self.STEmb = STEmbModel(SEDims, TEDims, K * d, device)
        self.STAttBlockEnc = STAttModel(K, d, adj, device)
        self.STAttBlockDec = STAttModel(K, d, adj, device)
        self.transformAttention = TransformAttentionModel(K, d)
        self.P = P
        self.L = L
        self.device = device
        self.fc26 = torch.nn.Linear(D, D)
        self.fc27 = torch.nn.Linear(D, 1)
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, SE, TE):
        adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
        X = self.fc2(F.gelu(self.fc1(X)))
        STE = self.STEmb(SE, TE)
        STE_P = STE[:, : self.P]
        STE_Q = STE[:, self.P:]
        X = self.STAttBlockEnc(X, STE_P, adp, Mask=True)
        X = self.transformAttention(X, STE_P, STE_Q)
        X = self.STAttBlockDec(X, STE_Q, adp, Mask=True)
        X = self.fc27(self.dropout(F.gelu(self.fc26(X))))
        return X.squeeze(3)

class RGDANModel(nn.Module):
    """Wrapper to integrate RGDAN with the TrafficWheel pipeline.

    Expects the dataloader to provide tensors shaped as:
    - X: [B, T_in, N, F] where F >= 1 and channel 0 carries the signal
    - Y: [B, T_out, N, F]

    TE is synthesized internally from steps_per_day/days_per_week; SE is a
    learnable node embedding initialized to zeros (could be extended).
    """

    def __init__(self, args):
        super(RGDANModel, self).__init__()

        self.args = args
        self.device = args.get('device', 'cpu')
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.P = args.get('lag', args.get('history', 12))
        self.L = args.get('horizon', 12)

        # RGDAN hyper-parameters with defaults
        self.K = args.get('K', 3)
        self.d = args.get('d', 8)
        self.SEDims = args.get('SEDims', 16)
        self.TEDims = args.get('TEDims', 288 + 7)

        # Adjacency set (two views expected by STAttModel): use the distance
        # matrix from get_adj and build two masks, forward and backward edges.
        adj_distance = get_adj({'num_nodes': self.num_nodes})
        adj = []
        if adj_distance is None:
            base = torch.ones(self.num_nodes, self.num_nodes, device=self.device)
            adj = [base, base]
        else:
            base = torch.from_numpy(adj_distance).float().to(self.device)
            adj = [base, base.T]

        self.se_embedding = nn.Parameter(torch.zeros(self.num_nodes, self.SEDims), requires_grad=True)

        self.rgdan = RGDAN(
            K=self.K,
            d=self.d,
            SEDims=self.SEDims,
            TEDims=self.TEDims,
            P=self.P,
            L=self.L,
            device=self.device,
            adj=adj,
            num_nodes=self.num_nodes,
        )

    def forward(self, x):
        # x: [B, T_in, N, F_total]; channels = [orig_features..., time_in_day, day_in_week]
        x0 = x[..., 0:1]

        steps_per_day = self.args.get('steps_per_day', 288)
        days_per_week = self.args.get('days_per_week', 7)
        B, T_in, N, F_total = x.shape
        T_out = self.L

        # Extract TE for the observed window from the appended channels (constant across nodes)
        time_in_day_cont = x[:, :, 0, -2]  # [B, T_in]
        day_in_week_cont = x[:, :, 0, -1]  # [B, T_in]
        tod_idx = torch.round(time_in_day_cont * steps_per_day - 1e-6).clamp(0, steps_per_day - 1).long()
        dow_idx = torch.round(day_in_week_cont).clamp(0, days_per_week - 1).long()

        # Extrapolate TE for the horizon
        last_tod = tod_idx[:, -1]  # [B]
        last_dow = dow_idx[:, -1]  # [B]
        offsets = torch.arange(1, T_out + 1, device=x.device)
        future_tod_linear = last_tod.unsqueeze(1) + offsets.unsqueeze(0)
        future_tod = (future_tod_linear % steps_per_day).long()
        carry_days = (future_tod_linear // steps_per_day).long()
        future_dow = (last_dow.unsqueeze(1) + carry_days) % days_per_week

        TE_P = torch.stack([dow_idx, tod_idx], dim=-1)        # [B, T_in, 2]
        TE_Q = torch.stack([future_dow, future_tod], dim=-1)  # [B, T_out, 2]
        TE = torch.cat([TE_P, TE_Q], dim=1)                   # [B, T_in+T_out, 2]

        # SE: static node embeddings [N, SEDims]
        SE = self.se_embedding

        y = self.rgdan(x0, SE, TE)
        # Output: [B, T_out, N]
        if y.dim() == 3:
            y = y.unsqueeze(-1)
        return y
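
Below is a minimal usage sketch (not part of this commit) showing how the wrapper above could be exercised standalone. It assumes the pipeline has already flattened the data/model config sections into one dict, that data.get_adj.get_adj either resolves an adjacency for this node count or returns None (in which case the wrapper falls back to a dense graph), and that the last two input channels are time-of-day in [0, 1) and day-of-week as an index, as the forward pass expects.

import torch
from model.RGDAN.RGDAN import RGDANModel

# Hypothetical flattened config dict mirroring the YAML files added above.
args = {
    'device': 'cpu',
    'num_nodes': 307,
    'input_dim': 1,
    'output_dim': 1,
    'lag': 12,
    'horizon': 12,
    'K': 3,
    'd': 8,
    'SEDims': 16,
    'TEDims': 295,  # 7 (day of week) + 288 (time of day)
    'steps_per_day': 288,
    'days_per_week': 7,
}

model = RGDANModel(args)

B, T_in, N = 2, args['lag'], args['num_nodes']
signal = torch.rand(B, T_in, N, 1)                  # traffic readings (channel 0)
tod = torch.rand(B, T_in, N, 1)                     # time_in_day in [0, 1)
dow = torch.randint(0, 7, (B, T_in, N, 1)).float()  # day_in_week index
x = torch.cat([signal, tod, dow], dim=-1)           # [B, T_in, N, 3]

y = model(x)
print(y.shape)  # expected: [B, horizon, N, 1]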
@@ -22,6 +22,8 @@ from model.MegaCRN.MegaCRNModel import MegaCRNModel
 from model.ST_SSL.ST_SSL import STSSLModel
 from model.STGNRDE.Make_model import make_model as make_nrde_model
 from model.STAWnet.STAWnet import STAWnet
+from model.STEP.STEP import STEP
+from model.RGDAN.RGDAN import RGDANModel
 
 def model_selector(model):
     match model['type']:

@@ -49,4 +51,6 @@ def model_selector(model):
         case 'ST_SSL': return STSSLModel(model)
         case 'STGNRDE': return make_nrde_model(model)
         case 'STAWnet': return STAWnet(model)
+        case 'STEP': return STEP(model)
+        case 'RGDAN': return RGDANModel(model)
@@ -0,0 +1 @@
Subproject commit 566e2738da2d83f055718d8edb609ad8dc325204
@@ -13,7 +13,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
                                           lr_scheduler, kwargs[0], None)
         case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                           lr_scheduler, kwargs[0], None)
-        case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
+        case 'DCRNN': return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                      lr_scheduler)
         case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                                lr_scheduler)