import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalConvNet(nn.Module):
    """Stack of causal, dilated temporal convolution blocks applied per node.

    ``args`` keys read by the constructor:
        input_dim (int): number of leading input features fed to the network
            (e.g. 1).
        hidden_channels (sequence of int): output channels of each
            TemporalBlock (e.g. [64, 64, 64]); block ``i`` uses dilation
            ``2**i``.
        kernel_size (int): temporal kernel width of every block (e.g. 3).
        dropout (float): dropout probability inside every block (e.g. 0.2).
        output_dim (int): channels produced by the final 1x1 projection
            (usually equal to the input dimension, e.g. 1).
    """

    def __init__(self, args):
        super().__init__()
        num_inputs = args["input_dim"]
        num_channels = args["hidden_channels"]
        kernel_size = args["kernel_size"]
        dropout = args["dropout"]
        if len(num_channels) == 0:
            raise ValueError("hidden_channels must contain at least one entry")

        # Remembered so forward() feeds exactly the number of features the
        # first convolution was built for.
        self.input_dim = num_inputs

        layers = []
        for i, out_channels in enumerate(num_channels):
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            layers.append(
                TemporalBlock(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    dilation=2**i,  # exponentially growing receptive field
                    dropout=dropout,
                )
            )
        self.network = nn.Sequential(*layers)
        # 1x1 projection from the last hidden channel count to output_dim.
        self.proj = nn.Conv2d(num_channels[-1], args["output_dim"], kernel_size=1)

    def forward(self, x):
        """Run the TCN along the time axis independently for every node.

        Args:
            x: tensor of shape ``(batch_size, time_step, num_nodes, dim)``
                with ``dim >= input_dim``; only the first ``input_dim``
                features are used.

        Returns:
            Tensor of shape ``(batch_size, time_step, num_nodes, output_dim)``.
            The output at time ``t`` depends only on inputs at times ``<= t``.

        Raises:
            ValueError: if ``x`` has fewer than ``input_dim`` features.
        """
        if x.size(-1) < self.input_dim:
            raise ValueError(
                f"expected at least {self.input_dim} input features, "
                f"got {x.size(-1)}"
            )
        # Previously hard-coded to the first feature only, which broke any
        # configuration with input_dim > 1 (channel mismatch in conv1).
        x = x[..., : self.input_dim]
        # (B, T, N, C) -> (B, C, N, T): channels first and time on the last
        # axis, so the (1, k) kernels convolve along time only.
        x = x.permute(0, 3, 2, 1)
        x = self.network(x)
        x = self.proj(x)
        # (B, C_out, N, T) -> (B, T, N, C_out)
        return x.permute(0, 3, 2, 1)


class TemporalBlock(nn.Module):
    """Residual block of two causal dilated convolutions along the time axis.

    Input and output tensors are ``(batch, channels, num_nodes, time)``. The
    time length is preserved when ``stride == 1`` (the only value
    TemporalConvNet passes). ``stride`` is forwarded to ``nn.Conv2d`` as an
    int, so any other value would shrink both the node and time axes and the
    residual addition would fail on mismatched shapes.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, dilation, dropout
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation

        # Causal convolution: forward() pads the left side of the time axis
        # manually, so the conv layers themselves use padding=0.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            (1, kernel_size),
            stride=stride,
            padding=0,
            dilation=(1, dilation),
        )
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            (1, kernel_size),
            stride=stride,
            padding=0,
            dilation=(1, dilation),
        )
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.dropout2 = nn.Dropout(dropout)

        # 1x1 convolution to match residual channels when they differ.
        self.downsample = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else None
        )

    def forward(self, x):
        """Apply conv-norm-ReLU-dropout twice and add the residual branch."""
        # Residual branch: the input itself, or its 1x1-projected version.
        res = x if self.downsample is None else self.downsample(x)

        # Left padding on the time axis needed for a causal dilated conv.
        pad = (self.kernel_size - 1) * self.dilation

        # First convolution. F.pad's (left, right) tuple acts on the last
        # (time) dimension, so only past positions are padded with zeros.
        out = F.pad(x, (pad, 0))
        out = self.conv1(out)
        out = self.norm1(out)
        out = F.relu(out)
        out = self.dropout1(out)

        # Second convolution.
        out = F.pad(out, (pad, 0))
        out = self.conv2(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.dropout2(out)

        # Residual connection; with stride 1 the time length matches the input.
        return F.relu(out + res)