import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class DGCRM(nn.Module):
    """Stack of ``DDGCRNCell`` layers unrolled over the time axis.

    The first layer consumes ``dim_in`` features per node; every further
    layer consumes the ``dim_out`` hidden features of the layer below.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super(DGCRM, self).__init__()
        assert num_layers >= 1, "At least one DGCRM layer is required."
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        # Attribute name is part of the state_dict keys; do not rename.
        self.DGCRM_cells = nn.ModuleList(
            [
                DDGCRNCell(
                    node_num,
                    dim_in if layer == 0 else dim_out,
                    dim_out,
                    cheb_k,
                    embed_dim,
                )
                for layer in range(num_layers)
            ]
        )

    def forward(self, x, init_state, node_embeddings):
        """Run every layer over the whole sequence.

        Parameters:
        - x: Input tensor of shape (B, T, N, D) with N == node_num and
          D == dim_in (checked below).
        - init_state: Initial hidden states, indexable by layer; each entry
          is the (B, N, hidden_dim) state handed to that layer's cell.
        - node_embeddings: Two-element sequence. Element 0 is the embedding
          used to build the dynamic graph: either per time step with shape
          (B, T, N, E), or time-invariant (e.g. (N, E)), in which case the
          same tensor is used at every step. Element 1 is the (N, E)
          embedding used to generate the convolution weights.

        Returns:
        - Tuple ``(outputs, last_states)``: the top layer's hidden states
          stacked over time, shape (B, T, N, hidden_dim), and a list with
          the final hidden state of each layer.
        """
        assert (
            x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        ), "Input must have shape (B, T, node_num, dim_in)."
        seq_length = x.shape[1]
        dynamic_emb, static_emb = node_embeddings[0], node_embeddings[1]
        # Only a 4-D embedding carries a time axis that can be indexed per
        # step; a time-invariant embedding is passed through unchanged.
        has_time_axis = dynamic_emb.dim() == 4

        current_inputs = x
        output_hidden = []
        for layer, cell in enumerate(self.DGCRM_cells):
            state = init_state[layer]
            inner_states = []
            for t in range(seq_length):
                step_emb = dynamic_emb[:, t] if has_time_axis else dynamic_emb
                state = cell(current_inputs[:, t], state, [step_emb, static_emb])
                inner_states.append(state)
            output_hidden.append(state)
            # The hidden sequence of this layer is the next layer's input.
            current_inputs = torch.stack(inner_states, dim=1)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        """Return the stacked initial hidden states of all layers.

        Parameters:
        - batch_size: Size of the batch

        Returns:
        - Tensor of shape (num_layers, batch_size, node_num, hidden_dim)
        """
        return torch.stack(
            [cell.init_hidden_state(batch_size) for cell in self.DGCRM_cells],
            dim=0,
        )


class DDGCRN(nn.Module):
    """Two-stage model: a first encoder predicts and back-casts, a second
    encoder models the residual, and both predictions are summed.

    ``args`` is indexed with the keys ``num_nodes``, ``input_dim``,
    ``rnn_units``, ``output_dim``, ``horizon``, ``num_layers``, ``use_day``,
    ``use_week``, ``default_graph``, ``embed_dim`` and ``cheb_order``.

    Only channel 0 of the input is fed to the encoders, so ``input_dim``
    has to be 1 for the encoders' shape check to pass, and ``cheb_order``
    has to be 2 because ``DGCN`` stacks exactly two supports.
    """

    def __init__(self, args):
        super(DDGCRN, self).__init__()
        self.num_node = args["num_nodes"]
        self.input_dim = args["input_dim"]
        self.hidden_dim = args["rnn_units"]
        self.output_dim = args["output_dim"]
        self.horizon = args["horizon"]
        self.num_layers = args["num_layers"]
        self.use_day = args["use_day"]
        self.use_week = args["use_week"]
        self.default_graph = args["default_graph"]

        self.node_embeddings1 = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )
        # Not referenced in forward(); kept so existing checkpoints load.
        self.node_embeddings2 = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )
        # Time-of-day (288 slots) and day-of-week (7 slots) embeddings.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args["embed_dim"]))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args["embed_dim"]))
        # torch.empty returns uninitialised memory (possibly NaN/inf), so the
        # tables must be given defined values before they are used.
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)

        self.encoder1 = DGCRM(
            self.num_node,
            self.input_dim,
            self.hidden_dim,
            args["cheb_order"],
            args["embed_dim"],
            self.num_layers,
        )
        self.encoder2 = DGCRM(
            self.num_node,
            self.input_dim,
            self.hidden_dim,
            args["cheb_order"],
            args["embed_dim"],
            self.num_layers,
        )

        # Predictor heads: each collapses the hidden dimension of the last
        # hidden state into horizon * output_dim channels.
        self.end_conv1 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )
        self.end_conv2 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )
        self.end_conv3 = nn.Conv2d(
            1,
            self.horizon * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )

    def forward(self, source, **kwargs):
        """Forward pass of the DDGCRN model.

        Parameters:
        - source: Input tensor of shape (B, T_1, N, D). Channel 0 is the
          signal; channel 1 is read as a time-of-day fraction (multiplied by
          the number of time-of-day slots and truncated to an index) when
          ``use_day`` is set; channel 2 is read as a day-of-week index when
          ``use_week`` is set.
        - **kwargs: Accepted for call compatibility and ignored.

        Returns:
        - Tensor of shape (B, horizon * output_dim, N, 1): the sum of the
          first-stage prediction and the residual-stage prediction.
        """
        node_embedding1 = self.node_embeddings1
        if self.use_day:
            slots_per_day = self.T_i_D_emb.shape[0]
            day_index = (source[..., 1] * slots_per_day).long()
            # (N, E) * (B, T, N, E) -> (B, T, N, E)
            node_embedding1 = node_embedding1 * self.T_i_D_emb[day_index]
        if self.use_week:
            week_index = source[..., 2].long()
            node_embedding1 = node_embedding1 * self.D_i_W_emb[week_index]
        node_embeddings = [node_embedding1, self.node_embeddings1]

        # Only the first channel is fed to the recurrent encoders.
        source = source[..., 0].unsqueeze(-1)

        # Stage 1: encode, then predict (end_conv1) and back-cast (end_conv2)
        # from the last hidden state.
        init_state1 = self.encoder1.init_hidden(source.shape[0])
        output, _ = self.encoder1(source, init_state1, node_embeddings)
        output = self.dropout1(output[:, -1:, :, :])
        output1 = self.end_conv1(output)
        source1 = self.end_conv2(output)

        # Stage 2: model what stage 1 did not explain in the last `horizon`
        # input steps.
        source2 = source[:, -self.horizon :, ...] - source1
        residual_embeddings = node_embeddings
        if node_embedding1.dim() == 4:
            # Keep the time embeddings aligned with the residual window: the
            # residual covers the trailing steps, so it needs the trailing
            # embeddings rather than the leading ones.
            residual_embeddings = [
                node_embedding1[:, -source2.shape[1] :],
                self.node_embeddings1,
            ]
        init_state2 = self.encoder2.init_hidden(source2.shape[0])
        output2, _ = self.encoder2(source2, init_state2, residual_embeddings)
        output2 = self.dropout2(output2[:, -1:, :, :])
        output2 = self.end_conv3(output2)

        return output1 + output2


class DDGCRNCell(nn.Module):
    """Gated recurrent cell whose gates are computed by ``DGCN`` layers."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super(DDGCRNCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # `gate` emits both gates at once (2 * dim_out channels).
        self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
        self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings):
        """Advance the hidden state by one time step.

        Parameters:
        - x: Input of shape (B, N, dim_in)
        - state: Previous hidden state of shape (B, N, hidden_dim)
        - node_embeddings: Two-element sequence forwarded to ``DGCN``

        Returns:
        - New hidden state of shape (B, N, hidden_dim)
        """
        # The initial state is allocated on the default device; move it next
        # to the input (a no-op once the state already lives there).
        state = state.to(x.device)
        input_and_state = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, node_embeddings))
        return r * state + (1 - r) * hc

    def init_hidden_state(self, batch_size):
        """Return an all-zero state of shape (batch_size, node_num, hidden_dim)."""
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)


class DGCN(nn.Module):
    """Graph convolution over a graph generated from the input itself.

    Two supports are used: the identity and a normalised graph built from
    node embeddings modulated by a small MLP of the input. Per-node weights
    and biases are generated from a second node embedding.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(DGCN, self).__init__()
        self.cheb_k = cheb_k
        self.embed_dim = embed_dim

        # Parameter pools. `weights` and `bias` are not read by forward();
        # they are kept so the module's state_dict layout is unchanged.
        self.weights_pool = nn.Parameter(
            torch.empty(embed_dim, cheb_k, dim_in, dim_out)
        )
        self.weights = nn.Parameter(torch.empty(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.empty(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.empty(dim_out))
        # Freshly allocated tensors hold arbitrary memory; give them defined
        # values so the module is usable without external initialisation.
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.xavier_uniform_(self.weights)
        nn.init.xavier_uniform_(self.bias_pool)
        nn.init.zeros_(self.bias)

        # Hyperparameters of the filter MLP
        self.hyperGNN_dim = 16
        self.middle_dim = 2

        # MLP mapping the input features to an embed_dim-sized filter
        self.fc = nn.Sequential(
            OrderedDict(
                [
                    ("fc1", nn.Linear(dim_in, self.hyperGNN_dim)),
                    ("sigmoid1", nn.Sigmoid()),
                    ("fc2", nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                    ("sigmoid2", nn.Sigmoid()),
                    ("fc3", nn.Linear(self.middle_dim, self.embed_dim)),
                ]
            )
        )

    def forward(self, x, node_embeddings):
        """Forward pass for the DGCN model.

        Parameters:
        - x: Input tensor of shape [B, N, C]
        - node_embeddings: Two-element sequence. Element 0, of shape
          [B, N, E] or [N, E], is multiplied with the filter to build the
          dynamic graph; element 1, of shape [N, E], generates the per-node
          weights and biases.

        Returns:
        - x_gconv: Output tensor of shape [B, N, dim_out]

        Exactly two supports are stacked, so the contraction below only
        type-checks when the module was built with ``cheb_k == 2``.
        """
        dynamic_emb, static_emb = node_embeddings[0], node_embeddings[1]
        # shape[-2] is the node axis for both [B, N, E] and [N, E].
        node_num = dynamic_emb.shape[-2]

        # Input-dependent filter modulating the node embedding
        node_filter = self.fc(x)
        nodevec = torch.tanh(torch.mul(dynamic_emb, node_filter))

        # Dynamic graph from pairwise similarity of the modulated embeddings
        graph = F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1)))
        identity = torch.eye(node_num, device=graph.device, dtype=graph.dtype)
        supports2 = self.get_laplacian(graph, identity)

        # The first support is the identity, so its product with x is x.
        x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x)
        x_g = torch.stack([x, x_g2], dim=2)  # [B, N, 2, C]

        # Per-node weights and biases generated from the static embedding
        weights = torch.einsum("nd,dkio->nkio", static_emb, self.weights_pool)
        bias = torch.matmul(static_emb, self.bias_pool)
        return torch.einsum("bnki,nkio->bno", x_g, weights) + bias

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """Symmetrically normalise a graph: D^(-1/2) * graph * D^(-1/2).

        Parameters:
        - graph: Adjacency matrix of shape [..., N, N]
        - I: Identity matrix; only used when ``normalize`` is False
        - normalize: If True, normalise ``graph`` as given. If False, add
          ``I`` to ``graph`` first and then apply the same normalisation.

        Returns:
        - L: Normalised matrix with the same shape as ``graph``. Rows and
          columns of nodes whose degree is zero are zero instead of NaN.
        """
        if not normalize:
            graph = graph + I
        d_inv_sqrt = torch.sum(graph, dim=-1) ** (-1 / 2)
        # A zero degree gives inf here, and inf * 0 would poison the result
        # with NaN; treat isolated nodes as contributing nothing.
        d_inv_sqrt = torch.where(
            torch.isinf(d_inv_sqrt), torch.zeros_like(d_inv_sqrt), d_inv_sqrt
        )
        # Scaling rows and columns by broadcasting is equivalent to
        # multiplying by the diagonal matrix on both sides.
        return d_inv_sqrt.unsqueeze(-1) * graph * d_inv_sqrt.unsqueeze(-2)