add RGDAN
parent fa6eb90d65
commit 9e22712d77
@ -0,0 +1,48 @@
data:
  num_nodes: 358
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
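Note: RGDANModel, added later in this commit, reads its settings from a single flat dict (args['num_nodes'], args.get('lag', ...), args.get('K', 3), and so on), so the pipeline presumably flattens the data/model/train sections of a file like this before constructing the model. A minimal sketch of that assumption, with a hypothetical config path:

import yaml  # PyYAML

# Hypothetical path; the real location of this file is not shown in the diff.
with open('config/RGDAN/example.yaml') as f:
    cfg = yaml.safe_load(f)

# Merge the sections into the flat dict RGDANModel(args) expects (an assumption
# about the TrafficWheel pipeline, not code from this commit).
args = {**cfg['data'], **cfg['model'], **cfg['train'], 'device': 'cpu'}
print(args['num_nodes'], args['lag'], args['K'], args['TEDims'])  # 358 12 3 295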
@ -0,0 +1,48 @@
data:
  num_nodes: 307
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295 # 7 + 288

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
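The inline comment on TEDims above reflects how the value is derived: it is the width of the concatenated one-hot temporal encoding, days_per_week (7) plus steps_per_day (288), which matches how STEmbModel in the RGDAN code added below builds its time embedding. A small sketch of that relationship:

import torch
import torch.nn.functional as F

days_per_week, steps_per_day = 7, 288
TEDims = days_per_week + steps_per_day  # 295, as in the config

# TE holds integer (day_of_week, time_of_day) indices per time step.
TE = torch.tensor([[3, 100], [3, 101]])
dow = F.one_hot(TE[..., 0], num_classes=days_per_week)
tod = F.one_hot(TE[..., 1], num_classes=TEDims - days_per_week)
print(torch.cat((dow, tod), dim=-1).shape)  # torch.Size([2, 295])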
@ -0,0 +1,48 @@
data:
  num_nodes: 883
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295

train:
  loss_func: mae
  seed: 10
  batch_size: 8 # larger graph may need smaller batch
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
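The reduced batch size follows the inline comment above: RGDAN's blocks operate on feature tensors of shape [batch_size, lag, num_nodes, K * d], so per-sample memory grows roughly linearly with the node count. A back-of-envelope sketch (a rough estimate, not a measured profile):

def feature_tensor_mb(batch_size, num_nodes, lag=12, K=3, d=8, bytes_per_float=4):
    # Size of a single [B, T, N, K*d] activation tensor in megabytes.
    return batch_size * lag * num_nodes * K * d * bytes_per_float / 1e6

print(feature_tensor_mb(batch_size=64, num_nodes=307))  # ~22.6 MB per tensor
print(feature_tensor_mb(batch_size=8, num_nodes=883))   # ~8.1 MB per tensor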
@ -0,0 +1,48 @@
data:
  num_nodes: 170
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  K: 3
  d: 8
  SEDims: 16
  TEDims: 295

train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
@ -15,7 +15,7 @@ data:
  days_per_week: 7
  sample: 1
  input_dim: 3
-  batch_size: 8
+  batch_size: 64

model:
  type: 'STEP'
@ -68,7 +68,7 @@ model:
train:
  loss_func: mae
  seed: 10
-  batch_size: 8
+  batch_size: 64
  epochs: 100
  lr_init: 0.002
  weight_decay: 1.0e-5
@ -141,10 +141,10 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
                                                                      decoder_hidden_state)
            decoder_input = decoder_output
            outputs.append(decoder_output)
-            if self.training and self.use_curriculum_learning:
-                c = np.random.uniform(0, 1)
-                if c < self._compute_sampling_threshold(batches_seen):
-                    decoder_input = labels[t]
+            # if self.training and self.use_curriculum_learning:
+            #     c = np.random.uniform(0, 1)
+            #     if c < self._compute_sampling_threshold(batches_seen):
+            #         decoder_input = labels[t]
        outputs = torch.stack(outputs)
        return outputs
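The commented-out block above was the DCRNN decoder's scheduled sampling (curriculum learning): during training, with a probability given by _compute_sampling_threshold, the ground-truth label replaces the model's own prediction as the next decoder input. In the reference DCRNN code that threshold is an inverse-sigmoid decay in the number of batches seen; a sketch of that form, assuming this codebase's helper follows it:

import numpy as np

def compute_sampling_threshold(batches_seen, cl_decay_steps=2000):
    # Inverse-sigmoid decay: starts near 1 (mostly teacher forcing) and
    # decays toward 0 as training progresses.
    return cl_decay_steps / (cl_decay_steps + np.exp(batches_seen / cl_decay_steps))

print(compute_sampling_threshold(0))      # ~0.9995
print(compute_sampling_threshold(20000))  # ~0.083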
@ -0,0 +1,348 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from data.get_adj import get_adj


class gcn(torch.nn.Module):
    def __init__(self, k, d):
        super(gcn, self).__init__()
        D = k * d
        self.fc = torch.nn.Linear(2 * D, D)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, A):
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        H = torch.einsum('ncvl,vw->ncwl', (H, A))
        return self.dropout(H.contiguous())


class randomGAT(torch.nn.Module):
    def __init__(self, k, d, adj, device):
        super(randomGAT, self).__init__()
        D = k * d
        self.d = d
        self.K = k
        num_nodes = adj.shape[0]
        self.device = device
        self.fc = torch.nn.Linear(2 * D, D)
        self.adj = adj
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)

    def forward(self, X, STE):
        X = torch.cat((X, STE), dim=-1)
        H = F.gelu(self.fc(X))
        H = torch.cat(torch.split(H, self.d, dim=-1), dim=0)
        adp = torch.mm(self.nodevec1, self.nodevec2)
        zero_vec = torch.tensor(-9e15).to(self.device)
        adp = torch.where(self.adj > 0, adp, zero_vec)
        adj = F.softmax(adp, dim=-1)
        H = torch.einsum('vw,ncwl->ncvl', (adj, H))
        H = torch.cat(torch.split(H, H.shape[0] // self.K, dim=0), dim=-1)
        return F.gelu(H.contiguous())


class STEmbModel(torch.nn.Module):
    def __init__(self, SEDims, TEDims, OutDims, device):
        super(STEmbModel, self).__init__()
        self.TEDims = TEDims
        self.fc3 = torch.nn.Linear(SEDims, OutDims)
        self.fc4 = torch.nn.Linear(OutDims, OutDims)
        self.fc5 = torch.nn.Linear(TEDims, OutDims)
        self.fc6 = torch.nn.Linear(OutDims, OutDims)
        self.device = device

    def forward(self, SE, TE):
        SE = SE.unsqueeze(0).unsqueeze(0)
        SE = self.fc4(F.gelu(self.fc3(SE)))
        dayofweek = F.one_hot(TE[..., 0], num_classes=7)
        timeofday = F.one_hot(TE[..., 1], num_classes=self.TEDims - 7)
        TE = torch.cat((dayofweek, timeofday), dim=-1)
        TE = TE.unsqueeze(2).type(torch.FloatTensor).to(self.device)
        TE = self.fc6(F.gelu(self.fc5(TE)))
        sum_tensor = torch.add(SE, TE)
        return sum_tensor


class SpatialAttentionModel(torch.nn.Module):
    def __init__(self, K, d, adj, dropout=0.3, mask=True):
        super(SpatialAttentionModel, self).__init__()
        D = K * d
        self.fc7 = torch.nn.Linear(2 * D, D)
        self.fc8 = torch.nn.Linear(2 * D, D)
        self.fc9 = torch.nn.Linear(2 * D, D)
        self.fc10 = torch.nn.Linear(D, D)
        self.fc11 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.adj = adj
        self.mask = mask
        self.dropout = dropout
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X, STE):
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc7(X))
        key = F.gelu(self.fc8(X))
        value = F.gelu(self.fc9(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        attention = torch.matmul(query, torch.transpose(key, 2, 3))
        attention /= (self.d ** 0.5)
        if self.mask:
            zero_vec = -9e15 * torch.ones_like(attention)
            attention = torch.where(self.adj > 0, attention, zero_vec)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.fc11(F.gelu(self.fc10(X)))
        return X


class TemporalAttentionModel(torch.nn.Module):
    def __init__(self, K, d, device):
        super(TemporalAttentionModel, self).__init__()
        D = K * d
        self.fc12 = torch.nn.Linear(2 * D, D)
        self.fc13 = torch.nn.Linear(2 * D, D)
        self.fc14 = torch.nn.Linear(2 * D, D)
        self.fc15 = torch.nn.Linear(D, D)
        self.fc16 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, STE, Mask=True):
        X = torch.cat((X, STE), dim=-1)
        query = F.gelu(self.fc12(X))
        key = F.gelu(self.fc13(X))
        value = F.gelu(self.fc14(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        if Mask:
            num_steps = X.shape[1]
            mask = torch.ones(num_steps, num_steps).to(self.device)
            mask = torch.tril(mask)
            zero_vec = torch.tensor(-9e15).to(self.device)
            mask = mask.to(torch.bool)
            attention = torch.where(mask, attention, zero_vec)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.dropout(self.fc16(F.gelu(self.fc15(X))))
        return X


class GatedFusionModel(torch.nn.Module):
    def __init__(self, K, d):
        super(GatedFusionModel, self).__init__()
        D = K * d
        self.fc17 = torch.nn.Linear(D, D)
        self.fc18 = torch.nn.Linear(D, D)
        self.fc19 = torch.nn.Linear(D, D)
        self.fc20 = torch.nn.Linear(D, D)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, HS, HT):
        XS = self.fc17(HS)
        XT = self.fc18(HT)
        z = self.sigmoid(torch.add(XS, XT))
        H = torch.add((z * HS), ((1 - z) * HT))
        H = self.fc20(F.gelu(self.fc19(H)))
        return H


class STAttModel(torch.nn.Module):
    def __init__(self, K, d, adj, device):
        super(STAttModel, self).__init__()
        D = K * d
        self.fc30 = torch.nn.Linear(7 * D, D)
        self.gcn = gcn(K, d)
        self.gcn1 = randomGAT(K, d, adj[0], device)
        self.gcn2 = randomGAT(K, d, adj[0], device)
        self.gcn3 = randomGAT(K, d, adj[1], device)
        self.gcn4 = randomGAT(K, d, adj[1], device)
        self.temporalAttention = TemporalAttentionModel(K, d, device)
        self.gatedFusion = GatedFusionModel(K, d)

    def forward(self, X, STE, adp, Mask=True):
        HS1 = self.gcn1(X, STE)
        HS2 = self.gcn2(HS1, STE)
        HS3 = self.gcn3(X, STE)
        HS4 = self.gcn4(HS3, STE)
        HS5 = self.gcn(X, STE, adp)
        HS6 = self.gcn(HS5, STE, adp)
        HS = torch.cat((X, HS1, HS2, HS3, HS4, HS5, HS6), dim=-1)
        HS = F.gelu(self.fc30(HS))
        HT = self.temporalAttention(X, STE, Mask)
        H = self.gatedFusion(HS, HT)
        return torch.add(X, H)


class TransformAttentionModel(torch.nn.Module):
    def __init__(self, K, d):
        super(TransformAttentionModel, self).__init__()
        D = K * d
        self.fc21 = torch.nn.Linear(D, D)
        self.fc22 = torch.nn.Linear(D, D)
        self.fc23 = torch.nn.Linear(D, D)
        self.fc24 = torch.nn.Linear(D, D)
        self.fc25 = torch.nn.Linear(D, D)
        self.K = K
        self.d = d
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, X, STE_P, STE_Q):
        query = F.gelu(self.fc21(STE_Q))
        key = F.gelu(self.fc22(STE_P))
        value = F.gelu(self.fc23(X))
        query = torch.cat(torch.split(query, self.d, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.d, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.d, dim=-1), dim=0)
        query = torch.transpose(query, 2, 1)
        key = torch.transpose(torch.transpose(key, 1, 2), 2, 3)
        value = torch.transpose(value, 2, 1)
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        attention = self.softmax(attention)
        X = torch.matmul(attention, value)
        X = torch.transpose(X, 2, 1)
        X = torch.cat(torch.split(X, X.shape[0] // self.K, dim=0), dim=-1)
        X = self.fc25(F.gelu(self.fc24(X)))
        return X


class RGDAN(nn.Module):
    def __init__(self, K, d, SEDims, TEDims, P, L, device, adj, num_nodes):
        super(RGDAN, self).__init__()
        D = K * d
        self.fc1 = torch.nn.Linear(1, D)
        self.fc2 = torch.nn.Linear(D, D)
        self.STEmb = STEmbModel(SEDims, TEDims, K * d, device)
        self.STAttBlockEnc = STAttModel(K, d, adj, device)
        self.STAttBlockDec = STAttModel(K, d, adj, device)
        self.transformAttention = TransformAttentionModel(K, d)
        self.P = P
        self.L = L
        self.device = device
        self.fc26 = torch.nn.Linear(D, D)
        self.fc27 = torch.nn.Linear(D, 1)
        self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
        self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, X, SE, TE):
        adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
        X = self.fc2(F.gelu(self.fc1(X)))
        STE = self.STEmb(SE, TE)
        STE_P = STE[:, : self.P]
        STE_Q = STE[:, self.P:]
        X = self.STAttBlockEnc(X, STE_P, adp, Mask=True)
        X = self.transformAttention(X, STE_P, STE_Q)
        X = self.STAttBlockDec(X, STE_Q, adp, Mask=True)
        X = self.fc27(self.dropout(F.gelu(self.fc26(X))))
        return X.squeeze(3)


class RGDANModel(nn.Module):
    """Wrapper to integrate RGDAN with TrafficWheel pipeline.

    Expects dataloader to provide tensors shaped as:
    - X: [B, T_in, N, F] where F >= 1 and we use channel 0
    - Y: [B, T_out, N, F]
    We synthesize TE internally via steps_per_day/days_per_week and use learnable SE as zeros (or could be extended).
    """

    def __init__(self, args):
        super(RGDANModel, self).__init__()

        self.args = args
        self.device = args.get('device', 'cpu')
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.P = args.get('lag', args.get('history', 12))
        self.L = args.get('horizon', 12)

        # RGDAN hyper-params with defaults
        self.K = args.get('K', 3)
        self.d = args.get('d', 8)
        self.SEDims = args.get('SEDims', 16)
        self.TEDims = args.get('TEDims', 288 + 7)

        # adjacency set (two views expected by STAttModel)
        # use distance matrix from get_adj. Build two masks: forward and backward edges
        adj_distance = get_adj({'num_nodes': self.num_nodes})
        adj = []
        if adj_distance is None:
            base = torch.ones(self.num_nodes, self.num_nodes, device=self.device)
            adj = [base, base]
        else:
            base = torch.from_numpy(adj_distance).float().to(self.device)
            adj = [base, base.T]

        self.se_embedding = nn.Parameter(torch.zeros(self.num_nodes, self.SEDims), requires_grad=True)

        self.rgdan = RGDAN(
            K=self.K,
            d=self.d,
            SEDims=self.SEDims,
            TEDims=self.TEDims,
            P=self.P,
            L=self.L,
            device=self.device,
            adj=adj,
            num_nodes=self.num_nodes,
        )

    def forward(self, x):
        # x: [B, T_in, N, F_total]; channels = [orig_features..., time_in_day, day_in_week]
        x0 = x[..., 0:1]

        steps_per_day = self.args.get('steps_per_day', 288)
        days_per_week = self.args.get('days_per_week', 7)
        B, T_in, N, F_total = x.shape
        T_out = self.L

        # Extract TE for observed window from appended channels (constant across nodes)
        time_in_day_cont = x[:, :, 0, -2]  # [B, T_in]
        day_in_week_cont = x[:, :, 0, -1]  # [B, T_in]
        tod_idx = torch.round(time_in_day_cont * steps_per_day - 1e-6).clamp(0, steps_per_day - 1).long()
        dow_idx = torch.round(day_in_week_cont).clamp(0, days_per_week - 1).long()

        # Extrapolate TE for horizon
        last_tod = tod_idx[:, -1]  # [B]
        last_dow = dow_idx[:, -1]  # [B]
        offsets = torch.arange(1, T_out + 1, device=x.device)
        future_tod_linear = last_tod.unsqueeze(1) + offsets.unsqueeze(0)
        future_tod = (future_tod_linear % steps_per_day).long()
        carry_days = (future_tod_linear // steps_per_day).long()
        future_dow = (last_dow.unsqueeze(1) + carry_days) % days_per_week

        TE_P = torch.stack([dow_idx, tod_idx], dim=-1)  # [B, T_in, 2]
        TE_Q = torch.stack([future_dow, future_tod], dim=-1)  # [B, T_out, 2]
        TE = torch.cat([TE_P, TE_Q], dim=1)  # [B, T_in+T_out, 2]

        # SE: node static embeddings [N, SEDims]
        SE = self.se_embedding

        y = self.rgdan(x0, SE, TE)
        # Output: [B, T_out, N]
        if y.dim() == 3:
            y = y.unsqueeze(-1)
        return y
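A minimal smoke-test sketch of the wrapper above. It assumes a flat args dict and that get_adj returns None for an unrecognized node count, so the all-ones adjacency fallback in __init__ is used; shapes follow the class docstring:

import torch
from model.RGDAN.RGDAN import RGDANModel

args = {
    'device': 'cpu', 'num_nodes': 4, 'input_dim': 1, 'output_dim': 1,
    'lag': 12, 'horizon': 12, 'K': 3, 'd': 8, 'SEDims': 16, 'TEDims': 295,
    'steps_per_day': 288, 'days_per_week': 7,
}
model = RGDANModel(args)  # assumes get_adj({'num_nodes': 4}) returns None

# [B, T_in, N, F_total] with time_in_day and day_in_week as the last two channels.
x = torch.randn(2, 12, 4, 1)
time_in_day = torch.rand(2, 12, 1, 1).expand(2, 12, 4, 1)
day_in_week = torch.randint(0, 7, (2, 12, 1, 1)).float().expand(2, 12, 4, 1)
x = torch.cat([x, time_in_day, day_in_week], dim=-1)

y = model(x)
print(y.shape)  # expected torch.Size([2, 12, 4, 1])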
@ -22,6 +22,8 @@ from model.MegaCRN.MegaCRNModel import MegaCRNModel
from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet
+from model.STEP.STEP import STEP
+from model.RGDAN.RGDAN import RGDANModel

def model_selector(model):
    match model['type']:
@ -49,4 +51,6 @@ def model_selector(model):
        case 'ST_SSL': return STSSLModel(model)
        case 'STGNRDE': return make_nrde_model(model)
        case 'STAWnet': return STAWnet(model)
+        case 'STEP': return STEP(model)
+        case 'RGDAN': return RGDANModel(model)
@ -0,0 +1 @@
Subproject commit 566e2738da2d83f055718d8edb609ad8dc325204
@ -13,7 +13,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
                                           lr_scheduler, kwargs[0], None)
        case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                          lr_scheduler, kwargs[0], None)
-        case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
+        case 'DCRNN': return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                     lr_scheduler)
        case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                               lr_scheduler)