81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
from model.MegaCRN.MegaCRN import MegaCRN
|
||
|
||
|
||
class MegaCRNModel(nn.Module):
    """Thin adapter around ``MegaCRN`` that accepts a single input tensor.

    The wrapper fills in default hyper-parameters, feeds only the first input
    feature channel to ``MegaCRN``, builds zero placeholder ``y_cov`` /
    ``labels`` tensors when the caller does not provide them, and returns only
    the prediction tensor (the auxiliary outputs are discarded).
    """

    # Hyper-parameters written into ``args`` when the caller leaves them out.
    _DEFAULTS = {
        "rnn_units": 64,
        "num_layers": 1,
        "cheb_k": 3,
        "ycov_dim": 1,
        "mem_num": 20,
        "mem_dim": 64,
        "cl_decay_steps": 2000,
        "use_curriculum_learning": True,
        "horizon": 12,
    }

    def __init__(self, args):
        """Build the wrapped ``MegaCRN`` model.

        Args:
            args: Mutable mapping of hyper-parameters. ``num_nodes`` and
                ``output_dim`` are required. Every key of ``_DEFAULTS`` that
                is missing is written into ``args`` in place, so the caller
                can read the resolved values afterwards.

        Raises:
            KeyError: if ``num_nodes`` or ``output_dim`` is missing.
        """
        super().__init__()

        # Fill in defaults in place (mirrors the original per-key checks).
        for key, value in self._DEFAULTS.items():
            args.setdefault(key, value)

        self.model = MegaCRN(
            num_nodes=args["num_nodes"],
            input_dim=1,  # fixed at 1 because forward() keeps only channel 0
            output_dim=args["output_dim"],
            horizon=args["horizon"],
            rnn_units=args["rnn_units"],
            num_layers=args["num_layers"],
            cheb_k=args["cheb_k"],
            ycov_dim=args["ycov_dim"],
            mem_num=args["mem_num"],
            mem_dim=args["mem_dim"],
            cl_decay_steps=args["cl_decay_steps"],
            use_curriculum_learning=args["use_curriculum_learning"],
        )

        self.args = args
        # Number of *training* forward passes so far; handed to MegaCRN as
        # ``batches_seen`` (presumably drives its curriculum-learning
        # schedule -- TODO confirm against MegaCRN).
        self.batches_seen = 0

    def forward(self, x, y_cov=None, labels=None):
        """Run MegaCRN on the first feature channel of ``x``.

        Args:
            x: Input tensor. The code indexes it as
                (batch_size, seq_len, num_nodes, features) and keeps only
                feature 0.
            y_cov: Optional decoder covariates of shape
                (batch_size, horizon, num_nodes, ycov_dim). A zero tensor is
                used when omitted.
            labels: Optional targets of shape
                (batch_size, horizon, num_nodes, output_dim), forwarded to
                MegaCRN. A zero tensor is used when omitted.

        Returns:
            The prediction tensor produced by MegaCRN. The auxiliary outputs
            (h_att, query, pos, neg) are discarded.
        """
        # Follow the DDGCRN convention: keep only the first channel, but
        # retain a trailing feature axis of size 1.
        x = x[..., :1]
        batch_size, num_nodes = x.shape[0], x.shape[2]
        horizon = self.args["horizon"]

        # Placeholders follow x's dtype and device.
        if y_cov is None:
            y_cov = x.new_zeros(batch_size, horizon, num_nodes, self.args["ycov_dim"])
        if labels is None:
            # NOTE(review): if MegaCRN uses ``labels`` for teacher forcing in
            # training (use_curriculum_learning defaults to True), these zeros
            # would be fed to the decoder -- pass real targets when training.
            labels = x.new_zeros(
                batch_size, horizon, num_nodes, self.args["output_dim"]
            )

        output, _h_att, _query, _pos, _neg = self.model(
            x, y_cov, labels=labels, batches_seen=self.batches_seen
        )

        # Only training steps advance the counter; evaluation/inference
        # passes must not move the schedule forward.
        if self.training:
            self.batches_seen += 1

        return output
|