59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from model.ST-SSL.models import STSSL
|
|
from model.ST-SSL.layers import STEncoder, MLP
|
|
from data.get_adj import get_gso
|
|
|
|
class STSSLModel(nn.Module):
    """Adapter that exposes the ST-SSL network through a plain ``forward(x)``.

    The wrapper fills in hyper-parameters missing from ``args``, builds the
    underlying ``STSSL`` model, and swaps the time and node axes on the way in
    and out so callers can keep using a
    ``(batch_size, seq_len, num_nodes, features)`` layout.
    """

    # Ordered fallback table: (key, literal default, key to copy from).
    # Exactly one of the last two entries is meaningful per row; rows with a
    # source key are resolved lazily so the source is only read when needed.
    _DEFAULTS = (
        ('d_model', 64, None),
        ('d_output', None, 'output_dim'),
        ('input_length', None, 'n_his'),
        ('dropout', 0.1, None),
        ('nmb_prototype', 10, None),
        ('batch_size', 64, None),
        ('shm_temp', 0.1, None),
        ('yita', 0.5, None),
        ('percent', 0.1, None),
        ('device', 'cpu', None),
    )

    def __init__(self, args):
        """Build the wrapped ST-SSL model from a configuration mapping.

        Args:
            args: Mutable mapping of hyper-parameters. Missing entries listed
                in ``_DEFAULTS`` are written back into this same object, so
                the caller sees the filled-in values as well. ``output_dim``
                and ``n_his`` must be present whenever ``d_output`` /
                ``input_length`` are not.

        Raises:
            KeyError: If a derived default is needed but its source key
                (``output_dim`` or ``n_his``) is absent from ``args``.
        """
        super().__init__()

        # Fill in defaults in place; existing entries are never overwritten.
        for key, literal, source_key in self._DEFAULTS:
            if key not in args:
                args[key] = literal if source_key is None else args[source_key]

        # Keep the configuration on the instance; previously ``forward`` read
        # ``self.args`` although it was never assigned, which raised
        # AttributeError on the first forward pass.
        self.args = args

        # Build the graph (adjacency-derived) operator once and reuse it in
        # every forward pass instead of rebuilding it per batch.
        # NOTE(review): assumes get_gso is deterministic for fixed args and
        # that its result is not modified in place by STSSL -- confirm.
        self.graph = get_gso(args)

        # Create the ST-SSL model.
        self.model = STSSL(args)

    def forward(self, x):
        """Run the wrapped model and return its prediction.

        Args:
            x: Tensor laid out as ``(batch_size, seq_len, num_nodes, features)``.

        Returns:
            The output of ``self.model.predict`` with axes 1 and 2 swapped
            back; presumably ``(batch_size, seq_len, num_nodes, features)``,
            assuming ``predict`` returns a node-major 4-D tensor (STSSL is not
            visible from this file -- TODO confirm).

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if len(x.shape) != 4:
            raise ValueError(
                'expected x with shape (batch_size, seq_len, num_nodes, '
                f'features), got {tuple(x.shape)}'
            )

        # Reorder the input to (batch_size, num_nodes, seq_len, features).
        x = x.permute(0, 2, 1, 3)

        # Forward pass: the model returns two representations.
        repr1, repr2 = self.model(x, self.graph)

        # Turn the representations into a prediction.
        pred = self.model.predict(repr1, repr2)

        # Swap the node and time axes back for the caller.
        return pred.permute(0, 2, 1, 3)
|
|
|