# TrafficWheel/model/ST_SSL/ST-SSL.py
#
# 58 lines
# 1.7 KiB
# Python

import torch.nn as nn
from model.ST-SSL.models import STSSL
from model.ST-SSL.layers
from utils.get_adj import get_gso
class STSSLModel(nn.Module):
    """Adapter that wraps the project's ``STSSL`` network.

    On construction it fills in missing hyper-parameters in ``args``, builds
    the graph operator once via ``get_gso`` and instantiates ``STSSL``.
    On ``forward`` it swaps axes 1 and 2 of the input before calling the
    wrapped model and swaps them back on the prediction.

    Args:
        args: Mutable mapping of hyper-parameters (must support ``in`` and
            item assignment). It is modified in place: any key listed in
            the defaults below that is absent gets populated. ``output_dim``
            and ``n_his`` are required only when ``d_output`` /
            ``input_length`` are missing.
    """

    # Fallback values for keys that do not depend on other entries of args.
    _STATIC_DEFAULTS = (
        ('dropout', 0.1),
        ('nmb_prototype', 10),
        ('batch_size', 64),
        ('shm_temp', 0.1),
        ('yita', 0.5),
        ('percent', 0.1),
        ('device', 'cpu'),
    )

    def __init__(self, args):
        super().__init__()

        # Apply default hyper-parameters without overwriting caller values.
        if 'd_model' not in args:
            args['d_model'] = 64
        # These two defaults are derived from other keys, so they are only
        # looked up when actually needed (avoids a spurious KeyError).
        if 'd_output' not in args:
            args['d_output'] = args['output_dim']
        if 'input_length' not in args:
            args['input_length'] = args['n_his']
        for key, value in self._STATIC_DEFAULTS:
            if key not in args:
                args[key] = value

        # forward() needs the configuration; the original never stored it,
        # which made every forward pass fail with AttributeError.
        self.args = args

        # Build the adjacency / graph-shift operator once, after defaults
        # are in place, and reuse it on every forward pass.
        # NOTE(review): assumes get_gso(args) is deterministic for a fixed
        # args; if it must be refreshed per call, recompute it in forward().
        self.graph = get_gso(args)

        # Create the wrapped ST-SSL model.
        self.model = STSSL(args)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its prediction.

        Args:
            x: 4-D tensor; per the original author's comments the layout is
                (batch_size, seq_len, num_nodes, features).

        Returns:
            The output of ``self.model.predict`` with axes 1 and 2 swapped
            back (presumably (batch_size, seq_len, num_nodes, features) --
            depends on what ``STSSL.predict`` returns).

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if len(x.shape) != 4:
            raise ValueError(
                'expected a 4-D input (batch_size, seq_len, num_nodes, '
                'features), got shape {}'.format(tuple(x.shape))
            )

        # Reorder to (batch_size, num_nodes, seq_len, features) for STSSL.
        x = x.permute(0, 2, 1, 3)

        # Forward pass yields two representations that predict() consumes.
        repr1, repr2 = self.model(x, self.graph)
        pred = self.model.predict(repr1, repr2)

        # Swap axes 1 and 2 back to the caller's layout.
        return pred.permute(0, 2, 1, 3)