TrafficWheel/model/GWN/GraphWaveNet.py

148 lines
5.8 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch, torch.nn as nn, torch.nn.functional as F
class nconv(nn.Module):
    """Graph-convolution primitive implemented with a single einsum.

    Contracts the third axis of a 4-D tensor with the first axis of a 2-D
    matrix: ``out[n, c, w, l] = sum_v x[n, c, v, l] * A[v, w]``.
    """

    def forward(self, x, A):
        """Propagate ``x`` through the matrix ``A``.

        Args:
            x: 4-D tensor indexed ``(n, c, v, l)`` — presumably
                (batch, channels, nodes, time); confirm against the caller.
            A: 2-D tensor indexed ``(v, w)`` (e.g. an adjacency matrix).

        Returns:
            A contiguous tensor indexed ``(n, c, w, l)``.
        """
        propagated = torch.einsum('ncvl,vw->ncwl', x, A)
        # einsum may hand back a strided view; make the layout dense.
        return propagated.contiguous()
class linear(nn.Module):
    """Per-position linear projection implemented as a 1x1 2-D convolution.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        # A kernel of size 1 mixes channels only; the two trailing axes of
        # the input are left untouched.
        self.mlp = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x`` and return the result."""
        projected = self.mlp(x)
        return projected
class gcn(nn.Module):
    """Graph-convolution layer that mixes multi-hop features from several supports.

    For every support matrix the input is propagated repeatedly (up to
    ``order`` hops); the input and every propagated tensor are concatenated
    along the channel axis and projected back to ``c_out`` channels.

    Args:
        c_in: Channels of the incoming tensor.
        c_out: Channels produced by the layer.
        dropout: Dropout probability applied to the projected output.
        support_len: Number of support matrices passed to ``forward``.
        order: Number of propagation hops per support matrix.
    """

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        self.nconv = nconv()
        # One copy of the input plus `order` propagated copies per support
        # end up stacked on the channel axis before the projection.
        stacked_channels = c_in * (1 + order * support_len)
        self.mlp = linear(stacked_channels, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        """Run the graph convolution.

        Args:
            x: Input tensor; channels are expected on axis 1.
            support: Iterable of matrices accepted by ``self.nconv``.

        Returns:
            The projected tensor, with dropout applied when the module is in
            training mode.
        """
        # The first hop is always applied, even if `order` is below 1.
        hops = max(1, self.order)
        features = [x]
        for adj in support:
            hidden = x
            for _ in range(hops):
                hidden = self.nconv(hidden, adj)
                features.append(hidden)
        stacked = torch.cat(features, dim=1)
        projected = self.mlp(stacked)
        return F.dropout(projected, self.dropout, training=self.training)
class gwnet(nn.Module):
    """Graph WaveNet: gated dilated temporal convolutions combined with graph convolutions.

    The model stacks ``blocks * layers`` units. Each unit applies a gated
    temporal convolution (``tanh`` filter times ``sigmoid`` gate), adds a
    skip contribution, then either a graph convolution or a 1x1 residual
    convolution, followed by a residual connection and batch normalisation.

    Args:
        args: Mapping of hyper-parameters. Required keys: ``dropout``,
            ``blocks``, ``layers``, ``gcn_bool``, ``addaptadj``, ``in_dim``,
            ``residual_channels``, ``dilation_channels``, ``skip_channels``,
            ``end_channels``, ``out_dim``, ``kernel_size``. Optional keys:
            ``supports`` (sequence of adjacency tensors) and ``aptinit``
            (matrix used to initialise the adaptive adjacency via SVD).
            ``num_nodes`` and ``device`` are required when the adaptive
            adjacency is enabled and ``aptinit`` is not given.
    """

    def __init__(self, args):
        super().__init__()
        # Basic hyper-parameters.
        self.dropout, self.blocks, self.layers = args['dropout'], args['blocks'], args['layers']
        self.gcn_bool, self.addaptadj = args['gcn_bool'], args['addaptadj']
        # Remembered so forward() slices exactly as many input features as
        # start_conv was built for.
        self.in_dim = args['in_dim']

        # Containers for the per-layer modules (registration order is kept
        # stable so existing checkpoints keep loading).
        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
        self.residual_convs, self.skip_convs = nn.ModuleList(), nn.ModuleList()
        self.bn, self.gconv = nn.ModuleList(), nn.ModuleList()
        self.start_conv = nn.Conv2d(args['in_dim'], args['residual_channels'], 1)
        self.supports = args.get('supports', None)

        receptive_field = 1
        self.supports_len = len(self.supports) if self.supports is not None else 0

        # Adaptive adjacency: two learnable rank-10 node embeddings whose
        # product forms an extra support matrix in forward().
        if self.gcn_bool and self.addaptadj:
            if self.supports is None:
                self.supports = []
            aptinit = args.get('aptinit', None)
            if aptinit is None:
                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10, device=args['device']))
                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes'], device=args['device']))
            else:
                # Initialise the embeddings from the top-10 singular
                # components of the supplied matrix.
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(initemb1)
                self.nodevec2 = nn.Parameter(initemb2)
            self.supports_len += 1

        ks, res, dil = args['kernel_size'], args['residual_channels'], args['dilation_channels']
        skip, endc, out_dim = args['skip_channels'], args['end_channels'], args['out_dim']

        # Build the stacked layers; the dilation doubles inside a block and
        # resets to 1 at the start of each block.
        for _ in range(self.blocks):
            add_scope, new_dil = ks - 1, 1
            for _ in range(self.layers):
                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.residual_convs.append(nn.Conv2d(dil, res, 1))
                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
                self.bn.append(nn.BatchNorm2d(res))
                new_dil *= 2
                receptive_field += add_scope
                add_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(gcn(dil, res, args['dropout'], support_len=self.supports_len))

        # Output head.
        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
        self.receptive_field = receptive_field

    def forward(self, input):
        """Run the model.

        Args:
            input: Tensor laid out as ``(batch, time, nodes, features)``.
                Only the first ``in_dim`` features are used.

        Returns:
            Tensor of shape ``(batch, out_dim, nodes, steps)`` where
            ``steps`` is the padded input length minus
            ``receptive_field - 1``.
        """
        # Keep the configured number of features and move them to the
        # channel axis: (batch, features, nodes, time).
        input = input[..., :self.in_dim].transpose(1, 3)
        # Prepend one zero step on the time axis.
        input = F.pad(input, (1, 0, 0, 0))
        in_len = input.size(3)
        # Left-pad the time axis so the dilated stack always has enough history.
        if in_len < self.receptive_field:
            x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input

        x = self.start_conv(x)
        skip = 0
        new_supports = None

        # Adaptive adjacency, recomputed from the node embeddings each call.
        if self.gcn_bool and self.addaptadj and self.supports is not None:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            # Unpacking accepts supports given as a list or as a tuple.
            new_supports = [*self.supports, adp]

        use_graph_conv = self.gcn_bool and self.supports is not None
        active_supports = new_supports if self.addaptadj else self.supports

        for i in range(self.blocks * self.layers):
            residual = x
            # Gated temporal convolution.
            f = self.filter_convs[i](residual).tanh()
            g = self.gate_convs[i](residual).sigmoid()
            x = f * g

            # Accumulate the skip path; earlier contributions are cropped to
            # the (shorter) time length of the current layer.
            s = self.skip_convs[i](x)
            if isinstance(skip, torch.Tensor):
                skip = skip[:, :, :, -s.size(3):] + s
            else:
                skip = 0 + s

            # Spatial mixing: graph convolution when supports are available,
            # otherwise a plain 1x1 convolution.
            if use_graph_conv:
                x = self.gconv[i](x, active_supports)
            else:
                x = self.residual_convs[i](x)

            # Residual connection on the most recent time steps, then BN.
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))