73 lines
3.2 KiB
Python
Executable File
73 lines
3.2 KiB
Python
Executable File
import torch
|
||
import torch.nn as nn
|
||
from model.PDG2SEQ.PDG2Seq_DGCN import PDG2Seq_GCN
|
||
from collections import OrderedDict
|
||
import torch.nn.functional as F
|
||
class FC(nn.Module):
    """Small three-layer MLP applied along the last axis of its input.

    Architecture::

        Linear(dim_in, hidden_dim) -> Sigmoid
        -> Linear(hidden_dim, middle_dim) -> Sigmoid
        -> Linear(middle_dim, dim_out)

    The layers are registered inside ``self.mlp`` under the names ``fc1``,
    ``sigmoid1``, ``fc2``, ``sigmoid2`` and ``fc3``, so state-dict keys such as
    ``mlp.fc1.weight`` are stable across versions.

    Args:
        dim_in: Size of the last dimension of the input.
        dim_out: Size of the last dimension of the output.
        hidden_dim: Width of the first hidden layer (default 16).
        middle_dim: Width of the bottleneck hidden layer (default 2).
    """

    def __init__(self, dim_in, dim_out, hidden_dim=16, middle_dim=2):
        super(FC, self).__init__()
        self.hyperGNN_dim = hidden_dim
        self.middle_dim = middle_dim
        # NOTE(review): the original author questioned why three linear layers
        # and Sigmoid (rather than ReLU) activations are used here; ReLU was
        # tried and left commented out. Sigmoid is kept to preserve behavior.
        self.mlp = nn.Sequential(
            OrderedDict([
                ('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
                ('sigmoid1', nn.Sigmoid()),
                ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                ('sigmoid2', nn.Sigmoid()),
                ('fc3', nn.Linear(self.middle_dim, dim_out)),
            ])
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``.

        Leading dimensions are preserved; the last dimension goes from
        ``dim_in`` to ``dim_out``.
        """
        return self.mlp(x)
|
||
|
||
class PDG2SeqCell(nn.Module):
    """One GRU-style recurrent step of PDG2Seq over a step-specific graph.

    This module only performs the gated update inside the GRU cell; the graph
    convolution itself lives in ``PDG2Seq_GCN`` (presumably what the original
    author's note called "AGCN"), so convolution changes belong there.

    At each step two MLPs (``fc1``/``fc2``) turn the concatenated input and
    hidden state into per-node filters. These modulate ``node_embeddings[0]``
    and ``node_embeddings[1]`` to build an adjacency matrix for this step. The
    gate and candidate graph convolutions then run on that adjacency.

    Args:
        node_num: Number of graph nodes ``N``.
        dim_in: Feature size of the per-step input ``x``.
        dim_out: Hidden-state size (stored as ``hidden_dim``).
        cheb_k: Forwarded unchanged to ``PDG2Seq_GCN``.
        embed_dim: Forwarded unchanged to ``PDG2Seq_GCN``.
        time_dim: Output size of the filter MLPs; also forwarded to
            ``PDG2Seq_GCN``.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim):
        super(PDG2SeqCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # Emits 2 * hidden_dim channels, later split into the two gates.
        self.gate = PDG2Seq_GCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim, time_dim)
        # Emits the candidate hidden state.
        self.update = PDG2Seq_GCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim, time_dim)
        # Filter generators for the two embedding sides of the dynamic graph.
        self.fc1 = FC(dim_in + self.hidden_dim, time_dim)
        self.fc2 = FC(dim_in + self.hidden_dim, time_dim)

    def forward(self, x, state, node_embeddings):
        """Advance the hidden state by one step.

        Args:
            x: Input at this step, ``[B, N, dim_in]``.
            state: Previous hidden state, ``[B, N, hidden_dim]``; moved to
                ``x.device`` before use.
            node_embeddings: Indexable container. Items 0 and 1 are combined
                with the filters via ``einsum('bd,bnd->bnd')`` and so must be
                ``[B, time_dim]``; item 2 is passed through to the graph
                convolutions.

        Returns:
            New hidden state, ``[B, N, hidden_dim]``.
        """
        state = state.to(x.device)
        input_and_state = torch.cat((x, state), dim=-1)  # [B, N, dim_in + hidden_dim]
        filter1 = self.fc1(input_and_state)  # [B, N, time_dim]
        filter2 = self.fc2(input_and_state)  # [B, N, time_dim]

        # Scale every node's filter by the per-batch embedding vector.
        nodevec1 = torch.tanh(torch.einsum('bd,bnd->bnd', node_embeddings[0], filter1))  # [B, N, time_dim]
        nodevec2 = torch.tanh(torch.einsum('bd,bnd->bnd', node_embeddings[1], filter2))  # [B, N, time_dim]

        # Antisymmetric score matrix A - A^T with A = nodevec1 @ nodevec2^T.
        adj = torch.matmul(nodevec1, nodevec2.transpose(2, 1)) - torch.matmul(
            nodevec2, nodevec1.transpose(2, 1))

        adj1 = PDG2SeqCell.preprocessing(F.relu(adj))
        # NOTE(review): adj is antisymmetric, so -adj^T == adj and adj2 equals
        # adj1 up to floating-point rounding. If the reversed graph was
        # intended, preprocessing(F.relu(-adj)) would give the transposed
        # direction. Kept unchanged to preserve trained-model behavior --
        # confirm intent.
        adj2 = PDG2SeqCell.preprocessing(F.relu(-adj.transpose(-2, -1)))
        adj = [adj1, adj2]

        z_r = torch.sigmoid(self.gate(input_and_state, adj, node_embeddings[2]))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        # z gates the previous state inside the candidate input (reset-like
        # role); r interpolates between old state and candidate (update-like).
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, adj, node_embeddings[2]))
        h = r * state + (1 - r) * hc
        return h

    def init_hidden_state(self, batch_size):
        """Return an all-zero hidden state ``[batch_size, node_num, hidden_dim]``.

        The tensor is created on the default device; ``forward`` moves the
        state to the input's device.
        """
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)

    @staticmethod
    def preprocessing(adj):
        """Add self-loops and row-normalize a (batched) adjacency matrix.

        A dynamically generated adjacency may have no diagonal entries, so an
        identity is added before each row is divided by its sum, i.e.
        ``D^-1 (A + I)`` with ``D = diag(rowsum(A + I))``.

        Args:
            adj: Tensor whose last two dimensions are ``[N, N]``.

        Returns:
            Tensor of the same shape and device as ``adj`` (and the same
            dtype for floating-point ``adj``).
        """
        num_nodes = adj.shape[-1]
        # Create the identity directly on adj's device/dtype. This avoids a
        # host allocation + transfer on every call, and keeps reduced
        # precision (e.g. float16) from being promoted to float32.
        adj = adj + torch.eye(num_nodes, device=adj.device, dtype=adj.dtype)
        row_sum = adj.sum(-1, keepdim=True)
        return adj / row_sum