From 61565cd33a952723a42ea08346de16a3b2d669dd Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Wed, 14 May 2025 13:09:20 +0800
Subject: [PATCH] Add detailed comments to the GraphWaveNet model to improve
 code readability and maintainability
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model/GWN/GraphWaveNet.py | 46 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/model/GWN/GraphWaveNet.py b/model/GWN/GraphWaveNet.py
index d565f71..55b711e 100755
--- a/model/GWN/GraphWaveNet.py
+++ b/model/GWN/GraphWaveNet.py
@@ -2,10 +2,18 @@ import torch, torch.nn as nn, torch.nn.functional as F
 
 class nconv(nn.Module):
+    """
+    Graph convolution operator.
+    Uses einsum to multiply the input with the adjacency matrix.
+    """
     def forward(self, x, A):
         return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()
 
 
 class linear(nn.Module):
+    """
+    Linear transformation layer.
+    Implements the linear transform as a 1x1 convolution.
+    """
     def __init__(self, c_in, c_out):
         super().__init__()
         self.mlp = nn.Conv2d(c_in, c_out, 1)
@@ -15,6 +23,10 @@ class linear(nn.Module):
 
 
 class gcn(nn.Module):
+    """
+    Graph convolution network layer.
+    Implements higher-order graph convolution and supports multiple adjacency matrices.
+    """
     def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
         super().__init__()
         self.nconv = nconv()
@@ -33,16 +45,27 @@ class gcn(nn.Module):
 
 
 class gwnet(nn.Module):
+    """
+    Main Graph WaveNet model.
+    Combines graph convolution and temporal convolution for spatio-temporal forecasting.
+    """
     def __init__(self, args):
         super().__init__()
+        # Basic hyperparameters
         self.dropout, self.blocks, self.layers = args['dropout'], args['blocks'], args['layers']
         self.gcn_bool, self.addaptadj = args['gcn_bool'], args['addaptadj']
+
+        # Convolution layers and sub-modules
         self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
         self.residual_convs, self.skip_convs, self.bn, self.gconv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
         self.start_conv = nn.Conv2d(args['in_dim'], args['residual_channels'], 1)
         self.supports = args.get('supports', None)
+
+        # Receptive field of the stacked dilated convolutions
         receptive_field = 1
         self.supports_len = len(self.supports) if self.supports is not None else 0
+
+        # If the adaptive adjacency matrix is enabled, initialize the node embeddings
         if self.gcn_bool and self.addaptadj:
             aptinit = args.get('aptinit', None)
             if aptinit is None:
@@ -58,11 +81,16 @@ class gwnet(nn.Module):
             self.nodevec1 = nn.Parameter(initemb1)
             self.nodevec2 = nn.Parameter(initemb2)
             self.supports_len += 1
+
+        # Model hyperparameters
         ks, res, dil, skip, endc, out_dim = args['kernel_size'], args['residual_channels'], args['dilation_channels'], \
             args['skip_channels'], args['end_channels'], args['out_dim']
+
+        # Build the model layers
         for b in range(self.blocks):
             add_scope, new_dil = ks - 1, 1
             for i in range(self.layers):
+                # Temporal (dilated) convolution layers
                 self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                 self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
                 self.residual_convs.append(nn.Conv2d(dil, res, 1))
@@ -72,30 +100,48 @@ class gwnet(nn.Module):
                 receptive_field += add_scope
                 add_scope *= 2
                 if self.gcn_bool: self.gconv.append(gcn(dil, res, args['dropout'], support_len=self.supports_len))
+
+        # Output layers
         self.end_conv_1 = nn.Conv2d(skip, endc, 1)
         self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
         self.receptive_field = receptive_field
 
     def forward(self, input):
+        """
+        Forward pass.
+        Implements the model's inference procedure.
+        """
+        # Data preprocessing
         input = input[..., 0:2].transpose(1, 3)
         input = F.pad(input, (1, 0, 0, 0))
         in_len = input.size(3)
         x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) if in_len < self.receptive_field else input
+
+        # Initial convolution
         x, skip, new_supports = self.start_conv(x), 0, None
+
+        # If the adaptive adjacency matrix is enabled, compute the new adjacency matrix
         if self.gcn_bool and self.addaptadj and self.supports is not None:
             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
             new_supports = self.supports + [adp]
+
+        # Forward pass through the main network layers
         for i in range(self.blocks * self.layers):
             residual = x
+            # Temporal convolution (gated activation)
             f = self.filter_convs[i](residual).tanh()
             g = self.gate_convs[i](residual).sigmoid()
            x = f * g
             s = self.skip_convs[i](x)
             skip = (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + s
+
+            # Graph convolution
             if self.gcn_bool and self.supports is not None:
                 x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
             else:
                 x = self.residual_convs[i](x)
             x = x + residual[:, :, :, -x.size(3):]
             x = self.bn[i](x)
+
+        # Output layers
         return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
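For context (not part of the patch itself): a minimal usage sketch of the commented model, assuming the repository root is on the Python path. The hyperparameter values, node count, and sequence length below are illustrative assumptions rather than the repository's configuration, and the graph-convolution / adaptive-adjacency branches are disabled so that no precomputed supports are needed.

    import torch
    from model.GWN.GraphWaveNet import gwnet

    # Illustrative hyperparameters (assumed values, not taken from the repository config)
    args = {
        'dropout': 0.3, 'blocks': 4, 'layers': 2,
        'gcn_bool': False, 'addaptadj': False, 'supports': None,  # skip graph convolution entirely
        'in_dim': 2, 'out_dim': 12,                               # 2 input features, 12-step horizon
        'residual_channels': 32, 'dilation_channels': 32,
        'skip_channels': 256, 'end_channels': 512,
        'kernel_size': 2,
    }

    model = gwnet(args)
    x = torch.randn(8, 12, 207, 2)   # (batch, time, nodes, features); 207 nodes is an arbitrary example
    y = model(x)                     # -> (batch, out_dim, nodes, 1) after the dilated convolutions
    print(y.shape)

With blocks=4, layers=2, and kernel_size=2 the receptive field computed in __init__ is 13, so the 12-step input (padded by one step in forward) is consumed down to a single output step per node.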