from __future__ import division import torch import torch.nn as nn from torch.nn import init import numbers import torch.nn.functional as F class nconv(nn.Module): def __init__(self): super(nconv,self).__init__() def forward(self,x, A): x = torch.einsum('ncwl,vw->ncvl',(x,A)) return x.contiguous() class dy_nconv(nn.Module): def __init__(self): super(dy_nconv,self).__init__() def forward(self,x, A): x = torch.einsum('ncvl,nvwl->ncwl',(x,A)) return x.contiguous() class linear(nn.Module): def __init__(self,c_in,c_out,bias=True): super(linear,self).__init__() self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=bias) def forward(self,x): return self.mlp(x) class prop(nn.Module): def __init__(self,c_in,c_out,gdep,dropout,alpha): super(prop, self).__init__() self.nconv = nconv() self.mlp = linear(c_in,c_out) self.gdep = gdep self.dropout = dropout self.alpha = alpha def forward(self,x,adj): adj = adj + torch.eye(adj.size(0)).to(x.device) d = adj.sum(1) h = x dv = d a = adj / dv.view(-1, 1) for i in range(self.gdep): h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) ho = self.mlp(h) return ho class mixprop(nn.Module): def __init__(self,c_in,c_out,gdep,dropout,alpha): super(mixprop, self).__init__() self.nconv = nconv() self.mlp = linear((gdep+1)*c_in,c_out) self.gdep = gdep self.dropout = dropout self.alpha = alpha def forward(self,x,adj): adj = adj + torch.eye(adj.size(0)).to(x.device) d = adj.sum(1) h = x out = [h] a = adj / d.view(-1, 1) for i in range(self.gdep): h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) out.append(h) ho = torch.cat(out,dim=1) ho = self.mlp(ho) return ho class dy_mixprop(nn.Module): def __init__(self,c_in,c_out,gdep,dropout,alpha): super(dy_mixprop, self).__init__() self.nconv = dy_nconv() self.mlp1 = linear((gdep+1)*c_in,c_out) self.mlp2 = linear((gdep+1)*c_in,c_out) self.gdep = gdep self.dropout = dropout self.alpha = alpha self.lin1 = linear(c_in,c_in) self.lin2 = linear(c_in,c_in) def forward(self,x): #adj = adj + torch.eye(adj.size(0)).to(x.device) #d = adj.sum(1) x1 = torch.tanh(self.lin1(x)) x2 = torch.tanh(self.lin2(x)) adj = self.nconv(x1.transpose(2,1),x2) adj0 = torch.softmax(adj, dim=2) adj1 = torch.softmax(adj.transpose(2,1), dim=2) h = x out = [h] for i in range(self.gdep): h = self.alpha*x + (1-self.alpha)*self.nconv(h,adj0) out.append(h) ho = torch.cat(out,dim=1) ho1 = self.mlp1(ho) h = x out = [h] for i in range(self.gdep): h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1) out.append(h) ho = torch.cat(out, dim=1) ho2 = self.mlp2(ho) return ho1+ho2 class dilated_1D(nn.Module): def __init__(self, cin, cout, dilation_factor=2): super(dilated_1D, self).__init__() self.tconv = nn.ModuleList() self.kernel_set = [2,3,6,7] self.tconv = nn.Conv2d(cin,cout,(1,7),dilation=(1,dilation_factor)) def forward(self,input): x = self.tconv(input) return x class dilated_inception(nn.Module): def __init__(self, cin, cout, dilation_factor=2): super(dilated_inception, self).__init__() self.tconv = nn.ModuleList() self.kernel_set = [2,3,6,7] cout = int(cout/len(self.kernel_set)) for kern in self.kernel_set: self.tconv.append(nn.Conv2d(cin,cout,(1,kern),dilation=(1,dilation_factor))) def forward(self,input): x = [] for i in range(len(self.kernel_set)): x.append(self.tconv[i](input)) for i in range(len(self.kernel_set)): x[i] = x[i][...,-x[-1].size(3):] x = torch.cat(x,dim=1) return x class graph_constructor(nn.Module): def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): super(graph_constructor, self).__init__() self.nnodes = nnodes if static_feat is not None: xd = static_feat.shape[1] self.lin1 = nn.Linear(xd, dim) self.lin2 = nn.Linear(xd, dim) else: self.emb1 = nn.Embedding(nnodes, dim) self.emb2 = nn.Embedding(nnodes, dim) self.lin1 = nn.Linear(dim,dim) self.lin2 = nn.Linear(dim,dim) self.device = device self.k = k self.dim = dim self.alpha = alpha self.static_feat = static_feat def forward(self, idx): if self.static_feat is None: nodevec1 = self.emb1(idx) nodevec2 = self.emb2(idx) else: nodevec1 = self.static_feat[idx,:] nodevec2 = nodevec1 nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) adj = F.relu(torch.tanh(self.alpha*a)) mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) mask.fill_(float('0')) s1,t1 = (adj + torch.rand_like(adj)*0.01).topk(self.k,1) mask.scatter_(1,t1,s1.fill_(1)) adj = adj*mask return adj def fullA(self, idx): if self.static_feat is None: nodevec1 = self.emb1(idx) nodevec2 = self.emb2(idx) else: nodevec1 = self.static_feat[idx,:] nodevec2 = nodevec1 nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) adj = F.relu(torch.tanh(self.alpha*a)) return adj class graph_global(nn.Module): def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): super(graph_global, self).__init__() self.nnodes = nnodes self.A = nn.Parameter(torch.randn(nnodes, nnodes).to(device), requires_grad=True).to(device) def forward(self, idx): return F.relu(self.A) class graph_undirected(nn.Module): def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): super(graph_undirected, self).__init__() self.nnodes = nnodes if static_feat is not None: xd = static_feat.shape[1] self.lin1 = nn.Linear(xd, dim) else: self.emb1 = nn.Embedding(nnodes, dim) self.lin1 = nn.Linear(dim,dim) self.device = device self.k = k self.dim = dim self.alpha = alpha self.static_feat = static_feat def forward(self, idx): if self.static_feat is None: nodevec1 = self.emb1(idx) nodevec2 = self.emb1(idx) else: nodevec1 = self.static_feat[idx,:] nodevec2 = nodevec1 nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) nodevec2 = torch.tanh(self.alpha*self.lin1(nodevec2)) a = torch.mm(nodevec1, nodevec2.transpose(1,0)) adj = F.relu(torch.tanh(self.alpha*a)) mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) mask.fill_(float('0')) s1,t1 = adj.topk(self.k,1) mask.scatter_(1,t1,s1.fill_(1)) adj = adj*mask return adj class graph_directed(nn.Module): def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): super(graph_directed, self).__init__() self.nnodes = nnodes if static_feat is not None: xd = static_feat.shape[1] self.lin1 = nn.Linear(xd, dim) self.lin2 = nn.Linear(xd, dim) else: self.emb1 = nn.Embedding(nnodes, dim) self.emb2 = nn.Embedding(nnodes, dim) self.lin1 = nn.Linear(dim,dim) self.lin2 = nn.Linear(dim,dim) self.device = device self.k = k self.dim = dim self.alpha = alpha self.static_feat = static_feat def forward(self, idx): if self.static_feat is None: nodevec1 = self.emb1(idx) nodevec2 = self.emb2(idx) else: nodevec1 = self.static_feat[idx,:] nodevec2 = nodevec1 nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) a = torch.mm(nodevec1, nodevec2.transpose(1,0)) adj = F.relu(torch.tanh(self.alpha*a)) mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) mask.fill_(float('0')) s1,t1 = adj.topk(self.k,1) mask.scatter_(1,t1,s1.fill_(1)) adj = adj*mask return adj class LayerNorm(nn.Module): __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine'] def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): super(LayerNorm, self).__init__() if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) self.normalized_shape = tuple(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): if self.elementwise_affine: init.ones_(self.weight) init.zeros_(self.bias) def forward(self, input, idx): if self.elementwise_affine: return F.layer_norm(input, tuple(input.shape[1:]), self.weight[:,idx,:], self.bias[:,idx,:], self.eps) else: return F.layer_norm(input, tuple(input.shape[1:]), self.weight, self.bias, self.eps) def extra_repr(self): return '{normalized_shape}, eps={eps}, ' \ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)