Compare commits
No commits in common. "5f8c31af2e5a58999b949af1f2ba6394afad3da2" and "bc9a2667c232d7768299dfe7686d2075a3ae77e2" have entirely different histories.
5f8c31af2e
...
bc9a2667c2
|
|
@ -1,48 +0,0 @@
|
||||||
data:
|
|
||||||
num_nodes: 358
|
|
||||||
lag: 12
|
|
||||||
horizon: 12
|
|
||||||
val_ratio: 0.2
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: False
|
|
||||||
normalizer: std
|
|
||||||
column_wise: False
|
|
||||||
default_graph: True
|
|
||||||
add_time_in_day: True
|
|
||||||
add_day_in_week: True
|
|
||||||
steps_per_day: 288
|
|
||||||
days_per_week: 7
|
|
||||||
|
|
||||||
model:
|
|
||||||
input_dim: 3
|
|
||||||
output_dim: 1
|
|
||||||
history: 12
|
|
||||||
horizon: 12
|
|
||||||
granularity: 288
|
|
||||||
dropout: 0.1
|
|
||||||
channels: 32
|
|
||||||
|
|
||||||
|
|
||||||
train:
|
|
||||||
loss_func: mae
|
|
||||||
seed: 10
|
|
||||||
batch_size: 64
|
|
||||||
epochs: 300
|
|
||||||
lr_init: 0.003
|
|
||||||
weight_decay: 0
|
|
||||||
lr_decay: False
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: "5,20,40,70"
|
|
||||||
early_stop: True
|
|
||||||
early_stop_patience: 15
|
|
||||||
grad_norm: False
|
|
||||||
max_grad_norm: 5
|
|
||||||
real_value: True
|
|
||||||
|
|
||||||
test:
|
|
||||||
mae_thresh: null
|
|
||||||
mape_thresh: 0.0
|
|
||||||
|
|
||||||
log:
|
|
||||||
log_step: 200
|
|
||||||
plot: False
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
data:
|
|
||||||
num_nodes: 307
|
|
||||||
lag: 12
|
|
||||||
horizon: 12
|
|
||||||
val_ratio: 0.2
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: False
|
|
||||||
normalizer: std
|
|
||||||
column_wise: False
|
|
||||||
default_graph: True
|
|
||||||
add_time_in_day: True
|
|
||||||
add_day_in_week: True
|
|
||||||
steps_per_day: 288
|
|
||||||
days_per_week: 7
|
|
||||||
|
|
||||||
model:
|
|
||||||
input_dim: 3
|
|
||||||
output_dim: 1
|
|
||||||
history: 12
|
|
||||||
horizon: 12
|
|
||||||
granularity: 288
|
|
||||||
dropout: 0.1
|
|
||||||
channels: 64
|
|
||||||
|
|
||||||
|
|
||||||
train:
|
|
||||||
loss_func: mae
|
|
||||||
seed: 10
|
|
||||||
batch_size: 64
|
|
||||||
epochs: 300
|
|
||||||
lr_init: 0.003
|
|
||||||
weight_decay: 0
|
|
||||||
lr_decay: False
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: "5,20,40,70"
|
|
||||||
early_stop: True
|
|
||||||
early_stop_patience: 15
|
|
||||||
grad_norm: False
|
|
||||||
max_grad_norm: 5
|
|
||||||
real_value: True
|
|
||||||
|
|
||||||
test:
|
|
||||||
mae_thresh: null
|
|
||||||
mape_thresh: 0.0
|
|
||||||
|
|
||||||
log:
|
|
||||||
log_step: 200
|
|
||||||
plot: False
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
data:
|
|
||||||
num_nodes: 883
|
|
||||||
lag: 12
|
|
||||||
horizon: 12
|
|
||||||
val_ratio: 0.2
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: False
|
|
||||||
normalizer: std
|
|
||||||
column_wise: False
|
|
||||||
default_graph: True
|
|
||||||
add_time_in_day: True
|
|
||||||
add_day_in_week: True
|
|
||||||
steps_per_day: 288
|
|
||||||
days_per_week: 7
|
|
||||||
|
|
||||||
model:
|
|
||||||
input_dim: 3
|
|
||||||
output_dim: 1
|
|
||||||
history: 12
|
|
||||||
horizon: 12
|
|
||||||
granularity: 288
|
|
||||||
dropout: 0.1
|
|
||||||
channels: 128
|
|
||||||
|
|
||||||
|
|
||||||
train:
|
|
||||||
loss_func: mae
|
|
||||||
seed: 10
|
|
||||||
batch_size: 16
|
|
||||||
epochs: 300
|
|
||||||
lr_init: 0.003
|
|
||||||
weight_decay: 0
|
|
||||||
lr_decay: False
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: "5,20,40,70"
|
|
||||||
early_stop: True
|
|
||||||
early_stop_patience: 15
|
|
||||||
grad_norm: False
|
|
||||||
max_grad_norm: 5
|
|
||||||
real_value: True
|
|
||||||
|
|
||||||
test:
|
|
||||||
mae_thresh: null
|
|
||||||
mape_thresh: 0.0
|
|
||||||
|
|
||||||
log:
|
|
||||||
log_step: 200
|
|
||||||
plot: False
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
||||||
data:
|
|
||||||
num_nodes: 170
|
|
||||||
lag: 12
|
|
||||||
horizon: 12
|
|
||||||
val_ratio: 0.2
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: False
|
|
||||||
normalizer: std
|
|
||||||
column_wise: False
|
|
||||||
default_graph: True
|
|
||||||
add_time_in_day: True
|
|
||||||
add_day_in_week: True
|
|
||||||
steps_per_day: 288
|
|
||||||
days_per_week: 7
|
|
||||||
|
|
||||||
model:
|
|
||||||
input_dim: 3
|
|
||||||
output_dim: 1
|
|
||||||
history: 12
|
|
||||||
horizon: 12
|
|
||||||
granularity: 288
|
|
||||||
dropout: 0.1
|
|
||||||
channels: 96
|
|
||||||
|
|
||||||
|
|
||||||
train:
|
|
||||||
loss_func: mae
|
|
||||||
seed: 10
|
|
||||||
batch_size: 64
|
|
||||||
epochs: 300
|
|
||||||
lr_init: 0.003
|
|
||||||
weight_decay: 0
|
|
||||||
lr_decay: False
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: "5,20,40,70"
|
|
||||||
early_stop: True
|
|
||||||
early_stop_patience: 15
|
|
||||||
grad_norm: False
|
|
||||||
max_grad_norm: 5
|
|
||||||
real_value: True
|
|
||||||
|
|
||||||
test:
|
|
||||||
mae_thresh: null
|
|
||||||
mape_thresh: 0.0
|
|
||||||
|
|
||||||
log:
|
|
||||||
log_step: 200
|
|
||||||
plot: False
|
|
||||||
|
|
@ -1,368 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
class GLU(nn.Module):
|
|
||||||
def __init__(self, features, dropout=0.1):
|
|
||||||
super(GLU, self).__init__()
|
|
||||||
self.conv1 = nn.Conv2d(features, features, (1, 1))
|
|
||||||
self.conv2 = nn.Conv2d(features, features, (1, 1))
|
|
||||||
self.conv3 = nn.Conv2d(features, features, (1, 1))
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x1 = self.conv1(x)
|
|
||||||
x2 = self.conv2(x)
|
|
||||||
out = x1 * torch.sigmoid(x2)
|
|
||||||
out = self.dropout(out)
|
|
||||||
out = self.conv3(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
# class TemporalEmbedding(nn.Module):
|
|
||||||
# def __init__(self, time, features):
|
|
||||||
# super(TemporalEmbedding, self).__init__()
|
|
||||||
#
|
|
||||||
# self.time = time
|
|
||||||
# # self.time_day = nn.Parameter(torch.empty(time, features))
|
|
||||||
# # nn.init.xavier_uniform_(self.time_day)
|
|
||||||
# #
|
|
||||||
# # self.time_week = nn.Parameter(torch.empty(7, features))
|
|
||||||
# # nn.init.xavier_uniform_(self.time_week)
|
|
||||||
# self.time_day = nn.Embedding(time, features)
|
|
||||||
# self.time_week = nn.Embedding(7, features)
|
|
||||||
#
|
|
||||||
# def forward(self, x):
|
|
||||||
# day_emb = x[..., 1]
|
|
||||||
# # time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
|
|
||||||
# # time_day = time_day.transpose(1, 2).contiguous()
|
|
||||||
#
|
|
||||||
# week_emb = x[..., 2]
|
|
||||||
# # time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]
|
|
||||||
# # time_week = time_week.transpose(1, 2).contiguous()
|
|
||||||
#
|
|
||||||
# t_idx = (day_emb[:, -1, :, ] * (self.time - 1)).long() # (B, N)
|
|
||||||
# d_idx = week_emb[:, -1, :, ].long() # (B, N)
|
|
||||||
# # time_emb = self.time_embedding(t_idx) # (B, N, hidden_dim)
|
|
||||||
# # day_emb = self.day_embedding(d_idx) # (B, N, hidden_dim)
|
|
||||||
#
|
|
||||||
# tem_emb = t_idx + d_idx
|
|
||||||
#
|
|
||||||
# # tem_emb = tem_emb.permute(0, 3, 1, 2)
|
|
||||||
#
|
|
||||||
# return tem_emb
|
|
||||||
class TemporalEmbedding(nn.Module):
|
|
||||||
def __init__(self, time, features):
|
|
||||||
super(TemporalEmbedding, self).__init__()
|
|
||||||
|
|
||||||
self.time = time
|
|
||||||
self.time_day = nn.Parameter(torch.empty(time, features))
|
|
||||||
nn.init.xavier_uniform_(self.time_day)
|
|
||||||
|
|
||||||
self.time_week = nn.Parameter(torch.empty(7, features))
|
|
||||||
nn.init.xavier_uniform_(self.time_week)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
day_emb = x[..., 1]
|
|
||||||
time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
|
|
||||||
time_day = time_day.transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
week_emb = x[..., 2]
|
|
||||||
time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]
|
|
||||||
time_week = time_week.transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
tem_emb = time_day + time_week
|
|
||||||
|
|
||||||
tem_emb = tem_emb.permute(0,3,1,2)
|
|
||||||
|
|
||||||
return tem_emb
|
|
||||||
|
|
||||||
class Diffusion_GCN(nn.Module):
|
|
||||||
def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
|
|
||||||
super().__init__()
|
|
||||||
self.diffusion_step = diffusion_step
|
|
||||||
self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
def forward(self, x, adj):
|
|
||||||
out = []
|
|
||||||
for i in range(0, self.diffusion_step):
|
|
||||||
if adj.dim() == 3:
|
|
||||||
x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
|
|
||||||
out.append(x)
|
|
||||||
elif adj.dim() == 2:
|
|
||||||
x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
|
|
||||||
out.append(x)
|
|
||||||
x = torch.cat(out, dim=1)
|
|
||||||
x = self.conv(x)
|
|
||||||
output = self.dropout(x)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class Graph_Generator(nn.Module):
|
|
||||||
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
|
|
||||||
super().__init__()
|
|
||||||
self.memory = nn.Parameter(torch.randn(channels, num_nodes))
|
|
||||||
nn.init.xavier_uniform_(self.memory)
|
|
||||||
self.fc = nn.Linear(2, 1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
adj_dyn_1 = torch.softmax(
|
|
||||||
F.relu(
|
|
||||||
torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
|
|
||||||
/ math.sqrt(x.shape[1])
|
|
||||||
),
|
|
||||||
-1,
|
|
||||||
)
|
|
||||||
adj_dyn_2 = torch.softmax(
|
|
||||||
F.relu(
|
|
||||||
torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
|
|
||||||
/ math.sqrt(x.shape[1])
|
|
||||||
),
|
|
||||||
-1,
|
|
||||||
)
|
|
||||||
# adj_dyn = (adj_dyn_1 + adj_dyn_2 + adj)/2
|
|
||||||
adj_f = torch.cat([(adj_dyn_1).unsqueeze(-1)] + [(adj_dyn_2).unsqueeze(-1)], dim=-1)
|
|
||||||
adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
|
|
||||||
|
|
||||||
topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
|
|
||||||
mask = torch.zeros_like(adj_f)
|
|
||||||
mask.scatter_(-1, topk_indices, 1)
|
|
||||||
adj_f = adj_f * mask
|
|
||||||
|
|
||||||
return adj_f
|
|
||||||
|
|
||||||
|
|
||||||
class DGCN(nn.Module):
|
|
||||||
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(channels, channels, (1, 1))
|
|
||||||
self.generator = Graph_Generator(channels, num_nodes, diffusion_step, dropout)
|
|
||||||
self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
|
|
||||||
self.emb = emb
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
skip = x
|
|
||||||
x = self.conv(x)
|
|
||||||
adj_dyn = self.generator(x)
|
|
||||||
x = self.gcn(x, adj_dyn)
|
|
||||||
x = x * self.emb + skip
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Splitting(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(Splitting, self).__init__()
|
|
||||||
|
|
||||||
def even(self, x):
|
|
||||||
return x[:, :, :, ::2]
|
|
||||||
|
|
||||||
def odd(self, x):
|
|
||||||
return x[:, :, :, 1::2]
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return (self.even(x), self.odd(x))
|
|
||||||
|
|
||||||
|
|
||||||
class IDGCN(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device,
|
|
||||||
channels=64,
|
|
||||||
diffusion_step=1,
|
|
||||||
splitting=True,
|
|
||||||
num_nodes=170,
|
|
||||||
dropout=0.2, emb=None
|
|
||||||
):
|
|
||||||
super(IDGCN, self).__init__()
|
|
||||||
|
|
||||||
device = device
|
|
||||||
self.dropout = dropout
|
|
||||||
self.num_nodes = num_nodes
|
|
||||||
self.splitting = splitting
|
|
||||||
self.split = Splitting()
|
|
||||||
|
|
||||||
Conv1 = []
|
|
||||||
Conv2 = []
|
|
||||||
Conv3 = []
|
|
||||||
Conv4 = []
|
|
||||||
pad_l = 3
|
|
||||||
pad_r = 3
|
|
||||||
|
|
||||||
k1 = 5
|
|
||||||
k2 = 3
|
|
||||||
Conv1 += [
|
|
||||||
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
|
||||||
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
|
||||||
nn.Dropout(self.dropout),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
|
||||||
nn.Tanh(),
|
|
||||||
]
|
|
||||||
Conv2 += [
|
|
||||||
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
|
||||||
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
|
||||||
nn.Dropout(self.dropout),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
|
||||||
nn.Tanh(),
|
|
||||||
]
|
|
||||||
Conv4 += [
|
|
||||||
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
|
||||||
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
|
||||||
nn.Dropout(self.dropout),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
|
||||||
nn.Tanh(),
|
|
||||||
]
|
|
||||||
Conv3 += [
|
|
||||||
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
|
||||||
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
|
||||||
nn.Dropout(self.dropout),
|
|
||||||
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
|
||||||
nn.Tanh(),
|
|
||||||
]
|
|
||||||
|
|
||||||
self.conv1 = nn.Sequential(*Conv1)
|
|
||||||
self.conv2 = nn.Sequential(*Conv2)
|
|
||||||
self.conv3 = nn.Sequential(*Conv3)
|
|
||||||
self.conv4 = nn.Sequential(*Conv4)
|
|
||||||
|
|
||||||
self.dgcn = DGCN(channels, num_nodes, diffusion_step, dropout, emb)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.splitting:
|
|
||||||
(x_even, x_odd) = self.split(x)
|
|
||||||
else:
|
|
||||||
(x_even, x_odd) = x
|
|
||||||
|
|
||||||
x1 = self.conv1(x_even)
|
|
||||||
x1 = self.dgcn(x1)
|
|
||||||
d = x_odd.mul(torch.tanh(x1))
|
|
||||||
|
|
||||||
x2 = self.conv2(x_odd)
|
|
||||||
x2 = self.dgcn(x2)
|
|
||||||
c = x_even.mul(torch.tanh(x2))
|
|
||||||
|
|
||||||
x3 = self.conv3(c)
|
|
||||||
x3 = self.dgcn(x3)
|
|
||||||
x_odd_update = d + x3
|
|
||||||
|
|
||||||
x4 = self.conv4(d)
|
|
||||||
x4 = self.dgcn(x4)
|
|
||||||
x_even_update = c + x4
|
|
||||||
|
|
||||||
return (x_even_update, x_odd_update)
|
|
||||||
|
|
||||||
|
|
||||||
class IDGCN_Tree(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self, device, channels=64, diffusion_step=1, num_nodes=170, dropout=0.1
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.memory1 = nn.Parameter(torch.randn(channels, num_nodes, 6))
|
|
||||||
self.memory2 = nn.Parameter(torch.randn(channels, num_nodes, 3))
|
|
||||||
self.memory3 = nn.Parameter(torch.randn(channels, num_nodes, 3))
|
|
||||||
|
|
||||||
self.IDGCN1 = IDGCN(
|
|
||||||
device=device,
|
|
||||||
splitting=True,
|
|
||||||
channels=channels,
|
|
||||||
diffusion_step=diffusion_step,
|
|
||||||
num_nodes=num_nodes,
|
|
||||||
dropout=dropout, emb=self.memory1
|
|
||||||
)
|
|
||||||
self.IDGCN2 = IDGCN(
|
|
||||||
device=device,
|
|
||||||
splitting=True,
|
|
||||||
channels=channels,
|
|
||||||
diffusion_step=diffusion_step,
|
|
||||||
num_nodes=num_nodes,
|
|
||||||
dropout=dropout, emb=self.memory2
|
|
||||||
)
|
|
||||||
self.IDGCN3 = IDGCN(
|
|
||||||
device=device,
|
|
||||||
splitting=True,
|
|
||||||
channels=channels,
|
|
||||||
diffusion_step=diffusion_step,
|
|
||||||
num_nodes=num_nodes,
|
|
||||||
dropout=dropout, emb=self.memory2
|
|
||||||
)
|
|
||||||
|
|
||||||
def concat(self, even, odd):
|
|
||||||
even = even.permute(3, 1, 2, 0)
|
|
||||||
odd = odd.permute(3, 1, 2, 0)
|
|
||||||
len = even.shape[0]
|
|
||||||
_ = []
|
|
||||||
for i in range(len):
|
|
||||||
_.append(even[i].unsqueeze(0))
|
|
||||||
_.append(odd[i].unsqueeze(0))
|
|
||||||
return torch.cat(_, 0).permute(3, 1, 2, 0)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x_even_update1, x_odd_update1 = self.IDGCN1(x)
|
|
||||||
x_even_update2, x_odd_update2 = self.IDGCN2(x_even_update1)
|
|
||||||
x_even_update3, x_odd_update3 = self.IDGCN3(x_odd_update1)
|
|
||||||
concat1 = self.concat(x_even_update2, x_odd_update2)
|
|
||||||
concat2 = self.concat(x_even_update3, x_odd_update3)
|
|
||||||
concat0 = self.concat(concat1, concat2)
|
|
||||||
output = concat0 + x
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class STIDGCN(nn.Module):
|
|
||||||
def __init__(self, args):
|
|
||||||
"""
|
|
||||||
device, input_dim, num_nodes, channels, granularity, dropout=0.1
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
device = args['device']
|
|
||||||
input_dim = args['input_dim']
|
|
||||||
self.num_nodes = args['num_nodes']
|
|
||||||
self.output_len = 12
|
|
||||||
channels = args['channels']
|
|
||||||
granularity = args['granularity']
|
|
||||||
dropout = args['dropout']
|
|
||||||
diffusion_step = 1
|
|
||||||
|
|
||||||
self.Temb = TemporalEmbedding(granularity, channels)
|
|
||||||
|
|
||||||
self.start_conv = nn.Conv2d(
|
|
||||||
in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tree = IDGCN_Tree(
|
|
||||||
device=device,
|
|
||||||
channels=channels * 2,
|
|
||||||
diffusion_step=diffusion_step,
|
|
||||||
num_nodes=self.num_nodes,
|
|
||||||
dropout=dropout,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.glu = GLU(channels * 2, dropout)
|
|
||||||
|
|
||||||
self.regression_layer = nn.Conv2d(
|
|
||||||
channels * 2, self.output_len, kernel_size=(1, self.output_len)
|
|
||||||
)
|
|
||||||
|
|
||||||
def param_num(self):
|
|
||||||
return sum([param.nelement() for param in self.parameters()])
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
input = input.transpose(1, 3)
|
|
||||||
x = input
|
|
||||||
# Encoder
|
|
||||||
# Data Embedding
|
|
||||||
time_emb = self.Temb(input.permute(0, 3, 2, 1))
|
|
||||||
x = torch.cat([self.start_conv(x)] + [time_emb], dim=1)
|
|
||||||
# IDGCN_Tree
|
|
||||||
x = self.tree(x)
|
|
||||||
# Decoder
|
|
||||||
gcn = self.glu(x) + x
|
|
||||||
prediction = self.regression_layer(F.relu(gcn))
|
|
||||||
return prediction
|
|
||||||
|
|
@ -14,7 +14,6 @@ from model.STSGCN.STSGCN import STSGCN
|
||||||
from model.STGODE.STGODE import ODEGCN
|
from model.STGODE.STGODE import ODEGCN
|
||||||
from model.PDG2SEQ.PDG2Seq import PDG2Seq
|
from model.PDG2SEQ.PDG2Seq import PDG2Seq
|
||||||
from model.STMLP.STMLP import STMLP
|
from model.STMLP.STMLP import STMLP
|
||||||
from model.STIDGCN.STIDGCN import STIDGCN
|
|
||||||
|
|
||||||
def model_selector(model):
|
def model_selector(model):
|
||||||
match model['type']:
|
match model['type']:
|
||||||
|
|
@ -34,5 +33,4 @@ def model_selector(model):
|
||||||
case 'STGODE': return ODEGCN(model)
|
case 'STGODE': return ODEGCN(model)
|
||||||
case 'PDG2SEQ': return PDG2Seq(model)
|
case 'PDG2SEQ': return PDG2Seq(model)
|
||||||
case 'STMLP': return STMLP(model)
|
case 'STMLP': return STMLP(model)
|
||||||
case 'STIDGCN': return STIDGCN(model)
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue