import torch
import torch.nn as nn
import torch.nn.functional as F


class nconv(nn.Module):
    """Graph convolution step: mix the node dimension with an adjacency matrix."""

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        # x: (batch, channels, nodes, time), A: (nodes, nodes)
        x = torch.einsum("ncvl,vw->ncwl", (x, A))
        return x.contiguous()


class linear(nn.Module):
    """1x1 convolution acting as a per-node, per-step linear layer."""

    def __init__(self, c_in, c_out):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(
            c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True
        )

    def forward(self, x):
        return self.mlp(x)


class gcn(nn.Module):
    """Diffusion graph convolution over a list of support (adjacency) matrices."""

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super(gcn, self).__init__()
        self.nconv = nconv()
        # The input plus `order` diffusion steps per support are concatenated on the channel dim.
        c_in = (order * support_len + 1) * c_in
        self.mlp = linear(c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)
            out.append(x1)
            for k in range(2, self.order + 1):
                x2 = self.nconv(x1, a)
                out.append(x2)
                x1 = x2
        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h


class gwnet(nn.Module):
    """Graph WaveNet: stacked gated dilated convolutions with optional graph convolutions."""

    def __init__(self, args):
        super(gwnet, self).__init__()
        self.dropout = args["dropout"]
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.gcn_bool = args["gcn_bool"]
        self.addaptadj = args["addaptadj"]

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gconv = nn.ModuleList()

        self.start_conv = nn.Conv2d(
            in_channels=args["in_dim"],
            out_channels=args["residual_channels"],
            kernel_size=(1, 1),
        )
        self.supports = args.get("supports", None)

        receptive_field = 1

        self.supports_len = 0
        if self.supports is not None:
            self.supports_len += len(self.supports)

        if self.gcn_bool and self.addaptadj:
            aptinit = args.get("aptinit", None)
            if self.supports is None:
                self.supports = []
            if aptinit is None:
                # Learn the adaptive adjacency from two randomly initialised node embeddings.
                self.nodevec1 = nn.Parameter(
                    torch.randn(args["num_nodes"], 10, device=args["device"]),
                    requires_grad=True,
                )
                self.nodevec2 = nn.Parameter(
                    torch.randn(10, args["num_nodes"], device=args["device"]),
                    requires_grad=True,
                )
            else:
                # Initialise the node embeddings from a rank-10 SVD of the given matrix.
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(
                    initemb1.to(args["device"]), requires_grad=True
                )
                self.nodevec2 = nn.Parameter(
                    initemb2.to(args["device"]), requires_grad=True
                )
            self.supports_len += 1

        kernel_size = args["kernel_size"]
        residual_channels = args["residual_channels"]
        dilation_channels = args["dilation_channels"]
        skip_channels = args["skip_channels"]
        end_channels = args["end_channels"]
        out_dim = args["out_dim"]
        dropout = args["dropout"]

        for b in range(self.blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(self.layers):
                # dilated convolutions
                self.filter_convs.append(
                    nn.Conv2d(
                        in_channels=residual_channels,
                        out_channels=dilation_channels,
                        kernel_size=(1, kernel_size),
                        dilation=new_dilation,
                    )
                )
                self.gate_convs.append(
                    nn.Conv2d(
                        in_channels=residual_channels,
                        out_channels=dilation_channels,
                        kernel_size=(1, kernel_size),
                        dilation=new_dilation,
                    )
                )
                # 1x1 convolution for residual connection
                self.residual_convs.append(
                    nn.Conv2d(
                        in_channels=dilation_channels,
                        out_channels=residual_channels,
                        kernel_size=(1, 1),
                    )
                )
                # 1x1 convolution for skip connection
                self.skip_convs.append(
                    nn.Conv2d(
                        in_channels=dilation_channels,
                        out_channels=skip_channels,
                        kernel_size=(1, 1),
                    )
                )
                self.bn.append(nn.BatchNorm2d(residual_channels))
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(
                        gcn(
                            dilation_channels,
                            residual_channels,
                            dropout,
                            support_len=self.supports_len,
                        )
                    )

        self.end_conv_1 = nn.Conv2d(
            in_channels=skip_channels,
            out_channels=end_channels,
            kernel_size=(1, 1),
            bias=True,
        )
        self.end_conv_2 = nn.Conv2d(
            in_channels=end_channels,
            out_channels=out_dim,
            kernel_size=(1, 1),
            bias=True,
        )
        self.receptive_field = receptive_field

    def forward(self, input):
        # Keep the first two input features and move them to the channel dimension:
        # (batch, time, nodes, features) -> (batch, features, nodes, time).
        input = input[..., 0:2]
        input = input.transpose(1, 3)
        # Left-pad the time axis by one step.
        input = nn.functional.pad(input, (1, 0, 0, 0))
        in_len = input.size(3)
        if in_len < self.receptive_field:
            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input
        x = self.start_conv(x)
        skip = 0

        # Calculate the adaptive adjacency matrix once per forward pass.
        new_supports = None
        if self.gcn_bool and self.addaptadj and self.supports is not None:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            new_supports = self.supports + [adp]

        # WaveNet layers
        for i in range(self.blocks * self.layers):
            #            |----------------------------------------|     *residual*
            #            |                                        |
            #            |    |-- conv -- tanh --|                |
            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
            #                 |-- conv -- sigm --|     |
            #                                         1x1
            #                                          |
            # ---------------------------------------> + ------------->  *skip*
            residual = x

            # dilated convolution
            filter = self.filter_convs[i](residual)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](residual)
            gate = torch.sigmoid(gate)
            x = filter * gate

            # parametrized skip connection
            s = x
            s = self.skip_convs[i](s)
            if isinstance(skip, torch.Tensor):
                # Trim the accumulated skip tensor to the current output length.
                skip = skip[:, :, :, -s.size(3):]
            skip = s + skip

            if self.gcn_bool and self.supports is not None:
                if self.addaptadj:
                    x = self.gconv[i](x, new_supports)
                else:
                    x = self.gconv[i](x, self.supports)
            else:
                x = self.residual_convs[i](x)

            # Residual connection: keep only the last x.size(3) time steps of the input.
            y = residual[:, :, :, -x.size(3):]
            x = x + y
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x
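

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It builds a gwnet
# from the args dict that gwnet.__init__ reads and pushes a random batch
# through it. The keys mirror the lookups above; the concrete values are
# illustrative assumptions, not settings prescribed by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        "device": "cpu",          # assumed device for the demo
        "num_nodes": 207,         # assumed number of graph nodes
        "dropout": 0.3,
        "supports": None,         # no precomputed adjacency; adaptive one is learned
        "aptinit": None,
        "gcn_bool": True,
        "addaptadj": True,
        "in_dim": 2,              # forward() keeps the first two input features
        "out_dim": 12,
        "residual_channels": 32,
        "dilation_channels": 32,
        "skip_channels": 256,
        "end_channels": 512,
        "kernel_size": 2,
        "blocks": 4,
        "layers": 2,
    }
    model = gwnet(args)

    # Input layout expected by forward(): (batch, time, nodes, features).
    # With 12 time steps, the internal left-pad brings the length to the
    # receptive field of 13, so the temporal dimension collapses to 1.
    dummy = torch.randn(8, 12, args["num_nodes"], 2)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # torch.Size([8, 12, 207, 1])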