import torch
import torch.nn as nn


class Forecast(nn.Module):
    """Autoregressive multi-step forecasting head.

    Starting from the last time step of ``hidden_states_dif``, the module
    repeatedly feeds a sliding window of the ``k_t`` most recent steps into a
    caller-supplied spatial-temporal convolution to produce the next step.
    It then projects the last (feature) dimension of every produced step with
    a linear layer.

    Args:
        hidden_dim: Size of the last dimension of the rolled-out states
            (``in_features`` of the output projection).
        forecast_hidden_dim: ``out_features`` of the output projection.
            Defaults to ``hidden_dim`` when ``None``.
        **model_args: Model configuration. Must contain ``'k_t'`` (window
            length fed to the convolution) and ``'horizon'`` (target sequence
            length). ``forward`` additionally reads ``'gap'``; the number of
            produced steps is ``int(horizon / gap)``.
    """

    def __init__(self, hidden_dim, forecast_hidden_dim=None, **model_args):
        super().__init__()
        if forecast_hidden_dim is None:
            # nn.Linear cannot accept None as out_features; keep the width.
            forecast_hidden_dim = hidden_dim
        self.k_t = model_args['k_t']
        # Use the horizon as the target output sequence length.
        self.output_seq_len = model_args['horizon']
        self.forecast_fc = nn.Linear(hidden_dim, forecast_hidden_dim)
        self.model_args = model_args

    def forward(self, gated_history_data: torch.Tensor,
                hidden_states_dif: torch.Tensor, localized_st_conv,
                dynamic_graph, static_graph) -> torch.Tensor:
        """Roll out ``int(horizon / gap)`` steps and project their features.

        Args:
            gated_history_data: 4-D tensor whose dim 1 is the time axis. Its
                most recent steps left-pad the convolution window while fewer
                than ``k_t`` steps have been produced.
            hidden_states_dif: 4-D tensor. Its last step along dim 1 seeds
                the rollout.
            localized_st_conv: Callable
                ``(window, dynamic_graph, static_graph) -> next_step``.
                Presumably returns a tensor with size 1 along dim 1 —
                TODO confirm against the caller.
            dynamic_graph: Passed through unchanged to ``localized_st_conv``.
            static_graph: Passed through unchanged to ``localized_st_conv``.

        Returns:
            All produced steps (seed first) concatenated along dim 1, with the
            last dimension projected by ``forecast_fc``.
        """
        num_steps = int(self.output_seq_len / self.model_args['gap'])
        # Seed with the final step; unsqueeze keeps an explicit time axis so
        # steps can be concatenated along dim 1.
        steps = [hidden_states_dif[:, -1, :, :].unsqueeze(1)]
        for _ in range(num_steps - 1):
            recent = steps[-self.k_t:]
            missing = self.k_t - len(recent)
            if missing > 0:
                # Not enough predictions yet: left-pad the window with the
                # tail of the history so the convolution sees k_t steps.
                recent = [gated_history_data[:, -missing:, :, :]] + recent
            window = torch.cat(recent, dim=1)
            steps.append(localized_st_conv(window, dynamic_graph, static_graph))
        predict = torch.cat(steps, dim=1)
        return self.forecast_fc(predict)