126 lines
4.0 KiB
Python
Executable File
126 lines
4.0 KiB
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class ARIMA(nn.Module):
    """Learnable ARIMA(p, d, q)-style filter applied independently to every node.

    ``args`` is a dict with keys:
        p (int): number of AR coefficients per node.
        d (int): number of differencing passes applied before filtering.
        q (int): number of MA coefficients per node.
        num_nodes (int, optional): number of nodes; defaults to 1.
        drift (bool, optional): if True, a learnable per-node offset
            (initialised to zero) is added to the output; defaults to False.

    NOTE(review): both the "AR" and the "MA" terms are linear filters over
    the differenced input itself, and every window ends at (and includes)
    the current time step. A textbook ARIMA uses strictly past lags for AR
    and past forecast residuals for MA; confirm this is the intended model.
    """

    def __init__(self, args):
        super().__init__()
        self.p = args["p"]  # autoregressive order
        self.d = args["d"]  # differencing order
        self.q = args["q"]  # moving-average order
        self.num_node = args.get("num_nodes", 1)  # number of nodes

        # One column of coefficients per node: [p, num_node] and [q, num_node].
        self.ar_params = nn.Parameter(torch.randn(self.p, self.num_node))
        self.ma_params = nn.Parameter(torch.randn(self.q, self.num_node))

        # Optional per-node offset added to the final output.
        self.drift = (
            nn.Parameter(torch.zeros(self.num_node))
            if args.get("drift", False)
            else None
        )

    def forward(self, x):
        """Apply the ARIMA filter to every node at once.

        Args:
            x: tensor of shape [batch_size, time, num_node, dim]; only
                feature 0 of the last axis is used.

        Returns:
            Tensor of shape [batch_size, time, num_node, 1].

        Raises:
            ValueError: if ``x`` has more nodes than the model was built for.
        """
        series = x[..., 0]  # [batch_size, time, num_node]
        num_node = series.shape[2]
        if num_node > self.num_node:
            raise ValueError(
                f"input has {num_node} nodes but the model was built for "
                f"{self.num_node}"
            )

        # Slicing keeps the previous behaviour of using the first `num_node`
        # parameter columns when the input has fewer nodes than the model.
        drift = self.drift[:num_node] if self.drift is not None else None
        y_hat = self._arima_forward(
            series,
            self.ar_params[:, :num_node],
            self.ma_params[:, :num_node],
            drift,
        )
        return y_hat.unsqueeze(-1)

    def _arima_forward(self, x, ar_params, ma_params, drift):
        """ARIMA forward pass for a single node or for several nodes together.

        Args:
            x: [batch_size, time] (single node) or [batch_size, time, num_node].
            ar_params: [p] for a single node, or [p, num_node].
            ma_params: [q] for a single node, or [q, num_node].
            drift: None, a scalar tensor (single node) or a [num_node] tensor.

        Returns:
            Tensor with the same shape as ``x``.
        """
        single_node = x.dim() == 2
        if single_node:
            x = x.unsqueeze(-1)
            ar_params = ar_params.reshape(-1, 1)
            ma_params = ma_params.reshape(-1, 1)

        # levels[k] is x differenced k times. Prepending the first value keeps
        # the time length unchanged, so every level k >= 1 starts with a 0.
        levels = [x]
        for _ in range(self.d):
            prev = levels[-1]
            levels.append(torch.diff(prev, dim=1, prepend=prev[:, :1]))
        x_diff = levels[-1]

        pred_diff = self._causal_filter(x_diff, ar_params) + self._causal_filter(
            x_diff, ma_params
        )

        y_hat = self._integrate(pred_diff, levels)

        if drift is not None:
            y_hat = y_hat + drift
        return y_hat.squeeze(-1) if single_node else y_hat

    @staticmethod
    def _causal_filter(series, weights):
        """Weighted sliding-window sum over time, zero-filled at the start.

        For K = weights.shape[0]:
        out[:, t, n] = sum_k series[:, t - K + 1 + k, n] * weights[k, n]
        when t >= K - 1, and 0 otherwise.

        Args:
            series: [batch_size, time, num_node].
            weights: [K, num_node].

        Returns:
            Tensor of shape [batch_size, time, num_node].
        """
        order = weights.shape[0]
        seq_len = series.shape[1]
        if order == 0 or seq_len < order:
            # No complete window fits, so every position lies in the zero fill
            # (unfold would raise here instead).
            return torch.zeros_like(series)
        windows = series.unfold(1, order, 1)  # [batch, seq_len - order + 1, num_node, order]
        out = torch.einsum("bwnk,kn->bwn", windows, weights)
        # Left-pad the time axis (dim 1) so the output realigns with `series`.
        return F.pad(out, (0, 0, order - 1, 0))

    def _integrate(self, pred_diff, levels):
        """Undo ``self.d`` rounds of differencing on the predicted series.

        The first ``d`` time steps are copied from the input. From step ``d``
        onwards, each level is rebuilt as its input value at step ``d - 1``
        plus the cumulative sum of the reconstructed level below. This
        inverts the differencing exactly when ``pred_diff`` equals the true
        differenced series. With d == 0 the prediction is returned unchanged.
        """
        d = self.d
        y = pred_diff
        for level in reversed(levels[:-1]):
            y = torch.cat(
                [level[:, :d], level[:, d - 1 : d] + torch.cumsum(y[:, d:], dim=1)],
                dim=1,
            )
        return y
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push a small random batch through the model and report
    # the output shape.
    batch_size, seq_len, num_node, feat_dim = 2, 20, 3, 1
    x = torch.randn(batch_size, seq_len, num_node, feat_dim)

    # Hyper-parameters: ARIMA(2, 1, 1) with a learnable per-node offset.
    args = {
        "p": 2,  # autoregressive order
        "d": 1,  # differencing order
        "q": 1,  # moving-average order
        "num_nodes": num_node,
        "drift": True,  # include the offset term
    }

    model = ARIMA(args)
    output = model(x)
    print(output.shape)  # expected: [batch_size, seq_len, num_node, 1]
|