Add detailed comments to the GraphWaveNet model to improve code readability and maintainability

czzhangheng 2025-05-14 13:09:20 +08:00
parent 4f7fb52707
commit 61565cd33a
1 changed file with 46 additions and 0 deletions


@@ -2,10 +2,18 @@ import torch, torch.nn as nn, torch.nn.functional as F
class nconv(nn.Module):
    """
    Graph convolution operator.
    Implements graph convolution as a single einsum contraction.
    """
    def forward(self, x, A): return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()
class linear(nn.Module):
    """
    Linear transformation layer.
    Implements the linear transform as a 1x1 convolution.
    """
    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, 1)
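As a shape sanity check for the contraction in nconv (not part of the commit), a minimal standalone snippet; all tensor sizes are illustrative:

import torch

# Same contraction as nconv.forward: mix the node dimension V through A
x = torch.randn(8, 32, 207, 12)   # (batch N, channels C, nodes V, time L)
A = torch.randn(207, 207)         # dense adjacency / transition matrix (V, W)
out = torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()
print(out.shape)                  # torch.Size([8, 32, 207, 12]) -- only the node dim is mixed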
@@ -15,6 +23,10 @@ class linear(nn.Module):
class gcn(nn.Module):
    """
    Graph convolutional network layer.
    Implements higher-order graph convolution over multiple adjacency matrices.
    """
    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        self.nconv = nconv()
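This hunk ends before gcn.forward. For context, the reference Graph WaveNet implementation of this layer concatenates the input with up to order diffusion steps over each adjacency matrix, then applies the 1x1 linear and dropout; the sketch below reproduces only the diffusion pattern as a plain function (all names and sizes are illustrative, not taken from this commit):

import torch

def gcn_forward(x, supports, order=2):
    """Higher-order diffusion: concatenate x with k-step propagations over each support."""
    out = [x]
    for a in supports:
        x1 = torch.einsum('ncvl,vw->ncwl', (x, a))      # same contraction as nconv
        out.append(x1)
        for _ in range(2, order + 1):
            x1 = torch.einsum('ncvl,vw->ncwl', (x1, a))
            out.append(x1)
    return torch.cat(out, dim=1)  # channels grow to c_in * (order * len(supports) + 1)

x = torch.randn(8, 32, 207, 12)
A = torch.softmax(torch.rand(207, 207), dim=1)          # row-normalized random adjacency
h = gcn_forward(x, [A, A.t()])
print(h.shape)  # torch.Size([8, 160, 207, 12]) = 32 * (2 * 2 + 1)

The channel growth explains the support_len and order constructor arguments: the 1x1 linear that follows must map c_in * (order * support_len + 1) channels back down to c_out.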
@@ -33,16 +45,27 @@ class gcn(nn.Module):
class gwnet(nn.Module):
    """
    Main Graph WaveNet model.
    Combines graph convolution and temporal convolution networks for spatio-temporal forecasting.
    """
    def __init__(self, args):
        super().__init__()
        # Basic hyperparameters
        self.dropout, self.blocks, self.layers = args['dropout'], args['blocks'], args['layers']
        self.gcn_bool, self.addaptadj = args['gcn_bool'], args['addaptadj']
        # Convolution layers and module containers
        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
        self.residual_convs, self.skip_convs, self.bn, self.gconv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        self.start_conv = nn.Conv2d(args['in_dim'], args['residual_channels'], 1)
        self.supports = args.get('supports', None)
        # Compute the receptive field
        receptive_field = 1
        self.supports_len = len(self.supports) if self.supports is not None else 0
        # If an adaptive adjacency matrix is used, initialize its parameters
        if self.gcn_bool and self.addaptadj:
            aptinit = args.get('aptinit', None)
            if aptinit is None:
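Both bodies of this if/else are elided from the hunk. In the reference Graph WaveNet code the branches differ only in how the node embeddings are seeded: randomly when aptinit is None, or from a truncated SVD of aptinit otherwise. A hedged standalone sketch of that pattern (num_nodes, emb_dim, and all values here are illustrative, not taken from this commit):

import torch
import torch.nn as nn

num_nodes, emb_dim = 207, 10

# aptinit is None: free random node embeddings
nodevec1 = nn.Parameter(torch.randn(num_nodes, emb_dim))
nodevec2 = nn.Parameter(torch.randn(emb_dim, num_nodes))

# aptinit provided: warm-start the embeddings from a rank-emb_dim SVD of the prior adjacency
aptinit = torch.rand(num_nodes, num_nodes)
m, p, n = torch.svd(aptinit)
initemb1 = torch.mm(m[:, :emb_dim], torch.diag(p[:emb_dim] ** 0.5))
initemb2 = torch.mm(torch.diag(p[:emb_dim] ** 0.5), n[:, :emb_dim].t())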
@@ -58,11 +81,16 @@ class gwnet(nn.Module):
            self.nodevec1 = nn.Parameter(initemb1)
            self.nodevec2 = nn.Parameter(initemb2)
            self.supports_len += 1
        # Read the model hyperparameters
        ks, res, dil, skip, endc, out_dim = args['kernel_size'], args['residual_channels'], args['dilation_channels'], \
            args['skip_channels'], args['end_channels'], args['out_dim']
        # Build the model layers
        for b in range(self.blocks):
            add_scope, new_dil = ks - 1, 1
            for i in range(self.layers):
                # Temporal convolution layers (filter and gate)
                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                self.residual_convs.append(nn.Conv2d(dil, res, 1))
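The receptive-field bookkeeping visible here adds add_scope per layer and doubles it each time (the matching new_dil update falls outside the hunk, but dilation doubling is the standard WaveNet scheme). A quick standalone check of the resulting arithmetic, with illustrative hyperparameters:

ks, blocks, layers = 2, 4, 2          # illustrative values, not from the diff
receptive_field = 1
for b in range(blocks):
    add_scope = ks - 1
    for i in range(layers):
        receptive_field += add_scope
        add_scope *= 2
print(receptive_field)                # 13 = 1 + blocks * (ks - 1) * (2**layers - 1)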
@@ -72,30 +100,48 @@ class gwnet(nn.Module):
                receptive_field += add_scope
                add_scope *= 2
                if self.gcn_bool: self.gconv.append(gcn(dil, res, args['dropout'], support_len=self.supports_len))
        # Output layers
        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
        self.receptive_field = receptive_field
    def forward(self, input):
        """
        Forward pass.
        Implements the model's inference procedure.
        """
        # Input preprocessing
        input = input[..., 0:2].transpose(1, 3)
        input = F.pad(input, (1, 0, 0, 0))
        in_len = input.size(3)
        x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) if in_len < self.receptive_field else input
        # Initial convolution
        x, skip, new_supports = self.start_conv(x), 0, None
        # If adaptive adjacency is enabled, compute the learned adjacency matrix
        if self.gcn_bool and self.addaptadj and self.supports is not None:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            new_supports = self.supports + [adp]
        # Forward pass through the stacked layers
        for i in range(self.blocks * self.layers):
            residual = x
            # Gated temporal convolution
            f = self.filter_convs[i](residual).tanh()
            g = self.gate_convs[i](residual).sigmoid()
            x = f * g
            s = self.skip_convs[i](x)
            skip = (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + s
            # Graph convolution (falls back to a 1x1 residual conv when disabled)
            if self.gcn_bool and self.supports is not None:
                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
            else:
                x = self.residual_convs[i](x)
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)
        # Output layers
        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
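For reference, a minimal usage sketch of the annotated model. Every args key appears in this diff, but the concrete values and tensor sizes are illustrative; gcn_bool/addaptadj are disabled so the elided adaptive-adjacency branches are not exercised, and the sketch assumes the elided parts of the file (skip/batch-norm registration, per-layer dilation doubling) match the reference implementation:

import torch

args = {
    'dropout': 0.3, 'blocks': 4, 'layers': 2,
    'gcn_bool': False, 'addaptadj': False, 'supports': None,
    'in_dim': 2, 'residual_channels': 32, 'dilation_channels': 32,
    'skip_channels': 256, 'end_channels': 512,
    'kernel_size': 2, 'out_dim': 12,
}
model = gwnet(args)             # assumes this file is on the import path
x = torch.randn(8, 12, 207, 2)  # (batch, time, nodes, features); features >= in_dim
y = model(x)
print(y.shape)                  # torch.Size([8, 12, 207, 1]) with these sizes

With kernel_size=2, blocks=4, layers=2 the receptive field works out to 13, so a 12-step input (plus the one-step left pad in forward) is consumed down to a single output step, with out_dim=12 forecast horizons in the channel dimension.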