import torch
import torch.nn as nn

from model.MegaCRN.MegaCRN import MegaCRN


class MegaCRNModel(nn.Module):
    """Adapter that exposes ``MegaCRN`` through a ``forward(x)`` style call.

    The wrapped network is called with an input tensor, a covariate tensor
    and a label tensor.  This adapter selects the first feature channel of
    the input, builds zero placeholders for the covariates/labels when the
    caller does not supply them, and returns only the prediction tensor.

    Args:
        args: Mutable mapping of hyperparameters.  ``'num_nodes'`` and
            ``'output_dim'`` are required (a ``KeyError`` is raised when
            they are missing).  Every key listed in ``_DEFAULTS`` is
            optional and is written into ``args`` in place when absent, so
            the caller's mapping reflects the effective configuration
            after construction.
    """

    # Values used for hyperparameters that the caller did not provide.
    _DEFAULTS = {
        'rnn_units': 64,
        'num_layers': 1,
        'cheb_k': 3,
        'ycov_dim': 1,
        'mem_num': 20,
        'mem_dim': 64,
        'cl_decay_steps': 2000,
        'use_curriculum_learning': True,
        'horizon': 12,
    }

    def __init__(self, args):
        super().__init__()

        # Fill in missing hyperparameters.  The mapping is updated in place
        # (not copied) to stay compatible with callers that read the
        # populated values back from their own ``args`` object.
        for key, value in self._DEFAULTS.items():
            if key not in args:
                args[key] = value

        # Build the wrapped MegaCRN network.
        self.model = MegaCRN(
            num_nodes=args['num_nodes'],
            input_dim=1,  # fixed to 1: forward() keeps only the first channel
            output_dim=args['output_dim'],
            horizon=args['horizon'],
            rnn_units=args['rnn_units'],
            num_layers=args['num_layers'],
            cheb_k=args['cheb_k'],
            ycov_dim=args['ycov_dim'],
            mem_num=args['mem_num'],
            mem_dim=args['mem_dim'],
            cl_decay_steps=args['cl_decay_steps'],
            use_curriculum_learning=args['use_curriculum_learning'],
        )
        self.args = args
        # Number of training batches processed so far; forwarded to the
        # wrapped model as ``batches_seen`` on every call.
        self.batches_seen = 0

    def forward(self, x, y_cov=None, labels=None):
        """Run MegaCRN on ``x`` and return its prediction tensor.

        Args:
            x: Input of shape (batch_size, seq_len, num_nodes, features).
                Only feature channel 0 is passed to the network.
            y_cov: Optional covariates for the forecast horizon.  When
                omitted, zeros of shape
                (batch_size, horizon, num_nodes, ycov_dim) are used.
            labels: Optional targets handed to the wrapped model.  When
                omitted, zeros of shape
                (batch_size, horizon, num_nodes, output_dim) are used.
                NOTE(review): presumably the wrapped model uses these for
                curriculum learning; if so, zero placeholders feed zeros in
                place of real targets during training -- confirm against
                ``MegaCRN.forward``.

        Returns:
            The first element of the tuple returned by the wrapped model;
            the remaining four elements are discarded.
        """
        # Keep only the first channel -> (batch_size, seq_len, num_nodes, 1).
        x = x[..., :1]

        batch_size = x.shape[0]
        num_nodes = x.shape[2]
        horizon = self.args['horizon']

        if y_cov is None:
            y_cov = torch.zeros(
                batch_size, horizon, num_nodes, self.args['ycov_dim'],
                device=x.device,
            )
        if labels is None:
            labels = torch.zeros(
                batch_size, horizon, num_nodes, self.args['output_dim'],
                device=x.device,
            )

        output, _h_att, _query, _pos, _neg = self.model(
            x, y_cov, labels=labels, batches_seen=self.batches_seen
        )

        # Count training batches only: evaluation passes must not advance
        # the counter that is handed to the model as ``batches_seen``.
        if self.training:
            self.batches_seen += 1

        return output