import torch
import torch.nn as nn
from model.TWDGCN.DGCRU import DDGCRNCell
from model.TWDGCN.ConnectionMatrix import ConnectionMatrix


class DGCRM(nn.Module):
    """Stack of ``DDGCRNCell`` recurrent layers unrolled over the time axis.

    The first layer maps ``dim_in`` input features to ``dim_out``; every later
    layer maps ``dim_out`` to ``dim_out`` and consumes the full output sequence
    of the layer below it.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        assert num_layers >= 1, 'At least one DGCRM layer is required.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.conn = ConnectionMatrix()
        # Layer 0 reads the raw input features; deeper layers read the hidden
        # sequence produced by the previous layer.
        self.DGCRM_cells = nn.ModuleList([
            DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim)
            for i in range(num_layers)
        ])

    def forward(self, x, init_state, node_embeddings):
        """
        Forward pass of the DGCRM model.

        Parameters:
        - x: Input tensor of shape (B, T, N, D) with N == node_num and D == dim_in
        - init_state: Initial hidden states indexable by layer, e.g. the
          (num_layers, B, N, hidden_dim) tensor returned by ``init_hidden``
        - node_embeddings: Pair ``[dynamic, static]``. ``dynamic`` is sliced per
          time step as ``dynamic[:, t, :, :]``, so it needs a time axis of
          length >= T at dim 1; ``static`` is passed to every step unchanged.

        Returns:
        - Output sequence of the last layer, the per-step states stacked on
          dim 1 (assuming each cell returns (B, N, hidden_dim), this is
          (B, T, N, hidden_dim))
        - List holding the final hidden state of each layer
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim, (
            f'Expected x of shape (B, T, {self.node_num}, {self.input_dim}), '
            f'got {tuple(x.shape)}'
        )
        seq_length = x.shape[1]
        current_inputs = x
        output_hidden = []
        for i in range(self.num_layers):
            state = init_state[i]
            inner_states = []
            # Connectivity is always derived from the ORIGINAL input x (not from
            # the hidden sequence fed to deeper layers); conn_mtx[t] is used at
            # step t -- presumably one matrix per time step, confirm in
            # ConnectionMatrix.
            conn_mtx = self.conn.get(x)
            for t in range(seq_length):
                state = self.DGCRM_cells[i](
                    current_inputs[:, t, :, :],
                    state,
                    [node_embeddings[0][:, t, :, :], node_embeddings[1]],
                    conn_mtx[t],
                )
                inner_states.append(state)
            output_hidden.append(state)
            # The whole output sequence of this layer becomes the next layer's input.
            current_inputs = torch.stack(inner_states, dim=1)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        """
        Initialize hidden states for DGCRM layers.

        Parameters:
        - batch_size: Size of the batch

        Returns:
        - Each cell's ``init_hidden_state(batch_size)``, stacked on a new
          leading dim of size num_layers
        """
        return torch.stack(
            [cell.init_hidden_state(batch_size) for cell in self.DGCRM_cells],
            dim=0,
        )


class TWDGCN(nn.Module):
    """Two-stage DGCRM forecaster.

    Stage 1 (``encoder1``) encodes the signal channel and predicts ``output1``
    via ``end_conv1``. ``end_conv2``'s output is subtracted from the last
    ``horizon`` input steps, and stage 2 (``encoder2``) encodes that residual
    to predict ``output2`` via ``end_conv3``. The model returns
    ``output1 + output2``.
    """

    def __init__(self, args):
        super().__init__()
        self.num_node = args['num_nodes']
        self.input_dim = args['input_dim']
        self.hidden_dim = args['rnn_units']
        self.output_dim = args['output_dim']
        self.horizon = args['horizon']
        self.num_layers = args['num_layers']
        self.use_day = args['use_day']
        self.use_week = args['use_week']
        self.default_graph = args['default_graph']
        # Number of time-of-day slots in T_i_D_emb. Optional config key; the
        # default of 288 matches the previously hard-coded value.
        try:
            self.steps_per_day = args['steps_per_day']
        except KeyError:
            self.steps_per_day = 288

        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
        # NOTE(review): node_embeddings2 is not referenced in forward(); it is
        # kept so existing checkpoints (state_dict keys) still load.
        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
        self.T_i_D_emb = nn.Parameter(torch.empty(self.steps_per_day, args['embed_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
        # torch.empty leaves the memory uninitialized (it may hold NaN/inf), so
        # give both tables a defined starting point. A training script that
        # re-initializes all parameters simply overwrites these values.
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
        self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'],
                              args['embed_dim'], self.num_layers)
        self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'],
                              args['embed_dim'], self.num_layers)
        # Predictor: each conv collapses the hidden dim (kernel width hidden_dim)
        # and emits horizon * output_dim channels.
        self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

    def forward(self, source):
        """
        Forward pass of the TWDGCN model.

        Parameters:
        - source: Input tensor of shape (B, T_1, N, D). Channel 0 is the signal.
          When use_day is set, channel 1 is scaled by steps_per_day and
          truncated to index T_i_D_emb (assumed to be a time-of-day fraction in
          [0, 1)); when use_week is set, channel 2 is truncated to index
          D_i_W_emb (assumed to be a day-of-week index in [0, 7)).

        Returns:
        - Sum of both stages' predictions; assuming (B, N, hidden_dim) hidden
          states, shape (B, horizon * output_dim, N, 1)
        """
        batch_size, seq_len = source.shape[0], source.shape[1]

        node_embedding1 = self.node_embeddings1
        if self.use_day:
            t_i_d_data = source[..., 1]
            T_i_D_emb = self.T_i_D_emb[(t_i_d_data * self.steps_per_day).long()]
            node_embedding1 = node_embedding1 * T_i_D_emb
        if self.use_week:
            d_i_w_data = source[..., 2]
            D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
            node_embedding1 = node_embedding1 * D_i_W_emb
        if not (self.use_day or self.use_week):
            # DGCRM slices the dynamic embedding as [:, t, :, :], which needs a
            # (B, T, N, E) tensor. Without either time embedding it would still
            # be (N, E) and raise IndexError, so broadcast it explicitly.
            node_embedding1 = node_embedding1.expand(batch_size, seq_len, -1, -1)
        node_embeddings = [node_embedding1, self.node_embeddings1]

        # Keep only the signal channel: (B, T_1, N, 1).
        source = source[..., 0].unsqueeze(-1)

        # Stage 1: encode the input and predict from the last hidden step.
        init_state1 = self.encoder1.init_hidden(source.shape[0])
        output, _ = self.encoder1(source, init_state1, node_embeddings)
        output = self.dropout1(output[:, -1:, :, :])
        output1 = self.end_conv1(output)

        # Subtract end_conv2's output from the last `horizon` input steps; the
        # residual is stage 2's input. NOTE: the shapes only line up when
        # output_dim == 1 and T_1 >= horizon. encoder2 runs `horizon` steps, so
        # it reads the first `horizon` slices of the dynamic embedding.
        source1 = self.end_conv2(output)
        source2 = source[:, -self.horizon:, ...] - source1

        # Stage 2: encode the residual and predict the remaining component.
        init_state2 = self.encoder2.init_hidden(source2.shape[0])
        output2, _ = self.encoder2(source2, init_state2, node_embeddings)
        output2 = self.dropout2(output2[:, -1:, :, :])
        output2 = self.end_conv3(output2)
        return output1 + output2