TrafficWheel/model/D2STGNN/diffusion_block/forecast.py

28 lines
1.2 KiB
Python

import torch
import torch.nn as nn
class Forecast(nn.Module):
    """Autoregressive forecast head.

    Starting from the last step (along dim 1) of the supplied hidden states,
    the module repeatedly feeds the most recent ``k_t`` steps to a
    caller-supplied convolution to roll out additional steps, concatenates
    all steps along dim 1 and projects the result with a linear layer.

    Args:
        hidden_dim: Input feature size of the final linear projection.
        forecast_hidden_dim: Output feature size of the final linear
            projection. When ``None`` (the default) it falls back to
            ``hidden_dim``; previously the ``None`` default was handed to
            ``nn.Linear`` unchanged and construction failed.
        **model_args: Model configuration. ``'k_t'`` and ``'horizon'`` are
            read at construction time and ``'gap'`` is read in ``forward``.

    Raises:
        KeyError: If ``'k_t'`` or ``'horizon'`` is missing from
            ``model_args`` (``'gap'`` is only looked up in ``forward``).
    """

    def __init__(self, hidden_dim, forecast_hidden_dim=None, **model_args):
        super().__init__()
        if forecast_hidden_dim is None:
            # nn.Linear cannot be built with out_features=None; keep the
            # feature width unchanged when the caller does not specify one.
            forecast_hidden_dim = hidden_dim
        self.k_t = model_args['k_t']
        # Use `horizon` as the target sequence length.
        self.output_seq_len = model_args['horizon']
        self.forecast_fc = nn.Linear(hidden_dim, forecast_hidden_dim)
        self.model_args = model_args

    def forward(self, gated_history_data, hidden_states_dif, localized_st_conv, dynamic_graph, static_graph):
        """Roll out the forecast and project it.

        Args:
            gated_history_data: 4-D tensor whose trailing steps along dim 1
                pad the convolution input while fewer than ``k_t``
                predictions exist.
                NOTE(review): presumably (batch, time, nodes, features) -
                confirm against the caller.
            hidden_states_dif: 4-D tensor; its last step along dim 1 seeds
                the rollout.
            localized_st_conv: Callable invoked as
                ``localized_st_conv(window, dynamic_graph, static_graph)``.
                Its result is concatenated with the other steps along dim 1.
            dynamic_graph: Passed through to ``localized_st_conv`` untouched.
            static_graph: Passed through to ``localized_st_conv`` untouched.

        Returns:
            The concatenated steps (seed plus rolled-out steps) after the
            linear projection ``forecast_fc``.
        """
        # Seed: last step along dim 1, with that dim restored (size 1) so it
        # can be concatenated with later steps.
        predict = [hidden_states_dif[:, -1, :, :].unsqueeze(1)]

        # The seed already counts as one step, hence the "- 1".
        num_rollout_steps = int(self.output_seq_len / self.model_args['gap']) - 1

        for _ in range(num_rollout_steps):
            window = predict[-self.k_t:]
            missing = self.k_t - len(window)
            if missing > 0:
                # Not enough predictions yet: prepend the most recent
                # history steps so the window spans k_t entries.
                window = [gated_history_data[:, -missing:, :, :]] + window
            window = torch.cat(window, dim=1)
            predict.append(localized_st_conv(window, dynamic_graph, static_graph))

        predict = torch.cat(predict, dim=1)
        return self.forecast_fc(predict)