TrafficWheel/model/MegaCRN/MegaCRNModel.py

67 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from model.MegaCRN.MegaCRN import MegaCRN
class MegaCRNModel(nn.Module):
    """Adapter exposing MegaCRN through a single-tensor ``forward(x)`` call.

    Only the first feature channel of the input is fed to MegaCRN (it is
    constructed with ``input_dim=1``). Decoder covariates (``y_cov``) and
    target sequences (``labels``) default to zero tensors unless the caller
    supplies them to ``forward``.
    """

    # Hyperparameters written into ``args`` when absent, in this order.
    _DEFAULT_ARGS = (
        ('rnn_units', 64),
        ('num_layers', 1),
        ('cheb_k', 3),
        ('ycov_dim', 1),
        ('mem_num', 20),
        ('mem_dim', 64),
        ('cl_decay_steps', 2000),
        ('use_curriculum_learning', True),
        ('horizon', 12),
    )

    def __init__(self, args):
        """Build the wrapped MegaCRN model.

        Args:
            args: Mutable mapping of hyperparameters. ``'num_nodes'`` and
                ``'output_dim'`` are required. Every key listed in
                ``_DEFAULT_ARGS`` that is missing is written into ``args``
                in place, and the same mapping is kept as ``self.args``.
        """
        super().__init__()
        # Fill missing hyperparameters with defaults (mutates the caller's mapping).
        for key, value in self._DEFAULT_ARGS:
            if key not in args:
                args[key] = value

        self.model = MegaCRN(
            num_nodes=args['num_nodes'],
            input_dim=1,  # fixed to 1: forward() keeps only the first input channel
            output_dim=args['output_dim'],
            horizon=args['horizon'],
            rnn_units=args['rnn_units'],
            num_layers=args['num_layers'],
            cheb_k=args['cheb_k'],
            ycov_dim=args['ycov_dim'],
            mem_num=args['mem_num'],
            mem_dim=args['mem_dim'],
            cl_decay_steps=args['cl_decay_steps'],
            use_curriculum_learning=args['use_curriculum_learning'],
        )
        self.args = args
        # Count of training-mode forward passes, handed to MegaCRN as
        # ``batches_seen`` (presumably drives its curriculum-learning decay).
        # Plain int attribute, not a registered buffer, so it is not saved
        # in state_dict.
        self.batches_seen = 0

    def forward(self, x: torch.Tensor, labels=None, y_cov=None) -> torch.Tensor:
        """Run MegaCRN on the first channel of ``x``.

        Args:
            x: Input tensor; only ``x[..., 0]`` is used. Layout is assumed to be
                (batch, seq_len, num_nodes, features), since dims 0 and 2 are
                read as batch size and node count — TODO confirm with the loader.
            labels: Optional target tensor laid out like the output,
                (batch, horizon, num_nodes, C) with C >= output_dim; only the
                first ``output_dim`` channels are passed on. When omitted, a
                zero tensor is passed. NOTE(review): if MegaCRN uses ``labels``
                for curriculum learning while training, zeros will be fed to
                its decoder — pass real targets in that case.
            y_cov: Optional decoder covariates,
                (batch, horizon, num_nodes, ycov_dim). Defaults to zeros.

        Returns:
            The first element of MegaCRN's output tuple (the prediction).
        """
        horizon = self.args['horizon']
        output_dim = self.args['output_dim']

        # Following DDGCRN, keep only the first channel: (B, T, N, 1).
        x = x[..., :1]
        batch_size, num_nodes = x.shape[0], x.shape[2]

        # Zero placeholders follow x's dtype and device.
        if y_cov is None:
            y_cov = x.new_zeros(batch_size, horizon, num_nodes, self.args['ycov_dim'])
        if labels is None:
            labels = x.new_zeros(batch_size, horizon, num_nodes, output_dim)
        else:
            labels = labels[..., :output_dim]

        # MegaCRN also returns h_att, query, pos, neg. They are discarded here.
        # NOTE(review): presumably inputs to MegaCRN's auxiliary memory losses,
        # which this wrapper therefore never applies.
        output, _h_att, _query, _pos, _neg = self.model(
            x, y_cov, labels=labels, batches_seen=self.batches_seen
        )

        # Only training passes advance the schedule counter; eval/test passes
        # must not change what MegaCRN sees as ``batches_seen``.
        if self.training:
            self.batches_seen += 1
        return output