import torch
import torch.nn as nn

from model.MegaCRN.MegaCRN import MegaCRN

# Hyperparameters written into ``args`` when the caller does not provide them.
# Insertion order matches the order in which defaults were originally applied.
_DEFAULT_ARGS = {
    "rnn_units": 64,
    "num_layers": 1,
    "cheb_k": 3,
    "ycov_dim": 1,
    "mem_num": 20,
    "mem_dim": 64,
    "cl_decay_steps": 2000,
    "use_curriculum_learning": True,
    "horizon": 12,
}


class MegaCRNModel(nn.Module):
    """Adapter exposing ``MegaCRN`` through a ``forward(x)`` interface.

    Only the first feature channel of ``x`` is passed to the wrapped model.
    Only the first element of the tuple returned by ``MegaCRN`` is returned;
    the remaining four outputs are discarded.
    """

    def __init__(self, args):
        """Build the wrapped ``MegaCRN`` model.

        Args:
            args: Dict-like configuration. ``num_nodes`` and ``output_dim``
                must be present. Any key of ``_DEFAULT_ARGS`` that is missing
                is written into ``args`` in place, so the caller's object sees
                the defaults afterwards. ``args`` is kept as ``self.args``.
        """
        super().__init__()
        # Fill in missing hyperparameters in place (callers may rely on this).
        for key, value in _DEFAULT_ARGS.items():
            if key not in args:
                args[key] = value

        self.model = MegaCRN(
            num_nodes=args["num_nodes"],
            input_dim=1,  # fixed to 1 because forward() keeps only channel 0
            output_dim=args["output_dim"],
            horizon=args["horizon"],
            rnn_units=args["rnn_units"],
            num_layers=args["num_layers"],
            cheb_k=args["cheb_k"],
            ycov_dim=args["ycov_dim"],
            mem_num=args["mem_num"],
            mem_dim=args["mem_dim"],
            cl_decay_steps=args["cl_decay_steps"],
            use_curriculum_learning=args["use_curriculum_learning"],
        )
        self.args = args
        # Number of training-mode forward passes so far; forwarded to MegaCRN
        # as ``batches_seen``. Not advanced in eval mode.
        self.batches_seen = 0

    def forward(self, x, y_cov=None, labels=None):
        """Run MegaCRN on the first channel of ``x`` and return its prediction.

        Args:
            x: Tensor indexed as (batch, seq_len, num_nodes, features); only
                ``x[..., 0]`` is used (following the DDGCRN convention).
            y_cov: Optional covariate tensor of shape
                (batch, horizon, num_nodes, ycov_dim). Defaults to zeros.
            labels: Optional target tensor of shape
                (batch, horizon, num_nodes, output_dim). Defaults to zeros.
                NOTE(review): with ``use_curriculum_learning`` enabled,
                MegaCRN presumably uses ``labels`` for teacher forcing while
                training, in which case the zero default would be fed in as
                ground truth. Pass real targets when training; TODO confirm
                against MegaCRN's decoder.

        Returns:
            The first element of the tuple returned by ``self.model``.
        """
        batch_size, num_nodes = x.shape[0], x.shape[2]
        horizon = self.args["horizon"]

        # Keep only the first channel: (batch, seq_len, num_nodes, 1).
        x = x[..., 0].unsqueeze(-1)

        if y_cov is None:
            y_cov = torch.zeros(
                batch_size,
                horizon,
                num_nodes,
                self.args["ycov_dim"],
                device=x.device,
            )
        if labels is None:
            labels = torch.zeros(
                batch_size,
                horizon,
                num_nodes,
                self.args["output_dim"],
                device=x.device,
            )

        output, h_att, query, pos, neg = self.model(
            x, y_cov, labels=labels, batches_seen=self.batches_seen
        )

        # Only training steps advance the schedule counter; evaluation and
        # inference passes must not shift it.
        if self.training:
            self.batches_seen += 1
        return output