添加STIDGCN
This commit is contained in:
parent
229b6320b9
commit
97eb39073a
|
|
@ -0,0 +1,48 @@
|
||||||
|
data:
|
||||||
|
num_nodes: 358
|
||||||
|
lag: 12
|
||||||
|
horizon: 12
|
||||||
|
val_ratio: 0.2
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: False
|
||||||
|
normalizer: std
|
||||||
|
column_wise: False
|
||||||
|
default_graph: True
|
||||||
|
add_time_in_day: True
|
||||||
|
add_day_in_week: True
|
||||||
|
steps_per_day: 288
|
||||||
|
days_per_week: 7
|
||||||
|
|
||||||
|
model:
|
||||||
|
input_dim: 3
|
||||||
|
output_dim: 1
|
||||||
|
history: 12
|
||||||
|
horizon: 12
|
||||||
|
granularity: 288
|
||||||
|
dropout: 0.1
|
||||||
|
channels: 32
|
||||||
|
|
||||||
|
|
||||||
|
train:
|
||||||
|
loss_func: mae
|
||||||
|
seed: 10
|
||||||
|
batch_size: 64
|
||||||
|
epochs: 300
|
||||||
|
lr_init: 0.003
|
||||||
|
weight_decay: 0
|
||||||
|
lr_decay: False
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
early_stop: True
|
||||||
|
early_stop_patience: 15
|
||||||
|
grad_norm: False
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: True
|
||||||
|
|
||||||
|
test:
|
||||||
|
mae_thresh: null
|
||||||
|
mape_thresh: 0.0
|
||||||
|
|
||||||
|
log:
|
||||||
|
log_step: 200
|
||||||
|
plot: False
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
data:
|
||||||
|
num_nodes: 307
|
||||||
|
lag: 12
|
||||||
|
horizon: 12
|
||||||
|
val_ratio: 0.2
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: False
|
||||||
|
normalizer: std
|
||||||
|
column_wise: False
|
||||||
|
default_graph: True
|
||||||
|
add_time_in_day: True
|
||||||
|
add_day_in_week: True
|
||||||
|
steps_per_day: 288
|
||||||
|
days_per_week: 7
|
||||||
|
|
||||||
|
model:
|
||||||
|
input_dim: 3
|
||||||
|
output_dim: 1
|
||||||
|
history: 12
|
||||||
|
horizon: 12
|
||||||
|
granularity: 288
|
||||||
|
dropout: 0.1
|
||||||
|
channels: 64
|
||||||
|
|
||||||
|
|
||||||
|
train:
|
||||||
|
loss_func: mae
|
||||||
|
seed: 10
|
||||||
|
batch_size: 64
|
||||||
|
epochs: 300
|
||||||
|
lr_init: 0.003
|
||||||
|
weight_decay: 0
|
||||||
|
lr_decay: False
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
early_stop: True
|
||||||
|
early_stop_patience: 15
|
||||||
|
grad_norm: False
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: True
|
||||||
|
|
||||||
|
test:
|
||||||
|
mae_thresh: null
|
||||||
|
mape_thresh: 0.0
|
||||||
|
|
||||||
|
log:
|
||||||
|
log_step: 200
|
||||||
|
plot: False
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
data:
|
||||||
|
num_nodes: 883
|
||||||
|
lag: 12
|
||||||
|
horizon: 12
|
||||||
|
val_ratio: 0.2
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: False
|
||||||
|
normalizer: std
|
||||||
|
column_wise: False
|
||||||
|
default_graph: True
|
||||||
|
add_time_in_day: True
|
||||||
|
add_day_in_week: True
|
||||||
|
steps_per_day: 288
|
||||||
|
days_per_week: 7
|
||||||
|
|
||||||
|
model:
|
||||||
|
input_dim: 3
|
||||||
|
output_dim: 1
|
||||||
|
history: 12
|
||||||
|
horizon: 12
|
||||||
|
granularity: 288
|
||||||
|
dropout: 0.1
|
||||||
|
channels: 128
|
||||||
|
|
||||||
|
|
||||||
|
train:
|
||||||
|
loss_func: mae
|
||||||
|
seed: 10
|
||||||
|
batch_size: 16
|
||||||
|
epochs: 300
|
||||||
|
lr_init: 0.003
|
||||||
|
weight_decay: 0
|
||||||
|
lr_decay: False
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
early_stop: True
|
||||||
|
early_stop_patience: 15
|
||||||
|
grad_norm: False
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: True
|
||||||
|
|
||||||
|
test:
|
||||||
|
mae_thresh: null
|
||||||
|
mape_thresh: 0.0
|
||||||
|
|
||||||
|
log:
|
||||||
|
log_step: 200
|
||||||
|
plot: False
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
data:
|
||||||
|
num_nodes: 170
|
||||||
|
lag: 12
|
||||||
|
horizon: 12
|
||||||
|
val_ratio: 0.2
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: False
|
||||||
|
normalizer: std
|
||||||
|
column_wise: False
|
||||||
|
default_graph: True
|
||||||
|
add_time_in_day: True
|
||||||
|
add_day_in_week: True
|
||||||
|
steps_per_day: 288
|
||||||
|
days_per_week: 7
|
||||||
|
|
||||||
|
model:
|
||||||
|
input_dim: 3
|
||||||
|
output_dim: 1
|
||||||
|
history: 12
|
||||||
|
horizon: 12
|
||||||
|
granularity: 288
|
||||||
|
dropout: 0.1
|
||||||
|
channels: 96
|
||||||
|
|
||||||
|
|
||||||
|
train:
|
||||||
|
loss_func: mae
|
||||||
|
seed: 10
|
||||||
|
batch_size: 64
|
||||||
|
epochs: 300
|
||||||
|
lr_init: 0.003
|
||||||
|
weight_decay: 0
|
||||||
|
lr_decay: False
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
early_stop: True
|
||||||
|
early_stop_patience: 15
|
||||||
|
grad_norm: False
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: True
|
||||||
|
|
||||||
|
test:
|
||||||
|
mae_thresh: null
|
||||||
|
mape_thresh: 0.0
|
||||||
|
|
||||||
|
log:
|
||||||
|
log_step: 200
|
||||||
|
plot: False
|
||||||
|
|
@ -0,0 +1,368 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class GLU(nn.Module):
|
||||||
|
def __init__(self, features, dropout=0.1):
|
||||||
|
super(GLU, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(features, features, (1, 1))
|
||||||
|
self.conv2 = nn.Conv2d(features, features, (1, 1))
|
||||||
|
self.conv3 = nn.Conv2d(features, features, (1, 1))
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x1 = self.conv1(x)
|
||||||
|
x2 = self.conv2(x)
|
||||||
|
out = x1 * torch.sigmoid(x2)
|
||||||
|
out = self.dropout(out)
|
||||||
|
out = self.conv3(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# class TemporalEmbedding(nn.Module):
|
||||||
|
# def __init__(self, time, features):
|
||||||
|
# super(TemporalEmbedding, self).__init__()
|
||||||
|
#
|
||||||
|
# self.time = time
|
||||||
|
# # self.time_day = nn.Parameter(torch.empty(time, features))
|
||||||
|
# # nn.init.xavier_uniform_(self.time_day)
|
||||||
|
# #
|
||||||
|
# # self.time_week = nn.Parameter(torch.empty(7, features))
|
||||||
|
# # nn.init.xavier_uniform_(self.time_week)
|
||||||
|
# self.time_day = nn.Embedding(time, features)
|
||||||
|
# self.time_week = nn.Embedding(7, features)
|
||||||
|
#
|
||||||
|
# def forward(self, x):
|
||||||
|
# day_emb = x[..., 1]
|
||||||
|
# # time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
|
||||||
|
# # time_day = time_day.transpose(1, 2).contiguous()
|
||||||
|
#
|
||||||
|
# week_emb = x[..., 2]
|
||||||
|
# # time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]
|
||||||
|
# # time_week = time_week.transpose(1, 2).contiguous()
|
||||||
|
#
|
||||||
|
# t_idx = (day_emb[:, -1, :, ] * (self.time - 1)).long() # (B, N)
|
||||||
|
# d_idx = week_emb[:, -1, :, ].long() # (B, N)
|
||||||
|
# # time_emb = self.time_embedding(t_idx) # (B, N, hidden_dim)
|
||||||
|
# # day_emb = self.day_embedding(d_idx) # (B, N, hidden_dim)
|
||||||
|
#
|
||||||
|
# tem_emb = t_idx + d_idx
|
||||||
|
#
|
||||||
|
# # tem_emb = tem_emb.permute(0, 3, 1, 2)
|
||||||
|
#
|
||||||
|
# return tem_emb
|
||||||
|
class TemporalEmbedding(nn.Module):
|
||||||
|
def __init__(self, time, features):
|
||||||
|
super(TemporalEmbedding, self).__init__()
|
||||||
|
|
||||||
|
self.time = time
|
||||||
|
self.time_day = nn.Parameter(torch.empty(time, features))
|
||||||
|
nn.init.xavier_uniform_(self.time_day)
|
||||||
|
|
||||||
|
self.time_week = nn.Parameter(torch.empty(7, features))
|
||||||
|
nn.init.xavier_uniform_(self.time_week)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
day_emb = x[..., 1]
|
||||||
|
time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
|
||||||
|
time_day = time_day.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
week_emb = x[..., 2]
|
||||||
|
time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]
|
||||||
|
time_week = time_week.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
tem_emb = time_day + time_week
|
||||||
|
|
||||||
|
tem_emb = tem_emb.permute(0,3,1,2)
|
||||||
|
|
||||||
|
return tem_emb
|
||||||
|
|
||||||
|
class Diffusion_GCN(nn.Module):
|
||||||
|
def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.diffusion_step = diffusion_step
|
||||||
|
self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x, adj):
|
||||||
|
out = []
|
||||||
|
for i in range(0, self.diffusion_step):
|
||||||
|
if adj.dim() == 3:
|
||||||
|
x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
|
||||||
|
out.append(x)
|
||||||
|
elif adj.dim() == 2:
|
||||||
|
x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
|
||||||
|
out.append(x)
|
||||||
|
x = torch.cat(out, dim=1)
|
||||||
|
x = self.conv(x)
|
||||||
|
output = self.dropout(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Graph_Generator(nn.Module):
|
||||||
|
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.memory = nn.Parameter(torch.randn(channels, num_nodes))
|
||||||
|
nn.init.xavier_uniform_(self.memory)
|
||||||
|
self.fc = nn.Linear(2, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
adj_dyn_1 = torch.softmax(
|
||||||
|
F.relu(
|
||||||
|
torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
|
||||||
|
/ math.sqrt(x.shape[1])
|
||||||
|
),
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
adj_dyn_2 = torch.softmax(
|
||||||
|
F.relu(
|
||||||
|
torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
|
||||||
|
/ math.sqrt(x.shape[1])
|
||||||
|
),
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
# adj_dyn = (adj_dyn_1 + adj_dyn_2 + adj)/2
|
||||||
|
adj_f = torch.cat([(adj_dyn_1).unsqueeze(-1)] + [(adj_dyn_2).unsqueeze(-1)], dim=-1)
|
||||||
|
adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
|
||||||
|
|
||||||
|
topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
|
||||||
|
mask = torch.zeros_like(adj_f)
|
||||||
|
mask.scatter_(-1, topk_indices, 1)
|
||||||
|
adj_f = adj_f * mask
|
||||||
|
|
||||||
|
return adj_f
|
||||||
|
|
||||||
|
|
||||||
|
class DGCN(nn.Module):
|
||||||
|
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(channels, channels, (1, 1))
|
||||||
|
self.generator = Graph_Generator(channels, num_nodes, diffusion_step, dropout)
|
||||||
|
self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
|
||||||
|
self.emb = emb
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
skip = x
|
||||||
|
x = self.conv(x)
|
||||||
|
adj_dyn = self.generator(x)
|
||||||
|
x = self.gcn(x, adj_dyn)
|
||||||
|
x = x * self.emb + skip
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Splitting(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Splitting, self).__init__()
|
||||||
|
|
||||||
|
def even(self, x):
|
||||||
|
return x[:, :, :, ::2]
|
||||||
|
|
||||||
|
def odd(self, x):
|
||||||
|
return x[:, :, :, 1::2]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return (self.even(x), self.odd(x))
|
||||||
|
|
||||||
|
|
||||||
|
class IDGCN(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device,
|
||||||
|
channels=64,
|
||||||
|
diffusion_step=1,
|
||||||
|
splitting=True,
|
||||||
|
num_nodes=170,
|
||||||
|
dropout=0.2, emb=None
|
||||||
|
):
|
||||||
|
super(IDGCN, self).__init__()
|
||||||
|
|
||||||
|
device = device
|
||||||
|
self.dropout = dropout
|
||||||
|
self.num_nodes = num_nodes
|
||||||
|
self.splitting = splitting
|
||||||
|
self.split = Splitting()
|
||||||
|
|
||||||
|
Conv1 = []
|
||||||
|
Conv2 = []
|
||||||
|
Conv3 = []
|
||||||
|
Conv4 = []
|
||||||
|
pad_l = 3
|
||||||
|
pad_r = 3
|
||||||
|
|
||||||
|
k1 = 5
|
||||||
|
k2 = 3
|
||||||
|
Conv1 += [
|
||||||
|
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
||||||
|
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
||||||
|
nn.Dropout(self.dropout),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
||||||
|
nn.Tanh(),
|
||||||
|
]
|
||||||
|
Conv2 += [
|
||||||
|
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
||||||
|
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
||||||
|
nn.Dropout(self.dropout),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
||||||
|
nn.Tanh(),
|
||||||
|
]
|
||||||
|
Conv4 += [
|
||||||
|
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
||||||
|
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
||||||
|
nn.Dropout(self.dropout),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
||||||
|
nn.Tanh(),
|
||||||
|
]
|
||||||
|
Conv3 += [
|
||||||
|
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
|
||||||
|
nn.LeakyReLU(negative_slope=0.01, inplace=True),
|
||||||
|
nn.Dropout(self.dropout),
|
||||||
|
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
|
||||||
|
nn.Tanh(),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.conv1 = nn.Sequential(*Conv1)
|
||||||
|
self.conv2 = nn.Sequential(*Conv2)
|
||||||
|
self.conv3 = nn.Sequential(*Conv3)
|
||||||
|
self.conv4 = nn.Sequential(*Conv4)
|
||||||
|
|
||||||
|
self.dgcn = DGCN(channels, num_nodes, diffusion_step, dropout, emb)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.splitting:
|
||||||
|
(x_even, x_odd) = self.split(x)
|
||||||
|
else:
|
||||||
|
(x_even, x_odd) = x
|
||||||
|
|
||||||
|
x1 = self.conv1(x_even)
|
||||||
|
x1 = self.dgcn(x1)
|
||||||
|
d = x_odd.mul(torch.tanh(x1))
|
||||||
|
|
||||||
|
x2 = self.conv2(x_odd)
|
||||||
|
x2 = self.dgcn(x2)
|
||||||
|
c = x_even.mul(torch.tanh(x2))
|
||||||
|
|
||||||
|
x3 = self.conv3(c)
|
||||||
|
x3 = self.dgcn(x3)
|
||||||
|
x_odd_update = d + x3
|
||||||
|
|
||||||
|
x4 = self.conv4(d)
|
||||||
|
x4 = self.dgcn(x4)
|
||||||
|
x_even_update = c + x4
|
||||||
|
|
||||||
|
return (x_even_update, x_odd_update)
|
||||||
|
|
||||||
|
|
||||||
|
class IDGCN_Tree(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, device, channels=64, diffusion_step=1, num_nodes=170, dropout=0.1
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.memory1 = nn.Parameter(torch.randn(channels, num_nodes, 6))
|
||||||
|
self.memory2 = nn.Parameter(torch.randn(channels, num_nodes, 3))
|
||||||
|
self.memory3 = nn.Parameter(torch.randn(channels, num_nodes, 3))
|
||||||
|
|
||||||
|
self.IDGCN1 = IDGCN(
|
||||||
|
device=device,
|
||||||
|
splitting=True,
|
||||||
|
channels=channels,
|
||||||
|
diffusion_step=diffusion_step,
|
||||||
|
num_nodes=num_nodes,
|
||||||
|
dropout=dropout, emb=self.memory1
|
||||||
|
)
|
||||||
|
self.IDGCN2 = IDGCN(
|
||||||
|
device=device,
|
||||||
|
splitting=True,
|
||||||
|
channels=channels,
|
||||||
|
diffusion_step=diffusion_step,
|
||||||
|
num_nodes=num_nodes,
|
||||||
|
dropout=dropout, emb=self.memory2
|
||||||
|
)
|
||||||
|
self.IDGCN3 = IDGCN(
|
||||||
|
device=device,
|
||||||
|
splitting=True,
|
||||||
|
channels=channels,
|
||||||
|
diffusion_step=diffusion_step,
|
||||||
|
num_nodes=num_nodes,
|
||||||
|
dropout=dropout, emb=self.memory2
|
||||||
|
)
|
||||||
|
|
||||||
|
def concat(self, even, odd):
|
||||||
|
even = even.permute(3, 1, 2, 0)
|
||||||
|
odd = odd.permute(3, 1, 2, 0)
|
||||||
|
len = even.shape[0]
|
||||||
|
_ = []
|
||||||
|
for i in range(len):
|
||||||
|
_.append(even[i].unsqueeze(0))
|
||||||
|
_.append(odd[i].unsqueeze(0))
|
||||||
|
return torch.cat(_, 0).permute(3, 1, 2, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_even_update1, x_odd_update1 = self.IDGCN1(x)
|
||||||
|
x_even_update2, x_odd_update2 = self.IDGCN2(x_even_update1)
|
||||||
|
x_even_update3, x_odd_update3 = self.IDGCN3(x_odd_update1)
|
||||||
|
concat1 = self.concat(x_even_update2, x_odd_update2)
|
||||||
|
concat2 = self.concat(x_even_update3, x_odd_update3)
|
||||||
|
concat0 = self.concat(concat1, concat2)
|
||||||
|
output = concat0 + x
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class STIDGCN(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
"""
|
||||||
|
device, input_dim, num_nodes, channels, granularity, dropout=0.1
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
device = args['device']
|
||||||
|
input_dim = args['input_dim']
|
||||||
|
self.num_nodes = args['num_nodes']
|
||||||
|
self.output_len = 12
|
||||||
|
channels = args['channels']
|
||||||
|
granularity = args['granularity']
|
||||||
|
dropout = args['dropout']
|
||||||
|
diffusion_step = 1
|
||||||
|
|
||||||
|
self.Temb = TemporalEmbedding(granularity, channels)
|
||||||
|
|
||||||
|
self.start_conv = nn.Conv2d(
|
||||||
|
in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tree = IDGCN_Tree(
|
||||||
|
device=device,
|
||||||
|
channels=channels * 2,
|
||||||
|
diffusion_step=diffusion_step,
|
||||||
|
num_nodes=self.num_nodes,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.glu = GLU(channels * 2, dropout)
|
||||||
|
|
||||||
|
self.regression_layer = nn.Conv2d(
|
||||||
|
channels * 2, self.output_len, kernel_size=(1, self.output_len)
|
||||||
|
)
|
||||||
|
|
||||||
|
def param_num(self):
|
||||||
|
return sum([param.nelement() for param in self.parameters()])
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
input = input.transpose(1, 3)
|
||||||
|
x = input
|
||||||
|
# Encoder
|
||||||
|
# Data Embedding
|
||||||
|
time_emb = self.Temb(input.permute(0, 3, 2, 1))
|
||||||
|
x = torch.cat([self.start_conv(x)] + [time_emb], dim=1)
|
||||||
|
# IDGCN_Tree
|
||||||
|
x = self.tree(x)
|
||||||
|
# Decoder
|
||||||
|
gcn = self.glu(x) + x
|
||||||
|
prediction = self.regression_layer(F.relu(gcn))
|
||||||
|
return prediction
|
||||||
|
|
@ -13,8 +13,8 @@ from model.STFGNN.STFGNN import STFGNN
|
||||||
from model.STSGCN.STSGCN import STSGCN
|
from model.STSGCN.STSGCN import STSGCN
|
||||||
from model.STGODE.STGODE import ODEGCN
|
from model.STGODE.STGODE import ODEGCN
|
||||||
from model.PDG2SEQ.PDG2Seq import PDG2Seq
|
from model.PDG2SEQ.PDG2Seq import PDG2Seq
|
||||||
from model.EXP.EXP import EXP
|
from model.STIDGCN.STIDGCN import STIDGCN
|
||||||
from model.EXPB.EXP_b import EXPB
|
|
||||||
|
|
||||||
def model_selector(model):
|
def model_selector(model):
|
||||||
match model['type']:
|
match model['type']:
|
||||||
|
|
@ -33,6 +33,7 @@ def model_selector(model):
|
||||||
case 'STSGCN': return STSGCN(model)
|
case 'STSGCN': return STSGCN(model)
|
||||||
case 'STGODE': return ODEGCN(model)
|
case 'STGODE': return ODEGCN(model)
|
||||||
case 'PDG2SEQ': return PDG2Seq(model)
|
case 'PDG2SEQ': return PDG2Seq(model)
|
||||||
case 'EXP': return EXP(model)
|
case 'STIDGCN': return STIDGCN(model)
|
||||||
case 'EXPB': return EXPB(model)
|
# case 'EXP': return EXP(model)
|
||||||
|
# case 'EXPB': return EXPB(model)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue