import torch
import torch.nn as nn
import torch.nn.functional as F

from model.MTGNN.layer import *


class gtnet(nn.Module):
    """MTGNN spatio-temporal forecasting network.

    Stacks ``layers`` blocks. Each block applies a gated dilated temporal convolution
    (``dilated_inception`` filter * gate). That is followed either by graph convolution
    (``mixprop``) or by a 1x1 residual convolution. Per-layer skip outputs are summed
    and fed to a 1x1-convolution output head.

    Args:
        configs: mapping of hyper-parameters.
            Required keys: ``gcn_true``, ``buildA_true``, ``num_nodes``, ``device``,
            ``dropout``, ``subgraph_size``, ``node_dim``, ``dilation_exponential``,
            ``conv_channels``, ``residual_channels``, ``skip_channels``,
            ``end_channels``, ``seq_len``, ``in_dim``, ``out_len``, ``out_dim``,
            ``layers``, ``propalpha``, ``tanhalpha``, ``layer_norm_affline``,
            ``gcn_depth``.
            Optional keys: ``predefined_A``, ``static_feat`` (default ``None``).
    """

    def __init__(self, configs):
        super(gtnet, self).__init__()
        self.gcn_true = configs['gcn_true']  # whether to apply graph convolution
        self.buildA_true = configs['buildA_true']  # whether to learn the adjacency matrix
        self.num_nodes = configs['num_nodes']  # number of graph nodes
        self.device = configs['device']  # device (CPU/GPU)
        self.dropout = configs['dropout']  # dropout rate
        self.predefined_A = configs.get('predefined_A', None)  # predefined adjacency matrix
        self.static_feat = configs.get('static_feat', None)  # static node features
        self.subgraph_size = configs['subgraph_size']  # subgraph size for graph construction
        self.node_dim = configs['node_dim']  # node embedding dimension
        self.dilation_exponential = configs['dilation_exponential']  # per-layer dilation factor
        self.conv_channels = configs['conv_channels']  # temporal-conv channels
        self.residual_channels = configs['residual_channels']  # residual channels
        self.skip_channels = configs['skip_channels']  # skip-connection channels
        self.end_channels = configs['end_channels']  # output-head channels
        self.seq_length = configs['seq_len']  # input sequence length
        self.in_dim = configs['in_dim']  # input feature dimension
        self.out_len = configs['out_len']  # output sequence length
        self.out_dim = configs['out_dim']  # output feature dimension
        self.layers = configs['layers']  # number of layers
        self.propalpha = configs['propalpha']  # graph propagation alpha
        self.tanhalpha = configs['tanhalpha']  # tanh alpha for graph construction
        # Whether LayerNorm uses an elementwise affine transform (config key spelling kept).
        self.layer_norm_affline = configs['layer_norm_affline']
        self.gcn_depth = configs['gcn_depth']  # graph convolution depth

        self.filter_convs = nn.ModuleList()  # temporal filter convolutions
        self.gate_convs = nn.ModuleList()  # temporal gate convolutions
        self.residual_convs = nn.ModuleList()  # residual 1x1 convolutions
        self.skip_convs = nn.ModuleList()  # skip-connection convolutions
        self.gconv1 = nn.ModuleList()  # graph convs over the adjacency
        self.gconv2 = nn.ModuleList()  # graph convs over the transposed adjacency
        self.norm = nn.ModuleList()  # per-layer normalisation

        self.start_conv = nn.Conv2d(in_channels=self.in_dim,
                                    out_channels=self.residual_channels,
                                    kernel_size=(1, 1))
        self.gc = graph_constructor(self.num_nodes, self.subgraph_size, self.node_dim,
                                    self.device, alpha=self.tanhalpha,
                                    static_feat=self.static_feat)

        # Kernel width used for receptive-field arithmetic.
        # NOTE(review): presumably the largest kernel inside dilated_inception -- confirm.
        kernel_size = 7
        self.receptive_field = self._receptive_field(kernel_size, self.dilation_exponential,
                                                     self.layers)
        # forward() left-pads inputs shorter than the receptive field up to it, so the
        # time length entering the stack is the larger of the two.
        full_len = max(self.seq_length, self.receptive_field)

        # Module creation order is kept stable so parameter initialisation is reproducible.
        new_dilation = 1
        for j in range(1, self.layers + 1):
            rf_size_j = self._receptive_field(kernel_size, self.dilation_exponential, j)
            # Time length of layer j's output; the LayerNorm shape below is sized to it.
            t_j = full_len - rf_size_j + 1

            self.filter_convs.append(dilated_inception(self.residual_channels,
                                                       self.conv_channels,
                                                       dilation_factor=new_dilation))
            self.gate_convs.append(dilated_inception(self.residual_channels,
                                                     self.conv_channels,
                                                     dilation_factor=new_dilation))
            self.residual_convs.append(nn.Conv2d(in_channels=self.conv_channels,
                                                 out_channels=self.residual_channels,
                                                 kernel_size=(1, 1)))
            # Kernel spans the whole remaining time axis, so each skip output has width 1.
            self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels,
                                             out_channels=self.skip_channels,
                                             kernel_size=(1, t_j)))
            if self.gcn_true:
                self.gconv1.append(mixprop(self.conv_channels, self.residual_channels,
                                           self.gcn_depth, self.dropout, self.propalpha))
                self.gconv2.append(mixprop(self.conv_channels, self.residual_channels,
                                           self.gcn_depth, self.dropout, self.propalpha))
            self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, t_j),
                                       elementwise_affine=self.layer_norm_affline))
            new_dilation *= self.dilation_exponential

        self.end_conv_1 = nn.Conv2d(in_channels=self.skip_channels,
                                    out_channels=self.end_channels,
                                    kernel_size=(1, 1), bias=True)
        # One output channel per (horizon step, output feature) pair.
        self.end_conv_2 = nn.Conv2d(in_channels=self.end_channels,
                                    out_channels=self.out_len * self.out_dim,
                                    kernel_size=(1, 1), bias=True)
        # skip0 sees the (possibly padded) raw input; skipE sees the last layer's output.
        # Both collapse the time axis to width 1.
        self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels,
                               kernel_size=(1, full_len), bias=True)
        self.skipE = nn.Conv2d(in_channels=self.residual_channels,
                               out_channels=self.skip_channels,
                               kernel_size=(1, full_len - self.receptive_field + 1),
                               bias=True)

        # Default node index: all nodes.
        self.idx = torch.arange(self.num_nodes).to(self.device)

    @staticmethod
    def _receptive_field(kernel_size, dilation_exponential, num_layers):
        """Return the temporal receptive field of ``num_layers`` stacked dilated convolutions.

        When ``dilation_exponential > 1`` the dilation is multiplied by it at every layer,
        giving a geometric sum. Otherwise the field grows linearly with ``num_layers``.
        """
        if dilation_exponential > 1:
            return int(1 + (kernel_size - 1) * (dilation_exponential ** num_layers - 1)
                       / (dilation_exponential - 1))
        return num_layers * (kernel_size - 1) + 1

    def forward(self, input, idx=None):
        """Predict ``out_len`` future steps for every node.

        Args:
            input: tensor whose last axis ends with two periodic-embedding channels,
                which are dropped. After ``transpose(1, 3)`` its dim 3 must equal
                ``seq_len``. NOTE(review): presumably shaped
                (batch, seq_len, num_nodes, in_dim + 2) -- confirm against the data loader.
            idx: optional node-index tensor passed to the graph constructor and LayerNorm.
                Defaults to all nodes (``self.idx``).

        Returns:
            Tensor of shape (batch, out_len, nodes, out_dim).

        Raises:
            AssertionError: if the input time length differs from ``seq_len``.
        """
        input = input[..., :-2]  # drop the periodic-embedding channels
        input = input.transpose(1, 3)
        seq_len = input.size(3)
        assert seq_len == self.seq_length, 'input sequence length not equal to preset sequence length'

        # NOTE(review): the code from here up to the output reshape was lost in this copy of
        # the file (text between '<' and '>' had been stripped). It is restored following
        # the standard MTGNN gtnet forward -- confirm against version control.
        if self.seq_length < self.receptive_field:
            # Left-pad the time axis up to one full receptive field.
            input = F.pad(input, (self.receptive_field - self.seq_length, 0, 0, 0))

        node_idx = self.idx if idx is None else idx
        if self.gcn_true:
            # A learned adjacency is built from the selected nodes; otherwise use the given one.
            adp = self.gc(node_idx) if self.buildA_true else self.predefined_A

        x = self.start_conv(input)
        skip = self.skip0(F.dropout(input, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            filt = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filt * gate, self.dropout, training=self.training)
            skip = self.skip_convs[i](x) + skip
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.residual_convs[i](x)
            # The dilated convs shorten the time axis; align the residual to the last steps.
            x = x + residual[:, :, :, -x.size(3):]
            x = self.norm[i](x, node_idx)

        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        # [b, t*c, n, 1] -> [b, t, c, n] -> [b, t, n, c]
        # Use the actual node count so node subsets selected via ``idx`` reshape correctly.
        x = x.reshape(x.size(0), self.out_len, self.out_dim, x.size(2))
        x = x.permute(0, 1, 3, 2)
        return x