TrafficWheel/model/TWDGCN/TWDGCN.py

183 lines
5.6 KiB
Python
Executable File

import torch
import torch.nn as nn
from model.TWDGCN.DGCRU import DDGCRNCell
from model.TWDGCN.ConnectionMatrix import ConnectionMatrix
class DGCRM(nn.Module):
    """Stack of ``DDGCRNCell`` layers unrolled step by step over the time axis.

    Layer 0 maps ``dim_in`` input features to ``dim_out`` units; every deeper
    layer maps ``dim_out`` to ``dim_out`` and consumes the full sequence of
    hidden states emitted by the layer below it.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        assert num_layers >= 1, "At least one DGCRM layer is required."
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        # forward() indexes the result of self.conn.get(...) by time step.
        self.conn = ConnectionMatrix()
        cells = []
        layer_in = dim_in
        for _ in range(num_layers):
            cells.append(DDGCRNCell(node_num, layer_in, dim_out, cheb_k, embed_dim))
            # Every layer after the first reads the previous layer's hidden size.
            layer_in = dim_out
        self.DGCRM_cells = nn.ModuleList(cells)

    def forward(self, x, init_state, node_embeddings):
        """Run every layer over the whole input sequence.

        Parameters:
        - x: input tensor of shape (B, T, N, D) with N == node_num and
          D == input_dim (checked by the assert below).
        - init_state: per-layer initial hidden states; ``init_state[i]`` seeds
          layer i (for example the tensor returned by ``init_hidden``).
        - node_embeddings: pair ``[dynamic, static]``. ``dynamic`` is sliced as
          ``dynamic[:, t, :, :]`` at step t; ``static`` is passed through as-is.

        Returns:
        - the top layer's per-step hidden states stacked along dim 1;
        - a list holding the last hidden state of each layer.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        steps = x.shape[1]
        layer_input = x
        final_states = []
        for layer_idx, cell in enumerate(self.DGCRM_cells):
            hidden = init_state[layer_idx]
            # Re-evaluated for every layer, always from the raw input x
            # (not from the previous layer's output).
            conn_per_step = self.conn.get(x)
            per_step = []
            for t in range(steps):
                step_embeddings = [node_embeddings[0][:, t, :, :], node_embeddings[1]]
                hidden = cell(
                    layer_input[:, t, :, :],
                    hidden,
                    step_embeddings,
                    conn_per_step[t],
                )
                per_step.append(hidden)
            final_states.append(hidden)
            layer_input = torch.stack(per_step, dim=1)
        return layer_input, final_states

    def init_hidden(self, batch_size):
        """Build the initial hidden states for all layers.

        Parameters:
        - batch_size: number of samples in the batch.

        Returns:
        - each cell's ``init_hidden_state(batch_size)`` stacked along a new
          leading layer dimension (dim 0).
        """
        per_layer = [cell.init_hidden_state(batch_size) for cell in self.DGCRM_cells]
        return torch.stack(per_layer, dim=0)
class TWDGCN(nn.Module):
    """Two-stage DGCRM forecaster with a residual second stage.

    ``encoder1`` reads input channel 0. From its last hidden state,
    ``end_conv1`` produces the first forecast. ``end_conv2`` produces a term
    that is subtracted from the last ``horizon`` input steps. ``encoder2``
    models that residual sequence, and its forecast (``end_conv3``) is added to
    the first one.
    """

    # Row counts of the learnable time-of-day / day-of-week embedding tables.
    STEPS_PER_DAY = 288
    DAYS_PER_WEEK = 7

    def __init__(self, args):
        super().__init__()
        self.num_node = args["num_nodes"]
        self.input_dim = args["input_dim"]
        self.hidden_dim = args["rnn_units"]
        self.output_dim = args["output_dim"]
        self.horizon = args["horizon"]
        self.num_layers = args["num_layers"]
        self.use_day = args["use_day"]
        self.use_week = args["use_week"]
        self.default_graph = args["default_graph"]
        self.node_embeddings1 = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )
        # Not read by forward() in this file; kept so existing state_dicts load.
        self.node_embeddings2 = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )
        self.T_i_D_emb = nn.Parameter(torch.empty(self.STEPS_PER_DAY, args["embed_dim"]))
        self.D_i_W_emb = nn.Parameter(torch.empty(self.DAYS_PER_WEEK, args["embed_dim"]))
        # torch.empty leaves the memory uninitialised. Fill both tables with 1 so
        # the multiplicative modulation in forward() starts as the identity.
        # ones_ draws no random numbers, so the initialisation of every module
        # created below is unchanged.
        nn.init.ones_(self.T_i_D_emb)
        nn.init.ones_(self.D_i_W_emb)
        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
        self.encoder1 = DGCRM(
            self.num_node,
            self.input_dim,
            self.hidden_dim,
            args["cheb_order"],
            args["embed_dim"],
            self.num_layers,
        )
        self.encoder2 = DGCRM(
            self.num_node,
            self.input_dim,
            self.hidden_dim,
            args["cheb_order"],
            args["embed_dim"],
            self.num_layers,
        )
        # Predictor heads: each maps a (B, 1, N, hidden_dim) map to
        # (B, horizon * output_dim, N, 1) through a (1, hidden_dim) kernel.
        self.end_conv1 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )
        self.end_conv2 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )
        self.end_conv3 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )

    def _dynamic_node_embedding(self, source):
        """Return per-step node embeddings of shape (B, T, N, embed_dim).

        ``node_embeddings1`` is multiplied by the time-of-day embedding (channel
        1 of ``source``) when ``use_day`` is set. It is multiplied by the
        day-of-week embedding (channel 2) when ``use_week`` is set. With neither
        feature enabled, the static (N, embed_dim) table is broadcast over batch
        and time so that DGCRM can slice it as ``emb[:, t, :, :]``.
        """
        emb = self.node_embeddings1
        if self.use_day:
            # NOTE(review): .long() truncates, so a fraction k/288 stored in
            # float32 may land on k-1; left unchanged to keep trained checkpoints
            # consistent.
            tod_index = (source[..., 1] * self.STEPS_PER_DAY).long()
            # A value of exactly 1.0 would index one row past the table.
            tod_index = tod_index.clamp(max=self.STEPS_PER_DAY - 1)
            emb = emb * self.T_i_D_emb[tod_index]
        if self.use_week:
            emb = emb * self.D_i_W_emb[source[..., 2].long()]
        if emb.dim() == 2:
            emb = emb.expand(source.shape[0], source.shape[1], -1, -1)
        return emb

    def forward(self, source):
        """Forecast ``horizon`` steps from ``source``.

        Parameters:
        - source: input tensor of shape (B, T, N, C). Channel 0 is the modelled
          signal. Channel 1 is scaled by STEPS_PER_DAY to index ``T_i_D_emb``
          when ``use_day`` is set. Channel 2 indexes ``D_i_W_emb`` when
          ``use_week`` is set.

        Returns:
        - the sum of the ``end_conv1`` and ``end_conv3`` outputs, of shape
          (B, horizon * output_dim, N, 1). This assumes the DGCRM hidden states
          are (B, N, hidden_dim); TODO confirm against DDGCRNCell.
        """
        dynamic_emb = self._dynamic_node_embedding(source)
        node_embeddings = [dynamic_emb, self.node_embeddings1]
        source = source[..., 0].unsqueeze(-1)

        init_state1 = self.encoder1.init_hidden(source.shape[0])
        output, _ = self.encoder1(source, init_state1, node_embeddings)
        output = self.dropout1(output[:, -1:, :, :])
        output1 = self.end_conv1(output)
        source1 = self.end_conv2(output)

        # NOTE(review): this subtraction broadcasts only when
        # horizon * output_dim matches the sliced step count (e.g. output_dim == 1).
        source2 = source[:, -self.horizon :, ...] - source1
        # encoder2 walks the last `horizon` input steps, so hand it the matching
        # time embeddings rather than those of the first `horizon` steps.
        residual_embeddings = [dynamic_emb[:, -self.horizon :], self.node_embeddings1]
        init_state2 = self.encoder2.init_hidden(source2.shape[0])
        output2, _ = self.encoder2(source2, init_state2, residual_embeddings)
        output2 = self.dropout2(output2[:, -1:, :, :])
        output2 = self.end_conv3(output2)
        return output1 + output2