# TrafficWheel/model/DDGCRN/DDGCRN.py
# (file-viewer metadata from the original listing: 249 lines, 9.0 KiB, Python)

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class DGCRM(nn.Module):
    """Stack of DDGCRNCell layers unrolled over the time axis.

    Layer 0 maps ``dim_in`` features to ``dim_out`` hidden units; every later
    layer maps ``dim_out`` to ``dim_out`` and consumes the full hidden-state
    sequence of the layer below.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        assert num_layers >= 1, 'At least one DGCRM layer is required.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        cells = []
        for layer in range(num_layers):
            # Only the bottom layer sees the raw input width.
            layer_in = dim_in if layer == 0 else dim_out
            cells.append(DDGCRNCell(node_num, layer_in, dim_out, cheb_k, embed_dim))
        self.DGCRM_cells = nn.ModuleList(cells)

    def forward(self, x, init_state, node_embeddings):
        """Unroll every layer over the whole sequence.

        Parameters:
        - x: Input tensor of shape (B, T, N, D); N must equal ``node_num`` and
          D must equal ``dim_in``.
        - init_state: Initial hidden states, indexed per layer
          (e.g. the (num_layers, B, N, hidden_dim) tensor from ``init_hidden``).
        - node_embeddings: Two-element sequence. Element 0 is sliced as
          ``[:, t, :, :]`` per step, so it needs a time axis at dim 1; element 1
          is passed through unchanged.
        Returns:
        - The top layer's hidden states stacked over time, shape (B, T, N, hidden_dim).
        - A list with each layer's final hidden state.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        steps = x.shape[1]
        dynamic_emb = node_embeddings[0]
        static_emb = node_embeddings[1]

        layer_input = x
        last_states = []
        for layer, cell in enumerate(self.DGCRM_cells):
            h = init_state[layer]
            per_step = []
            for t in range(steps):
                step_embeddings = [dynamic_emb[:, t, :, :], static_emb]
                h = cell(layer_input[:, t, :, :], h, step_embeddings)
                per_step.append(h)
            last_states.append(h)
            # The next layer consumes this layer's full output sequence.
            layer_input = torch.stack(per_step, dim=1)
        return layer_input, last_states

    def init_hidden(self, batch_size):
        """Return zero hidden states for all layers, stacked on a new dim 0.

        Parameters:
        - batch_size: Size of the batch
        Returns:
        - Tensor of shape (num_layers, batch_size, node_num, hidden_dim)
        """
        initial = [cell.init_hidden_state(batch_size) for cell in self.DGCRM_cells]
        return torch.stack(initial, dim=0)
class DDGCRN(nn.Module):
    """Decomposition Dynamic Graph Convolutional Recurrent Network.

    Two DGCRM encoders share the same node embeddings. The first encoder's
    last hidden state produces a forecast (``end_conv1``) and an estimate
    (``end_conv2``) that is subtracted from the last ``horizon`` input steps;
    the second encoder runs on that residual and its forecast (``end_conv3``)
    is added to the first one.
    """

    # Rows of the time-of-day embedding table. Channel 1 of the input is
    # multiplied by this value and truncated to an index.
    STEPS_PER_DAY = 288
    # Rows of the day-of-week embedding table.
    DAYS_PER_WEEK = 7

    def __init__(self, args):
        super(DDGCRN, self).__init__()
        self.num_node = args['num_nodes']
        self.input_dim = args['input_dim']
        self.hidden_dim = args['rnn_units']
        self.output_dim = args['output_dim']
        self.horizon = args['horizon']
        self.num_layers = args['num_layers']
        self.use_day = args['use_day']
        self.use_week = args['use_week']
        self.default_graph = args['default_graph']  # stored but not used in this class
        # Creation order below is kept as-is: it fixes both the RNG draws and the
        # state_dict key order.
        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
        # NOTE(review): node_embeddings2 is registered but forward() never reads it.
        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
        # Allocated uninitialised (torch.empty); presumably the trainer
        # initialises all parameters — TODO confirm.
        self.T_i_D_emb = nn.Parameter(torch.empty(self.STEPS_PER_DAY, args['embed_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(self.DAYS_PER_WEEK, args['embed_dim']))
        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
        self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
                              self.num_layers)
        self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'], args['embed_dim'],
                              self.num_layers)
        # Predictor: each conv maps a (B, 1, N, hidden) map to (B, horizon * output_dim, N, 1).
        self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

    def forward(self, source):
        """Run both encoders and return the summed forecast.

        Parameters:
        - source: Input tensor of shape (B, T_1, N, C).
          - Channel 0 is the signal to forecast.
          - Channel 1 is read only when ``use_day``: a time-of-day value that is
            scaled by ``STEPS_PER_DAY`` and truncated to an index.
          - Channel 2 is read only when ``use_week``: a day-of-week index.
        Returns:
        - Tensor of shape (B, horizon * output_dim, N, 1).
        """
        node_embedding1 = self.node_embeddings1
        if self.use_day:
            t_i_d_data = source[..., 1]
            # NOTE(review): .long() truncates, and a value of exactly 1.0 would
            # index past the table; assumes time-of-day lies in [0, 1) — TODO confirm.
            T_i_D_emb = self.T_i_D_emb[(t_i_d_data * self.STEPS_PER_DAY).long()]
            node_embedding1 = node_embedding1 * T_i_D_emb
        if self.use_week:
            d_i_w_data = source[..., 2]
            D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
            node_embedding1 = node_embedding1 * D_i_W_emb
        if node_embedding1.dim() == 2:
            # Neither temporal embedding is enabled, so this is still the static
            # (N, D) table. DGCRM slices it as [:, t, :, :], so add explicit
            # batch and time axes (a broadcast view, no copy).
            node_embedding1 = node_embedding1.expand(source.shape[0], source.shape[1], -1, -1)
        node_embeddings = [node_embedding1, self.node_embeddings1]
        source = source[..., 0].unsqueeze(-1)

        # First branch: forecast plus an estimate of the recent window.
        init_state1 = self.encoder1.init_hidden(source.shape[0])
        output, _ = self.encoder1(source, init_state1, node_embeddings)
        output = self.dropout1(output[:, -1:, :, :])  # last time step only
        output1 = self.end_conv1(output)
        source1 = self.end_conv2(output)

        # Second branch on the residual.
        # NOTE(review): the subtraction assumes horizon * output_dim equals the
        # number of sliced steps, i.e. output_dim == 1 and T_1 >= horizon — TODO confirm.
        source2 = source[:, -self.horizon:, ...] - source1
        init_state2 = self.encoder2.init_hidden(source2.shape[0])
        output2, _ = self.encoder2(source2, init_state2, node_embeddings)
        output2 = self.dropout2(output2[:, -1:, :, :])
        output2 = self.end_conv3(output2)
        return output1 + output2
class DDGCRNCell(nn.Module):
    """GRU-style recurrent cell whose gates are dynamic graph convolutions.

    ``z`` scales the previous state inside the candidate computation, and ``r``
    interpolates between the previous state and the candidate.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        gcn_in = dim_in + dim_out
        # `gate` must be built before `update` so RNG draws (nn.Linear init) keep their order.
        self.gate = DGCN(gcn_in, 2 * dim_out, cheb_k, embed_dim)
        self.update = DGCN(gcn_in, dim_out, cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings):
        """Advance the cell one step; x is [B, N, dim_in], state [B, N, hidden_dim]."""
        state = state.to(x.device)
        gate_logits = self.gate(torch.cat((x, state), dim=-1), node_embeddings)
        # The gate output has 2 * hidden_dim channels: first half z, second half r.
        z, r = torch.chunk(torch.sigmoid(gate_logits), 2, dim=-1)
        candidate_input = torch.cat((x, z * state), dim=-1)
        candidate = torch.tanh(self.update(candidate_input, node_embeddings))
        return r * state + (1 - r) * candidate

    def init_hidden_state(self, batch_size):
        """Return a zero state of shape (batch_size, node_num, hidden_dim) on the default device."""
        return torch.zeros((batch_size, self.node_num, self.hidden_dim))
class DGCN(nn.Module):
    """Dynamic graph convolution with node-adaptive weights.

    A small hyper-network (``fc``) turns the input features into a filter that
    modulates the dynamic node embeddings. Their ReLU'd similarity, degree-
    normalised, is a per-sample support matrix S. Features are propagated with
    Chebyshev terms of S and mixed by weights generated from the static node
    embeddings.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(DGCN, self).__init__()
        if cheb_k < 1:
            raise ValueError(f"cheb_k must be >= 1, got {cheb_k}")
        self.cheb_k = cheb_k
        self.embed_dim = embed_dim
        # Parameters are allocated uninitialised (torch.FloatTensor); presumably
        # the training code initialises every parameter — TODO confirm.
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        # `weights` and `bias` are not used in forward(); removing them would
        # change the state_dict keys, so they are kept.
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        # Hyper-network sizes
        self.hyperGNN_dim = 16
        self.middle_dim = 2
        # Hyper-network: dim_in -> 16 -> 2 -> embed_dim
        self.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
            ('sigmoid1', nn.Sigmoid()),
            ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
            ('sigmoid2', nn.Sigmoid()),
            ('fc3', nn.Linear(self.middle_dim, self.embed_dim))
        ]))

    def forward(self, x, node_embeddings):
        """Apply the dynamic graph convolution.

        Parameters:
        - x: Input tensor of shape [B, N, dim_in]
        - node_embeddings: Two-element sequence.
          - Element 0 (dynamic) is multiplied element-wise with ``fc(x)``
            ([B, N, embed_dim]), so it must broadcast to that shape.
          - Element 1 (static) has shape [N, embed_dim].
        Returns:
        - x_gconv: Output tensor of shape [B, N, dim_out]
        """
        node_num = x.shape[1]
        # Match x's device and dtype so float64/half models do not hit a dtype
        # mismatch in the matmuls.
        identity = torch.eye(node_num, device=x.device, dtype=x.dtype)

        hyper_filter = self.fc(x)
        nodevec = torch.tanh(torch.mul(node_embeddings[0], hyper_filter))
        adjacency = F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1)))
        supports = self.get_laplacian(adjacency, identity)  # [B, N, N]

        # Chebyshev recurrence on the signal: T0 x = x (identity support),
        # T1 x = S x, Tk x = 2 S T(k-1) x - T(k-2) x.
        # For cheb_k == 2 this is exactly [I x, S x].
        x_g = [x]
        if self.cheb_k > 1:
            x_g.append(torch.einsum("bnm,bmc->bnc", supports, x))
        for _ in range(2, self.cheb_k):
            x_g.append(2 * torch.einsum("bnm,bmc->bnc", supports, x_g[-1]) - x_g[-2])
        x_g = torch.stack(x_g, dim=2)  # [B, N, cheb_k, dim_in]

        # Node-specific weights and bias generated from the static embeddings.
        weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
        bias = torch.matmul(node_embeddings[1], self.bias_pool)
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias
        return x_gconv

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """Symmetrically degree-normalise a (batched) adjacency matrix.

        Despite the name, this returns D^{-1/2} A D^{-1/2}, not I minus that.
        When ``normalize`` is False, self-loops are added first (A + I) and the
        result is still degree-normalised. Nodes whose degree is not positive
        get zero rows/columns instead of inf/NaN.

        Parameters:
        - graph: Adjacency matrix of shape [..., N, N]
        - I: Identity matrix (only used when ``normalize`` is False)
        - normalize: If False, add self-loops before normalising
        Returns:
        - L: Normalised matrix with the same shape as ``graph``
        """
        if not normalize:
            graph = graph + I
        degree = torch.sum(graph, dim=-1)
        positive = degree > 0
        # Route non-positive degrees through a dummy 1 so neither the forward
        # pass nor the gradient of pow(-1/2) produces inf/NaN.
        safe_degree = torch.where(positive, degree, torch.ones_like(degree))
        d_inv_sqrt = torch.where(positive, safe_degree ** (-1 / 2), torch.zeros_like(degree))
        D_inv_sqrt = torch.diag_embed(d_inv_sqrt)
        L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt)
        return L