import torch
import torch.nn as nn
from torch.nn import BatchNorm2d, Conv2d, ModuleList, Parameter
import torch.nn.functional as F


def nconv(x, A):
    """Multiply x by the adjacency matrix along the source-node axis."""
    return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()


class GraphConvNet(nn.Module):
    """Diffusion graph convolution: concatenates k-step propagations over each support."""

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        c_in = (order * support_len + 1) * c_in
        self.final_conv = Conv2d(c_in, c_out, (1, 1), padding=(0, 0), stride=(1, 1), bias=True)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support: list):
        out = [x]
        for a in support:
            x1 = nconv(x, a)
            out.append(x1)
            for k in range(2, self.order + 1):
                x2 = nconv(x1, a)
                out.append(x2)
                x1 = x2
        h = torch.cat(out, dim=1)
        h = self.final_conv(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h


class gwnet(nn.Module):
    """
    Main Graph WaveNet model.

    Combines graph convolution with dilated temporal (WaveNet-style) convolution
    for spatio-temporal forecasting.
    """

    def __init__(self, args):
        super().__init__()
        # Basic hyperparameters from the config dict
        self.dropout = args["dropout"]
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.do_graph_conv = args.get("do_graph_conv", True)
        self.cat_feat_gc = args.get("cat_feat_gc", False)
        self.addaptadj = args.get("addaptadj", True)
        supports = None  # no pre-computed adjacency supports are passed in this variant
        aptinit = args.get("aptinit", None)
        in_dim = args.get("in_dim")
        out_dim = args.get("out_dim")
        residual_channels = args.get("residual_channels")
        dilation_channels = args.get("dilation_channels")
        skip_channels = args.get("skip_channels")
        end_channels = args.get("end_channels")
        kernel_size = args.get("kernel_size")
        apt_size = args.get("apt_size", 10)

        if self.cat_feat_gc:
            # The first feature and the remaining features are embedded separately
            self.start_conv = nn.Conv2d(in_channels=1,  # hard-coded: only the first feature
                                        out_channels=residual_channels,
                                        kernel_size=(1, 1))
            self.cat_feature_conv = nn.Conv2d(in_channels=in_dim - 1,
                                              out_channels=residual_channels,
                                              kernel_size=(1, 1))
        else:
            self.start_conv = nn.Conv2d(in_channels=in_dim,
                                        out_channels=residual_channels,
                                        kernel_size=(1, 1))

        self.fixed_supports = supports or []
        receptive_field = 1

        self.supports_len = len(self.fixed_supports)
        if self.do_graph_conv and self.addaptadj:
            # Learnable node embeddings for the adaptive adjacency matrix
            if aptinit is None:
                nodevecs = (torch.randn(args["num_nodes"], apt_size),
                            torch.randn(apt_size, args["num_nodes"]))
            else:
                nodevecs = self.svd_init(args["num_nodes"], apt_size, aptinit)
            self.supports_len += 1
            self.nodevec1, self.nodevec2 = [Parameter(n.to(args["device"]), requires_grad=True)
                                            for n in nodevecs]

        depth = list(range(self.blocks * self.layers))

        # 1x1 convolutions for residual and skip connections
        self.residual_convs = ModuleList([Conv2d(dilation_channels, residual_channels, (1, 1)) for _ in depth])
        self.skip_convs = ModuleList([Conv2d(dilation_channels, skip_channels, (1, 1)) for _ in depth])
        self.bn = ModuleList([BatchNorm2d(residual_channels) for _ in depth])
        self.graph_convs = ModuleList([GraphConvNet(dilation_channels, residual_channels, self.dropout,
                                                    support_len=self.supports_len)
                                       for _ in depth])

        # Dilated convolutions: the dilation doubles at every layer within a block
        self.filter_convs = ModuleList()
        self.gate_convs = ModuleList()
        for b in range(self.blocks):
            additional_scope = kernel_size - 1
            D = 1  # dilation
            for i in range(self.layers):
                self.filter_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
                self.gate_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
                D *= 2
                receptive_field += additional_scope
                additional_scope *= 2
        self.receptive_field = receptive_field
        self.end_conv_1 = Conv2d(skip_channels, end_channels, (1, 1), bias=True)
        self.end_conv_2 = Conv2d(end_channels, out_dim, (1, 1), bias=True)
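    @staticmethod
    def svd_init(num_nodes, apt_size, aptinit):
        # NOTE: svd_init is called in __init__ but was missing from this file.
        # This is a hedged sketch following the SVD-based node-embedding
        # initialization used in the reference Graph WaveNet implementation;
        # num_nodes is only used for a shape sanity check (an assumption made
        # to match the call signature above).
        assert aptinit.shape == (num_nodes, num_nodes)
        m, p, n = torch.svd(aptinit)
        nodevec1 = torch.mm(m[:, :apt_size], torch.diag(p[:apt_size] ** 0.5))
        nodevec2 = torch.mm(torch.diag(p[:apt_size] ** 0.5), n[:, :apt_size].t())
        return nodevec1, nodevec2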
    def forward(self, input):
        # input: (batch, time, n_nodes, features) -> (batch, features, n_nodes, time)
        x = input.transpose(1, 3)
        in_len = x.size(3)
        if in_len < self.receptive_field:
            # left-pad the time axis so every layer sees a full receptive field
            x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0))
        if self.cat_feat_gc:
            # embed the first feature and the remaining features separately, then sum
            f1, f2 = x[:, [0]], x[:, 1:]
            x1 = self.start_conv(f1)
            x2 = F.leaky_relu(self.cat_feature_conv(f2))
            x = x1 + x2
        else:
            x = self.start_conv(x)
        skip = 0

        # Calculate the adaptive adjacency matrix once per forward pass
        adjacency_matrices = self.fixed_supports
        if self.addaptadj:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            adjacency_matrices = self.fixed_supports + [adp]

        # WaveNet layers
        for i in range(self.blocks * self.layers):
            residual = x
            # gated dilated convolution
            filt = torch.tanh(self.filter_convs[i](residual))
            gate = torch.sigmoid(self.gate_convs[i](residual))
            x = filt * gate
            # parametrized skip connection
            s = self.skip_convs[i](x)
            if torch.is_tensor(skip):
                # crop the accumulated skip to the (shrinking) time length of s
                skip = skip[:, :, :, -s.size(3):]  # TODO(SS): Mean/Max Pool?
            skip = s + skip
            if i == (self.blocks * self.layers - 1):
                # the last x contributes only through the skip connection
                break

            if self.do_graph_conv:
                graph_out = self.graph_convs[i](x, adjacency_matrices)
                x = x + graph_out if self.cat_feat_gc else graph_out
            else:
                x = self.residual_convs[i](x)
            x = x + residual[:, :, :, -x.size(3):]  # TODO(SS): Mean/Max Pool?
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)  # (batch, out_dim, n_nodes, 1)
        # x = x.transpose(1, 3)
        return x
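

if __name__ == "__main__":
    # Hedged smoke test, not part of the original module: the configuration values
    # below are illustrative assumptions chosen only to exercise the forward pass,
    # not values prescribed by this repository.
    args = {
        "dropout": 0.3,
        "blocks": 4,
        "layers": 2,
        "in_dim": 2,
        "out_dim": 12,
        "residual_channels": 32,
        "dilation_channels": 32,
        "skip_channels": 256,
        "end_channels": 512,
        "kernel_size": 2,
        "num_nodes": 207,
        "device": "cpu",
    }
    model = gwnet(args)
    # Dummy batch shaped (batch, time, n_nodes, features); shorter sequences are
    # left-padded up to the receptive field inside forward().
    dummy = torch.randn(8, 12, args["num_nodes"], args["in_dim"])
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([8, 12, 207, 1])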