101 lines
3.7 KiB
Python
Executable File
101 lines
3.7 KiB
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class ARIMA(nn.Module):
|
|
def __init__(self, args):
|
|
super(ARIMA, self).__init__()
|
|
self.p = args['p'] # 自回归阶数
|
|
self.d = args['d'] # 差分阶数
|
|
self.q = args['q'] # 移动平均阶数
|
|
self.num_node = args.get('num_nodes', 1) # 节点数量
|
|
|
|
# AR 参数
|
|
self.ar_params = nn.Parameter(torch.randn(self.p, self.num_node))
|
|
# MA 参数
|
|
self.ma_params = nn.Parameter(torch.randn(self.q, self.num_node))
|
|
# 偏移项(可选)
|
|
self.drift = nn.Parameter(torch.zeros(self.num_node)) if args.get('drift', False) else None
|
|
|
|
def forward(self, x):
|
|
"""
|
|
输入: [batch_size, time, num_node, dim]
|
|
输出: [batch_size, time, num_node, 1]
|
|
"""
|
|
# 提取目标维度的输入
|
|
x = x[..., 0].unsqueeze(-1) # [batch_size, time, num_node, 1]
|
|
batch_size, time, num_node, _ = x.shape
|
|
|
|
# 初始化输出张量
|
|
y_hat = torch.zeros_like(x)
|
|
|
|
# 逐节点处理
|
|
for n in range(num_node):
|
|
node_x = x[:, :, n, 0] # 当前节点的时间序列 [batch_size, time]
|
|
drift = self.drift[n] if self.drift is not None else None
|
|
node_y_hat = self._arima_forward(node_x, self.ar_params[:, n], self.ma_params[:, n], drift)
|
|
y_hat[:, :, n, 0] = node_y_hat
|
|
|
|
return y_hat
|
|
|
|
def _arima_forward(self, x, ar_params, ma_params, drift):
|
|
"""
|
|
单节点 ARIMA 前向传播
|
|
输入: x [batch_size, time]
|
|
输出: y_hat [batch_size, time]
|
|
"""
|
|
batch_size, time = x.shape
|
|
y_hat = torch.zeros_like(x)
|
|
|
|
# 差分
|
|
x_diff = x
|
|
for _ in range(self.d):
|
|
x_diff = torch.diff(x_diff, dim=1, prepend=x_diff[:, :1]) # 使用 prepend 保持时间维度不变
|
|
|
|
# 自回归部分
|
|
ar_out = torch.zeros_like(x_diff)
|
|
if self.p > 0:
|
|
# 使用 unfold 方法创建滑动窗口
|
|
x_diff_windows = x_diff.unfold(1, self.p, 1) # [batch_size, time - self.p + 1, self.p]
|
|
ar_out = torch.matmul(x_diff_windows, ar_params.reshape(-1, 1)).squeeze(-1) # [batch_size, time - self.p + 1]
|
|
ar_out = F.pad(ar_out, (self.p - 1, 0)) # 通过零填充使 ar_out 和 x_diff 的时间维度一致
|
|
|
|
# 移动平均部分
|
|
ma_out = torch.zeros_like(x_diff)
|
|
if self.q > 0:
|
|
# 使用 unfold 方法创建滑动窗口
|
|
x_diff_windows = x_diff.unfold(1, self.q, 1) # [batch_size, time - self.q + 1, self.q]
|
|
ma_out = torch.matmul(x_diff_windows, ma_params.reshape(-1, 1)).squeeze(-1) # [batch_size, time - self.q + 1]
|
|
ma_out = F.pad(ma_out, (self.q - 1, 0)) # 通过零填充使 ma_out 和 x_diff 的时间维度一致
|
|
|
|
# 预测值
|
|
pred_diff = ar_out + ma_out
|
|
|
|
# 积分恢复原始序列
|
|
y_hat[:, :self.d] = x[:, :self.d] # 前 d 个值直接复制
|
|
y_hat[:, self.d:] = torch.cumsum(pred_diff[:, self.d:], dim=1) + x[:, self.d - 1:self.d]
|
|
|
|
# 添加偏移项(如果有)
|
|
if drift is not None:
|
|
y_hat += drift
|
|
|
|
return y_hat
|
|
|
|
if __name__ == '__main__':
|
|
# 输入数据 [batch_size, time, num_node, dim]
|
|
batch_size, time, num_node, dim = 2, 20, 3, 1
|
|
x = torch.randn(batch_size, time, num_node, dim)
|
|
|
|
# 模型参数
|
|
args = {
|
|
'p': 2, # 自回归阶数
|
|
'd': 1, # 差分阶数
|
|
'q': 1, # 移动平均阶数
|
|
'num_nodes': num_node,
|
|
'drift': True # 是否包含偏移项
|
|
}
|
|
|
|
# 初始化模型
|
|
model = ARIMA(args)
|
|
output = model(x)
|
|
print(output.shape) # 应为 [batch_size, time, num_node, 1] |