TrafficWheel/model/ARIMA/ARIMA.py

126 lines
4.0 KiB
Python
Executable File

import torch
import torch.nn as nn
import torch.nn.functional as F
class ARIMA(nn.Module):
def __init__(self, args):
super(ARIMA, self).__init__()
self.p = args["p"] # 自回归阶数
self.d = args["d"] # 差分阶数
self.q = args["q"] # 移动平均阶数
self.num_node = args.get("num_nodes", 1) # 节点数量
# AR 参数
self.ar_params = nn.Parameter(torch.randn(self.p, self.num_node))
# MA 参数
self.ma_params = nn.Parameter(torch.randn(self.q, self.num_node))
# 偏移项(可选)
self.drift = (
nn.Parameter(torch.zeros(self.num_node))
if args.get("drift", False)
else None
)
def forward(self, x):
"""
输入: [batch_size, time, num_node, dim]
输出: [batch_size, time, num_node, 1]
"""
# 提取目标维度的输入
x = x[..., 0].unsqueeze(-1) # [batch_size, time, num_node, 1]
batch_size, time, num_node, _ = x.shape
# 初始化输出张量
y_hat = torch.zeros_like(x)
# 逐节点处理
for n in range(num_node):
node_x = x[:, :, n, 0] # 当前节点的时间序列 [batch_size, time]
drift = self.drift[n] if self.drift is not None else None
node_y_hat = self._arima_forward(
node_x, self.ar_params[:, n], self.ma_params[:, n], drift
)
y_hat[:, :, n, 0] = node_y_hat
return y_hat
def _arima_forward(self, x, ar_params, ma_params, drift):
"""
单节点 ARIMA 前向传播
输入: x [batch_size, time]
输出: y_hat [batch_size, time]
"""
batch_size, time = x.shape
y_hat = torch.zeros_like(x)
# 差分
x_diff = x
for _ in range(self.d):
x_diff = torch.diff(
x_diff, dim=1, prepend=x_diff[:, :1]
) # 使用 prepend 保持时间维度不变
# 自回归部分
ar_out = torch.zeros_like(x_diff)
if self.p > 0:
# 使用 unfold 方法创建滑动窗口
x_diff_windows = x_diff.unfold(
1, self.p, 1
) # [batch_size, time - self.p + 1, self.p]
ar_out = torch.matmul(x_diff_windows, ar_params.reshape(-1, 1)).squeeze(
-1
) # [batch_size, time - self.p + 1]
ar_out = F.pad(
ar_out, (self.p - 1, 0)
) # 通过零填充使 ar_out 和 x_diff 的时间维度一致
# 移动平均部分
ma_out = torch.zeros_like(x_diff)
if self.q > 0:
# 使用 unfold 方法创建滑动窗口
x_diff_windows = x_diff.unfold(
1, self.q, 1
) # [batch_size, time - self.q + 1, self.q]
ma_out = torch.matmul(x_diff_windows, ma_params.reshape(-1, 1)).squeeze(
-1
) # [batch_size, time - self.q + 1]
ma_out = F.pad(
ma_out, (self.q - 1, 0)
) # 通过零填充使 ma_out 和 x_diff 的时间维度一致
# 预测值
pred_diff = ar_out + ma_out
# 积分恢复原始序列
y_hat[:, : self.d] = x[:, : self.d] # 前 d 个值直接复制
y_hat[:, self.d :] = (
torch.cumsum(pred_diff[:, self.d :], dim=1) + x[:, self.d - 1 : self.d]
)
# 添加偏移项(如果有)
if drift is not None:
y_hat += drift
return y_hat
if __name__ == "__main__":
# 输入数据 [batch_size, time, num_node, dim]
batch_size, time, num_node, dim = 2, 20, 3, 1
x = torch.randn(batch_size, time, num_node, dim)
# 模型参数
args = {
"p": 2, # 自回归阶数
"d": 1, # 差分阶数
"q": 1, # 移动平均阶数
"num_nodes": num_node,
"drift": True, # 是否包含偏移项
}
# 初始化模型
model = ARIMA(args)
output = model(x)
print(output.shape) # 应为 [batch_size, time, num_node, 1]