import torch
import torch.nn as nn
import torch.nn.functional as F


class nconv(nn.Module):
    """Graph convolution step: mixes node features along the node dimension with adjacency A."""

    def forward(self, x, A):
        # x: (batch, channels, nodes, time), A: (nodes, nodes)
        return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()


class linear(nn.Module):
    """1x1 convolution acting as a per-node, per-step linear layer."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        return self.mlp(x)


class gcn(nn.Module):
    """Diffusion graph convolution: concatenates powers of each support applied to x, then projects."""

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        self.nconv = nconv()
        # One copy of x plus `order` propagation steps for each support matrix.
        c_in = (order * support_len + 1) * c_in
        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order

    def forward(self, x, support):
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)
            out.append(x1)
            for _ in range(2, self.order + 1):
                x1 = self.nconv(x1, a)
                out.append(x1)
        return F.dropout(self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training)


class gwnet(nn.Module):
    """Graph WaveNet: gated dilated temporal convolutions with optional (adaptive) graph convolution."""

    def __init__(self, args):
        super().__init__()
        self.dropout, self.blocks, self.layers = args['dropout'], args['blocks'], args['layers']
        self.gcn_bool, self.addaptadj = args['gcn_bool'], args['addaptadj']
        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
        self.residual_convs, self.skip_convs = nn.ModuleList(), nn.ModuleList()
        self.bn, self.gconv = nn.ModuleList(), nn.ModuleList()
        self.start_conv = nn.Conv2d(args['in_dim'], args['residual_channels'], 1)
        self.supports = args.get('supports', None)
        receptive_field = 1
        self.supports_len = len(self.supports) if self.supports is not None else 0

        # Adaptive adjacency: learn node embeddings whose product defines an extra support matrix.
        if self.gcn_bool and self.addaptadj:
            aptinit = args.get('aptinit', None)
            if aptinit is None:
                if self.supports is None:
                    self.supports = []
                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10, device=args['device']))
                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes'], device=args['device']))
                self.supports_len += 1
            else:
                if self.supports is None:
                    self.supports = []
                # Initialize the node embeddings from a rank-10 SVD of the given adjacency.
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(initemb1)
                self.nodevec2 = nn.Parameter(initemb2)
                self.supports_len += 1

        ks, res, dil, skip, endc, out_dim = (args['kernel_size'], args['residual_channels'],
                                             args['dilation_channels'], args['skip_channels'],
                                             args['end_channels'], args['out_dim'])

        # WaveNet-style blocks: the dilation doubles at each layer and restarts in every block.
        for _ in range(self.blocks):
            add_scope, new_dil = ks - 1, 1
            for _ in range(self.layers):
                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.residual_convs.append(nn.Conv2d(dil, res, 1))
                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
                self.bn.append(nn.BatchNorm2d(res))
                new_dil *= 2
                receptive_field += add_scope
                add_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(gcn(dil, res, args['dropout'], support_len=self.supports_len))

        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
        self.receptive_field = receptive_field

    def forward(self, input):
        # Expect (batch, time, nodes, features); keep the first two features and
        # move them to the channel dimension: (batch, 2, nodes, time).
        input = input[..., 0:2].transpose(1, 3)
        input = F.pad(input, (1, 0, 0, 0))
        in_len = input.size(3)
        # Left-pad the time axis so the dilated convolutions cover the full receptive field.
        x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) if in_len < self.receptive_field else input
        x, skip, new_supports = self.start_conv(x), 0, None

        if self.gcn_bool and self.addaptadj and self.supports is not None:
            # Build the adaptive adjacency on the fly and append it to the fixed supports.
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            new_supports = self.supports + [adp]

        for i in range(self.blocks * self.layers):
            residual = x
            # Gated temporal convolution.
            f = self.filter_convs[i](residual).tanh()
            g = self.gate_convs[i](residual).sigmoid()
            x = f * g
            # Accumulate skip connections, truncating earlier ones to the current time length.
            s = self.skip_convs[i](x)
            skip = (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + s
            if self.gcn_bool and self.supports is not None:
                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
            else:
                x = self.residual_convs[i](x)
            # Residual connection, aligned on the time axis, followed by batch norm.
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
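

# --- Usage sketch (not part of the original module) ---
# A minimal example of constructing and running the model. The hyperparameter
# values below are illustrative assumptions, not values taken from the code
# above; only the dictionary keys are dictated by gwnet.__init__.
if __name__ == '__main__':
    num_nodes, seq_len, batch_size = 207, 12, 8
    args = {
        'device': 'cpu',
        'num_nodes': num_nodes,
        'dropout': 0.3,
        'blocks': 4,
        'layers': 2,
        'gcn_bool': True,
        'addaptadj': True,
        'aptinit': None,           # or a (num_nodes, num_nodes) adjacency to seed the adaptive embeddings
        'supports': None,          # or a list of (num_nodes, num_nodes) support matrices
        'in_dim': 2,
        'out_dim': 12,
        'residual_channels': 32,
        'dilation_channels': 32,
        'skip_channels': 256,
        'end_channels': 512,
        'kernel_size': 2,
    }
    model = gwnet(args)

    # forward() expects (batch, time, nodes, features) and uses only the first two features.
    x = torch.randn(batch_size, seq_len, num_nodes, 2)
    y = model(x)
    print(y.shape)  # (batch, out_dim, num_nodes, T_out)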