import torch
import torch.nn as nn
import torch.nn.functional as F

from model.MTGNN.layer import *
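
# gtnet: MTGNN-style network combining a learned adjacency matrix (graph_constructor),
# dilated temporal inception convolutions, and mix-hop graph propagation (mixprop)
# for multivariate time-series forecasting.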
class gtnet(nn.Module):
    def __init__(self, configs):
        super(gtnet, self).__init__()
        self.gcn_true = configs['gcn_true']  # whether to use graph convolution
        self.buildA_true = configs['buildA_true']  # whether to learn the adjacency matrix dynamically
        self.num_nodes = configs['num_nodes']  # number of nodes
        self.device = configs['device']  # device (CPU/GPU)
        self.dropout = configs['dropout']  # dropout rate
        self.predefined_A = configs.get('predefined_A', None)  # predefined adjacency matrix
        self.static_feat = configs.get('static_feat', None)  # static node features
        self.subgraph_size = configs['subgraph_size']  # subgraph size (neighbors kept per node in the learned graph)
        self.node_dim = configs['node_dim']  # node embedding dimension
        self.dilation_exponential = configs['dilation_exponential']  # dilation growth factor
        self.conv_channels = configs['conv_channels']  # convolution channels
        self.residual_channels = configs['residual_channels']  # residual channels
        self.skip_channels = configs['skip_channels']  # skip-connection channels
        self.end_channels = configs['end_channels']  # output-layer channels
        self.seq_length = configs['seq_len']  # input sequence length
        self.in_dim = configs['in_dim']  # input feature dimension
        self.out_len = configs['out_len']  # output sequence length
        self.out_dim = configs['out_dim']  # output feature dimension per step
        self.layers = configs['layers']  # number of layers
        self.propalpha = configs['propalpha']  # graph propagation alpha (mixprop retain ratio)
        self.tanhalpha = configs['tanhalpha']  # tanh saturation alpha for graph construction
        self.layer_norm_affline = configs['layer_norm_affline']  # whether layer norm uses an affine transform
        self.gcn_depth = configs['gcn_depth']  # graph convolution (propagation) depth

        self.filter_convs = nn.ModuleList()  # dilated filter convolutions
        self.gate_convs = nn.ModuleList()  # dilated gate convolutions
        self.residual_convs = nn.ModuleList()  # residual 1x1 convolutions
        self.skip_convs = nn.ModuleList()  # skip-connection convolutions
        self.gconv1 = nn.ModuleList()  # graph convolutions over A (forward direction)
        self.gconv2 = nn.ModuleList()  # graph convolutions over A^T (backward direction)
        self.norm = nn.ModuleList()  # normalization layers

        self.start_conv = nn.Conv2d(in_channels=self.in_dim,
                                    out_channels=self.residual_channels,
                                    kernel_size=(1, 1))
        self.gc = graph_constructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device,
                                    alpha=self.tanhalpha, static_feat=self.static_feat)

        # total temporal receptive field of the stacked dilated convolutions
        kernel_size = 7
        if self.dilation_exponential > 1:
            self.receptive_field = int(1 + (kernel_size-1)*(self.dilation_exponential**self.layers-1)/(self.dilation_exponential-1))
        else:
            self.receptive_field = self.layers*(kernel_size-1) + 1

        # build the temporal (dilated inception) and graph convolution modules for each layer
        for i in range(1):
            if self.dilation_exponential > 1:
                rf_size_i = int(1 + i*(kernel_size-1)*(self.dilation_exponential**self.layers-1)/(self.dilation_exponential-1))
            else:
                rf_size_i = i*self.layers*(kernel_size-1) + 1
            new_dilation = 1
            for j in range(1, self.layers+1):
                # receptive field covered up to layer j
                if self.dilation_exponential > 1:
                    rf_size_j = int(rf_size_i + (kernel_size-1)*(self.dilation_exponential**j-1)/(self.dilation_exponential-1))
                else:
                    rf_size_j = rf_size_i + j*(kernel_size-1)

                self.filter_convs.append(dilated_inception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
                self.gate_convs.append(dilated_inception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
                self.residual_convs.append(nn.Conv2d(in_channels=self.conv_channels,
                                                     out_channels=self.residual_channels,
                                                     kernel_size=(1, 1)))
                # the skip convolution collapses the remaining temporal dimension to length 1
                if self.seq_length > self.receptive_field:
                    self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels,
                                                     out_channels=self.skip_channels,
                                                     kernel_size=(1, self.seq_length-rf_size_j+1)))
                else:
                    self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels,
                                                     out_channels=self.skip_channels,
                                                     kernel_size=(1, self.receptive_field-rf_size_j+1)))

                if self.gcn_true:
                    self.gconv1.append(mixprop(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, self.propalpha))
                    self.gconv2.append(mixprop(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, self.propalpha))

                if self.seq_length > self.receptive_field:
                    self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, self.seq_length - rf_size_j + 1), elementwise_affine=self.layer_norm_affline))
                else:
                    self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1), elementwise_affine=self.layer_norm_affline))

                new_dilation *= self.dilation_exponential

        # output head: two 1x1 convolutions mapping skip features to out_len * out_dim values per node
        self.end_conv_1 = nn.Conv2d(in_channels=self.skip_channels,
                                    out_channels=self.end_channels,
                                    kernel_size=(1, 1),
                                    bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=self.end_channels,
                                    out_channels=self.out_len * self.out_dim,
                                    kernel_size=(1, 1),
                                    bias=True)
        # skip0 projects the raw input, skipE projects the last layer's output
        if self.seq_length > self.receptive_field:
            self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels, kernel_size=(1, self.seq_length), bias=True)
            self.skipE = nn.Conv2d(in_channels=self.residual_channels, out_channels=self.skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True)
        else:
            self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels, kernel_size=(1, self.receptive_field), bias=True)
            self.skipE = nn.Conv2d(in_channels=self.residual_channels, out_channels=self.skip_channels, kernel_size=(1, 1), bias=True)

        self.idx = torch.arange(self.num_nodes).to(self.device)

    def forward(self, input, idx=None):
        input = input[..., :-2]  # drop the periodic (time) embeddings in the last two feature channels
        input = input.transpose(1, 3)  # [b, seq_len, n, c] -> [b, c, n, seq_len]
        seq_len = input.size(3)
        assert seq_len == self.seq_length, 'input sequence length not equal to preset sequence length'

        # left-pad the temporal dimension so it covers the full receptive field
        if self.seq_length < self.receptive_field:
            input = nn.functional.pad(input, (self.receptive_field-self.seq_length, 0, 0, 0))

        # obtain the adjacency matrix: learned by graph_constructor or supplied as predefined_A
        if self.gcn_true:
            if self.buildA_true:
                if idx is None:
                    adp = self.gc(self.idx)
                else:
                    adp = self.gc(idx)
            else:
                adp = self.predefined_A

        x = self.start_conv(input)
        skip = self.skip0(F.dropout(input, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            # gated dilated temporal convolution
            filter = self.filter_convs[i](x)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](x)
            gate = torch.sigmoid(gate)
            x = filter * gate
            x = F.dropout(x, self.dropout, training=self.training)
            # accumulate skip connection
            s = x
            s = self.skip_convs[i](s)
            skip = s + skip
            # graph convolution over both edge directions, or a plain 1x1 convolution
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.residual_convs[i](x)

            # residual connection, cropped to the current temporal length
            x = x + residual[:, :, :, -x.size(3):]
            if idx is None:
                x = self.norm[i](x, self.idx)
            else:
                x = self.norm[i](x, idx)

        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)  # [b, t*c, n, 1]
        # [b, t*c, n, 1] -> [b, t, c, n] -> [b, t, n, c]
        x = x.reshape(x.size(0), self.out_len, self.out_dim, self.num_nodes)
        x = x.permute(0, 1, 3, 2)
        return x
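

# --------------------------------------------------------------------------
# Minimal usage sketch. The config values below are illustrative assumptions,
# not the project's defaults. The input layout assumed here is
# [batch, seq_len, num_nodes, in_dim + 2], where the last two feature
# channels stand in for the periodic embeddings stripped at the start of
# forward().
# --------------------------------------------------------------------------
if __name__ == '__main__':
    configs = {
        'gcn_true': True, 'buildA_true': True, 'num_nodes': 10,
        'device': torch.device('cpu'), 'dropout': 0.3, 'subgraph_size': 5,
        'node_dim': 40, 'dilation_exponential': 1, 'conv_channels': 16,
        'residual_channels': 16, 'skip_channels': 32, 'end_channels': 64,
        'seq_len': 12, 'in_dim': 1, 'out_len': 12, 'out_dim': 1,
        'layers': 3, 'propalpha': 0.05, 'tanhalpha': 3,
        'layer_norm_affline': True, 'gcn_depth': 2,
    }
    model = gtnet(configs)
    # two extra channels stand in for the periodic embeddings removed in forward()
    dummy = torch.randn(4, configs['seq_len'], configs['num_nodes'], configs['in_dim'] + 2)
    out = model(dummy)
    print(out.shape)  # expected: [4, out_len, num_nodes, out_dim]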