from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.PDG2SEQ.PDG2Seq_DGCN import PDG2Seq_GCN


class FC(nn.Module):
    """Small three-layer MLP: dim_in -> 16 -> 2 -> dim_out with sigmoid activations.

    Args:
        dim_in: Size of the last dimension of the input tensor.
        dim_out: Size of the last dimension of the output tensor.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.hyperGNN_dim = 16
        self.middle_dim = 2
        # Open question left by the original author: why are three linear
        # layers used here, and why is the activation a sigmoid?
        # (A ReLU variant of both activations was left commented out in the
        # original source.)
        # The OrderedDict keys become state_dict keys; do not rename them.
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("fc1", nn.Linear(dim_in, self.hyperGNN_dim)),
                    ("sigmoid1", nn.Sigmoid()),
                    ("fc2", nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                    ("sigmoid2", nn.Sigmoid()),
                    ("fc3", nn.Linear(self.middle_dim, dim_out)),
                ]
            )
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x`` and return the result."""
        return self.mlp(x)


class PDG2SeqCell(nn.Module):
    """GRU-style recurrent cell whose gates are PDG2Seq graph convolutions.

    This module only performs the GRU-internal update; per the original
    author's note, changes to the graph convolution itself belong in the
    AGCN / ``PDG2Seq_GCN`` module.

    Args:
        node_num: Number of graph nodes (used to size the initial state).
        dim_in: Per-node input feature size.
        dim_out: Per-node hidden-state size.
        cheb_k: Forwarded unchanged to ``PDG2Seq_GCN``.
        embed_dim: Forwarded unchanged to ``PDG2Seq_GCN``.
        time_dim: Output size of the two filter MLPs; also forwarded to
            ``PDG2Seq_GCN``.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # One graph convolution yields both gates (hence 2 * dim_out), which
        # are split apart in forward().
        self.gate = PDG2Seq_GCN(
            dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim, time_dim
        )
        self.update = PDG2Seq_GCN(
            dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim, time_dim
        )
        # Two independent MLPs mapping [input, state] to time_dim-sized filters.
        self.fc1 = FC(dim_in + self.hidden_dim, time_dim)
        self.fc2 = FC(dim_in + self.hidden_dim, time_dim)

    def forward(self, x, state, node_embeddings):
        """Run one recurrent step and return the new hidden state.

        Args:
            x: Input of shape (B, num_nodes, input_dim).
            state: Previous hidden state of shape (B, num_nodes, hidden_dim);
                it is moved to ``x.device`` before use.
            node_embeddings: Indexable with at least three entries. Entries 0
                and 1 are 2-D tensors (B, D) that are broadcast-multiplied
                with the (B, num_nodes, time_dim) filters, so D must equal
                time_dim. Entry 2 is passed through to the graph convolutions
                unchanged.

        Returns:
            New hidden state with the same shape as ``state``.
        """
        state = state.to(x.device)
        input_and_state = torch.cat((x, state), dim=-1)

        filter1 = self.fc1(input_and_state)  # [B, N, time_dim]
        filter2 = self.fc2(input_and_state)  # [B, N, time_dim]

        # Scale each node's filter element-wise by the per-sample embedding.
        nodevec1 = torch.tanh(
            torch.einsum("bd,bnd->bnd", node_embeddings[0], filter1)
        )  # [B, N, time_dim]
        nodevec2 = torch.tanh(
            torch.einsum("bd,bnd->bnd", node_embeddings[1], filter2)
        )  # [B, N, time_dim]

        # M - M^T with M = nodevec1 @ nodevec2^T: antisymmetric by construction.
        adj = torch.matmul(nodevec1, nodevec2.transpose(2, 1)) - torch.matmul(
            nodevec2, nodevec1.transpose(2, 1)
        )  # [B, N, N]
        adj1 = PDG2SeqCell.preprocessing(F.relu(adj))
        # NOTE(review): because adj is antisymmetric, -adj^T equals adj up to
        # floating-point rounding, so adj2 appears to duplicate adj1. If the
        # reverse-direction graph was intended this would be F.relu(-adj) or
        # F.relu(adj.transpose(-2, -1)). Left unchanged to preserve behavior;
        # confirm intent with the model authors.
        adj2 = PDG2SeqCell.preprocessing(F.relu(-adj.transpose(-2, -1)))
        adj = [adj1, adj2]

        z_r = torch.sigmoid(self.gate(input_and_state, adj, node_embeddings[2]))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        # z gates the previous state inside the candidate; r interpolates
        # between the previous state and the candidate.
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, adj, node_embeddings[2]))
        h = r * state + (1 - r) * hc
        return h

    def init_hidden_state(self, batch_size):
        """Return an all-zero state of shape (batch_size, node_num, hidden_dim).

        The tensor is created on the default device; forward() moves it to
        the input's device.
        """
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)

    @staticmethod
    def preprocessing(adj):
        """Add self-loops to ``adj`` and row-normalize it.

        Handles the case where the dynamic adjacency matrix may have no
        diagonal entries.

        Args:
            adj: Tensor whose last two dimensions form a square matrix.

        Returns:
            ``(adj + I)`` divided by its row sums (sum over the last dim).
        """
        num_nodes = adj.shape[-1]
        # Allocate the identity directly on adj's device instead of building
        # it on the default device and copying; dtype stays the default so
        # type promotion matches the previous behavior.
        adj = adj + torch.eye(num_nodes, device=adj.device)
        row_sum = adj.sum(-1, keepdim=True)
        # Equivalent to the formulation left commented out in the original:
        #   D = torch.diag_embed(torch.sum(adj, dim=-1) ** (-1))
        #   adj = torch.matmul(D, adj)
        adj = adj / row_sum
        return adj