import torch
import torch.nn as nn
import torch.nn.functional as F


class ARIMA(nn.Module):
    """Differentiable ARIMA(p, d, q) baseline with independent parameters per node.

    For each node the series is differenced ``d`` times. An ARMA(p, q) model
    then predicts the differenced value at every time step from strictly
    earlier values only. The differencing is inverted one step ahead using the
    observed past levels, so for ``t >= d`` the prediction at step ``t`` never
    sees the input at step ``t`` or later.

    Args:
        args: configuration mapping with keys ``'p'`` (autoregressive order),
            ``'d'`` (differencing order), ``'q'`` (moving-average order),
            optional ``'num_nodes'`` (default 1) and optional ``'drift'``
            (default False; adds a learnable constant term per node to the
            differenced model).
    """

    def __init__(self, args):
        super().__init__()
        self.p = args['p']  # autoregressive order
        self.d = args['d']  # differencing order
        self.q = args['q']  # moving-average order
        self.num_node = args.get('num_nodes', 1)  # number of nodes
        # AR coefficients, shape [p, num_node]: row i weights lag i + 1 of the
        # differenced series.
        self.ar_params = nn.Parameter(torch.randn(self.p, self.num_node))
        # MA coefficients, shape [q, num_node]: row j weights lag j + 1 of the
        # residual series.
        self.ma_params = nn.Parameter(torch.randn(self.q, self.num_node))
        # Optional constant (drift) term of the differenced model.
        self.drift = nn.Parameter(torch.zeros(self.num_node)) if args.get('drift', False) else None

    def forward(self, x):
        """Predict every time step of channel 0 of ``x``.

        Args:
            x: tensor of shape ``[batch_size, time, num_node, dim]``. Only
                ``x[..., 0]`` is used.

        Returns:
            Tensor of shape ``[batch_size, time, num_node, 1]``. For
            ``t >= d``, entry ``t`` depends only on inputs before ``t``. The
            first ``d`` steps are copied from the input.

        Raises:
            ValueError: if the node dimension of ``x`` differs from ``num_nodes``.
        """
        series = x[..., 0]  # [batch_size, time, num_node]
        if series.shape[2] != self.num_node:
            raise ValueError(
                f"expected {self.num_node} nodes, got {series.shape[2]}")
        # All nodes are processed at once: parameters of shape [order, num_node]
        # broadcast over the trailing node axis, so no per-node loop is needed.
        y_hat = self._arima_forward(series, self.ar_params, self.ma_params, self.drift)
        return y_hat.unsqueeze(-1)

    @staticmethod
    def _lag(series, k):
        """Shift ``series`` forward by ``k`` steps along dim 1, zero-filling the start."""
        if k <= 0:
            return series
        # series[:, :k] only provides the right shape for the zero head; if
        # k >= time the result is all zeros (series[:, :-k] is then empty).
        head = torch.zeros_like(series[:, :k])
        return torch.cat([head, series[:, :-k]], dim=1)

    def _arima_forward(self, x, ar_params, ma_params, drift):
        """One-step-ahead ARIMA prediction along dim 1 (time) of ``x``.

        Args:
            x: series with time on dim 1. Either ``[batch_size, time]`` for a
                single node (with ``ar_params`` of shape ``[p]``), or
                ``[batch_size, time, num_node]`` (with ``[p, num_node]``).
            ar_params: AR coefficients; ``ar_params[i]`` weights lag ``i + 1``.
            ma_params: MA coefficients; ``ma_params[j]`` weights lag ``j + 1``
                of the AR-stage residuals.
            drift: constant term broadcastable to one time step, or None.

        Returns:
            Tensor shaped like ``x``.
        """
        # Difference d times. Prepending the first value keeps the length, so
        # only the first d entries contain boundary artifacts.
        x_diff = x
        for _ in range(self.d):
            x_diff = torch.diff(x_diff, dim=1, prepend=x_diff[:, :1])

        # AR part: only strictly past differenced values (lags 1..p), so the
        # target x_diff[t] never feeds its own prediction.
        ar_out = torch.zeros_like(x_diff)
        for i in range(self.p):
            ar_out = ar_out + ar_params[i] * self._lag(x_diff, i + 1)
        if drift is not None:
            ar_out = ar_out + drift

        # MA part: the true innovations are unobserved. They are approximated
        # by the residuals of the AR(+drift) prediction, and the MA terms use
        # lags 1..q of those residuals (again strictly past values).
        ma_out = torch.zeros_like(x_diff)
        if self.q > 0:
            resid = x_diff - ar_out
            for j in range(self.q):
                ma_out = ma_out + ma_params[j] * self._lag(resid, j + 1)

        pred_diff = ar_out + ma_out

        # Invert the differencing one step ahead. For t >= d, x[t] - diff^d x[t]
        # is a combination of x[t-1], ..., x[t-d] only, so adding the
        # predicted difference gives a prediction of x[t]. This works for any
        # d, including d = 0 where the correction term is zero.
        y_hat = x - x_diff + pred_diff
        # The first d steps have no complete history; copy them from the input.
        return torch.cat([x[:, :self.d], y_hat[:, self.d:]], dim=1)


if __name__ == '__main__':
    # Input data: [batch_size, time, num_node, dim]
    batch_size, time, num_node, dim = 2, 20, 3, 1
    x = torch.randn(batch_size, time, num_node, dim)

    # Model hyper-parameters
    args = {
        'p': 2,  # autoregressive order
        'd': 1,  # differencing order
        'q': 1,  # moving-average order
        'num_nodes': num_node,
        'drift': True,  # include a constant (drift) term
    }

    model = ARIMA(args)
    output = model(x)
    print(output.shape)  # expected: [batch_size, time, num_node, 1]