import torch
import torch.nn as nn

from model.TWDGCN.DGCRU import DDGCRNCell
from model.TWDGCN.ConnectionMatrix import ConnectionMatrix


class DGCRM(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super(DGCRM, self).__init__()
        assert num_layers >= 1, "At least one DGCRM layer is required."

        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.conn = ConnectionMatrix()

        # Initialize DGCRM cells: the first layer consumes dim_in input
        # features, every subsequent layer consumes dim_out hidden features.
        self.DGCRM_cells = nn.ModuleList(
            [
                DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)
                if i == 0
                else DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)
                for i in range(num_layers)
            ]
        )

    def forward(self, x, init_state, node_embeddings):
        """
        Forward pass of the DGCRM module.

        Parameters:
        - x: Input tensor of shape (B, T, N, D)
        - init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim)
        - node_embeddings: Pair [dynamic embeddings of shape (B, T, N, embed_dim),
          static embeddings of shape (N, embed_dim)]

        Returns:
        - current_inputs: Hidden states of the last layer, shape (B, T, N, hidden_dim)
        - output_hidden: Final hidden state of every layer
        """
        assert (
            x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        ), "Input must have shape (B, T, N, D) matching node_num and input_dim."

        seq_length = x.shape[1]
        current_inputs = x
        output_hidden = []

        # The connectivity matrix depends only on the raw input x, so it is
        # computed once and shared across all layers.
        conn_mtx = self.conn.get(x)

        for i in range(self.num_layers):
            state = init_state[i]
            inner_states = []

            for t in range(seq_length):
                state = self.DGCRM_cells[i](
                    current_inputs[:, t, :, :],
                    state,
                    [node_embeddings[0][:, t, :, :], node_embeddings[1]],
                    conn_mtx[t],
                )
                inner_states.append(state)

            output_hidden.append(state)
            current_inputs = torch.stack(inner_states, dim=1)

        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        """
        Initialize hidden states for all DGCRM layers.

        Parameters:
        - batch_size: Size of the batch

        Returns:
        - Initial hidden states tensor of shape (num_layers, B, N, hidden_dim)
        """
        return torch.stack(
            [
                self.DGCRM_cells[i].init_hidden_state(batch_size)
                for i in range(self.num_layers)
            ],
            dim=0,
        )
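
# A shape-oriented usage sketch for DGCRM. The sizes below are hypothetical
# (B = batch, T = sequence length, N = node_num), and it assumes
# ConnectionMatrix.get(x) returns one matrix per time step, as forward()
# indexes conn_mtx[t]:
#
#     encoder = DGCRM(node_num=307, dim_in=1, dim_out=64, cheb_k=2, embed_dim=10)
#     x = torch.randn(8, 12, 307, 1)               # (B, T, N, D)
#     h0 = encoder.init_hidden(batch_size=8)       # (num_layers, B, N, 64)
#     dyn = torch.randn(8, 12, 307, 10)            # (B, T, N, embed_dim)
#     static = torch.randn(307, 10)                # (N, embed_dim)
#     out, _ = encoder(x, h0, [dyn, static])       # out: (B, T, N, 64)
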
class TWDGCN(nn.Module):
    def __init__(self, args):
        super(TWDGCN, self).__init__()

        self.num_node = args["num_nodes"]
        self.input_dim = args["input_dim"]
        self.hidden_dim = args["rnn_units"]
        self.output_dim = args["output_dim"]
        self.horizon = args["horizon"]
        self.num_layers = args["num_layers"]
        self.use_day = args["use_day"]
        self.use_week = args["use_week"]
        self.default_graph = args["default_graph"]

        self.node_embeddings1 = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )
        self.node_embeddings2 = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )
        # Time-of-day (288 five-minute slots per day) and day-of-week embedding
        # tables. torch.empty leaves these uninitialized, so initialize them
        # explicitly before use.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args["embed_dim"]))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args["embed_dim"]))
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)

        self.encoder1 = DGCRM(
            self.num_node,
            self.input_dim,
            self.hidden_dim,
            args["cheb_order"],
            args["embed_dim"],
            self.num_layers,
        )
        self.encoder2 = DGCRM(
            self.num_node,
            self.input_dim,
            self.hidden_dim,
            args["cheb_order"],
            args["embed_dim"],
            self.num_layers,
        )

        # Predictors: each 1 x hidden_dim convolution maps the last hidden
        # state (B, 1, N, hidden_dim) to (B, horizon * output_dim, N, 1).
        self.end_conv1 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )
        self.end_conv2 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )
        self.end_conv3 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )

    def forward(self, source):
        """
        Forward pass of the TWDGCN model.

        Parameters:
        - source: Input tensor of shape (B, T_1, N, D), where channel 0 holds
          the readings, channel 1 the normalized time of day in [0, 1), and
          channel 2 the day-of-week index

        Returns:
        - Output tensor of shape (B, horizon * output_dim, N, 1)
        """
        node_embedding1 = self.node_embeddings1

        # Modulate the static node embeddings with time-of-day and
        # day-of-week embeddings when those features are enabled.
        if self.use_day:
            t_i_d_data = source[..., 1]
            T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
            node_embedding1 = node_embedding1 * T_i_D_emb

        if self.use_week:
            d_i_w_data = source[..., 2]
            D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
            node_embedding1 = node_embedding1 * D_i_W_emb

        node_embeddings = [node_embedding1, self.node_embeddings1]
        source = source[..., 0].unsqueeze(-1)

        # First stage: encode the raw signal and predict from the last state.
        init_state1 = self.encoder1.init_hidden(source.shape[0])
        output, _ = self.encoder1(source, init_state1, node_embeddings)
        output = self.dropout1(output[:, -1:, :, :])
        output1 = self.end_conv1(output)

        # Second stage: form the residual between the recent input window and
        # an auxiliary prediction from the first stage, then encode it and
        # predict a correction.
        source1 = self.end_conv2(output)
        source2 = source[:, -self.horizon :, ...] - source1

        init_state2 = self.encoder2.init_hidden(source2.shape[0])
        output2, _ = self.encoder2(source2, init_state2, node_embeddings)
        output2 = self.dropout2(output2[:, -1:, :, :])
        output2 = self.end_conv3(output2)

        return output1 + output2
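
if __name__ == "__main__":
    # Minimal smoke-test sketch. The configuration values below are
    # hypothetical placeholders (the real ones come from the project's
    # config), and running it assumes DDGCRNCell and ConnectionMatrix
    # resolve from the imports at the top of this file.
    args = {
        "num_nodes": 307,
        "input_dim": 1,
        "rnn_units": 64,
        "output_dim": 1,
        "horizon": 12,
        "num_layers": 1,
        "use_day": True,
        "use_week": True,
        "default_graph": True,
        "embed_dim": 10,
        "cheb_order": 2,
    }
    model = TWDGCN(args)

    B, T, N = 8, 12, args["num_nodes"]
    values = torch.randn(B, T, N, 1)                 # traffic readings
    tod = torch.rand(B, T, N, 1)                     # time of day in [0, 1)
    dow = torch.randint(0, 7, (B, T, N, 1)).float()  # day-of-week index
    source = torch.cat([values, tod, dow], dim=-1)   # (B, T, N, 3)

    print(model(source).shape)  # expected: (B, horizon * output_dim, N, 1)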