Compare commits

..

No commits in common. "1b76cc6ce26dea851ceec25ddd96c207bf4cff14" and "5306d244081716258786d2b1ba50ea70fd9cdee3" have entirely different histories.

3 changed files with 199 additions and 320 deletions

View File

@ -1,119 +1,248 @@
import torch, torch.nn as nn, torch.nn.functional as F import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict from collections import OrderedDict
class DGCRM(nn.Module): class DGCRM(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
super().__init__() super(DGCRM, self).__init__()
self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers assert num_layers >= 1, 'At least one DGCRM layer is required.'
self.cells = nn.ModuleList(
[DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim) for i in range(num_layers)] self.node_num = node_num
) self.input_dim = dim_in
self.num_layers = num_layers
# Initialize DGCRM cells
self.DGCRM_cells = nn.ModuleList([
DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)
if i == 0 else
DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)
for i in range(num_layers)
])
def forward(self, x, init_state, node_embeddings): def forward(self, x, init_state, node_embeddings):
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim """
for i in range(self.num_layers): Forward pass of the DGCRM model.
state, inner = init_state[i].to(x.device), []
for t in range(x.shape[1]):
state = self.cells[i](x[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
inner.append(state)
init_state[i] = state
x = torch.stack(inner, dim=1)
return x, init_state
def init_hidden(self, bs): Parameters:
return torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0) - x: Input tensor of shape (B, T, N, D)
- init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim)
- node_embeddings: Node embeddings
"""
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
seq_length = x.shape[1]
current_inputs = x
output_hidden = []
for i in range(self.num_layers):
state = init_state[i]
inner_states = []
for t in range(seq_length):
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
[node_embeddings[0][:, t, :, :], node_embeddings[1]])
inner_states.append(state)
output_hidden.append(state)
current_inputs = torch.stack(inner_states, dim=1)
return current_inputs, output_hidden
def init_hidden(self, batch_size):
"""
Initialize hidden states for DGCRM layers.
Parameters:
- batch_size: Size of the batch
Returns:
- Initial hidden states tensor
"""
return torch.stack([
self.DGCRM_cells[i].init_hidden_state(batch_size)
for i in range(self.num_layers)
], dim=0)
class DDGCRN(nn.Module): class DDGCRN(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super(DDGCRN, self).__init__()
self.num_node, self.input_dim, self.hidden_dim = args['num_nodes'], args['input_dim'], args['rnn_units']
self.output_dim, self.horizon, self.num_layers = args['output_dim'], args['horizon'], args['num_layers'] self.num_node = args['num_nodes']
self.use_day, self.use_week = args['use_day'], args['use_week'] self.input_dim = args['input_dim']
self.hidden_dim = args['rnn_units']
self.output_dim = args['output_dim']
self.horizon = args['horizon']
self.num_layers = args['num_layers']
self.use_day = args['use_day']
self.use_week = args['use_week']
self.default_graph = args['default_graph']
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True) self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True) self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim'])) self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim'])) self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
self.drop1, self.drop2 = nn.Dropout(0.1), nn.Dropout(0.1)
self.dropout1 = nn.Dropout(p=0.1)
self.dropout2 = nn.Dropout(p=0.1)
self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'], self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers) self.num_layers)
self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'], self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers) self.num_layers)
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
def forward(self, source): # Predictor
node_embed = self.node_embeddings1 self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
def forward(self, source, **kwargs):
"""
Forward pass of the DDGCRN model.
Parameters:
- source: Input tensor of shape (B, T_1, N, D)
- mode: Control mode for the forward pass
Returns:
- Output tensor
"""
node_embedding1 = self.node_embeddings1
if self.use_day: if self.use_day:
node_embed = node_embed * self.T_i_D_emb[(source[..., 1] * 288).long()] t_i_d_data = source[..., 1]
T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
node_embedding1 = node_embedding1 * T_i_D_emb
if self.use_week: if self.use_week:
node_embed = node_embed * self.D_i_W_emb[source[..., 2].long()] d_i_w_data = source[..., 2]
node_embeddings = [node_embed, self.node_embeddings1] D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
node_embedding1 = node_embedding1 * D_i_W_emb
node_embeddings = [node_embedding1, self.node_embeddings1]
source = source[..., 0].unsqueeze(-1) source = source[..., 0].unsqueeze(-1)
init1 = self.encoder1.init_hidden(source.shape[0])
out, _ = self.encoder1(source, init1, node_embeddings) init_state1 = self.encoder1.init_hidden(source.shape[0])
out = self.drop1(out[:, -1:, :, :]) output, _ = self.encoder1(source, init_state1, node_embeddings)
out1 = self.end_conv1(out) output = self.dropout1(output[:, -1:, :, :])
src1 = self.end_conv2(out) output1 = self.end_conv1(output)
src2 = source[:, -self.horizon:, ...] - src1
init2 = self.encoder2.init_hidden(source.shape[0]) source1 = self.end_conv2(output)
out2, _ = self.encoder2(src2, init2, node_embeddings) source2 = source[:, -self.horizon:, ...] - source1
out2 = self.drop2(out2[:, -1:, :, :])
return out1 + self.end_conv3(out2) init_state2 = self.encoder2.init_hidden(source2.shape[0])
output2, _ = self.encoder2(source2, init_state2, node_embeddings)
output2 = self.dropout2(output2[:, -1:, :, :])
output2 = self.end_conv3(output2)
return output1 + output2
class DDGCRNCell(nn.Module): class DDGCRNCell(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
super().__init__() super(DDGCRNCell, self).__init__()
self.node_num, self.hidden_dim = node_num, dim_out self.node_num = node_num
self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim, node_num) self.hidden_dim = dim_out
self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim, node_num) self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)
def forward(self, x, state, node_embeddings): def forward(self, x, state, node_embeddings):
inp = torch.cat((x, state), -1) state = state.to(x.device)
z_r = torch.sigmoid(self.gate(inp, node_embeddings)) input_and_state = torch.cat((x, state), dim=-1)
z, r = torch.split(z_r, self.hidden_dim, -1) z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
hc = torch.tanh(self.update(torch.cat((x, z * state), -1), node_embeddings)) z, r = torch.split(z_r, self.hidden_dim, dim=-1)
return r * state + (1 - r) * hc candidate = torch.cat((x, z * state), dim=-1)
hc = torch.tanh(self.update(candidate, node_embeddings))
h = r * state + (1 - r) * hc
return h
def init_hidden_state(self, bs): def init_hidden_state(self, batch_size):
return torch.zeros(bs, self.node_num, self.hidden_dim) return torch.zeros(batch_size, self.node_num, self.hidden_dim)
class DGCN(nn.Module): class DGCN(nn.Module):
def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes): def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
super().__init__() super(DGCN, self).__init__()
self.cheb_k, self.embed_dim = cheb_k, embed_dim self.cheb_k = cheb_k
self.embed_dim = embed_dim
# Initialize parameters
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out)) self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
self.bias = nn.Parameter(torch.FloatTensor(dim_out)) self.bias = nn.Parameter(torch.FloatTensor(dim_out))
# Hyperparameters
self.hyperGNN_dim = 16
self.middle_dim = 2
# Fully connected layers
self.fc = nn.Sequential(OrderedDict([ self.fc = nn.Sequential(OrderedDict([
('fc1', nn.Linear(dim_in, 16)), ('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
('sigmoid1', nn.Sigmoid()), ('sigmoid1', nn.Sigmoid()),
('fc2', nn.Linear(16, 2)), ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
('sigmoid2', nn.Sigmoid()), ('sigmoid2', nn.Sigmoid()),
('fc3', nn.Linear(2, embed_dim)) ('fc3', nn.Linear(self.middle_dim, self.embed_dim))
])) ]))
# 预注册恒定不变的单位矩阵
self.register_buffer('eye', torch.eye(num_nodes))
def forward(self, x, node_embeddings): def forward(self, x, node_embeddings):
supp1 = self.eye.to(node_embeddings[0].device) """
filt = self.fc(x) Forward pass for the DGCN model.
nodevec = torch.tanh(node_embeddings[0] * filt)
supp2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1) Parameters:
x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x), - x: Input tensor of shape [B, N, C]
torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1) - node_embeddings: Node embeddings tensor of shape [N, D]
- connMtx: Connectivity matrix
Returns:
- x_gconv: Output tensor of shape [B, N, dim_out]
"""
node_num = node_embeddings[0].shape[1]
supports1 = torch.eye(node_num).to(node_embeddings[0].device) # Identity matrix
# Apply fully connected layers
filter = self.fc(x)
nodevec = torch.tanh(torch.mul(node_embeddings[0], filter)) # Element-wise multiplication
# Compute Laplacian
supports2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1)
# Graph convolution
x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
x_g = torch.stack([x_g1, x_g2], dim=1)
# Apply graph convolution weights and biases
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool) weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
bias = torch.matmul(node_embeddings[1], self.bias_pool) bias = torch.matmul(node_embeddings[1], self.bias_pool)
return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias
x_g = x_g.permute(0, 2, 1, 3) # Rearrange dimensions
x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # Graph convolution operation
return x_gconv
@staticmethod @staticmethod
def get_laplacian(graph, I, normalize=True): def get_laplacian(graph, I, normalize=True):
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5)) """
return torch.matmul(torch.matmul(D_inv, graph), D_inv) if normalize else torch.matmul( Compute the Laplacian of the graph.
torch.matmul(D_inv, graph + I), D_inv)
Parameters:
- graph: Adjacency matrix of the graph, [N, N]
- I: Identity matrix
- normalize: Whether to use the normalized Laplacian
Returns:
- L: Graph Laplacian
"""
if normalize:
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
else:
graph = graph + I
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
return L

View File

@ -1,248 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class DGCRM(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
super(DGCRM, self).__init__()
assert num_layers >= 1, 'At least one DGCRM layer is required.'
self.node_num = node_num
self.input_dim = dim_in
self.num_layers = num_layers
# Initialize DGCRM cells
self.DGCRM_cells = nn.ModuleList([
DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)
if i == 0 else
DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)
for i in range(num_layers)
])
def forward(self, x, init_state, node_embeddings):
"""
Forward pass of the DGCRM model.
Parameters:
- x: Input tensor of shape (B, T, N, D)
- init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim)
- node_embeddings: Node embeddings
"""
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
seq_length = x.shape[1]
current_inputs = x
output_hidden = []
for i in range(self.num_layers):
state = init_state[i]
inner_states = []
for t in range(seq_length):
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
[node_embeddings[0][:, t, :, :], node_embeddings[1]])
inner_states.append(state)
output_hidden.append(state)
current_inputs = torch.stack(inner_states, dim=1)
return current_inputs, output_hidden
def init_hidden(self, batch_size):
"""
Initialize hidden states for DGCRM layers.
Parameters:
- batch_size: Size of the batch
Returns:
- Initial hidden states tensor
"""
return torch.stack([
self.DGCRM_cells[i].init_hidden_state(batch_size)
for i in range(self.num_layers)
], dim=0)
class DDGCRN(nn.Module):
def __init__(self, args):
super(DDGCRN, self).__init__()
self.num_node = args['num_nodes']
self.input_dim = args['input_dim']
self.hidden_dim = args['rnn_units']
self.output_dim = args['output_dim']
self.horizon = args['horizon']
self.num_layers = args['num_layers']
self.use_day = args['use_day']
self.use_week = args['use_week']
self.default_graph = args['default_graph']
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
self.dropout1 = nn.Dropout(p=0.1)
self.dropout2 = nn.Dropout(p=0.1)
self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers)
self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers)
# Predictor
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
def forward(self, source, **kwargs):
"""
Forward pass of the DDGCRN model.
Parameters:
- source: Input tensor of shape (B, T_1, N, D)
- mode: Control mode for the forward pass
Returns:
- Output tensor
"""
node_embedding1 = self.node_embeddings1
if self.use_day:
t_i_d_data = source[..., 1]
T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
node_embedding1 = node_embedding1 * T_i_D_emb
if self.use_week:
d_i_w_data = source[..., 2]
D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
node_embedding1 = node_embedding1 * D_i_W_emb
node_embeddings = [node_embedding1, self.node_embeddings1]
source = source[..., 0].unsqueeze(-1)
init_state1 = self.encoder1.init_hidden(source.shape[0])
output, _ = self.encoder1(source, init_state1, node_embeddings)
output = self.dropout1(output[:, -1:, :, :])
output1 = self.end_conv1(output)
source1 = self.end_conv2(output)
source2 = source[:, -self.horizon:, ...] - source1
init_state2 = self.encoder2.init_hidden(source2.shape[0])
output2, _ = self.encoder2(source2, init_state2, node_embeddings)
output2 = self.dropout2(output2[:, -1:, :, :])
output2 = self.end_conv3(output2)
return output1 + output2
class DDGCRNCell(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
super(DDGCRNCell, self).__init__()
self.node_num = node_num
self.hidden_dim = dim_out
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)
def forward(self, x, state, node_embeddings):
state = state.to(x.device)
input_and_state = torch.cat((x, state), dim=-1)
z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
candidate = torch.cat((x, z * state), dim=-1)
hc = torch.tanh(self.update(candidate, node_embeddings))
h = r * state + (1 - r) * hc
return h
def init_hidden_state(self, batch_size):
return torch.zeros(batch_size, self.node_num, self.hidden_dim)
class DGCN(nn.Module):
def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
super(DGCN, self).__init__()
self.cheb_k = cheb_k
self.embed_dim = embed_dim
# Initialize parameters
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
# Hyperparameters
self.hyperGNN_dim = 16
self.middle_dim = 2
# Fully connected layers
self.fc = nn.Sequential(OrderedDict([
('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
('sigmoid1', nn.Sigmoid()),
('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
('sigmoid2', nn.Sigmoid()),
('fc3', nn.Linear(self.middle_dim, self.embed_dim))
]))
def forward(self, x, node_embeddings):
"""
Forward pass for the DGCN model.
Parameters:
- x: Input tensor of shape [B, N, C]
- node_embeddings: Node embeddings tensor of shape [N, D]
- connMtx: Connectivity matrix
Returns:
- x_gconv: Output tensor of shape [B, N, dim_out]
"""
node_num = node_embeddings[0].shape[1]
supports1 = torch.eye(node_num).to(node_embeddings[0].device) # Identity matrix
# Apply fully connected layers
filter = self.fc(x)
nodevec = torch.tanh(torch.mul(node_embeddings[0], filter)) # Element-wise multiplication
# Compute Laplacian
supports2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1)
# Graph convolution
x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
x_g = torch.stack([x_g1, x_g2], dim=1)
# Apply graph convolution weights and biases
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
bias = torch.matmul(node_embeddings[1], self.bias_pool)
x_g = x_g.permute(0, 2, 1, 3) # Rearrange dimensions
x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # Graph convolution operation
return x_gconv
@staticmethod
def get_laplacian(graph, I, normalize=True):
"""
Compute the Laplacian of the graph.
Parameters:
- graph: Adjacency matrix of the graph, [N, N]
- I: Identity matrix
- normalize: Whether to use the normalized Laplacian
Returns:
- L: Graph Laplacian
"""
if normalize:
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
else:
graph = graph + I
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
return L

View File

@ -13,7 +13,6 @@ from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seq import PDG2Seq from model.PDG2SEQ.PDG2Seq import PDG2Seq
from model.EXP.EXP import EXP
def model_selector(model): def model_selector(model):
match model['type']: match model['type']:
@ -32,5 +31,4 @@ def model_selector(model):
case 'STSGCN': return STSGCN(model) case 'STSGCN': return STSGCN(model)
case 'STGODE': return ODEGCN(model) case 'STGODE': return ODEGCN(model)
case 'PDG2SEQ': return PDG2Seq(model) case 'PDG2SEQ': return PDG2Seq(model)
case 'EXP': return EXP(model)