import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalConvNet(nn.Module):
    def __init__(self, args):
        super(TemporalConvNet, self).__init__()
        num_inputs = args['input_dim']          # input feature dimension, e.g. 1
        num_channels = args['hidden_channels']  # list of hidden channel sizes, e.g. [64, 64, 64]
        kernel_size = args['kernel_size']       # temporal kernel size, e.g. 3
        dropout = args['dropout']               # dropout probability, e.g. 0.2

        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            dilation = 2 ** i  # exponentially growing dilation rate
            layers.append(
                TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation, dropout=dropout))

        self.network = nn.Sequential(*layers)
        # 1x1 projection mapping the last hidden channel count to the output
        # dimension (usually the same as the input dimension, e.g. 1)
        self.proj = nn.Conv2d(num_channels[-1], args['output_dim'], kernel_size=1)
    def forward(self, x):
        """
        Input x has shape (batch_size, time_step, num_nodes, dim).
        It is first permuted to (batch, dim, num_nodes, time_step) so that
        2D convolutions can slide along the time dimension.
        """
        x = x[..., 0:1]  # keep only the first input feature (consistent with input_dim == 1)
        x = x.permute(0, 3, 2, 1)
        x = self.network(x)
        x = self.proj(x)
        # back to (batch_size, time_step, num_nodes, output_dim)
        return x.permute(0, 3, 2, 1)


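# Note on receptive field (derived from the structure above): each TemporalBlock
# applies two causal convolutions with dilation 2**i, so a stack of L levels with
# kernel size k covers 1 + 2*(k-1)*(2**L - 1) time steps; e.g. k=3 and L=3
# levels give a receptive field of 29 steps.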
class TemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, dropout):
        super(TemporalBlock, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation

        # Causal convolution: the conv layers do no automatic padding themselves;
        # the forward pass left-pads the time dimension instead.
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                               stride=stride, padding=0, dilation=(1, dilation))
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv2d(out_channels, out_channels, (1, kernel_size),
                               stride=stride, padding=0, dilation=(1, dilation))
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.dropout2 = nn.Dropout(dropout)

        # If the channel counts do not match, a 1x1 convolution adjusts the
        # residual branch's channels
        self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
    def forward(self, x):
        # residual branch: the original x, or x passed through downsample
        res = x if self.downsample is None else self.downsample(x)
        # number of zeros to left-pad on the time axis so the convolution stays
        # causal (e.g. kernel_size=3, dilation=4 -> pad=8)
        pad = (self.kernel_size - 1) * self.dilation

        # first convolutional layer
        out = F.pad(x, (pad, 0))  # left-pad the time dimension with zeros; format is (left, right)
        out = self.conv1(out)
        out = self.norm1(out)
        out = F.relu(out)
        out = self.dropout1(out)

        # second convolutional layer
        out = F.pad(out, (pad, 0))
        out = self.conv2(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.dropout2(out)

        # residual connection; the output keeps the same number of time steps as the input
        return F.relu(out + res)
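

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file): the args values below
    # are illustrative assumptions matching the examples in the comments above.
    args = {
        'input_dim': 1,
        'hidden_channels': [64, 64, 64],
        'kernel_size': 3,
        'dropout': 0.2,
        'output_dim': 1,
    }
    model = TemporalConvNet(args)
    x = torch.randn(2, 12, 207, 1)  # (batch_size, time_step, num_nodes, dim)
    y = model(x)
    assert y.shape == (2, 12, 207, 1)
    print(y.shape)  # torch.Size([2, 12, 207, 1])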