import torch, torch.nn as nn, torch.nn.functional as F


class nconv(nn.Module):
    """Graph convolution operator.

    Uses einsum to contract the node dimension of x with a support
    (adjacency) matrix A: (N, C, V, L) x (V, W) -> (N, C, W, L).
    """
    def forward(self, x, A): return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()
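# A shape sketch of the operator above (illustrative values, not from any
# model config): the einsum contracts the node axis v of x with A, i.e.
# y[n, c, w, l] = sum_v x[n, c, v, l] * A[v, w].
#
#   x = torch.randn(8, 32, 207, 12)   # (batch, channels, nodes, time)
#   A = torch.softmax(torch.randn(207, 207), dim=1)
#   y = nconv()(x, A)                 # -> torch.Size([8, 32, 207, 12])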


class linear(nn.Module):
    """Linear transformation layer, implemented as a 1x1 convolution."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        return self.mlp(x)


class gcn(nn.Module):
    """Graph convolutional network layer.

    Performs higher-order graph convolution: for each support matrix it
    aggregates up to `order` hops, concatenates all hops with the input,
    and mixes them back to c_out channels with a 1x1 convolution.
    """
    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        self.nconv = nconv()
        c_in = (order * support_len + 1) * c_in  # input + order hops per support
        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order

    def forward(self, x, support):
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)
            out.append(x1)
            for _ in range(2, self.order + 1):
                x1 = self.nconv(x1, a)
                out.append(x1)
        return F.dropout(self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training)
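# A minimal usage sketch for gcn (hypothetical shapes and supports): with
# support_len=2 and order=2, `out` collects 1 + 2 * 2 = 5 feature maps, so
# the 1x1 mlp maps 5 * c_in channels back to c_out.
#
#   layer = gcn(c_in=32, c_out=32, dropout=0.3, support_len=2, order=2)
#   x = torch.randn(8, 32, 207, 12)              # (batch, channels, nodes, time)
#   supports = [torch.eye(207), torch.eye(207)]  # two support matrices
#   y = layer(x, supports)                       # -> torch.Size([8, 32, 207, 12])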


class gwnet(nn.Module):
    """Graph WaveNet model.

    Combines graph convolution with dilated temporal convolution for
    spatio-temporal forecasting.
    """
    def __init__(self, args):
        super().__init__()
        # Basic hyperparameters
        self.dropout, self.blocks, self.layers = args['dropout'], args['blocks'], args['layers']
        self.gcn_bool, self.addaptadj = args['gcn_bool'], args['addaptadj']

        # Convolution layers and modules
        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
        self.residual_convs, self.skip_convs, self.bn, self.gconv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        self.start_conv = nn.Conv2d(args['in_dim'], args['residual_channels'], 1)
        self.supports = args.get('supports', None)

        # Receptive field of the stacked dilated convolutions
        receptive_field = 1
        self.supports_len = len(self.supports) if self.supports is not None else 0

        # Adaptive adjacency matrix: learn two node-embedding factors whose
        # product (after ReLU + softmax in forward) forms an extra support
        if self.gcn_bool and self.addaptadj:
            aptinit = args.get('aptinit', None)
            if aptinit is None:
                if self.supports is None: self.supports = []
                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10, device=args['device']))
                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes'], device=args['device']))
                self.supports_len += 1
            else:
                if self.supports is None: self.supports = []
                # Initialize the factors from a rank-10 SVD of aptinit, so that
                # nodevec1 @ nodevec2 approximates the given adjacency matrix
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(initemb1)
                self.nodevec2 = nn.Parameter(initemb2)
                self.supports_len += 1

        # Model hyperparameters
        ks, res, dil, skip, endc, out_dim = (args['kernel_size'], args['residual_channels'],
                                             args['dilation_channels'], args['skip_channels'],
                                             args['end_channels'], args['out_dim'])

        # Build the stacked blocks; the dilation doubles per layer, so with
        # kernel size ks each block adds (ks - 1) * (2 ** layers - 1) to the
        # receptive field
        for _ in range(self.blocks):
            add_scope, new_dil = ks - 1, 1
            for _ in range(self.layers):
                # Gated temporal convolutions
                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.residual_convs.append(nn.Conv2d(dil, res, 1))
                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
                self.bn.append(nn.BatchNorm2d(res))
                new_dil *= 2
                receptive_field += add_scope
                add_scope *= 2
                if self.gcn_bool: self.gconv.append(gcn(dil, res, args['dropout'], support_len=self.supports_len))

        # Output layers
        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
        self.receptive_field = receptive_field

    def forward(self, input):
        """Forward pass.

        Expects input of shape (batch, time, nodes, features) and uses only
        the first two features.
        """
        # Preprocessing: move the first two features to the channel axis and
        # left-pad the time axis so it covers the receptive field
        input = input[..., 0:2].transpose(1, 3)
        input = F.pad(input, (1, 0, 0, 0))
        in_len = input.size(3)
        x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) if in_len < self.receptive_field else input

        # Initial convolution
        x, skip, new_supports = self.start_conv(x), 0, None

        # If the adaptive adjacency matrix is enabled, compute it and append
        # it to the list of supports
        if self.gcn_bool and self.addaptadj and self.supports is not None:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            new_supports = self.supports + [adp]

        # Main WaveNet-style blocks
        for i in range(self.blocks * self.layers):
            residual = x
            # Gated temporal convolution
            f = self.filter_convs[i](residual).tanh()
            g = self.gate_convs[i](residual).sigmoid()
            x = f * g
            s = self.skip_convs[i](x)
            # Accumulate skip connections, trimming to the current time length
            skip = (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + s

            # Graph convolution (or a plain 1x1 convolution when disabled)
            if self.gcn_bool and self.supports is not None:
                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
            else:
                x = self.residual_convs[i](x)
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        # Output layers
        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
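

if __name__ == "__main__":
    # Minimal smoke test with illustrative hyperparameters (assumed here, not
    # taken from any particular config); 207 nodes matches METR-LA, but any
    # value works.
    args = {
        'dropout': 0.3, 'blocks': 4, 'layers': 2,
        'gcn_bool': True, 'addaptadj': True,
        'in_dim': 2, 'out_dim': 12, 'num_nodes': 207,
        'residual_channels': 32, 'dilation_channels': 32,
        'skip_channels': 256, 'end_channels': 512,
        'kernel_size': 2, 'device': 'cpu',
    }
    model = gwnet(args)
    # Receptive field is 1 + blocks * (ks - 1) * (2 ** layers - 1) = 13, so a
    # 12-step history (plus the one-step pad in forward) is fully covered.
    x = torch.randn(8, 12, 207, 2)  # (batch, time, nodes, features)
    y = model(x)
    print(y.shape)  # torch.Size([8, 12, 207, 1])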