新增消耗分析模式,只需在原有的mode中调整为benchmark即可

This commit is contained in:
czzhangheng 2025-03-27 20:02:16 +08:00
parent 5306d24408
commit 8c839642e1
3 changed files with 314 additions and 195 deletions

View File

@ -1,248 +1,117 @@
import torch import torch, torch.nn as nn, torch.nn.functional as F
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict from collections import OrderedDict
class DGCRM(nn.Module): class DGCRM(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
super(DGCRM, self).__init__() super().__init__()
assert num_layers >= 1, 'At least one DGCRM layer is required.' self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
self.cells = nn.ModuleList(
self.node_num = node_num [DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim) for i in
self.input_dim = dim_in range(num_layers)])
self.num_layers = num_layers
# Initialize DGCRM cells
self.DGCRM_cells = nn.ModuleList([
DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)
if i == 0 else
DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)
for i in range(num_layers)
])
def forward(self, x, init_state, node_embeddings): def forward(self, x, init_state, node_embeddings):
"""
Forward pass of the DGCRM model.
Parameters:
- x: Input tensor of shape (B, T, N, D)
- init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim)
- node_embeddings: Node embeddings
"""
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
seq_length = x.shape[1]
current_inputs = x
output_hidden = []
for i in range(self.num_layers): for i in range(self.num_layers):
state = init_state[i] state, inner = init_state[i], []
inner_states = [] state = state.to(x.device)
for t in range(x.shape[1]):
state = self.cells[i](x[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
inner.append(state)
init_state[i] = state
x = torch.stack(inner, dim=1)
return x, init_state
for t in range(seq_length): def init_hidden(self, bs):
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, return torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0)
[node_embeddings[0][:, t, :, :], node_embeddings[1]])
inner_states.append(state)
output_hidden.append(state)
current_inputs = torch.stack(inner_states, dim=1)
return current_inputs, output_hidden
def init_hidden(self, batch_size):
"""
Initialize hidden states for DGCRM layers.
Parameters:
- batch_size: Size of the batch
Returns:
- Initial hidden states tensor
"""
return torch.stack([
self.DGCRM_cells[i].init_hidden_state(batch_size)
for i in range(self.num_layers)
], dim=0)
class DDGCRN(nn.Module): class DDGCRN(nn.Module):
def __init__(self, args): def __init__(self, args):
super(DDGCRN, self).__init__() super().__init__()
self.num_node, self.input_dim, self.hidden_dim = args['num_nodes'], args['input_dim'], args['rnn_units']
self.num_node = args['num_nodes'] self.output_dim, self.horizon, self.num_layers = args['output_dim'], args['horizon'], args['num_layers']
self.input_dim = args['input_dim'] self.use_day, self.use_week = args['use_day'], args['use_week']
self.hidden_dim = args['rnn_units']
self.output_dim = args['output_dim']
self.horizon = args['horizon']
self.num_layers = args['num_layers']
self.use_day = args['use_day']
self.use_week = args['use_week']
self.default_graph = args['default_graph']
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True) self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True) self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim'])) self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim'])) self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
self.drop1, self.drop2 = nn.Dropout(0.1), nn.Dropout(0.1)
self.dropout1 = nn.Dropout(p=0.1)
self.dropout2 = nn.Dropout(p=0.1)
self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'], self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers) self.num_layers)
self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'], self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers) self.num_layers)
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
# Predictor def forward(self, source):
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) node_embed = self.node_embeddings1
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
def forward(self, source, **kwargs):
"""
Forward pass of the DDGCRN model.
Parameters:
- source: Input tensor of shape (B, T_1, N, D)
- mode: Control mode for the forward pass
Returns:
- Output tensor
"""
node_embedding1 = self.node_embeddings1
if self.use_day: if self.use_day:
t_i_d_data = source[..., 1] node_embed = node_embed * self.T_i_D_emb[(source[..., 1] * 288).long()]
T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
node_embedding1 = node_embedding1 * T_i_D_emb
if self.use_week: if self.use_week:
d_i_w_data = source[..., 2] node_embed = node_embed * self.D_i_W_emb[source[..., 2].long()]
D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()] node_embeddings = [node_embed, self.node_embeddings1]
node_embedding1 = node_embedding1 * D_i_W_emb
node_embeddings = [node_embedding1, self.node_embeddings1]
source = source[..., 0].unsqueeze(-1) source = source[..., 0].unsqueeze(-1)
init1 = self.encoder1.init_hidden(source.shape[0])
init_state1 = self.encoder1.init_hidden(source.shape[0]) out, _ = self.encoder1(source, init1, node_embeddings)
output, _ = self.encoder1(source, init_state1, node_embeddings) out = self.drop1(out[:, -1:, :, :])
output = self.dropout1(output[:, -1:, :, :]) out1 = self.end_conv1(out)
output1 = self.end_conv1(output) src1 = self.end_conv2(out)
src2 = source[:, -self.horizon:, ...] - src1
source1 = self.end_conv2(output) init2 = self.encoder2.init_hidden(source.shape[0])
source2 = source[:, -self.horizon:, ...] - source1 out2, _ = self.encoder2(src2, init2, node_embeddings)
out2 = self.drop2(out2[:, -1:, :, :])
init_state2 = self.encoder2.init_hidden(source2.shape[0]) return out1 + self.end_conv3(out2)
output2, _ = self.encoder2(source2, init_state2, node_embeddings)
output2 = self.dropout2(output2[:, -1:, :, :])
output2 = self.end_conv3(output2)
return output1 + output2
class DDGCRNCell(nn.Module): class DDGCRNCell(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
super(DDGCRNCell, self).__init__() super().__init__()
self.node_num = node_num self.node_num, self.hidden_dim = node_num, dim_out
self.hidden_dim = dim_out self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim)
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim)
self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)
def forward(self, x, state, node_embeddings): def forward(self, x, state, node_embeddings):
state = state.to(x.device) inp = torch.cat((x, state), -1)
input_and_state = torch.cat((x, state), dim=-1) z_r = torch.sigmoid(self.gate(inp, node_embeddings))
z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) z, r = torch.split(z_r, self.hidden_dim, -1)
z, r = torch.split(z_r, self.hidden_dim, dim=-1) hc = torch.tanh(self.update(torch.cat((x, z * state), -1), node_embeddings))
candidate = torch.cat((x, z * state), dim=-1) return r * state + (1 - r) * hc
hc = torch.tanh(self.update(candidate, node_embeddings))
h = r * state + (1 - r) * hc
return h
def init_hidden_state(self, batch_size): def init_hidden_state(self, bs):
return torch.zeros(batch_size, self.node_num, self.hidden_dim) return torch.zeros(bs, self.node_num, self.hidden_dim)
class DGCN(nn.Module): class DGCN(nn.Module):
def __init__(self, dim_in, dim_out, cheb_k, embed_dim): def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
super(DGCN, self).__init__() super().__init__()
self.cheb_k = cheb_k self.cheb_k, self.embed_dim = cheb_k, embed_dim
self.embed_dim = embed_dim
# Initialize parameters
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out)) self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
self.bias = nn.Parameter(torch.FloatTensor(dim_out)) self.bias = nn.Parameter(torch.FloatTensor(dim_out))
# Hyperparameters
self.hyperGNN_dim = 16
self.middle_dim = 2
# Fully connected layers
self.fc = nn.Sequential(OrderedDict([ self.fc = nn.Sequential(OrderedDict([
('fc1', nn.Linear(dim_in, self.hyperGNN_dim)), ('fc1', nn.Linear(dim_in, 16)),
('sigmoid1', nn.Sigmoid()), ('sigmoid1', nn.Sigmoid()),
('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)), ('fc2', nn.Linear(16, 2)),
('sigmoid2', nn.Sigmoid()), ('sigmoid2', nn.Sigmoid()),
('fc3', nn.Linear(self.middle_dim, self.embed_dim)) ('fc3', nn.Linear(2, embed_dim))
])) ]))
def forward(self, x, node_embeddings): def forward(self, x, node_embeddings):
"""
Forward pass for the DGCN model.
Parameters:
- x: Input tensor of shape [B, N, C]
- node_embeddings: Node embeddings tensor of shape [N, D]
- connMtx: Connectivity matrix
Returns:
- x_gconv: Output tensor of shape [B, N, dim_out]
"""
node_num = node_embeddings[0].shape[1] node_num = node_embeddings[0].shape[1]
supports1 = torch.eye(node_num).to(node_embeddings[0].device) # Identity matrix supp1 = torch.eye(node_num).to(node_embeddings[0].device)
filt = self.fc(x)
# Apply fully connected layers nodevec = torch.tanh(node_embeddings[0] * filt)
filter = self.fc(x) supp2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1)
nodevec = torch.tanh(torch.mul(node_embeddings[0], filter)) # Element-wise multiplication x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x), torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1)
# Compute Laplacian
supports2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1)
# Graph convolution
x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
x_g = torch.stack([x_g1, x_g2], dim=1)
# Apply graph convolution weights and biases
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool) weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
bias = torch.matmul(node_embeddings[1], self.bias_pool) bias = torch.matmul(node_embeddings[1], self.bias_pool)
return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias
x_g = x_g.permute(0, 2, 1, 3) # Rearrange dimensions
x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # Graph convolution operation
return x_gconv
@staticmethod @staticmethod
def get_laplacian(graph, I, normalize=True): def get_laplacian(graph, I, normalize=True):
""" D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
Compute the Laplacian of the graph. return torch.matmul(torch.matmul(D_inv, graph), D_inv) if normalize else torch.matmul(
torch.matmul(D_inv, graph + I), D_inv)
Parameters:
- graph: Adjacency matrix of the graph, [N, N]
- I: Identity matrix
- normalize: Whether to use the normalized Laplacian
Returns:
- L: Graph Laplacian
"""
if normalize:
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
else:
graph = graph + I
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
return L

248
model/DDGCRN/DDGCRN_old.py Normal file
View File

@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class DGCRM(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
super(DGCRM, self).__init__()
assert num_layers >= 1, 'At least one DGCRM layer is required.'
self.node_num = node_num
self.input_dim = dim_in
self.num_layers = num_layers
# Initialize DGCRM cells
self.DGCRM_cells = nn.ModuleList([
DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)
if i == 0 else
DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)
for i in range(num_layers)
])
def forward(self, x, init_state, node_embeddings):
"""
Forward pass of the DGCRM model.
Parameters:
- x: Input tensor of shape (B, T, N, D)
- init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim)
- node_embeddings: Node embeddings
"""
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
seq_length = x.shape[1]
current_inputs = x
output_hidden = []
for i in range(self.num_layers):
state = init_state[i]
inner_states = []
for t in range(seq_length):
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
[node_embeddings[0][:, t, :, :], node_embeddings[1]])
inner_states.append(state)
output_hidden.append(state)
current_inputs = torch.stack(inner_states, dim=1)
return current_inputs, output_hidden
def init_hidden(self, batch_size):
"""
Initialize hidden states for DGCRM layers.
Parameters:
- batch_size: Size of the batch
Returns:
- Initial hidden states tensor
"""
return torch.stack([
self.DGCRM_cells[i].init_hidden_state(batch_size)
for i in range(self.num_layers)
], dim=0)
class DDGCRN(nn.Module):
def __init__(self, args):
super(DDGCRN, self).__init__()
self.num_node = args['num_nodes']
self.input_dim = args['input_dim']
self.hidden_dim = args['rnn_units']
self.output_dim = args['output_dim']
self.horizon = args['horizon']
self.num_layers = args['num_layers']
self.use_day = args['use_day']
self.use_week = args['use_week']
self.default_graph = args['default_graph']
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
self.dropout1 = nn.Dropout(p=0.1)
self.dropout2 = nn.Dropout(p=0.1)
self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers)
self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
self.num_layers)
# Predictor
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
def forward(self, source, **kwargs):
"""
Forward pass of the DDGCRN model.
Parameters:
- source: Input tensor of shape (B, T_1, N, D)
- mode: Control mode for the forward pass
Returns:
- Output tensor
"""
node_embedding1 = self.node_embeddings1
if self.use_day:
t_i_d_data = source[..., 1]
T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
node_embedding1 = node_embedding1 * T_i_D_emb
if self.use_week:
d_i_w_data = source[..., 2]
D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
node_embedding1 = node_embedding1 * D_i_W_emb
node_embeddings = [node_embedding1, self.node_embeddings1]
source = source[..., 0].unsqueeze(-1)
init_state1 = self.encoder1.init_hidden(source.shape[0])
output, _ = self.encoder1(source, init_state1, node_embeddings)
output = self.dropout1(output[:, -1:, :, :])
output1 = self.end_conv1(output)
source1 = self.end_conv2(output)
source2 = source[:, -self.horizon:, ...] - source1
init_state2 = self.encoder2.init_hidden(source2.shape[0])
output2, _ = self.encoder2(source2, init_state2, node_embeddings)
output2 = self.dropout2(output2[:, -1:, :, :])
output2 = self.end_conv3(output2)
return output1 + output2
class DDGCRNCell(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
super(DDGCRNCell, self).__init__()
self.node_num = node_num
self.hidden_dim = dim_out
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)
def forward(self, x, state, node_embeddings):
state = state.to(x.device)
input_and_state = torch.cat((x, state), dim=-1)
z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
candidate = torch.cat((x, z * state), dim=-1)
hc = torch.tanh(self.update(candidate, node_embeddings))
h = r * state + (1 - r) * hc
return h
def init_hidden_state(self, batch_size):
return torch.zeros(batch_size, self.node_num, self.hidden_dim)
class DGCN(nn.Module):
def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
super(DGCN, self).__init__()
self.cheb_k = cheb_k
self.embed_dim = embed_dim
# Initialize parameters
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
# Hyperparameters
self.hyperGNN_dim = 16
self.middle_dim = 2
# Fully connected layers
self.fc = nn.Sequential(OrderedDict([
('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
('sigmoid1', nn.Sigmoid()),
('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
('sigmoid2', nn.Sigmoid()),
('fc3', nn.Linear(self.middle_dim, self.embed_dim))
]))
def forward(self, x, node_embeddings):
"""
Forward pass for the DGCN model.
Parameters:
- x: Input tensor of shape [B, N, C]
- node_embeddings: Node embeddings tensor of shape [N, D]
- connMtx: Connectivity matrix
Returns:
- x_gconv: Output tensor of shape [B, N, dim_out]
"""
node_num = node_embeddings[0].shape[1]
supports1 = torch.eye(node_num).to(node_embeddings[0].device) # Identity matrix
# Apply fully connected layers
filter = self.fc(x)
nodevec = torch.tanh(torch.mul(node_embeddings[0], filter)) # Element-wise multiplication
# Compute Laplacian
supports2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1)
# Graph convolution
x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
x_g = torch.stack([x_g1, x_g2], dim=1)
# Apply graph convolution weights and biases
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
bias = torch.matmul(node_embeddings[1], self.bias_pool)
x_g = x_g.permute(0, 2, 1, 3) # Rearrange dimensions
x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # Graph convolution operation
return x_gconv
@staticmethod
def get_laplacian(graph, I, normalize=True):
"""
Compute the Laplacian of the graph.
Parameters:
- graph: Adjacency matrix of the graph, [N, N]
- I: Identity matrix
- normalize: Whether to use the normalized Laplacian
Returns:
- L: Graph Laplacian
"""
if normalize:
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
else:
graph = graph + I
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
return L

View File

@ -13,6 +13,7 @@ from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seq import PDG2Seq from model.PDG2SEQ.PDG2Seq import PDG2Seq
from model.EXP.EXP import EXP
def model_selector(model): def model_selector(model):
match model['type']: match model['type']:
@ -31,4 +32,5 @@ def model_selector(model):
case 'STSGCN': return STSGCN(model) case 'STSGCN': return STSGCN(model)
case 'STGODE': return ODEGCN(model) case 'STGODE': return ODEGCN(model)
case 'PDG2SEQ': return PDG2Seq(model) case 'PDG2SEQ': return PDG2Seq(model)
case 'EXP': return EXP(model)