67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
from model.MegaCRN.MegaCRN import MegaCRN
|
||
|
||
class MegaCRNModel(nn.Module):
    """Adapter exposing MegaCRN through a ``forward(x)`` API.

    Only the first feature channel of the input is fed to MegaCRN. The future
    covariates (``y_cov``) and targets (``labels``) that MegaCRN's forward
    takes default to zero placeholders unless the caller supplies them.
    """

    # Hyper-parameters written into ``args`` when the caller omits them.
    # Treated as read-only: values are copied into ``args``, never mutated here.
    _DEFAULT_ARGS = {
        'rnn_units': 64,
        'num_layers': 1,
        'cheb_k': 3,
        'ycov_dim': 1,
        'mem_num': 20,
        'mem_dim': 64,
        'cl_decay_steps': 2000,
        'use_curriculum_learning': True,
        'horizon': 12,
    }

    def __init__(self, args):
        """Build the wrapped MegaCRN model.

        Args:
            args: Mutable mapping of hyper-parameters. ``'num_nodes'`` and
                ``'output_dim'`` must be present. Every key of
                ``_DEFAULT_ARGS`` that is missing is written into ``args``
                in place, so the caller sees the filled-in defaults afterwards.
                The same object is kept as ``self.args``.

        Raises:
            KeyError: (for a plain ``dict`` ``args``) if ``'num_nodes'`` or
                ``'output_dim'`` is missing.
        """
        super().__init__()

        # Fill in defaults without overriding anything the caller set.
        for key, value in self._DEFAULT_ARGS.items():
            if key not in args:
                args[key] = value

        self.model = MegaCRN(
            num_nodes=args['num_nodes'],
            input_dim=1,  # fixed at 1: forward() keeps only the first input channel
            output_dim=args['output_dim'],
            horizon=args['horizon'],
            rnn_units=args['rnn_units'],
            num_layers=args['num_layers'],
            cheb_k=args['cheb_k'],
            ycov_dim=args['ycov_dim'],
            mem_num=args['mem_num'],
            mem_dim=args['mem_dim'],
            cl_decay_steps=args['cl_decay_steps'],
            use_curriculum_learning=args['use_curriculum_learning'],
        )

        self.args = args
        # Count of training forward passes, forwarded to MegaCRN as
        # ``batches_seen``. It is a plain int attribute, so it is not saved in
        # ``state_dict()``; a resumed run restarts it at 0 unless restored.
        self.batches_seen = 0

    def forward(self, x, *, y_cov=None, labels=None):
        """Run MegaCRN on the first feature channel of ``x``.

        Args:
            x: Input tensor laid out as ``(batch_size, seq_len, num_nodes,
                features)``. Only ``x[..., :1]`` is used.
            y_cov: Optional future covariates. Defaults to zeros of shape
                ``(batch_size, horizon, num_nodes, ycov_dim)`` on ``x``'s
                device and dtype.
            labels: Optional targets. Defaults to zeros of shape
                ``(batch_size, horizon, num_nodes, output_dim)`` on ``x``'s
                device and dtype.

        Returns:
            The first element of MegaCRN's output tuple (the prediction).
        """
        # Keep only the first channel, preserving a trailing size-1 axis:
        # (batch_size, seq_len, num_nodes, 1).
        x = x[..., :1]
        batch_size, num_nodes = x.shape[0], x.shape[2]
        horizon = self.args['horizon']

        if y_cov is None:
            y_cov = torch.zeros(
                batch_size, horizon, num_nodes, self.args['ycov_dim'],
                device=x.device, dtype=x.dtype,
            )
        if labels is None:
            # NOTE(review): if MegaCRN feeds ``labels`` back as decoder input
            # for curriculum learning in training mode (as the cl_decay_steps /
            # batches_seen arguments suggest), these zeros become the teacher
            # signal. Pass real targets via ``labels=`` during training and
            # confirm against MegaCRN.forward.
            labels = torch.zeros(
                batch_size, horizon, num_nodes, self.args['output_dim'],
                device=x.device, dtype=x.dtype,
            )

        # MegaCRN also returns (h_att, query, pos, neg). They are discarded
        # here; NOTE(review): confirm no auxiliary loss is expected to use them.
        output, _h_att, _query, _pos, _neg = self.model(
            x, y_cov, labels=labels, batches_seen=self.batches_seen
        )

        # Only training passes advance the schedule counter, so validation and
        # test batches do not skew the curriculum.
        if self.training:
            self.batches_seen += 1

        return output