import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class DGCRM(nn.Module):
    """Stack of DDGCRN cells applied recurrently over the input sequence."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
        self.cells = nn.ModuleList([
            DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim)
            for i in range(num_layers)
        ])

    def forward(self, x, init_state, node_embeddings):
        # x: (B, T, N, dim_in); init_state: (num_layers, B, N, hidden_dim)
        # node_embeddings[0]: time-varying embeddings (B, T, N, embed_dim)
        # node_embeddings[1]: static node embeddings (N, embed_dim)
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        output_hidden = []
        for i in range(self.num_layers):
            state, inner = init_state[i].to(x.device), []
            for t in range(x.shape[1]):
                state = self.cells[i](x[:, t, :, :], state,
                                      [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                inner.append(state)
            output_hidden.append(state)
            # the hidden sequence of this layer becomes the input sequence of the next layer
            x = torch.stack(inner, dim=1)
        return x, output_hidden

    def init_hidden(self, bs):
        return torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0)


class EXPB(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.patch_size = args.get('patch_size', 1)
        # input_dim is the number of channels fed to the encoder (the single observation channel)
        self.num_node, self.input_dim, self.hidden_dim = args['num_nodes'], args['input_dim'], args['rnn_units']
        self.output_dim, self.horizon, self.num_layers = args['output_dim'], args['horizon'], args['num_layers']
        self.use_day, self.use_week = args['use_day'], args['use_week']
        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']))
        # time-of-day (288 five-minute slots) and day-of-week embedding tables;
        # torch.empty allocates uninitialised memory, so initialise explicitly
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)
        self.drop = nn.Dropout(0.1)
        self.encoder = DGCRM(self.num_node, self.input_dim, self.hidden_dim,
                             args['cheb_order'], args['embed_dim'], self.num_layers)
        self.base_conv = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
        self.res_conv = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim + 1))

    def forward(self, source):
        # source: (B, T, N, D_total); channel 0 is the primary observation,
        # channels 1 and 2 carry the time-of-day and day-of-week encodings
        B, T, N, D_total = source.shape
        p = self.patch_size
        num_patches = T // p
        source = source[:, :num_patches * p, :, :].view(B, num_patches, p, N, D_total)
        # average the observation channel over each patch and reshape to (B, num_patches, N, 1)
        inp = source[..., 0].mean(dim=2, keepdim=True).permute(0, 1, 3, 2)
        # time encodings taken at the last step of each patch
        time_day = source[:, :, -1, :, 1]   # (B, num_patches, N), expected in [0, 1)
        time_week = source[:, :, -1, :, 2]  # (B, num_patches, N), expected in {0, ..., 6}
        patched_source = torch.cat([inp, time_day.unsqueeze(-1), time_week.unsqueeze(-1)], dim=-1)

        # modulate the static node embeddings with the temporal embedding tables
        node_embed = self.node_embeddings1
        if self.use_day:
            node_embed = node_embed * self.T_i_D_emb[(patched_source[..., 1] * 288).long()]
        if self.use_week:
            node_embed = node_embed * self.D_i_W_emb[patched_source[..., 2].long()]
        node_embeddings = [node_embed, self.node_embeddings1]

        init = self.encoder.init_hidden(B)
        enc_out, _ = self.encoder(inp, init, node_embeddings)
        rep = self.drop(enc_out[:, -1:, :, :])  # last hidden state, (B, 1, N, hidden_dim)

        # base prediction from the hidden state plus a residual branch that also
        # sees the last observed input value
        base = self.base_conv(rep)
        res_in = torch.cat([rep, inp[:, -1:, :, :]], dim=-1)
        res = self.res_conv(res_in)
        out = base + res  # (B, horizon * output_dim, N, 1)
        out = out.squeeze(-1).view(B, self.horizon, self.output_dim, N).permute(0, 1, 3, 2)
        return out  # (B, horizon, N, output_dim)


class DDGCRNCell(nn.Module):
    """GRU-style recurrent cell whose gates are dynamic graph convolutions."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num, self.hidden_dim = node_num, dim_out
        self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim, node_num)
        self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim, node_num)
        self.ln = nn.LayerNorm(dim_out)

    def forward(self, x, state, node_embeddings):
        # x: (B, N, dim_in); state: (B, N, hidden_dim)
        inp = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(inp, node_embeddings))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        hc = torch.tanh(self.update(torch.cat((x, z * state), dim=-1), node_embeddings))
        out = r * state + (1 - r) * hc
        return self.ln(out)

    def init_hidden_state(self, bs):
        return torch.zeros(bs, self.node_num, self.hidden_dim)


class DGCN(nn.Module):
    """Graph convolution over two supports (identity and an input-conditioned
    adaptive adjacency) with node-specific weights drawn from an embedding pool."""

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes):
        super().__init__()
        self.cheb_k, self.embed_dim = cheb_k, embed_dim
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))  # defined but not used in forward
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))  # defined but not used in forward
        # torch.FloatTensor allocates uninitialised memory, so initialise explicitly
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.xavier_uniform_(self.weights)
        nn.init.xavier_uniform_(self.bias_pool)
        nn.init.zeros_(self.bias)
        self.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(dim_in, 16)),
            ('sigmoid1', nn.Sigmoid()),
            ('fc2', nn.Linear(16, 2)),
            ('sigmoid2', nn.Sigmoid()),
            ('fc3', nn.Linear(2, embed_dim)),
        ]))
        self.register_buffer('eye', torch.eye(num_nodes))

    def forward(self, x, node_embeddings):
        # x: (B, N, dim_in)
        # node_embeddings[0]: time-varying embeddings (B, N, embed_dim)
        # node_embeddings[1]: static embeddings (N, embed_dim)
        supp1 = self.eye.to(node_embeddings[0].device)
        filt = self.fc(x)
        nodevec = torch.tanh(node_embeddings[0] * filt)  # input-conditioned node vectors
        supp2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1)
        # exactly two supports are stacked, so the einsum below assumes cheb_k == 2
        x_g = torch.stack([
            torch.einsum("nm,bmc->bnc", supp1, x),
            torch.einsum("bnm,bmc->bnc", supp2, x),
        ], dim=1)  # (B, 2, N, dim_in)
        weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)  # (N, cheb_k, dim_in, dim_out)
        bias = torch.matmul(node_embeddings[1], self.bias_pool)                          # (N, dim_out)
        return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        # symmetric normalisation D^{-1/2} A D^{-1/2}; the clamp guards against zero-degree nodes
        degree = torch.sum(graph, dim=-1).clamp(min=1e-12)
        D_inv = torch.diag_embed(degree ** (-0.5))
        if normalize:
            return torch.matmul(torch.matmul(D_inv, graph), D_inv)
        return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
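

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The hyper-parameter values in
# `args` below are assumptions for a quick smoke test, not values prescribed
# by this module; adjust them to the dataset at hand. Note that `input_dim`
# refers to the single observation channel fed to the encoder (the raw input
# carries two extra time-encoding channels), and `cheb_order` is expected to
# be 2 because DGCN stacks exactly two supports.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        'num_nodes': 10, 'input_dim': 1, 'rnn_units': 32, 'output_dim': 1,
        'horizon': 12, 'num_layers': 1, 'embed_dim': 8, 'cheb_order': 2,
        'use_day': True, 'use_week': True, 'patch_size': 1,
    }
    model = EXPB(args)

    B, T, N = 4, 12, args['num_nodes']
    # channel 0: observation, channel 1: time-of-day fraction in [0, 1),
    # channel 2: day-of-week index in {0, ..., 6}
    x = torch.zeros(B, T, N, 3)
    x[..., 0] = torch.randn(B, T, N)
    x[..., 1] = torch.randint(0, 288, (B, T, N)).float() / 288.0
    x[..., 2] = torch.randint(0, 7, (B, T, N)).float()

    y = model(x)
    print(y.shape)  # expected: (B, horizon, N, output_dim) = (4, 12, 10, 1)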