From 538548db0befc5563610f3cf580cce7df4e172ce Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Mon, 18 Aug 2025 21:49:14 +0800
Subject: [PATCH] add-model

---
 model/MegaCRN/MegaCRN.py      | 227 +++++++++++++++++++
 model/MegaCRN/MegaCRNModel.py |  66 ++++++
 model/ST_SSL/ST-SSL.py        |  58 +++++
 model/ST_SSL/ST_SSL.py        | 128 +++++++++++
 model/ST_SSL/aug.py           | 103 +++++++++
 model/ST_SSL/layers.py        | 103 +++++++++
 model/ST_SSL/models.py        | 156 +++++++++++++
 model/TEDDCF/ISTF.py          | 108 +++++++++
 model/TEDDCF/model.py         | 414 ++++++++++++++++++++++++++++++++++
 9 files changed, 1363 insertions(+)
 create mode 100644 model/MegaCRN/MegaCRN.py
 create mode 100644 model/MegaCRN/MegaCRNModel.py
 create mode 100644 model/ST_SSL/ST-SSL.py
 create mode 100644 model/ST_SSL/ST_SSL.py
 create mode 100644 model/ST_SSL/aug.py
 create mode 100644 model/ST_SSL/layers.py
 create mode 100644 model/ST_SSL/models.py
 create mode 100644 model/TEDDCF/ISTF.py
 create mode 100644 model/TEDDCF/model.py

diff --git a/model/MegaCRN/MegaCRN.py b/model/MegaCRN/MegaCRN.py
new file mode 100644
index 0000000..00f0c48
--- /dev/null
+++ b/model/MegaCRN/MegaCRN.py
@@ -0,0 +1,227 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import math
+import numpy as np
+
+class AGCN(nn.Module):
+    def __init__(self, dim_in, dim_out, cheb_k):
+        super(AGCN, self).__init__()
+        self.cheb_k = cheb_k
+        self.weights = nn.Parameter(torch.FloatTensor(2*cheb_k*dim_in, dim_out))  # 2 is the length of support
+        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
+        nn.init.xavier_normal_(self.weights)
+        nn.init.constant_(self.bias, val=0)
+
+    def forward(self, x, supports):
+        x_g = []
+        support_set = []
+        for support in supports:
+            support_ks = [torch.eye(support.shape[0]).to(support.device), support]
+            for k in range(2, self.cheb_k):
+                support_ks.append(torch.matmul(2 * support, support_ks[-1]) - support_ks[-2])
+            support_set.extend(support_ks)
+        for support in support_set:
+            x_g.append(torch.einsum("nm,bmc->bnc", support, x))
+        x_g = torch.cat(x_g, dim=-1)  # B, N, 2 * cheb_k * dim_in
+        x_gconv = torch.einsum('bni,io->bno', x_g, self.weights) + self.bias  # B, N, dim_out
+        return x_gconv
+
+class AGCRNCell(nn.Module):
+    def __init__(self, node_num, dim_in, dim_out, cheb_k):
+        super(AGCRNCell, self).__init__()
+        self.node_num = node_num
+        self.hidden_dim = dim_out
+        self.gate = AGCN(dim_in+self.hidden_dim, 2*dim_out, cheb_k)
+        self.update = AGCN(dim_in+self.hidden_dim, dim_out, cheb_k)
+
+    def forward(self, x, state, supports):
+        #x: B, num_nodes, input_dim
+        #state: B, num_nodes, hidden_dim
+        state = state.to(x.device)
+        input_and_state = torch.cat((x, state), dim=-1)
+        z_r = torch.sigmoid(self.gate(input_and_state, supports))
+        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
+        candidate = torch.cat((x, z*state), dim=-1)
+        hc = torch.tanh(self.update(candidate, supports))
+        h = r*state + (1-r)*hc
+        return h
+
+    def init_hidden_state(self, batch_size):
+        return torch.zeros(batch_size, self.node_num, self.hidden_dim)
+
+class ADCRNN_Encoder(nn.Module):
+    def __init__(self, node_num, dim_in, dim_out, cheb_k, num_layers):
+        super(ADCRNN_Encoder, self).__init__()
+        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
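+        # Structure note: the encoder stacks `num_layers` AGCRN cells; the first
+        # cell maps dim_in -> dim_out, every deeper cell maps dim_out -> dim_out,
+        # and forward() unrolls the stack over the input sequence.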
+        self.node_num = node_num
+        self.input_dim = dim_in
+        self.num_layers = num_layers
+        self.dcrnn_cells = nn.ModuleList()
+        self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k))
+        for _ in range(1, num_layers):
+            self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k))
+
+    def forward(self, x, init_state, supports):
+        #shape of x: (B, T, N, D), shape of init_state: (num_layers, B, N, hidden_dim)
+        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
+        seq_length = x.shape[1]
+        current_inputs = x
+        output_hidden = []
+        for i in range(self.num_layers):
+            state = init_state[i]
+            inner_states = []
+            for t in range(seq_length):
+                state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, supports)
+                inner_states.append(state)
+            output_hidden.append(state)
+            current_inputs = torch.stack(inner_states, dim=1)
+        #current_inputs: the outputs of last layer: (B, T, N, hidden_dim)
+        #last_state: (B, N, hidden_dim)
+        #output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim)
+        #return current_inputs, torch.stack(output_hidden, dim=0)
+        return current_inputs, output_hidden
+
+    def init_hidden(self, batch_size):
+        init_states = []
+        for i in range(self.num_layers):
+            init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size))
+        return init_states
+
+class ADCRNN_Decoder(nn.Module):
+    def __init__(self, node_num, dim_in, dim_out, cheb_k, num_layers):
+        super(ADCRNN_Decoder, self).__init__()
+        assert num_layers >= 1, 'At least one DCRNN layer in the Decoder.'
+        self.node_num = node_num
+        self.input_dim = dim_in
+        self.num_layers = num_layers
+        self.dcrnn_cells = nn.ModuleList()
+        self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k))
+        for _ in range(1, num_layers):
+            self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k))
+
+    def forward(self, xt, init_state, supports):
+        # xt: (B, N, D)
+        # init_state: (num_layers, B, N, hidden_dim)
+        assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
+        current_inputs = xt
+        output_hidden = []
+        for i in range(self.num_layers):
+            state = self.dcrnn_cells[i](current_inputs, init_state[i], supports)
+            output_hidden.append(state)
+            current_inputs = state
+        return current_inputs, output_hidden
+
+
+class MegaCRN(nn.Module):
+    def __init__(self, num_nodes, input_dim, output_dim, horizon, rnn_units, num_layers=1, cheb_k=3,
+                 ycov_dim=1, mem_num=20, mem_dim=64, cl_decay_steps=2000, use_curriculum_learning=True):
+        super(MegaCRN, self).__init__()
+        self.num_nodes = num_nodes
+        self.input_dim = input_dim
+        self.rnn_units = rnn_units
+        self.output_dim = output_dim
+        self.horizon = horizon
+        self.num_layers = num_layers
+        self.cheb_k = cheb_k
+        self.ycov_dim = ycov_dim
+        self.cl_decay_steps = cl_decay_steps
+        self.use_curriculum_learning = use_curriculum_learning
+
+        # memory
+        self.mem_num = mem_num
+        self.mem_dim = mem_dim
+        self.memory = self.construct_memory()
+
+        # encoder
+        self.encoder = ADCRNN_Encoder(self.num_nodes, self.input_dim, self.rnn_units, self.cheb_k, self.num_layers)
+
+        # decoder
+        self.decoder_dim = self.rnn_units + self.mem_dim
+        self.decoder = ADCRNN_Decoder(self.num_nodes, self.output_dim + self.ycov_dim, self.decoder_dim, self.cheb_k, self.num_layers)
+
+        # output
+        self.proj = nn.Sequential(nn.Linear(self.decoder_dim, self.output_dim, bias=True))
+
+    def compute_sampling_threshold(self, batches_seen):
+        return self.cl_decay_steps / (self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
+
+    def construct_memory(self):
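+        # The global memory bank: 'Memory' holds mem_num meta-node prototypes of
+        # size mem_dim, 'Wq' projects encoder states into the query space, and
+        # We1/We2 turn the prototypes into the two node-embedding matrices used
+        # to build the learned graph pair in forward().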
+        memory_dict = nn.ParameterDict()
+        memory_dict['Memory'] = nn.Parameter(torch.randn(self.mem_num, self.mem_dim), requires_grad=True)  # (M, d)
+        memory_dict['Wq'] = nn.Parameter(torch.randn(self.rnn_units, self.mem_dim), requires_grad=True)  # project to query
+        memory_dict['We1'] = nn.Parameter(torch.randn(self.num_nodes, self.mem_num), requires_grad=True)  # project memory to embedding
+        memory_dict['We2'] = nn.Parameter(torch.randn(self.num_nodes, self.mem_num), requires_grad=True)  # project memory to embedding
+        for param in memory_dict.values():
+            nn.init.xavier_normal_(param)
+        return memory_dict
+
+    def query_memory(self, h_t:torch.Tensor):
+        query = torch.matmul(h_t, self.memory['Wq'])  # (B, N, d)
+        att_score = torch.softmax(torch.matmul(query, self.memory['Memory'].t()), dim=-1)  # alpha: (B, N, M)
+        value = torch.matmul(att_score, self.memory['Memory'])  # (B, N, d)
+        _, ind = torch.topk(att_score, k=2, dim=-1)
+        pos = self.memory['Memory'][ind[:, :, 0]]  # B, N, d
+        neg = self.memory['Memory'][ind[:, :, 1]]  # B, N, d
+        return value, query, pos, neg
+
+    def forward(self, x, y_cov, labels=None, batches_seen=None):
+        node_embeddings1 = torch.matmul(self.memory['We1'], self.memory['Memory'])
+        node_embeddings2 = torch.matmul(self.memory['We2'], self.memory['Memory'])
+        g1 = F.softmax(F.relu(torch.mm(node_embeddings1, node_embeddings2.T)), dim=-1)
+        g2 = F.softmax(F.relu(torch.mm(node_embeddings2, node_embeddings1.T)), dim=-1)
+        supports = [g1, g2]
+        init_state = self.encoder.init_hidden(x.shape[0])
+        h_en, state_en = self.encoder(x, init_state, supports)  # B, T, N, hidden
+        h_t = h_en[:, -1, :, :]  # B, N, hidden (last state)
+
+        h_att, query, pos, neg = self.query_memory(h_t)
+        h_t = torch.cat([h_t, h_att], dim=-1)
+
+        ht_list = [h_t]*self.num_layers
+        go = torch.zeros((x.shape[0], self.num_nodes, self.output_dim), device=x.device)
+        out = []
+        for t in range(self.horizon):
+            h_de, ht_list = self.decoder(torch.cat([go, y_cov[:, t, ...]], dim=-1), ht_list, supports)
+            go = self.proj(h_de)
+            out.append(go)
+            if self.training and self.use_curriculum_learning:
+                c = np.random.uniform(0, 1)
+                if c < self.compute_sampling_threshold(batches_seen):
+                    go = labels[:, t, ...]
+        output = torch.stack(out, dim=1)
+
+        return output, h_att, query, pos, neg
+
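+# A typical training objective (an assumption here, following the official
+# MegaCRN implementation rather than anything in this file) combines the
+# prediction loss with two memory losses computed from the returned tensors:
+#   out, h_att, query, pos, neg = model(x, y_cov, labels, batches_seen)
+#   loss = masked_mae(out, labels) \
+#          + lamb * nn.TripletMarginLoss()(query, pos.detach(), neg.detach()) \
+#          + lamb1 * nn.MSELoss()(query, pos.detach())
+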
+def print_params(model):
+    # print trainable params
+    param_count = 0
+    print('Trainable parameter list:')
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            print(name, param.shape, param.numel())
+            param_count += param.numel()
+    print(f'In total: {param_count} trainable parameters. \n')
+    return
+
+def main():
+    import sys
+    import argparse
+    from torchsummary import summary
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--gpu", type=int, default=3, help="which GPU to use")
+    parser.add_argument('--num_variable', type=int, default=207, help='number of variables (e.g., 207 in METR-LA, 325 in PEMS-BAY)')
+    parser.add_argument('--his_len', type=int, default=12, help='sequence length of historical observation')
+    parser.add_argument('--seq_len', type=int, default=12, help='sequence length of prediction')
+    parser.add_argument('--channelin', type=int, default=1, help='number of input channels')
+    parser.add_argument('--channelout', type=int, default=1, help='number of output channels')
+    parser.add_argument('--rnn_units', type=int, default=64, help='number of hidden units')
+    args = parser.parse_args()
+    device = torch.device("cuda:{}".format(args.gpu)) if torch.cuda.is_available() else torch.device("cpu")
+    model = MegaCRN(num_nodes=args.num_variable, input_dim=args.channelin, output_dim=args.channelout, horizon=args.seq_len, rnn_units=args.rnn_units).to(device)
+    summary(model, [(args.his_len, args.num_variable, args.channelin), (args.seq_len, args.num_variable, args.channelout)], device=device)
+    print_params(model)
+
+if __name__ == '__main__':
+    main()
+
diff --git a/model/MegaCRN/MegaCRNModel.py b/model/MegaCRN/MegaCRNModel.py
new file mode 100644
index 0000000..fb6e6f0
--- /dev/null
+++ b/model/MegaCRN/MegaCRNModel.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn as nn
+from model.MegaCRN.MegaCRN import MegaCRN
+
+class MegaCRNModel(nn.Module):
+    def __init__(self, args):
+        super(MegaCRNModel, self).__init__()
+
+        # Set default hyper-parameters
+        if 'rnn_units' not in args:
+            args['rnn_units'] = 64
+        if 'num_layers' not in args:
+            args['num_layers'] = 1
+        if 'cheb_k' not in args:
+            args['cheb_k'] = 3
+        if 'ycov_dim' not in args:
+            args['ycov_dim'] = 1
+        if 'mem_num' not in args:
+            args['mem_num'] = 20
+        if 'mem_dim' not in args:
+            args['mem_dim'] = 64
+        if 'cl_decay_steps' not in args:
+            args['cl_decay_steps'] = 2000
+        if 'use_curriculum_learning' not in args:
+            args['use_curriculum_learning'] = True
+        if 'horizon' not in args:
+            args['horizon'] = 12
+
+        # Build the MegaCRN model
+        self.model = MegaCRN(
+            num_nodes=args['num_nodes'],
+            input_dim=1,  # fixed to 1: only the first input channel is used
+            output_dim=args['output_dim'],
+            horizon=args['horizon'],
+            rnn_units=args['rnn_units'],
+            num_layers=args['num_layers'],
+            cheb_k=args['cheb_k'],
+            ycov_dim=args['ycov_dim'],
+            mem_num=args['mem_num'],
+            mem_dim=args['mem_dim'],
+            cl_decay_steps=args['cl_decay_steps'],
+            use_curriculum_learning=args['use_curriculum_learning']
+        )
+
+        self.args = args
+        self.batches_seen = 0  # counter used for curriculum learning
+
+    def forward(self, x):
+        # x shape: (batch_size, seq_len, num_nodes, features)
+        # Following the DDGCRN convention, use only the first channel
+        x = x[..., 0].unsqueeze(-1)  # (batch_size, seq_len, num_nodes, 1)
+
+        # Build y_cov (zeros here; adjust to the real covariates if they are available)
+        y_cov = torch.zeros(x.shape[0], self.args['horizon'], x.shape[2], self.args['ycov_dim'], device=x.device)
+
+        # Build labels (zeros here; adjust to the real targets if they are available)
+        labels = torch.zeros(x.shape[0], self.args['horizon'], x.shape[2], self.args['output_dim'], device=x.device)
+
+        # Forward pass
+        output, h_att, query, pos, neg = self.model(x, y_cov, labels=labels, batches_seen=self.batches_seen)
+
+        # Update the batches_seen counter
+        self.batches_seen += 1
+
+        return output
+
diff --git a/model/ST_SSL/ST-SSL.py b/model/ST_SSL/ST-SSL.py
new file mode 100644
index 0000000..1aaa647
--- /dev/null
+++ b/model/ST_SSL/ST-SSL.py
@@ -0,0 +1,58 @@
+import torch
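+# Usage sketch (illustrative; it assumes `args` carries at least 'num_nodes',
+# 'output_dim' and 'n_his', and that data.get_adj.get_gso accepts it):
+#   model = STSSLModel(args)
+#   pred = model(x)  # x: (batch_size, seq_len, num_nodes, features)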
+import torch.nn as nn
+from model.ST_SSL.models import STSSL
+from model.ST_SSL.layers import STEncoder, MLP
+from data.get_adj import get_gso
+
+class STSSLModel(nn.Module):
+    def __init__(self, args):
+        super(STSSLModel, self).__init__()
+        # Fetch the adjacency matrix (graph shift operator)
+        gso = get_gso(args)
+
+        # Set default hyper-parameters
+        if 'd_model' not in args:
+            args['d_model'] = 64
+        if 'd_output' not in args:
+            args['d_output'] = args['output_dim']
+        if 'input_length' not in args:
+            args['input_length'] = args['n_his']
+        if 'dropout' not in args:
+            args['dropout'] = 0.1
+        if 'nmb_prototype' not in args:
+            args['nmb_prototype'] = 10
+        if 'batch_size' not in args:
+            args['batch_size'] = 64
+        if 'shm_temp' not in args:
+            args['shm_temp'] = 0.1
+        if 'yita' not in args:
+            args['yita'] = 0.5
+        if 'percent' not in args:
+            args['percent'] = 0.1
+        if 'device' not in args:
+            args['device'] = 'cpu'
+
+        # Build the ST-SSL model
+        self.model = STSSL(args)
+        self.args = args  # kept so that forward() can re-fetch the graph
+
+    def forward(self, x):
+        # x shape: (batch_size, seq_len, num_nodes, features)
+        batch_size, seq_len, num_nodes, features = x.shape
+
+        # Fetch the adjacency matrix
+        graph = get_gso(self.args)
+
+        # Rearrange the input
+        x = x.permute(0, 2, 1, 3)  # (batch_size, num_nodes, seq_len, features)
+
+        # Forward pass
+        repr1, repr2 = self.model(x, graph)
+
+        # Predict
+        pred = self.model.predict(repr1, repr2)
+
+        # Rearrange the output
+        pred = pred.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_nodes, features)
+
+        return pred
+
diff --git a/model/ST_SSL/ST_SSL.py b/model/ST_SSL/ST_SSL.py
new file mode 100644
index 0000000..7deeb21
--- /dev/null
+++ b/model/ST_SSL/ST_SSL.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from data.get_adj import get_gso
+
+class STSSLModel(nn.Module):
+    def __init__(self, args):
+        super(STSSLModel, self).__init__()
+
+        # Set default hyper-parameters
+        if 'd_model' not in args:
+            args['d_model'] = 64
+        if 'd_output' not in args:
+            args['d_output'] = args['output_dim']
+        if 'input_length' not in args:
+            args['input_length'] = args['n_his']
+        if 'dropout' not in args:
+            args['dropout'] = 0.1
+        if 'nmb_prototype' not in args:
+            args['nmb_prototype'] = 10
+        if 'batch_size' not in args:
+            args['batch_size'] = 64
+        if 'shm_temp' not in args:
+            args['shm_temp'] = 0.1
+        if 'yita' not in args:
+            args['yita'] = 0.5
+        if 'percent' not in args:
+            args['percent'] = 0.1
+        if 'device' not in args:
+            args['device'] = 'cpu'
+        if 'gso_type' not in args:
+            args['gso_type'] = 'sym_norm_lap'
+        if 'graph_conv_type' not in args:
+            args['graph_conv_type'] = 'cheb_graph_conv'
+
+        # Keep the arguments
+        self.args = args
+        self.num_nodes = args['num_nodes']
+        self.input_dim = args['input_dim']
+        self.output_dim = args['output_dim']
+        self.horizon = args['horizon']
+        self.d_model = args['d_model']
+
+        # Fetch the adjacency matrix (graph shift operator)
+        self.gso = get_gso(args)
+
+        # Temporal embeddings
+        self.T_i_D_emb = nn.Parameter(torch.empty(288, args['d_model']))
+        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['d_model']))
+
+        # Node embeddings
+        self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
+        self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
+
+        # Encoder - uses a single input channel
+        self.encoder = STEncoder(
+            Kt=3, Ks=3,
+            input_dim=1,  # only the first channel is used
+            hidden_dim=args['d_model'],
+            input_length=args['input_length'],
+            num_nodes=args['num_nodes'],
+            droprate=args['dropout']
+        )
+
+        # Prediction head
+        self.predictor = nn.Linear(args['d_model'], args['output_dim'])
+
+        # Initialize the parameters
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.node_emb_u)
+        nn.init.xavier_uniform_(self.node_emb_d)
+        nn.init.xavier_uniform_(self.T_i_D_emb)
+        nn.init.xavier_uniform_(self.D_i_W_emb)
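+
+    # Note on the embedding sizes: 288 matches the number of 5-minute slots per
+    # day and 7 the days of the week, i.e. the time-of-day / day-of-week
+    # covariates used by common traffic benchmarks.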
+    def forward(self, x):
+        # x shape: (batch_size, seq_len, num_nodes, features)
+        # Following the DDGCRN convention, use only the first channel
+        x = x[..., 0].unsqueeze(-1)  # (batch_size, seq_len, num_nodes, 1)
+
+        # Encode
+        encoded = self.encoder(x, self.gso)
+
+        # Predict from the last time step's output
+        last_hidden = encoded[:, -1, :, :]  # (batch_size, num_nodes, d_model)
+
+        # Predict the next `horizon` time steps
+        predictions = []
+        for t in range(self.horizon):
+            pred = self.predictor(last_hidden)  # (batch_size, num_nodes, output_dim)
+            predictions.append(pred)
+
+        # Stack the step-wise predictions
+        output = torch.stack(predictions, dim=1)  # (batch_size, horizon, num_nodes, output_dim)
+
+        return output
+
+
+class STEncoder(nn.Module):
+    def __init__(self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate):
+        super(STEncoder, self).__init__()
+        self.num_nodes = num_nodes
+        self.input_length = input_length
+
+        # Simplified spatio-temporal encoder with a single input channel
+        self.conv1 = nn.Conv2d(input_dim, hidden_dim//2, kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
+        self.conv2 = nn.Conv2d(hidden_dim//2, hidden_dim, kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
+        self.dropout = nn.Dropout(droprate)
+
+    def forward(self, x, graph):
+        # x: (batch_size, seq_len, num_nodes, features)
+        batch_size, seq_len, num_nodes, features = x.shape
+
+        # Rearrange the dimensions
+        x = x.permute(0, 3, 1, 2)  # (batch_size, features, seq_len, num_nodes)
+
+        # Convolutions
+        x = F.relu(self.conv1(x))
+        x = self.dropout(x)
+        x = F.relu(self.conv2(x))
+
+        # Rearrange back to the original layout
+        x = x.permute(0, 2, 3, 1)  # (batch_size, seq_len, num_nodes, features)
+
+        return x
+
diff --git a/model/ST_SSL/aug.py b/model/ST_SSL/aug.py
new file mode 100644
index 0000000..db764c2
--- /dev/null
+++ b/model/ST_SSL/aug.py
@@ -0,0 +1,103 @@
+import copy
+import numpy as np
+import torch
+
+def sim_global(flow_data, sim_type='cos'):
+    """Calculate the global similarity of traffic flow data.
+    :param flow_data: tensor, original flow [n,l,v,c] or location embedding [n,v,c]
+    :param sim_type: str, type of similarity, attention or cosine. ['att', 'cos']
+    :return sim: tensor, symmetric similarity, [v,v]
+    """
+    if len(flow_data.shape) == 4:
+        n,l,v,c = flow_data.shape
+        att_scaling = n * l * c
+        cos_scaling = torch.norm(flow_data, p=2, dim=(0, 1, 3)) ** -1  # cal 2-norm of each node, dim N
+        sim = torch.einsum('btnc, btmc->nm', flow_data, flow_data)
+    elif len(flow_data.shape) == 3:
+        n,v,c = flow_data.shape
+        att_scaling = n * c
+        cos_scaling = torch.norm(flow_data, p=2, dim=(0, 2)) ** -1  # cal 2-norm of each node, dim N
+        sim = torch.einsum('bnc, bmc->nm', flow_data, flow_data)
+    else:
+        raise ValueError('sim_global only supports inputs of rank 3 or 4, but got {}.'.format(len(flow_data.shape)))
+
+    if sim_type == 'cos':
+        # cosine similarity
+        scaling = torch.einsum('i, j->ij', cos_scaling, cos_scaling)
+        sim = sim * scaling
+    elif sim_type == 'att':
+        # scaled dot product similarity
+        scaling = float(att_scaling) ** -0.5
+        sim = torch.softmax(sim * scaling, dim=-1)
+    else:
+        raise ValueError('sim_global only supports sim_type in [att, cos].')
+
+    return sim
+
+def aug_topology(sim_mx, input_graph, percent=0.2):
+    """Generate the data augmentation from the topology (graph structure) perspective
+    for an undirected graph without self-loops.
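+    Roughly percent/2 of the undirected edges are removed (edges with low
+    similarity are the most likely to be dropped), and an equal number of new
+    edges are added between node pairs sampled with probability proportional
+    to their similarity.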
+    :param sim_mx: tensor, symmetric similarity, [v,v]
+    :param input_graph: tensor, adjacency matrix without self-loop, [v,v]
+    :return aug_graph: tensor, augmented adjacency matrix on cuda, [v,v]
+    """
+    ## edge dropping starts here
+    drop_percent = percent / 2
+
+    index_list = input_graph.nonzero()  # list of edges [row_idx, col_idx]
+
+    edge_num = int(index_list.shape[0] / 2)  # treat one undirected edge as two edges
+    edge_mask = (input_graph > 0).tril(diagonal=-1)
+    add_drop_num = int(edge_num * drop_percent / 2)
+    aug_graph = copy.deepcopy(input_graph)
+
+    drop_prob = torch.softmax(sim_mx[edge_mask], dim=0)
+    drop_prob = (1. - drop_prob).numpy()  # normalized similarity to get sampling probability
+    drop_prob /= drop_prob.sum()
+    drop_list = np.random.choice(edge_num, size=add_drop_num, p=drop_prob)
+    drop_index = index_list[drop_list]
+
+    zeros = torch.zeros_like(aug_graph[0, 0])
+    aug_graph[drop_index[:, 0], drop_index[:, 1]] = zeros
+    aug_graph[drop_index[:, 1], drop_index[:, 0]] = zeros
+
+    ## edge adding starts here
+    node_num = input_graph.shape[0]
+    x, y = np.meshgrid(range(node_num), range(node_num), indexing='ij')
+    mask = y < x
+    x, y = x[mask], y[mask]
+
+    add_prob = sim_mx[torch.ones(sim_mx.size(), dtype=bool).tril(diagonal=-1)]  # .numpy()
+    add_prob = torch.softmax(add_prob, dim=0).numpy()
+    add_list = np.random.choice(int((node_num * node_num - node_num) / 2),
+                                size=add_drop_num, p=add_prob)
+
+    ones = torch.ones_like(aug_graph[0, 0])
+    aug_graph[x[add_list], y[add_list]] = ones
+    aug_graph[y[add_list], x[add_list]] = ones
+
+    return aug_graph
+
+def aug_traffic(t_sim_mx, flow_data, percent=0.2):
+    """Generate the data augmentation from the traffic (node attribute) perspective.
+    :param t_sim_mx: temporal similarity matrix after softmax, [l,n,v]
+    :param flow_data: input flow data, [n,l,v,c]
+    """
+    l, n, v = t_sim_mx.shape
+    mask_num = int(n * l * v * percent)
+    aug_flow = copy.deepcopy(flow_data)
+
+    mask_prob = (1. - t_sim_mx.permute(1, 0, 2).reshape(-1)).numpy()
+    mask_prob /= mask_prob.sum()
+
+    x, y, z = np.meshgrid(range(n), range(l), range(v), indexing='ij')
+    mask_list = np.random.choice(n * l * v, size=mask_num, p=mask_prob)
+
+    zeros = torch.zeros_like(aug_flow[0, 0, 0])
+    aug_flow[
+        x.reshape(-1)[mask_list],
+        y.reshape(-1)[mask_list],
+        z.reshape(-1)[mask_list]] = zeros
+
+    return aug_flow
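+
+# Usage sketch (illustrative; shapes follow the docstrings above):
+#   s_sim = sim_global(flow, sim_type='cos')        # [v,v] spatial similarity
+#   g_aug = aug_topology(s_sim, adj, percent=0.2)   # perturbed adjacency
+#   x_aug = aug_traffic(t_sim, flow, percent=0.2)   # randomly masked flow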
+
diff --git a/model/ST_SSL/layers.py b/model/ST_SSL/layers.py
new file mode 100644
index 0000000..db764c2
--- /dev/null
+++ b/model/ST_SSL/layers.py
@@ -0,0 +1,103 @@
+import copy
+import numpy as np
+import torch
+
+def sim_global(flow_data, sim_type='cos'):
+    """Calculate the global similarity of traffic flow data.
+    :param flow_data: tensor, original flow [n,l,v,c] or location embedding [n,v,c]
+    :param sim_type: str, type of similarity, attention or cosine. ['att', 'cos']
+    :return sim: tensor, symmetric similarity, [v,v]
+    """
+    if len(flow_data.shape) == 4:
+        n,l,v,c = flow_data.shape
+        att_scaling = n * l * c
+        cos_scaling = torch.norm(flow_data, p=2, dim=(0, 1, 3)) ** -1  # cal 2-norm of each node, dim N
+        sim = torch.einsum('btnc, btmc->nm', flow_data, flow_data)
+    elif len(flow_data.shape) == 3:
+        n,v,c = flow_data.shape
+        att_scaling = n * c
+        cos_scaling = torch.norm(flow_data, p=2, dim=(0, 2)) ** -1  # cal 2-norm of each node, dim N
+        sim = torch.einsum('bnc, bmc->nm', flow_data, flow_data)
+    else:
+        raise ValueError('sim_global only supports inputs of rank 3 or 4, but got {}.'.format(len(flow_data.shape)))
+
+    if sim_type == 'cos':
+        # cosine similarity
+        scaling = torch.einsum('i, j->ij', cos_scaling, cos_scaling)
+        sim = sim * scaling
+    elif sim_type == 'att':
+        # scaled dot product similarity
+        scaling = float(att_scaling) ** -0.5
+        sim = torch.softmax(sim * scaling, dim=-1)
+    else:
+        raise ValueError('sim_global only supports sim_type in [att, cos].')
+
+    return sim
+
+def aug_topology(sim_mx, input_graph, percent=0.2):
+    """Generate the data augmentation from the topology (graph structure) perspective
+    for an undirected graph without self-loops.
+    :param sim_mx: tensor, symmetric similarity, [v,v]
+    :param input_graph: tensor, adjacency matrix without self-loop, [v,v]
+    :return aug_graph: tensor, augmented adjacency matrix on cuda, [v,v]
+    """
+    ## edge dropping starts here
+    drop_percent = percent / 2
+
+    index_list = input_graph.nonzero()  # list of edges [row_idx, col_idx]
+
+    edge_num = int(index_list.shape[0] / 2)  # treat one undirected edge as two edges
+    edge_mask = (input_graph > 0).tril(diagonal=-1)
+    add_drop_num = int(edge_num * drop_percent / 2)
+    aug_graph = copy.deepcopy(input_graph)
+
+    drop_prob = torch.softmax(sim_mx[edge_mask], dim=0)
+    drop_prob = (1. - drop_prob).numpy()  # normalized similarity to get sampling probability
+    drop_prob /= drop_prob.sum()
+    drop_list = np.random.choice(edge_num, size=add_drop_num, p=drop_prob)
+    drop_index = index_list[drop_list]
+
+    zeros = torch.zeros_like(aug_graph[0, 0])
+    aug_graph[drop_index[:, 0], drop_index[:, 1]] = zeros
+    aug_graph[drop_index[:, 1], drop_index[:, 0]] = zeros
+
+    ## edge adding starts here
+    node_num = input_graph.shape[0]
+    x, y = np.meshgrid(range(node_num), range(node_num), indexing='ij')
+    mask = y < x
+    x, y = x[mask], y[mask]
+
+    add_prob = sim_mx[torch.ones(sim_mx.size(), dtype=bool).tril(diagonal=-1)]  # .numpy()
+    add_prob = torch.softmax(add_prob, dim=0).numpy()
+    add_list = np.random.choice(int((node_num * node_num - node_num) / 2),
+                                size=add_drop_num, p=add_prob)
+
+    ones = torch.ones_like(aug_graph[0, 0])
+    aug_graph[x[add_list], y[add_list]] = ones
+    aug_graph[y[add_list], x[add_list]] = ones
+
+    return aug_graph
+
+def aug_traffic(t_sim_mx, flow_data, percent=0.2):
+    """Generate the data augmentation from the traffic (node attribute) perspective.
+    :param t_sim_mx: temporal similarity matrix after softmax, [l,n,v]
+    :param flow_data: input flow data, [n,l,v,c]
+    """
+    l, n, v = t_sim_mx.shape
+    mask_num = int(n * l * v * percent)
+    aug_flow = copy.deepcopy(flow_data)
+
+    mask_prob = (1. - t_sim_mx.permute(1, 0, 2).reshape(-1)).numpy()
+    mask_prob /= mask_prob.sum()
+
+    x, y, z = np.meshgrid(range(n), range(l), range(v), indexing='ij')
+    mask_list = np.random.choice(n * l * v, size=mask_num, p=mask_prob)
+
+    zeros = torch.zeros_like(aug_flow[0, 0, 0])
+    aug_flow[
+        x.reshape(-1)[mask_list],
+        y.reshape(-1)[mask_list],
+        z.reshape(-1)[mask_list]] = zeros
+
+    return aug_flow
+
diff --git a/model/ST_SSL/models.py b/model/ST_SSL/models.py
new file mode 100644
index 0000000..c128536
--- /dev/null
+++ b/model/ST_SSL/models.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Simplified masked_mae_loss
+def masked_mae_loss(mask_value=5.0):
+    def loss_fn(pred, target):
+        mask = (target != mask_value).float()
+        mae = F.l1_loss(pred * mask, target * mask, reduction='sum')
+        return mae / (mask.sum() + 1e-8)
+    return loss_fn
+
+# Simplified data-augmentation stubs
+def aug_topology(sim_mx, graph, percent=0.1):
+    return graph
+
+def aug_traffic(sim_mx, data, percent=0.1):
+    return data
+
+class STEncoder(nn.Module):
+    def __init__(self, Kt, Ks, blocks, input_length, num_nodes, droprate):
+        super(STEncoder, self).__init__()
+        self.num_nodes = num_nodes
+        self.input_length = input_length
+
+        # Simplified encoder - fixes the number of input channels
+        self.conv1 = nn.Conv2d(blocks[0][0], blocks[0][1], kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
+        self.conv2 = nn.Conv2d(blocks[0][1], blocks[0][2], kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
+        self.dropout = nn.Dropout(droprate)
+
+        # Placeholder similarity matrices
+        self.s_sim_mx = torch.randn(num_nodes, num_nodes)
+        self.t_sim_mx = torch.randn(input_length, input_length)
+
+    def forward(self, x, graph):
+        # x: (batch_size, num_nodes, seq_len, features)
+        batch_size, num_nodes, seq_len, features = x.shape
+
+        # Rearrange the dimensions
+        x = x.permute(0, 3, 1, 2)  # (batch_size, features, num_nodes, seq_len)
+
+        # Make sure the number of input channels is correct
+        if x.shape[1] != 2:  # adjust whenever there are not exactly 2 channels
+            if x.shape[1] == 1:
+                x = x.repeat(1, 2, 1, 1)  # duplicate to 2 channels
+            else:
+                x = x[:, :2, :, :]  # keep the first 2 channels
+
+        # Convolutions
+        x = F.relu(self.conv1(x))
+        x = self.dropout(x)
+        x = F.relu(self.conv2(x))
+
+        # Rearrange back to the original layout
+        x = x.permute(0, 2, 3, 1)  # (batch_size, num_nodes, seq_len, features)
+
+        return x
+
+class MLP(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(MLP, self).__init__()
+        self.fc = nn.Linear(input_dim, output_dim)
+
+    def forward(self, x):
+        return self.fc(x)
+
+class TemporalHeteroModel(nn.Module):
+    def __init__(self, d_model, batch_size, num_nodes, device):
+        super(TemporalHeteroModel, self).__init__()
+        self.fc = nn.Linear(d_model, 1)
+
+    def forward(self, z1, z2):
+        return F.mse_loss(self.fc(z1), self.fc(z2))
+
+class SpatialHeteroModel(nn.Module):
+    def __init__(self, d_model, nmb_prototype, batch_size, shm_temp):
+        super(SpatialHeteroModel, self).__init__()
+        self.fc = nn.Linear(d_model, 1)
+
+    def forward(self, z1, z2):
+        return F.mse_loss(self.fc(z1), self.fc(z2))
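+
+# The two heads above are simplified stand-ins: their nmb_prototype, shm_temp,
+# batch_size and device arguments are accepted but unused, and both reduce the
+# original contrastive objectives to a plain MSE agreement term so that the
+# training interface stays unchanged.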
+
+class STSSL(nn.Module):
+    def __init__(self, args):
+        super(STSSL, self).__init__()
+        # spatial temporal encoder
+        self.encoder = STEncoder(Kt=3, Ks=3, blocks=[[2, int(args['d_model']//2), args['d_model']], [args['d_model'], int(args['d_model']//2), args['d_model']]],
+                                 input_length=args['input_length'], num_nodes=args['num_nodes'], droprate=args['dropout'])
+
+        # traffic flow prediction branch
+        self.mlp = MLP(args['d_model'], args['d_output'])
+        # temporal heterogeneity modeling branch
+        self.thm = TemporalHeteroModel(args['d_model'], args['batch_size'], args['num_nodes'], args['device'])
+        # spatial heterogeneity modeling branch
+        self.shm = SpatialHeteroModel(args['d_model'], args['nmb_prototype'], args['batch_size'], args['shm_temp'])
+        self.mae = masked_mae_loss(mask_value=5.0)
+        self.args = args
+
+    def forward(self, view1, graph):
+        repr1 = self.encoder(view1, graph)  # view1: n,l,v,c; graph: v,v
+
+        s_sim_mx = self.fetch_spatial_sim()
+        graph2 = aug_topology(s_sim_mx, graph, percent=self.args['percent']*2)
+
+        t_sim_mx = self.fetch_temporal_sim()
+        view2 = aug_traffic(t_sim_mx, view1, percent=self.args['percent'])
+
+        repr2 = self.encoder(view2, graph2)
+        return repr1, repr2
+
+    def fetch_spatial_sim(self):
+        """
+        Fetch the region similarity matrix generated by the region embedding.
+        Note this can be called only when spatial_sim is True.
+        :return sim_mx: tensor, similarity matrix, (v, v)
+        """
+        return self.encoder.s_sim_mx.cpu()
+
+    def fetch_temporal_sim(self):
+        return self.encoder.t_sim_mx.cpu()
+
+    def predict(self, z1, z2):
+        '''Predicting future traffic flow.
+        :param z1, z2 (tensor): shape nvc
+        :return: nlvc, l=1, c=2
+        '''
+        return self.mlp(z1)
+
+    def loss(self, z1, z2, y_true, scaler, loss_weights):
+        l1 = self.pred_loss(z1, z2, y_true, scaler)
+        sep_loss = [l1.item()]
+        loss = loss_weights[0] * l1
+
+        l2 = self.temporal_loss(z1, z2)
+        sep_loss.append(l2.item())
+        loss += loss_weights[1] * l2
+
+        l3 = self.spatial_loss(z1, z2)
+        sep_loss.append(l3.item())
+        loss += loss_weights[2] * l3
+        return loss, sep_loss
+
+    def pred_loss(self, z1, z2, y_true, scaler):
+        y_pred = scaler.inverse_transform(self.predict(z1, z2))
+        y_true = scaler.inverse_transform(y_true)
+
+        loss = self.args['yita'] * self.mae(y_pred[..., 0], y_true[..., 0]) + \
+               (1 - self.args['yita']) * self.mae(y_pred[..., 1], y_true[..., 1])
+        return loss
+
+    def temporal_loss(self, z1, z2):
+        return self.thm(z1, z2)
+
+    def spatial_loss(self, z1, z2):
+        return self.shm(z1, z2)
+
diff --git a/model/TEDDCF/ISTF.py b/model/TEDDCF/ISTF.py
new file mode 100644
index 0000000..f77c4f8
--- /dev/null
+++ b/model/TEDDCF/ISTF.py
@@ -0,0 +1,108 @@
+
+import torch.nn as nn
+import torch
+from torchinfo import summary
+
+
+class AttentionLayer(nn.Module):
+    """Perform attention across the -2 dim (the -1 dim is `model_dim`).
+
+    Make sure the tensor is permuted to the correct shape before attention.
+
+    E.g.
+    - Input shape (batch_size, in_steps, num_nodes, model_dim).
+    - Then the attention will be performed across the nodes.
+
+    Also, it supports different src and tgt lengths,
+    but `src_length == K_length == V_length` must hold.
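+
+    Shape sketch: Q/K/V are each projected to model_dim, split into num_heads
+    chunks of head_dim = model_dim // num_heads along the last dim and
+    re-stacked along the batch dim, so the attention itself runs on
+    (num_heads * batch_size, ..., length, head_dim) tensors.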
+    """
+
+    def __init__(self, model_dim, num_heads=8, mask=False):
+        super().__init__()
+
+        self.model_dim = model_dim  # 152
+        self.num_heads = num_heads
+        self.mask = mask
+
+        self.head_dim = model_dim // num_heads
+
+        self.FC_Q = nn.Linear(model_dim, model_dim)  # [152,152]
+        self.FC_K = nn.Linear(model_dim, model_dim)
+        self.FC_V = nn.Linear(model_dim, model_dim)
+
+        self.out_proj = nn.Linear(model_dim, model_dim)
+
+    def forward(self, query, key, value):
+        # Q (batch_size, ..., tgt_length, model_dim)
+        # K, V (batch_size, ..., src_length, model_dim)
+        batch_size = query.shape[0]  # 16 / 64
+        tgt_length = query.shape[-2]  # 12 / 170
+        src_length = key.shape[-2]  # 12 / 170
+
+        query = self.FC_Q(query)  # [64,6,170,152]
+        key = self.FC_K(key)
+        value = self.FC_V(value)
+
+        # Qhead, Khead, Vhead (num_heads * batch_size, ..., length, head_dim)
+        query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)  # [512,6,170,24]
+        key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
+        value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)
+
+        key = key.transpose(
+            -1, -2
+        )  # (num_heads * batch_size, ..., head_dim, src_length)
+
+        attn_score = (  # [64,170,12,12]
+            query @ key
+        ) / self.head_dim**0.5  # (num_heads * batch_size, ..., tgt_length, src_length)
+
+        if self.mask:
+            mask = torch.ones(
+                tgt_length, src_length, dtype=torch.bool, device=query.device
+            ).tril()  # lower triangular part of the matrix
+            attn_score.masked_fill_(~mask, -torch.inf)  # fill in-place
+
+        attn_score = torch.softmax(attn_score, dim=-1)  # [64,170,12,12]
+        out = attn_score @ value
+        out = torch.cat(
+            torch.split(out, batch_size, dim=0), dim=-1
+        )  # (batch_size, ..., tgt_length, head_dim * num_heads = model_dim) [16,170,12,152]
+
+        out = self.out_proj(out)  # [64,6,170,152]
+
+        return out
+
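+# Usage sketch (illustrative): the attended axis is whichever one is laid out
+# at dim -2 before the call; with x of shape (batch, steps, nodes, model_dim)
+# the attention runs across the nodes:
+#   attn = AttentionLayer(model_dim, num_heads=8)
+#   out = attn(x, x, x)
+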
+class SelfAttentionLayer(nn.Module):
+    def __init__(
+        self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
+    ):
+        super().__init__()
+
+        self.attn = AttentionLayer(model_dim, num_heads, mask)
+        self.feed_forward = nn.Sequential(
+            nn.Linear(model_dim, feed_forward_dim),  # [152,256]
+            nn.ReLU(inplace=True),
+            nn.Linear(feed_forward_dim, model_dim),  # [256,152]
+        )
+        self.ln1 = nn.LayerNorm(model_dim)
+        self.ln2 = nn.LayerNorm(model_dim)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+    def forward(self, x, dim=-2):
+        x = x.transpose(dim, -2)
+        # x: (batch_size, ..., length, model_dim)
+        residual = x
+        out = self.attn(x, x, x)  # (batch_size, ..., length, model_dim) [16,170,12,152]
+        out = self.dropout1(out)
+        out = self.ln1(residual + out)
+
+        residual = out
+        out = self.feed_forward(out)  # (batch_size, ..., length, model_dim)
+        out = self.dropout2(out)
+        out = self.ln2(residual + out)
+
+        out = out.transpose(dim, -2)  # [64,6,170,152]
+        return out
\ No newline at end of file
diff --git a/model/TEDDCF/model.py b/model/TEDDCF/model.py
new file mode 100644
index 0000000..b0f0dae
--- /dev/null
+++ b/model/TEDDCF/model.py
@@ -0,0 +1,414 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import pandas as pd
+import sys
+from model.TEDDCF.ISTF import SelfAttentionLayer
+
+
+class GLU(nn.Module):
+    def __init__(self, features, dropout=0.1):  # PEMS08: 192
+        super(GLU, self).__init__()
+        self.conv1 = nn.Conv2d(features, features, (1, 1))
+        self.conv2 = nn.Conv2d(features, features, (1, 1))
+        self.conv3 = nn.Conv2d(features, features, (1, 1))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv2(x)
+        out = x1 * torch.sigmoid(x2)
+        out = self.dropout(out)
+        out = self.conv3(out)
+        return out  # [64,192,170,12]
+
+
+class TemporalEmbedding(nn.Module):
+    def __init__(self, time, features):
+        super(TemporalEmbedding, self).__init__()
+        # S08: time 288, features 96
+        self.time = time
+        self.time_day = nn.Parameter(torch.empty(time, features))  # [288,96]
+        nn.init.xavier_uniform_(self.time_day)
+
+        self.time_week = nn.Parameter(torch.empty(7, features))  # [7,96]
+        nn.init.xavier_uniform_(self.time_week)
+
+    def forward(self, x):
+        # x in: [64,12,170,3]
+        day_emb = x[..., 1]
+        time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
+        time_day = time_day.transpose(1, 2).contiguous()
+
+        week_emb = x[..., 2]
+        time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]  # [64,12,170,96]
+        time_week = time_week.transpose(1, 2).contiguous()  # torch.Size([64, 170, 12, 96])
+
+        tem_emb = time_day + time_week  # [64,170,12,96]
+        tem_emb = tem_emb.permute(0, 3, 1, 2)  # [64,96,170,12]
+
+        return tem_emb
+
+
+class Diffusion_GCN(nn.Module):
+    def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
+        super().__init__()
+        self.diffusion_step = diffusion_step  # 1
+        self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))  # [192,192,(1,1)]
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, adj):
+        out = []
+        for i in range(0, self.diffusion_step):  # 1
+            if adj.dim() == 3:
+                x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
+                out.append(x)
+            elif adj.dim() == 2:
+                x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
+                out.append(x)
+        x = torch.cat(out, dim=1)
+        x = self.conv(x)
+        output = self.dropout(x)
+        return output
+
+
+class EventGraph_Fusion(nn.Module):
+    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
+        super().__init__()
+        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
+        nn.init.xavier_uniform_(self.memory)
+        self.fc = nn.Linear(2, 1)
+
+    def forward(self, x):
+        adj_dyn_1 = torch.softmax(
+            F.relu(
+                torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
+                / math.sqrt(x.shape[1])
+            ),
+            -1,
+        )
+        adj_dyn_2 = torch.softmax(
+            F.relu(
+                torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
+                / math.sqrt(x.shape[1])
+            ),
+            -1,
+        )
+        adj_f = torch.cat([(adj_dyn_1).unsqueeze(-1)] + [(adj_dyn_2).unsqueeze(-1)], dim=-1)
+
+        adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
+
+        topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1]*0.8), dim=-1)
+
+        mask = torch.zeros_like(adj_f)
+
+        mask.scatter_(-1, topk_indices, 1)
+
+        adj_f = adj_f * mask
+
+        return adj_f
+
+
+class EventGCN(nn.Module):
+    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
+        super().__init__()
+
+        self.conv = nn.Conv2d(channels, channels, (1, 1))
+        self.generator = EventGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
+        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
+        self.emb = emb
+
+    def forward(self, x):
+        skip = x
+        x = self.conv(x)
+        adj_dyn = self.generator(x)
+        x = self.gcn(x, adj_dyn)
+        x = x*self.emb + skip
+
+        return x
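+
+# TrendGCN below mirrors EventGCN, but derives its adjacency from
+# TrendGraph_Fusion, which mixes the data-driven graph with an adaptive
+# node-embedding graph (E_adaptive) instead of the memory-based one.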
+
+class TrendGCN(nn.Module):
+    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
+        super().__init__()
+
+        self.conv = nn.Conv2d(channels, channels, (1, 1))
+        self.generator = TrendGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
+        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
+        self.emb = emb
+
+    def forward(self, x):
+        skip = x
+        x = self.conv(x)
+        adj_dyn = self.generator(x)
+        x = self.gcn(x, adj_dyn)
+        x = x*self.emb + skip
+
+        return x
+
+
+class TrendGraph_Fusion(nn.Module):
+    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
+        super().__init__()
+        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
+        nn.init.xavier_uniform_(self.memory)
+        self.fc = nn.Linear(2, 1)
+        self.E_adaptive = nn.Parameter(torch.randn(num_nodes, 10))
+
+    def forward(self, x):
+        # adj_dyn_1 = torch.softmax(
+        #     F.relu(
+        #         torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
+        #         / math.sqrt(x.shape[1])
+        #     ),
+        #     -1,
+        # )
+
+        adj_dyn_2 = torch.softmax(
+            F.relu(
+                torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
+                / math.sqrt(x.shape[1])
+            ),
+            -1,
+        )
+        adj_adp = F.softmax(F.relu(torch.mm(self.E_adaptive, self.E_adaptive.transpose(0, 1))), dim=1)
+
+        adj_adp_expanded = adj_adp.unsqueeze(0)
+
+        adj_adp = adj_adp_expanded.repeat(x.shape[0], 1, 1)
+
+        adj_f = torch.cat([(adj_dyn_2).unsqueeze(-1)] + [(adj_adp).unsqueeze(-1)], dim=-1)
+
+        adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
+
+        topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
+
+        mask = torch.zeros_like(adj_f)
+
+        mask.scatter_(-1, topk_indices, 1)
+
+        adj_f = adj_f * mask
+
+        return adj_f
+
+class Chomp1d(nn.Module):
+    """
+    extra dimension will be added by padding, remove it
+    """
+    def __init__(self, chomp_size):
+        super(Chomp1d, self).__init__()
+        self.chomp_size = chomp_size
+
+    def forward(self, x):
+        return x[:, :, :, :-self.chomp_size].contiguous()
+
+class TemporalConvNet(nn.Module):
+    def __init__(self, features, kernel_size=2, dropout=0.2, levels=1):
+        super(TemporalConvNet, self).__init__()
+
+        layers = []
+        for i in range(levels):
+            dilation_size = 2 ** i
+            padding = (kernel_size - 1) * dilation_size
+            self.conv = nn.Conv2d(features, features, (1, kernel_size), dilation=(1, dilation_size),
+                                  padding=(0, padding))
+            self.chomp = Chomp1d(padding)
+            self.relu = nn.ReLU()
+            self.dropout = nn.Dropout(dropout)
+
+            layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
+        self.tcn = nn.Sequential(*layers)
+
+    def forward(self, xh):
+        xh = self.tcn(xh)
+        return xh
+
+class FeedForward(nn.Module):
+    def __init__(self, fea, res_ln=False):
+        super(FeedForward, self).__init__()
+        self.res_ln = res_ln
+        self.L = len(fea) - 1  # 2
+        self.linear = nn.ModuleList([nn.Linear(fea[i], fea[i+1]) for i in range(self.L)])
+        self.ln = nn.LayerNorm(fea[self.L], elementwise_affine=False)
+
+    def forward(self, inputs):
+        x = inputs
+        for i in range(self.L):
+            x = self.linear[i](x)
+            if i != self.L-1:
+                x = F.relu(x)
+        if self.res_ln:
+            x += inputs
+            x = self.ln(x)
+        return x
+
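+# Adaptive_Fusion cross-attends the low-frequency stream (queries) to the
+# high-frequency stream (keys/values), optionally with a causal mask; it is
+# the STWave-style fusion kept as a commented-out alternative in
+# TEDDCF.forward below.
+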
+class Adaptive_Fusion(nn.Module):
+    def __init__(self, heads, dims):
+        super(Adaptive_Fusion, self).__init__()
+        features = dims  # 192
+        self.h = heads  # 8
+        self.d = int(dims / heads)  # 16
+
+        self.qlfc = FeedForward([features, features])
+        self.khfc = FeedForward([features, features])
+        self.vhfc = FeedForward([features, features])
+        self.ofc = FeedForward([features, features])
+
+        self.ln = nn.LayerNorm(features, elementwise_affine=False)
+        self.ff = FeedForward([features, features, features], True)
+
+    def forward(self, xl, xh, Mask=True):
+        '''
+        xl: [B,T,N,F]
+        xh: [B,T,N,F]
+        return: [B,T,N,F]
+        '''
+        # xl += te
+        # xh += te
+
+        query = self.qlfc(xl)  # [B,T,N,F]
+        keyh = torch.relu(self.khfc(xh))  # [B,T,N,F]
+        valueh = torch.relu(self.vhfc(xh))  # [B,T,N,F]
+
+        query = torch.cat(torch.split(query, self.d, -1), 0).permute(0, 2, 1, 3)  # [k*B,N,T,d]
+        keyh = torch.cat(torch.split(keyh, self.d, -1), 0).permute(0, 2, 3, 1)  # [k*B,N,d,T]
+        valueh = torch.cat(torch.split(valueh, self.d, -1), 0).permute(0, 2, 1, 3)  # [k*B,N,T,d]
+
+        attentionh = torch.matmul(query, keyh)  # [k*B,N,T,T]
+
+        if Mask:
+            batch_size = xl.shape[0]
+            num_steps = xl.shape[1]
+            num_vertexs = xl.shape[2]
+            mask = torch.ones(num_steps, num_steps).to(xl.device)  # [T,T]
+            mask = torch.tril(mask)  # [T,T]
+            mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0)  # [1,1,T,T]
+            mask = mask.repeat(self.h * batch_size, num_vertexs, 1, 1)  # [k*B,N,T,T]
+            mask = mask.to(torch.bool)
+            zero_vec = (-2 ** 15 + 1) * torch.ones_like(attentionh).to(xl.device)  # [k*B,N,T,T]
+            attentionh = torch.where(mask, attentionh, zero_vec)
+
+        attentionh /= (self.d ** 0.5)  # scaled
+        attentionh = F.softmax(attentionh, -1)  # [k*B,N,T,T]
+
+        value = torch.matmul(attentionh, valueh)  # [k*B,N,T,d]
+
+        value = torch.cat(torch.split(value, value.shape[0] // self.h, 0), -1).permute(0, 2, 1, 3)  # [B,T,N,F]
+        value = self.ofc(value)
+        value = value + xl
+
+        value = self.ln(value)
+
+        return self.ff(value)  # [64,12,170,128]
+
+class TEDDCF(nn.Module):
+    def __init__(
+        self, device, input_dim, num_nodes, channels, granularity, dropout=0.1
+    ):
+        super().__init__()
+
+        self.device = device
+        self.num_nodes = num_nodes
+        self.output_len = 12
+        self.input_len = 12
+        self.heads = 8
+        diffusion_step = 1
+
+        self.Temb = TemporalEmbedding(granularity, channels)
+
+        self.start_conv = nn.Conv2d(
+            in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
+        )
+
+        self.glu = GLU(channels*2, dropout)
+
+        self.regression_layer = nn.Conv2d(
+            channels*2, self.output_len, kernel_size=(1, self.output_len)
+        )
+
+        self.temporal_conv = TemporalConvNet(channels*2)
+        self.pre_h = nn.Conv2d(in_channels=self.input_len, out_channels=self.output_len, kernel_size=(1, 1))
+        self.adp_f = Adaptive_Fusion(self.heads, channels*2)
+
+        num_layers = 3
+        self.attn_layers_t = nn.ModuleList(
+            [
+                SelfAttentionLayer(channels*2, feed_forward_dim=256, num_heads=4, dropout=0.1)
+                for _ in range(num_layers)  # 3
+            ]
+        )
+        self.xh_emb = nn.Parameter(torch.randn(channels*2, num_nodes, 12))
+        self.xh_dgcn = EventGCN(channels*2, num_nodes, diffusion_step=1, dropout=0.1, emb=self.xh_emb)
+
+        self.xl_emb = nn.Parameter(torch.randn(channels*2, num_nodes, 12))
+        self.xl_dgcn = TrendGCN(channels*2, num_nodes, diffusion_step=1, dropout=0.1, emb=self.xl_emb)
+
+    def param_num(self):
+        return sum([param.nelement() for param in self.parameters()])
+
+    def forward(self, inputxl, inputxh):
+        xl = inputxl
+        xh = inputxh
+
+        # Encoder
+        # Data Embedding
+        time_embl = self.Temb(inputxl.permute(0, 3, 2, 1))
+        time_embh = self.Temb(inputxh.permute(0, 3, 2, 1))
+        #t = self.start_conv(x)  # [64,96,170,12]
+        xl = torch.cat([self.start_conv(xl)] + [time_embl], dim=1)
+        xh = torch.cat([self.start_conv(xh)] + [time_embh], dim=1)
+
+        xl = xl.permute(0, 3, 2, 1)
+        for attn in self.attn_layers_t:
+            xl = attn(xl, dim=1)
+        xl = xl.permute(0, 3, 2, 1)
+
+        xl = self.xl_dgcn(xl)
+        xl = self.glu(xl) + xl
+
+        xh = self.temporal_conv(xh)
+
+        xh = self.xh_dgcn(xh)
+
+        # simple plus
+        x_all = xh + xl
+        # STwave_fusion
+        # xl = xl.transpose(1, 3)
+        # xh = self.pre_h(xh.transpose(1, 3))  # [64,12,170,192]
+        # x_all = self.adp_f(xl, xh)
+        # x_all = x_all.transpose(1, 3)
+
+        prediction = self.regression_layer(F.relu(x_all))
+
+        return prediction
\ No newline at end of file