TrafficWheel/model/MTGNN/MTGNN.py

155 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch.nn as nn
from model.MTGNN.layer import *
class gtnet(nn.Module):
    """MTGNN forecasting network built from gated dilated temporal convolutions.

    Each of the ``layers`` blocks applies a tanh/sigmoid-gated pair of
    ``dilated_inception`` temporal convolutions, feeds a skip connection, then
    mixes information across nodes with two ``mixprop`` graph convolutions
    (over ``A`` and ``A^T``) when ``gcn_true`` is set, or a 1x1 conv otherwise,
    followed by a residual add and a node-indexed ``LayerNorm``.  The summed
    skip connections are projected by two 1x1 convolutions to
    ``out_len * out_dim`` channels.

    Args:
        configs: dict of hyper-parameters. Required keys: ``gcn_true``,
            ``buildA_true``, ``num_nodes``, ``device``, ``dropout``,
            ``subgraph_size``, ``node_dim``, ``dilation_exponential``,
            ``conv_channels``, ``residual_channels``, ``skip_channels``,
            ``end_channels``, ``seq_len``, ``in_dim``, ``out_len``,
            ``out_dim``, ``layers``, ``propalpha``, ``tanhalpha``,
            ``layer_norm_affline``, ``gcn_depth``. Optional keys:
            ``predefined_A`` and ``static_feat`` (default ``None``).
    """

    def __init__(self, configs):
        super(gtnet, self).__init__()
        self.gcn_true = configs['gcn_true']  # whether to apply graph convolution
        self.buildA_true = configs['buildA_true']  # whether to learn the adjacency matrix
        self.num_nodes = configs['num_nodes']  # number of graph nodes
        self.device = configs['device']  # device (CPU/GPU)
        self.dropout = configs['dropout']  # dropout rate
        # NOTE(review): stored as a plain attribute (not a buffer), so
        # model.to(...) will not move it; it must already be on the right device.
        self.predefined_A = configs.get('predefined_A', None)  # predefined adjacency matrix
        self.static_feat = configs.get('static_feat', None)  # static node features
        self.subgraph_size = configs['subgraph_size']  # subgraph size for graph construction
        self.node_dim = configs['node_dim']  # node embedding dimension
        self.dilation_exponential = configs['dilation_exponential']  # dilation growth factor
        self.conv_channels = configs['conv_channels']  # temporal conv channels
        self.residual_channels = configs['residual_channels']  # residual channels
        self.skip_channels = configs['skip_channels']  # skip-connection channels
        self.end_channels = configs['end_channels']  # output-head hidden channels
        self.seq_length = configs['seq_len']  # input sequence length
        self.in_dim = configs['in_dim']  # input feature dimension
        self.out_len = configs['out_len']  # output sequence length
        self.out_dim = configs['out_dim']  # output feature dimension
        self.layers = configs['layers']  # number of layers
        self.propalpha = configs['propalpha']  # graph propagation alpha
        self.tanhalpha = configs['tanhalpha']  # tanh saturation alpha for graph construction
        self.layer_norm_affline = configs['layer_norm_affline']  # LayerNorm elementwise affine
        self.gcn_depth = configs['gcn_depth']  # graph convolution depth

        self.filter_convs = nn.ModuleList()  # tanh branch of the gated temporal convs
        self.gate_convs = nn.ModuleList()  # sigmoid branch of the gated temporal convs
        self.residual_convs = nn.ModuleList()  # 1x1 convs used when gcn_true is False
        self.skip_convs = nn.ModuleList()  # per-layer skip projections
        self.gconv1 = nn.ModuleList()  # graph convs over A
        self.gconv2 = nn.ModuleList()  # graph convs over A^T
        self.norm = nn.ModuleList()  # per-layer LayerNorms

        self.start_conv = nn.Conv2d(in_channels=self.in_dim,
                                    out_channels=self.residual_channels,
                                    kernel_size=(1, 1))
        self.gc = graph_constructor(self.num_nodes, self.subgraph_size, self.node_dim,
                                    self.device, alpha=self.tanhalpha,
                                    static_feat=self.static_feat)

        # NOTE(review): presumably the largest kernel used inside
        # dilated_inception; the receptive-field arithmetic depends on it.
        kernel_size = 7
        self.receptive_field = self._receptive_field(
            self.layers, kernel_size, self.dilation_exponential)
        # Inputs shorter than the receptive field are left-padded up to it in
        # forward(), so every layer sees max(seq_length, receptive_field) steps.
        total_len = max(self.seq_length, self.receptive_field)

        dilation = 1
        for j in range(1, self.layers + 1):
            # Time steps left after layer j's dilated convolutions.
            steps_left = total_len - self._receptive_field(
                j, kernel_size, self.dilation_exponential) + 1
            self.filter_convs.append(dilated_inception(
                self.residual_channels, self.conv_channels, dilation_factor=dilation))
            self.gate_convs.append(dilated_inception(
                self.residual_channels, self.conv_channels, dilation_factor=dilation))
            self.residual_convs.append(nn.Conv2d(in_channels=self.conv_channels,
                                                 out_channels=self.residual_channels,
                                                 kernel_size=(1, 1)))
            # Kernel spans the remaining time axis, collapsing it to length 1.
            self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels,
                                             out_channels=self.skip_channels,
                                             kernel_size=(1, steps_left)))
            if self.gcn_true:
                self.gconv1.append(mixprop(self.conv_channels, self.residual_channels,
                                           self.gcn_depth, self.dropout, self.propalpha))
                self.gconv2.append(mixprop(self.conv_channels, self.residual_channels,
                                           self.gcn_depth, self.dropout, self.propalpha))
            self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, steps_left),
                                       elementwise_affine=self.layer_norm_affline))
            dilation *= self.dilation_exponential

        self.end_conv_1 = nn.Conv2d(in_channels=self.skip_channels,
                                    out_channels=self.end_channels,
                                    kernel_size=(1, 1),
                                    bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=self.end_channels,
                                    out_channels=self.out_len * self.out_dim,
                                    kernel_size=(1, 1),
                                    bias=True)
        # skip0 collapses the (possibly padded) raw input; skipE collapses the
        # last layer's output (whose length is total_len - receptive_field + 1).
        self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels,
                               kernel_size=(1, total_len), bias=True)
        self.skipE = nn.Conv2d(in_channels=self.residual_channels,
                               out_channels=self.skip_channels,
                               kernel_size=(1, total_len - self.receptive_field + 1),
                               bias=True)
        # Registered as a non-persistent buffer so model.to()/DataParallel move
        # it with the module, while state_dict contents stay unchanged.
        self.register_buffer('idx', torch.arange(self.num_nodes).to(self.device),
                             persistent=False)

    @staticmethod
    def _receptive_field(num_layers, kernel_size, dilation_exponential):
        """Return the receptive field (in time steps) of the first ``num_layers`` layers.

        With ``dilation_exponential > 1`` dilations grow geometrically
        (1, d, d^2, ...); otherwise every layer has dilation 1.
        """
        if dilation_exponential > 1:
            return int(1 + (kernel_size - 1) * (dilation_exponential ** num_layers - 1)
                       / (dilation_exponential - 1))
        return num_layers * (kernel_size - 1) + 1

    def forward(self, input, idx=None):
        """Run the network.

        Args:
            input: tensor whose last axis holds ``in_dim`` features followed by
                two period-embedding channels (dropped here); after
                ``transpose(1, 3)`` axis 3 must have ``seq_len`` entries —
                presumably laid out as (batch, seq_len, num_nodes, in_dim + 2).
            idx: optional node indices passed to the graph constructor and the
                LayerNorms; defaults to all nodes. NOTE(review): the final
                reshape uses ``self.num_nodes``.

        Returns:
            Tensor of shape (batch, out_len, num_nodes, out_dim).

        Raises:
            ValueError: if the input time length differs from ``seq_len``, or
                graph convolution needs ``predefined_A`` but none was given.
        """
        input = input[..., :-2]  # drop the period-embedding channels
        input = input.transpose(1, 3)  # -> (batch, channels, nodes, time)
        seq_len = input.size(3)
        if seq_len != self.seq_length:
            raise ValueError(
                'input sequence length not equal to preset sequence length '
                f'(got {seq_len}, expected {self.seq_length})')
        if self.seq_length < self.receptive_field:
            # Left-pad the time axis so the stacked convolutions have enough history.
            input = nn.functional.pad(input, (self.receptive_field - self.seq_length, 0, 0, 0))

        node_idx = self.idx if idx is None else idx

        if self.gcn_true:
            if self.buildA_true:
                adp = self.gc(node_idx)
            else:
                adp = self.predefined_A
                if adp is None:
                    raise ValueError(
                        'predefined_A must be provided when gcn_true is set '
                        'and buildA_true is not')

        x = self.start_conv(input)
        skip = self.skip0(F.dropout(input, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            filt = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filt * gate, self.dropout, training=self.training)
            skip = self.skip_convs[i](x) + skip
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.residual_convs[i](x)
            # The dilated convs shorten the time axis; align the residual to its tail.
            x = x + residual[:, :, :, -x.size(3):]
            x = self.norm[i](x, node_idx)

        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)  # (batch, out_len * out_dim, nodes, 1)
        # (b, t*c, n, 1) -> (b, t, c, n) -> (b, t, n, c)
        x = x.reshape(x.size(0), self.out_len, self.out_dim, self.num_nodes)
        return x.permute(0, 1, 3, 2)