TrafficWheel/model/ARIMA/ARIMA.py

101 lines
3.7 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class ARIMA(nn.Module):
def __init__(self, args):
super(ARIMA, self).__init__()
self.p = args['p'] # 自回归阶数
self.d = args['d'] # 差分阶数
self.q = args['q'] # 移动平均阶数
self.num_node = args.get('num_nodes', 1) # 节点数量
# AR 参数
self.ar_params = nn.Parameter(torch.randn(self.p, self.num_node))
# MA 参数
self.ma_params = nn.Parameter(torch.randn(self.q, self.num_node))
# 偏移项(可选)
self.drift = nn.Parameter(torch.zeros(self.num_node)) if args.get('drift', False) else None
def forward(self, x):
"""
输入: [batch_size, time, num_node, dim]
输出: [batch_size, time, num_node, 1]
"""
# 提取目标维度的输入
x = x[..., 0].unsqueeze(-1) # [batch_size, time, num_node, 1]
batch_size, time, num_node, _ = x.shape
# 初始化输出张量
y_hat = torch.zeros_like(x)
# 逐节点处理
for n in range(num_node):
node_x = x[:, :, n, 0] # 当前节点的时间序列 [batch_size, time]
drift = self.drift[n] if self.drift is not None else None
node_y_hat = self._arima_forward(node_x, self.ar_params[:, n], self.ma_params[:, n], drift)
y_hat[:, :, n, 0] = node_y_hat
return y_hat
def _arima_forward(self, x, ar_params, ma_params, drift):
"""
单节点 ARIMA 前向传播
输入: x [batch_size, time]
输出: y_hat [batch_size, time]
"""
batch_size, time = x.shape
y_hat = torch.zeros_like(x)
# 差分
x_diff = x
for _ in range(self.d):
x_diff = torch.diff(x_diff, dim=1, prepend=x_diff[:, :1]) # 使用 prepend 保持时间维度不变
# 自回归部分
ar_out = torch.zeros_like(x_diff)
if self.p > 0:
# 使用 unfold 方法创建滑动窗口
x_diff_windows = x_diff.unfold(1, self.p, 1) # [batch_size, time - self.p + 1, self.p]
ar_out = torch.matmul(x_diff_windows, ar_params.reshape(-1, 1)).squeeze(-1) # [batch_size, time - self.p + 1]
ar_out = F.pad(ar_out, (self.p - 1, 0)) # 通过零填充使 ar_out 和 x_diff 的时间维度一致
# 移动平均部分
ma_out = torch.zeros_like(x_diff)
if self.q > 0:
# 使用 unfold 方法创建滑动窗口
x_diff_windows = x_diff.unfold(1, self.q, 1) # [batch_size, time - self.q + 1, self.q]
ma_out = torch.matmul(x_diff_windows, ma_params.reshape(-1, 1)).squeeze(-1) # [batch_size, time - self.q + 1]
ma_out = F.pad(ma_out, (self.q - 1, 0)) # 通过零填充使 ma_out 和 x_diff 的时间维度一致
# 预测值
pred_diff = ar_out + ma_out
# 积分恢复原始序列
y_hat[:, :self.d] = x[:, :self.d] # 前 d 个值直接复制
y_hat[:, self.d:] = torch.cumsum(pred_diff[:, self.d:], dim=1) + x[:, self.d - 1:self.d]
# 添加偏移项(如果有)
if drift is not None:
y_hat += drift
return y_hat
if __name__ == '__main__':
# 输入数据 [batch_size, time, num_node, dim]
batch_size, time, num_node, dim = 2, 20, 3, 1
x = torch.randn(batch_size, time, num_node, dim)
# 模型参数
args = {
'p': 2, # 自回归阶数
'd': 1, # 差分阶数
'q': 1, # 移动平均阶数
'num_nodes': num_node,
'drift': True # 是否包含偏移项
}
# 初始化模型
model = ARIMA(args)
output = model(x)
print(output.shape) # 应为 [batch_size, time, num_node, 1]