TrafficWheel/model/TCN/TCN.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class TemporalConvNet(nn.Module):
    def __init__(self, args):
        super(TemporalConvNet, self).__init__()
        num_inputs = args['input_dim']           # input feature dimension, e.g. 1
        num_channels = args['hidden_channels']   # list of hidden channel sizes, e.g. [64, 64, 64]
        kernel_size = args['kernel_size']        # convolution kernel size, e.g. 3
        dropout = args['dropout']                # dropout probability, e.g. 0.2
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            dilation = 2 ** i  # exponentially growing dilation rate
            layers.append(
                TemporalBlock(in_channels, out_channels, kernel_size,
                              stride=1, dilation=dilation, dropout=dropout))
        self.network = nn.Sequential(*layers)
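        # Note: with num_levels blocks, dilation 2**i, and two causal convolutions per
        # block, the receptive field on the time axis is
        # 1 + 2 * (kernel_size - 1) * (2**num_levels - 1) steps.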
        # Projection layer maps the final channel count to the output dimension
        # (usually the same as the input dimension, e.g. 1).
        self.proj = nn.Conv2d(num_channels[-1], args['output_dim'], kernel_size=1)

    def forward(self, x):
        """
        Input x has shape (batch_size, time_step, num_nodes, dim).
        It is first permuted to (batch, dim, num_nodes, time_step) so that
        2D convolutions can slide along the time dimension.
        """
        x = x[..., 0:1]             # keep only the first input feature
        x = x.permute(0, 3, 2, 1)   # (batch, dim, num_nodes, time_step)
        x = self.network(x)
        x = self.proj(x)
        # Back to (batch_size, time_step, num_nodes, output_dim)
        return x.permute(0, 3, 2, 1)
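
# Shape flow through TemporalConvNet.forward, for reference:
#   input           (batch, time_step, num_nodes, dim)
#   take feature 0  (batch, time_step, num_nodes, 1)
#   permute         (batch, 1, num_nodes, time_step)
#   self.network    (batch, hidden_channels[-1], num_nodes, time_step)
#   self.proj       (batch, output_dim, num_nodes, time_step)
#   final permute   (batch, time_step, num_nodes, output_dim)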


class TemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, dropout):
        super(TemporalBlock, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Causal convolution: padding is applied manually in forward(), so the
        # convolution layers themselves use no padding.
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                               stride=stride, padding=0, dilation=(1, dilation))
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, (1, kernel_size),
                               stride=stride, padding=0, dilation=(1, dilation))
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.dropout2 = nn.Dropout(dropout)
        # If the channel counts differ, a 1x1 convolution adjusts the residual branch
        self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None

    def forward(self, x):
        # Residual branch keeps the original x (possibly passed through downsample)
        res = x if self.downsample is None else self.downsample(x)
        # Number of zeros to pad on the left of the time dimension for causal convolution
        pad = (self.kernel_size - 1) * self.dilation
        # First convolution
        out = F.pad(x, (pad, 0))  # zero-pad the left of the time dimension; format is (left, right)
        out = self.conv1(out)
        out = self.norm1(out)
        out = F.relu(out)
        out = self.dropout1(out)
        # Second convolution
        out = F.pad(out, (pad, 0))
        out = self.conv2(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.dropout2(out)
        # Residual connection; the causal left-padding keeps the output time length equal to the input's
        return F.relu(out + res)
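

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): the hyper-parameter
# values and tensor sizes below are illustrative assumptions, chosen to match
# the example values mentioned in the comments above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        'input_dim': 1,                   # feature dimension per node
        'hidden_channels': [64, 64, 64],  # three temporal blocks
        'kernel_size': 3,
        'dropout': 0.2,
        'output_dim': 1,
    }
    model = TemporalConvNet(args)
    # (batch_size, time_step, num_nodes, dim); 207 nodes is an arbitrary example
    x = torch.randn(8, 12, 207, 1)
    y = model(x)
    print(y.shape)  # expected: torch.Size([8, 12, 207, 1])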