"""Adapter exposing ST-SSL's ``STSSL`` model through a single ``forward(x)`` call."""

import importlib

import torch
import torch.nn as nn

from data.get_adj import get_gso

# The package directory is named "ST-SSL"; a hyphen is not valid in an
# ``import`` statement (``from model.ST-SSL.models import ...`` is a
# SyntaxError), so resolve the modules by their dotted name instead.
_stssl_models = importlib.import_module("model.ST-SSL.models")
_stssl_layers = importlib.import_module("model.ST-SSL.layers")
STSSL = _stssl_models.STSSL
STEncoder = _stssl_layers.STEncoder
MLP = _stssl_layers.MLP

# Hyperparameters filled into ``args`` when the caller omits them, in this
# order. A ``None`` placeholder marks a default that is copied from another
# key (see ``_DERIVED_DEFAULTS``) rather than a constant.
_STSSL_DEFAULTS = {
    'd_model': 64,
    'd_output': None,
    'input_length': None,
    'dropout': 0.1,
    'nmb_prototype': 10,
    'batch_size': 64,
    'shm_temp': 0.1,
    'yita': 0.5,
    'percent': 0.1,
    'device': 'cpu',
}

# Defaults taken from another entry of ``args``; the source key is only read
# when the target key is missing, so it need not exist otherwise.
_DERIVED_DEFAULTS = {
    'd_output': 'output_dim',
    'input_length': 'n_his',
}


def _fill_default_args(args):
    """Insert every missing key of ``_STSSL_DEFAULTS`` into ``args`` in place.

    Raises:
        KeyError: if ``d_output`` / ``input_length`` is missing and the key it
            is derived from (``output_dim`` / ``n_his``) is missing too.
    """
    for key, value in _STSSL_DEFAULTS.items():
        if key in args:
            continue
        source_key = _DERIVED_DEFAULTS.get(key)
        args[key] = args[source_key] if source_key is not None else value


class STSSLModel(nn.Module):
    """Wrap ``STSSL`` so callers pass one input tensor and get one prediction.

    Args:
        args: dict-like hyperparameter container. Missing ST-SSL settings are
            filled in place (so the caller's object is modified), then it is
            passed to ``get_gso`` and ``STSSL``.
    """

    def __init__(self, args):
        super(STSSLModel, self).__init__()
        _fill_default_args(args)
        # Kept for callers that read ``model.args``; forward relied on it but it
        # was previously never assigned.
        self.args = args

        # Build the graph once instead of on every forward pass. A tensor graph
        # is registered as a non-persistent buffer so ``.to(device)`` moves it
        # along with the model without adding it to the state_dict.
        graph = get_gso(args)
        if isinstance(graph, torch.Tensor):
            self.register_buffer('graph', graph, persistent=False)
        else:
            self.graph = graph

        self.model = STSSL(args)

    def forward(self, x):
        """Run ST-SSL on ``x`` and return its prediction.

        Args:
            x: 4-D tensor; the original code documents it as
                (batch_size, seq_len, num_nodes, features) -- TODO confirm
                against the data loader.
        """
        # Swap axes 1 and 2 -> (batch_size, num_nodes, seq_len, features).
        x = x.permute(0, 2, 1, 3)
        repr1, repr2 = self.model(x, self.graph)
        pred = self.model.predict(repr1, repr2)
        # NOTE(review): assumes predict returns a 4-D tensor in
        # (batch, nodes, seq, features) layout; swap axes 1 and 2 back.
        return pred.permute(0, 2, 1, 3)