新增消耗分析模式,只需在原有的mode中调整为benchmark即可
This commit is contained in:
parent
5306d24408
commit
8c839642e1
|
|
@ -1,248 +1,117 @@
|
||||||
import torch
|
import torch, torch.nn as nn, torch.nn.functional as F
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class DGCRM(nn.Module):
|
class DGCRM(nn.Module):
|
||||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
|
||||||
super(DGCRM, self).__init__()
|
super().__init__()
|
||||||
assert num_layers >= 1, 'At least one DGCRM layer is required.'
|
self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
|
||||||
|
self.cells = nn.ModuleList(
|
||||||
self.node_num = node_num
|
[DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim) for i in
|
||||||
self.input_dim = dim_in
|
range(num_layers)])
|
||||||
self.num_layers = num_layers
|
|
||||||
|
|
||||||
# Initialize DGCRM cells
|
|
||||||
self.DGCRM_cells = nn.ModuleList([
|
|
||||||
DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)
|
|
||||||
if i == 0 else
|
|
||||||
DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)
|
|
||||||
for i in range(num_layers)
|
|
||||||
])
|
|
||||||
|
|
||||||
def forward(self, x, init_state, node_embeddings):
|
def forward(self, x, init_state, node_embeddings):
|
||||||
"""
|
|
||||||
Forward pass of the DGCRM model.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- x: Input tensor of shape (B, T, N, D)
|
|
||||||
- init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim)
|
|
||||||
- node_embeddings: Node embeddings
|
|
||||||
"""
|
|
||||||
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
||||||
|
|
||||||
seq_length = x.shape[1]
|
|
||||||
current_inputs = x
|
|
||||||
output_hidden = []
|
|
||||||
|
|
||||||
for i in range(self.num_layers):
|
for i in range(self.num_layers):
|
||||||
state = init_state[i]
|
state, inner = init_state[i], []
|
||||||
inner_states = []
|
state = state.to(x.device)
|
||||||
|
for t in range(x.shape[1]):
|
||||||
|
state = self.cells[i](x[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
|
||||||
|
inner.append(state)
|
||||||
|
init_state[i] = state
|
||||||
|
x = torch.stack(inner, dim=1)
|
||||||
|
return x, init_state
|
||||||
|
|
||||||
for t in range(seq_length):
|
def init_hidden(self, bs):
|
||||||
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
|
return torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0)
|
||||||
[node_embeddings[0][:, t, :, :], node_embeddings[1]])
|
|
||||||
inner_states.append(state)
|
|
||||||
|
|
||||||
output_hidden.append(state)
|
|
||||||
current_inputs = torch.stack(inner_states, dim=1)
|
|
||||||
|
|
||||||
return current_inputs, output_hidden
|
|
||||||
|
|
||||||
def init_hidden(self, batch_size):
|
|
||||||
"""
|
|
||||||
Initialize hidden states for DGCRM layers.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- batch_size: Size of the batch
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Initial hidden states tensor
|
|
||||||
"""
|
|
||||||
return torch.stack([
|
|
||||||
self.DGCRM_cells[i].init_hidden_state(batch_size)
|
|
||||||
for i in range(self.num_layers)
|
|
||||||
], dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
class DDGCRN(nn.Module):
|
class DDGCRN(nn.Module):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super(DDGCRN, self).__init__()
|
super().__init__()
|
||||||
|
self.num_node, self.input_dim, self.hidden_dim = args['num_nodes'], args['input_dim'], args['rnn_units']
|
||||||
self.num_node = args['num_nodes']
|
self.output_dim, self.horizon, self.num_layers = args['output_dim'], args['horizon'], args['num_layers']
|
||||||
self.input_dim = args['input_dim']
|
self.use_day, self.use_week = args['use_day'], args['use_week']
|
||||||
self.hidden_dim = args['rnn_units']
|
|
||||||
self.output_dim = args['output_dim']
|
|
||||||
self.horizon = args['horizon']
|
|
||||||
self.num_layers = args['num_layers']
|
|
||||||
self.use_day = args['use_day']
|
|
||||||
self.use_week = args['use_week']
|
|
||||||
self.default_graph = args['default_graph']
|
|
||||||
|
|
||||||
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
|
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
|
||||||
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
|
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
|
||||||
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
|
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
|
||||||
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
|
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
|
||||||
|
self.drop1, self.drop2 = nn.Dropout(0.1), nn.Dropout(0.1)
|
||||||
self.dropout1 = nn.Dropout(p=0.1)
|
|
||||||
self.dropout2 = nn.Dropout(p=0.1)
|
|
||||||
|
|
||||||
self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
|
self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
|
||||||
self.num_layers)
|
self.num_layers)
|
||||||
self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
|
self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
|
||||||
self.num_layers)
|
self.num_layers)
|
||||||
|
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
|
||||||
|
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
|
||||||
|
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
|
||||||
|
|
||||||
# Predictor
|
def forward(self, source):
|
||||||
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
node_embed = self.node_embeddings1
|
||||||
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
|
||||||
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
|
||||||
|
|
||||||
def forward(self, source, **kwargs):
|
|
||||||
"""
|
|
||||||
Forward pass of the DDGCRN model.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- source: Input tensor of shape (B, T_1, N, D)
|
|
||||||
- mode: Control mode for the forward pass
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Output tensor
|
|
||||||
"""
|
|
||||||
node_embedding1 = self.node_embeddings1
|
|
||||||
|
|
||||||
if self.use_day:
|
if self.use_day:
|
||||||
t_i_d_data = source[..., 1]
|
node_embed = node_embed * self.T_i_D_emb[(source[..., 1] * 288).long()]
|
||||||
T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
|
|
||||||
node_embedding1 = node_embedding1 * T_i_D_emb
|
|
||||||
|
|
||||||
if self.use_week:
|
if self.use_week:
|
||||||
d_i_w_data = source[..., 2]
|
node_embed = node_embed * self.D_i_W_emb[source[..., 2].long()]
|
||||||
D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
|
node_embeddings = [node_embed, self.node_embeddings1]
|
||||||
node_embedding1 = node_embedding1 * D_i_W_emb
|
|
||||||
|
|
||||||
node_embeddings = [node_embedding1, self.node_embeddings1]
|
|
||||||
source = source[..., 0].unsqueeze(-1)
|
source = source[..., 0].unsqueeze(-1)
|
||||||
|
init1 = self.encoder1.init_hidden(source.shape[0])
|
||||||
init_state1 = self.encoder1.init_hidden(source.shape[0])
|
out, _ = self.encoder1(source, init1, node_embeddings)
|
||||||
output, _ = self.encoder1(source, init_state1, node_embeddings)
|
out = self.drop1(out[:, -1:, :, :])
|
||||||
output = self.dropout1(output[:, -1:, :, :])
|
out1 = self.end_conv1(out)
|
||||||
output1 = self.end_conv1(output)
|
src1 = self.end_conv2(out)
|
||||||
|
src2 = source[:, -self.horizon:, ...] - src1
|
||||||
source1 = self.end_conv2(output)
|
init2 = self.encoder2.init_hidden(source.shape[0])
|
||||||
source2 = source[:, -self.horizon:, ...] - source1
|
out2, _ = self.encoder2(src2, init2, node_embeddings)
|
||||||
|
out2 = self.drop2(out2[:, -1:, :, :])
|
||||||
init_state2 = self.encoder2.init_hidden(source2.shape[0])
|
return out1 + self.end_conv3(out2)
|
||||||
output2, _ = self.encoder2(source2, init_state2, node_embeddings)
|
|
||||||
output2 = self.dropout2(output2[:, -1:, :, :])
|
|
||||||
output2 = self.end_conv3(output2)
|
|
||||||
|
|
||||||
return output1 + output2
|
|
||||||
|
|
||||||
|
|
||||||
class DDGCRNCell(nn.Module):
|
class DDGCRNCell(nn.Module):
|
||||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
|
||||||
super(DDGCRNCell, self).__init__()
|
super().__init__()
|
||||||
self.node_num = node_num
|
self.node_num, self.hidden_dim = node_num, dim_out
|
||||||
self.hidden_dim = dim_out
|
self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim)
|
||||||
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
|
self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim)
|
||||||
self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)
|
|
||||||
|
|
||||||
def forward(self, x, state, node_embeddings):
|
def forward(self, x, state, node_embeddings):
|
||||||
state = state.to(x.device)
|
inp = torch.cat((x, state), -1)
|
||||||
input_and_state = torch.cat((x, state), dim=-1)
|
z_r = torch.sigmoid(self.gate(inp, node_embeddings))
|
||||||
z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
|
z, r = torch.split(z_r, self.hidden_dim, -1)
|
||||||
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
|
hc = torch.tanh(self.update(torch.cat((x, z * state), -1), node_embeddings))
|
||||||
candidate = torch.cat((x, z * state), dim=-1)
|
return r * state + (1 - r) * hc
|
||||||
hc = torch.tanh(self.update(candidate, node_embeddings))
|
|
||||||
h = r * state + (1 - r) * hc
|
|
||||||
return h
|
|
||||||
|
|
||||||
def init_hidden_state(self, batch_size):
|
def init_hidden_state(self, bs):
|
||||||
return torch.zeros(batch_size, self.node_num, self.hidden_dim)
|
return torch.zeros(bs, self.node_num, self.hidden_dim)
|
||||||
|
|
||||||
|
|
||||||
class DGCN(nn.Module):
|
class DGCN(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
|
def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
|
||||||
super(DGCN, self).__init__()
|
super().__init__()
|
||||||
self.cheb_k = cheb_k
|
self.cheb_k, self.embed_dim = cheb_k, embed_dim
|
||||||
self.embed_dim = embed_dim
|
|
||||||
|
|
||||||
# Initialize parameters
|
|
||||||
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
|
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
|
||||||
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
|
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
|
||||||
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
|
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
|
||||||
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
self.hyperGNN_dim = 16
|
|
||||||
self.middle_dim = 2
|
|
||||||
|
|
||||||
# Fully connected layers
|
|
||||||
self.fc = nn.Sequential(OrderedDict([
|
self.fc = nn.Sequential(OrderedDict([
|
||||||
('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
|
('fc1', nn.Linear(dim_in, 16)),
|
||||||
('sigmoid1', nn.Sigmoid()),
|
('sigmoid1', nn.Sigmoid()),
|
||||||
('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
|
('fc2', nn.Linear(16, 2)),
|
||||||
('sigmoid2', nn.Sigmoid()),
|
('sigmoid2', nn.Sigmoid()),
|
||||||
('fc3', nn.Linear(self.middle_dim, self.embed_dim))
|
('fc3', nn.Linear(2, embed_dim))
|
||||||
]))
|
]))
|
||||||
|
|
||||||
def forward(self, x, node_embeddings):
|
def forward(self, x, node_embeddings):
|
||||||
"""
|
|
||||||
Forward pass for the DGCN model.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- x: Input tensor of shape [B, N, C]
|
|
||||||
- node_embeddings: Node embeddings tensor of shape [N, D]
|
|
||||||
- connMtx: Connectivity matrix
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- x_gconv: Output tensor of shape [B, N, dim_out]
|
|
||||||
"""
|
|
||||||
|
|
||||||
node_num = node_embeddings[0].shape[1]
|
node_num = node_embeddings[0].shape[1]
|
||||||
supports1 = torch.eye(node_num).to(node_embeddings[0].device) # Identity matrix
|
supp1 = torch.eye(node_num).to(node_embeddings[0].device)
|
||||||
|
filt = self.fc(x)
|
||||||
# Apply fully connected layers
|
nodevec = torch.tanh(node_embeddings[0] * filt)
|
||||||
filter = self.fc(x)
|
supp2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1)
|
||||||
nodevec = torch.tanh(torch.mul(node_embeddings[0], filter)) # Element-wise multiplication
|
x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x), torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1)
|
||||||
|
|
||||||
# Compute Laplacian
|
|
||||||
supports2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1)
|
|
||||||
|
|
||||||
# Graph convolution
|
|
||||||
x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
|
|
||||||
x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
|
|
||||||
x_g = torch.stack([x_g1, x_g2], dim=1)
|
|
||||||
|
|
||||||
# Apply graph convolution weights and biases
|
|
||||||
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
|
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
|
||||||
bias = torch.matmul(node_embeddings[1], self.bias_pool)
|
bias = torch.matmul(node_embeddings[1], self.bias_pool)
|
||||||
|
return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias
|
||||||
x_g = x_g.permute(0, 2, 1, 3) # Rearrange dimensions
|
|
||||||
x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # Graph convolution operation
|
|
||||||
|
|
||||||
return x_gconv
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_laplacian(graph, I, normalize=True):
|
def get_laplacian(graph, I, normalize=True):
|
||||||
"""
|
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
|
||||||
Compute the Laplacian of the graph.
|
return torch.matmul(torch.matmul(D_inv, graph), D_inv) if normalize else torch.matmul(
|
||||||
|
torch.matmul(D_inv, graph + I), D_inv)
|
||||||
Parameters:
|
|
||||||
- graph: Adjacency matrix of the graph, [N, N]
|
|
||||||
- I: Identity matrix
|
|
||||||
- normalize: Whether to use the normalized Laplacian
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- L: Graph Laplacian
|
|
||||||
"""
|
|
||||||
if normalize:
|
|
||||||
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
|
|
||||||
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
|
|
||||||
else:
|
|
||||||
graph = graph + I
|
|
||||||
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
|
|
||||||
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
|
|
||||||
return L
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,248 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
class DGCRM(nn.Module):
|
||||||
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
|
||||||
|
super(DGCRM, self).__init__()
|
||||||
|
assert num_layers >= 1, 'At least one DGCRM layer is required.'
|
||||||
|
|
||||||
|
self.node_num = node_num
|
||||||
|
self.input_dim = dim_in
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
# Initialize DGCRM cells
|
||||||
|
self.DGCRM_cells = nn.ModuleList([
|
||||||
|
DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)
|
||||||
|
if i == 0 else
|
||||||
|
DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)
|
||||||
|
for i in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x, init_state, node_embeddings):
|
||||||
|
"""
|
||||||
|
Forward pass of the DGCRM model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- x: Input tensor of shape (B, T, N, D)
|
||||||
|
- init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim)
|
||||||
|
- node_embeddings: Node embeddings
|
||||||
|
"""
|
||||||
|
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
||||||
|
|
||||||
|
seq_length = x.shape[1]
|
||||||
|
current_inputs = x
|
||||||
|
output_hidden = []
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
state = init_state[i]
|
||||||
|
inner_states = []
|
||||||
|
|
||||||
|
for t in range(seq_length):
|
||||||
|
state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
|
||||||
|
[node_embeddings[0][:, t, :, :], node_embeddings[1]])
|
||||||
|
inner_states.append(state)
|
||||||
|
|
||||||
|
output_hidden.append(state)
|
||||||
|
current_inputs = torch.stack(inner_states, dim=1)
|
||||||
|
|
||||||
|
return current_inputs, output_hidden
|
||||||
|
|
||||||
|
def init_hidden(self, batch_size):
|
||||||
|
"""
|
||||||
|
Initialize hidden states for DGCRM layers.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- batch_size: Size of the batch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- Initial hidden states tensor
|
||||||
|
"""
|
||||||
|
return torch.stack([
|
||||||
|
self.DGCRM_cells[i].init_hidden_state(batch_size)
|
||||||
|
for i in range(self.num_layers)
|
||||||
|
], dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class DDGCRN(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super(DDGCRN, self).__init__()
|
||||||
|
|
||||||
|
self.num_node = args['num_nodes']
|
||||||
|
self.input_dim = args['input_dim']
|
||||||
|
self.hidden_dim = args['rnn_units']
|
||||||
|
self.output_dim = args['output_dim']
|
||||||
|
self.horizon = args['horizon']
|
||||||
|
self.num_layers = args['num_layers']
|
||||||
|
self.use_day = args['use_day']
|
||||||
|
self.use_week = args['use_week']
|
||||||
|
self.default_graph = args['default_graph']
|
||||||
|
|
||||||
|
self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
|
||||||
|
self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
|
||||||
|
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
|
||||||
|
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
|
||||||
|
|
||||||
|
self.dropout1 = nn.Dropout(p=0.1)
|
||||||
|
self.dropout2 = nn.Dropout(p=0.1)
|
||||||
|
|
||||||
|
self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
|
||||||
|
self.num_layers)
|
||||||
|
self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
|
||||||
|
self.num_layers)
|
||||||
|
|
||||||
|
# Predictor
|
||||||
|
self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||||
|
self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||||
|
self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||||
|
|
||||||
|
def forward(self, source, **kwargs):
|
||||||
|
"""
|
||||||
|
Forward pass of the DDGCRN model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- source: Input tensor of shape (B, T_1, N, D)
|
||||||
|
- mode: Control mode for the forward pass
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- Output tensor
|
||||||
|
"""
|
||||||
|
node_embedding1 = self.node_embeddings1
|
||||||
|
|
||||||
|
if self.use_day:
|
||||||
|
t_i_d_data = source[..., 1]
|
||||||
|
T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
|
||||||
|
node_embedding1 = node_embedding1 * T_i_D_emb
|
||||||
|
|
||||||
|
if self.use_week:
|
||||||
|
d_i_w_data = source[..., 2]
|
||||||
|
D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
|
||||||
|
node_embedding1 = node_embedding1 * D_i_W_emb
|
||||||
|
|
||||||
|
node_embeddings = [node_embedding1, self.node_embeddings1]
|
||||||
|
source = source[..., 0].unsqueeze(-1)
|
||||||
|
|
||||||
|
init_state1 = self.encoder1.init_hidden(source.shape[0])
|
||||||
|
output, _ = self.encoder1(source, init_state1, node_embeddings)
|
||||||
|
output = self.dropout1(output[:, -1:, :, :])
|
||||||
|
output1 = self.end_conv1(output)
|
||||||
|
|
||||||
|
source1 = self.end_conv2(output)
|
||||||
|
source2 = source[:, -self.horizon:, ...] - source1
|
||||||
|
|
||||||
|
init_state2 = self.encoder2.init_hidden(source2.shape[0])
|
||||||
|
output2, _ = self.encoder2(source2, init_state2, node_embeddings)
|
||||||
|
output2 = self.dropout2(output2[:, -1:, :, :])
|
||||||
|
output2 = self.end_conv3(output2)
|
||||||
|
|
||||||
|
return output1 + output2
|
||||||
|
|
||||||
|
|
||||||
|
class DDGCRNCell(nn.Module):
|
||||||
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
|
||||||
|
super(DDGCRNCell, self).__init__()
|
||||||
|
self.node_num = node_num
|
||||||
|
self.hidden_dim = dim_out
|
||||||
|
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
|
||||||
|
self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)
|
||||||
|
|
||||||
|
def forward(self, x, state, node_embeddings):
|
||||||
|
state = state.to(x.device)
|
||||||
|
input_and_state = torch.cat((x, state), dim=-1)
|
||||||
|
z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
|
||||||
|
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
|
||||||
|
candidate = torch.cat((x, z * state), dim=-1)
|
||||||
|
hc = torch.tanh(self.update(candidate, node_embeddings))
|
||||||
|
h = r * state + (1 - r) * hc
|
||||||
|
return h
|
||||||
|
|
||||||
|
def init_hidden_state(self, batch_size):
|
||||||
|
return torch.zeros(batch_size, self.node_num, self.hidden_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class DGCN(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
|
||||||
|
super(DGCN, self).__init__()
|
||||||
|
self.cheb_k = cheb_k
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
# Initialize parameters
|
||||||
|
self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
|
||||||
|
self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
|
||||||
|
self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
|
||||||
|
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
self.hyperGNN_dim = 16
|
||||||
|
self.middle_dim = 2
|
||||||
|
|
||||||
|
# Fully connected layers
|
||||||
|
self.fc = nn.Sequential(OrderedDict([
|
||||||
|
('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
|
||||||
|
('sigmoid1', nn.Sigmoid()),
|
||||||
|
('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
|
||||||
|
('sigmoid2', nn.Sigmoid()),
|
||||||
|
('fc3', nn.Linear(self.middle_dim, self.embed_dim))
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x, node_embeddings):
|
||||||
|
"""
|
||||||
|
Forward pass for the DGCN model.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- x: Input tensor of shape [B, N, C]
|
||||||
|
- node_embeddings: Node embeddings tensor of shape [N, D]
|
||||||
|
- connMtx: Connectivity matrix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- x_gconv: Output tensor of shape [B, N, dim_out]
|
||||||
|
"""
|
||||||
|
|
||||||
|
node_num = node_embeddings[0].shape[1]
|
||||||
|
supports1 = torch.eye(node_num).to(node_embeddings[0].device) # Identity matrix
|
||||||
|
|
||||||
|
# Apply fully connected layers
|
||||||
|
filter = self.fc(x)
|
||||||
|
nodevec = torch.tanh(torch.mul(node_embeddings[0], filter)) # Element-wise multiplication
|
||||||
|
|
||||||
|
# Compute Laplacian
|
||||||
|
supports2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1)
|
||||||
|
|
||||||
|
# Graph convolution
|
||||||
|
x_g1 = torch.einsum("nm,bmc->bnc", supports1, x)
|
||||||
|
x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
|
||||||
|
x_g = torch.stack([x_g1, x_g2], dim=1)
|
||||||
|
|
||||||
|
# Apply graph convolution weights and biases
|
||||||
|
weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
|
||||||
|
bias = torch.matmul(node_embeddings[1], self.bias_pool)
|
||||||
|
|
||||||
|
x_g = x_g.permute(0, 2, 1, 3) # Rearrange dimensions
|
||||||
|
x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # Graph convolution operation
|
||||||
|
|
||||||
|
return x_gconv
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_laplacian(graph, I, normalize=True):
|
||||||
|
"""
|
||||||
|
Compute the Laplacian of the graph.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- graph: Adjacency matrix of the graph, [N, N]
|
||||||
|
- I: Identity matrix
|
||||||
|
- normalize: Whether to use the normalized Laplacian
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- L: Graph Laplacian
|
||||||
|
"""
|
||||||
|
if normalize:
|
||||||
|
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
|
||||||
|
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
|
||||||
|
else:
|
||||||
|
graph = graph + I
|
||||||
|
D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2))
|
||||||
|
L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
|
||||||
|
return L
|
||||||
|
|
||||||
|
|
@ -13,6 +13,7 @@ from model.STFGNN.STFGNN import STFGNN
|
||||||
from model.STSGCN.STSGCN import STSGCN
|
from model.STSGCN.STSGCN import STSGCN
|
||||||
from model.STGODE.STGODE import ODEGCN
|
from model.STGODE.STGODE import ODEGCN
|
||||||
from model.PDG2SEQ.PDG2Seq import PDG2Seq
|
from model.PDG2SEQ.PDG2Seq import PDG2Seq
|
||||||
|
from model.EXP.EXP import EXP
|
||||||
|
|
||||||
def model_selector(model):
|
def model_selector(model):
|
||||||
match model['type']:
|
match model['type']:
|
||||||
|
|
@ -31,4 +32,5 @@ def model_selector(model):
|
||||||
case 'STSGCN': return STSGCN(model)
|
case 'STSGCN': return STSGCN(model)
|
||||||
case 'STGODE': return ODEGCN(model)
|
case 'STGODE': return ODEGCN(model)
|
||||||
case 'PDG2SEQ': return PDG2Seq(model)
|
case 'PDG2SEQ': return PDG2Seq(model)
|
||||||
|
case 'EXP': return EXP(model)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue