add-model

This commit is contained in:
czzhangheng 2025-08-18 21:49:14 +08:00
parent af795043c8
commit 538548db0b
9 changed files with 1363 additions and 0 deletions

227
model/MegaCRN/MegaCRN.py Normal file
View File

@ -0,0 +1,227 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np
class AGCN(nn.Module):
def __init__(self, dim_in, dim_out, cheb_k):
super(AGCN, self).__init__()
self.cheb_k = cheb_k
self.weights = nn.Parameter(torch.FloatTensor(2*cheb_k*dim_in, dim_out)) # 2 is the length of support
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
nn.init.xavier_normal_(self.weights)
nn.init.constant_(self.bias, val=0)
def forward(self, x, supports):
x_g = []
support_set = []
for support in supports:
support_ks = [torch.eye(support.shape[0]).to(support.device), support]
for k in range(2, self.cheb_k):
support_ks.append(torch.matmul(2 * support, support_ks[-1]) - support_ks[-2])
support_set.extend(support_ks)
for support in support_set:
x_g.append(torch.einsum("nm,bmc->bnc", support, x))
x_g = torch.cat(x_g, dim=-1) # B, N, 2 * cheb_k * dim_in
x_gconv = torch.einsum('bni,io->bno', x_g, self.weights) + self.bias # b, N, dim_out
return x_gconv
class AGCRNCell(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k):
super(AGCRNCell, self).__init__()
self.node_num = node_num
self.hidden_dim = dim_out
self.gate = AGCN(dim_in+self.hidden_dim, 2*dim_out, cheb_k)
self.update = AGCN(dim_in+self.hidden_dim, dim_out, cheb_k)
def forward(self, x, state, supports):
#x: B, num_nodes, input_dim
#state: B, num_nodes, hidden_dim
state = state.to(x.device)
input_and_state = torch.cat((x, state), dim=-1)
z_r = torch.sigmoid(self.gate(input_and_state, supports))
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
candidate = torch.cat((x, z*state), dim=-1)
hc = torch.tanh(self.update(candidate, supports))
h = r*state + (1-r)*hc
return h
def init_hidden_state(self, batch_size):
return torch.zeros(batch_size, self.node_num, self.hidden_dim)
class ADCRNN_Encoder(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, num_layers):
super(ADCRNN_Encoder, self).__init__()
assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
self.node_num = node_num
self.input_dim = dim_in
self.num_layers = num_layers
self.dcrnn_cells = nn.ModuleList()
self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k))
for _ in range(1, num_layers):
self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k))
def forward(self, x, init_state, supports):
#shape of x: (B, T, N, D), shape of init_state: (num_layers, B, N, hidden_dim)
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
seq_length = x.shape[1]
current_inputs = x
output_hidden = []
for i in range(self.num_layers):
state = init_state[i]
inner_states = []
for t in range(seq_length):
state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, supports)
inner_states.append(state)
output_hidden.append(state)
current_inputs = torch.stack(inner_states, dim=1)
#current_inputs: the outputs of last layer: (B, T, N, hidden_dim)
#last_state: (B, N, hidden_dim)
#output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim)
#return current_inputs, torch.stack(output_hidden, dim=0)
return current_inputs, output_hidden
def init_hidden(self, batch_size):
init_states = []
for i in range(self.num_layers):
init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size))
return init_states
class ADCRNN_Decoder(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, num_layers):
super(ADCRNN_Decoder, self).__init__()
assert num_layers >= 1, 'At least one DCRNN layer in the Decoder.'
self.node_num = node_num
self.input_dim = dim_in
self.num_layers = num_layers
self.dcrnn_cells = nn.ModuleList()
self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k))
for _ in range(1, num_layers):
self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k))
def forward(self, xt, init_state, supports):
# xt: (B, N, D)
# init_state: (num_layers, B, N, hidden_dim)
assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
current_inputs = xt
output_hidden = []
for i in range(self.num_layers):
state = self.dcrnn_cells[i](current_inputs, init_state[i], supports)
output_hidden.append(state)
current_inputs = state
return current_inputs, output_hidden
class MegaCRN(nn.Module):
def __init__(self, num_nodes, input_dim, output_dim, horizon, rnn_units, num_layers=1, cheb_k=3,
ycov_dim=1, mem_num=20, mem_dim=64, cl_decay_steps=2000, use_curriculum_learning=True):
super(MegaCRN, self).__init__()
self.num_nodes = num_nodes
self.input_dim = input_dim
self.rnn_units = rnn_units
self.output_dim = output_dim
self.horizon = horizon
self.num_layers = num_layers
self.cheb_k = cheb_k
self.ycov_dim = ycov_dim
self.cl_decay_steps = cl_decay_steps
self.use_curriculum_learning = use_curriculum_learning
# memory
self.mem_num = mem_num
self.mem_dim = mem_dim
self.memory = self.construct_memory()
# encoder
self.encoder = ADCRNN_Encoder(self.num_nodes, self.input_dim, self.rnn_units, self.cheb_k, self.num_layers)
# deocoder
self.decoder_dim = self.rnn_units + self.mem_dim
self.decoder = ADCRNN_Decoder(self.num_nodes, self.output_dim + self.ycov_dim, self.decoder_dim, self.cheb_k, self.num_layers)
# output
self.proj = nn.Sequential(nn.Linear(self.decoder_dim, self.output_dim, bias=True))
def compute_sampling_threshold(self, batches_seen):
return self.cl_decay_steps / (self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
def construct_memory(self):
memory_dict = nn.ParameterDict()
memory_dict['Memory'] = nn.Parameter(torch.randn(self.mem_num, self.mem_dim), requires_grad=True) # (M, d)
memory_dict['Wq'] = nn.Parameter(torch.randn(self.rnn_units, self.mem_dim), requires_grad=True) # project to query
memory_dict['We1'] = nn.Parameter(torch.randn(self.num_nodes, self.mem_num), requires_grad=True) # project memory to embedding
memory_dict['We2'] = nn.Parameter(torch.randn(self.num_nodes, self.mem_num), requires_grad=True) # project memory to embedding
for param in memory_dict.values():
nn.init.xavier_normal_(param)
return memory_dict
def query_memory(self, h_t:torch.Tensor):
query = torch.matmul(h_t, self.memory['Wq']) # (B, N, d)
att_score = torch.softmax(torch.matmul(query, self.memory['Memory'].t()), dim=-1) # alpha: (B, N, M)
value = torch.matmul(att_score, self.memory['Memory']) # (B, N, d)
_, ind = torch.topk(att_score, k=2, dim=-1)
pos = self.memory['Memory'][ind[:, :, 0]] # B, N, d
neg = self.memory['Memory'][ind[:, :, 1]] # B, N, d
return value, query, pos, neg
def forward(self, x, y_cov, labels=None, batches_seen=None):
node_embeddings1 = torch.matmul(self.memory['We1'], self.memory['Memory'])
node_embeddings2 = torch.matmul(self.memory['We2'], self.memory['Memory'])
g1 = F.softmax(F.relu(torch.mm(node_embeddings1, node_embeddings2.T)), dim=-1)
g2 = F.softmax(F.relu(torch.mm(node_embeddings2, node_embeddings1.T)), dim=-1)
supports = [g1, g2]
init_state = self.encoder.init_hidden(x.shape[0])
h_en, state_en = self.encoder(x, init_state, supports) # B, T, N, hidden
h_t = h_en[:, -1, :, :] # B, N, hidden (last state)
h_att, query, pos, neg = self.query_memory(h_t)
h_t = torch.cat([h_t, h_att], dim=-1)
ht_list = [h_t]*self.num_layers
go = torch.zeros((x.shape[0], self.num_nodes, self.output_dim), device=x.device)
out = []
for t in range(self.horizon):
h_de, ht_list = self.decoder(torch.cat([go, y_cov[:, t, ...]], dim=-1), ht_list, supports)
go = self.proj(h_de)
out.append(go)
if self.training and self.use_curriculum_learning:
c = np.random.uniform(0, 1)
if c < self.compute_sampling_threshold(batches_seen):
go = labels[:, t, ...]
output = torch.stack(out, dim=1)
return output, h_att, query, pos, neg
def print_params(model):
# print trainable params
param_count = 0
print('Trainable parameter list:')
for name, param in model.named_parameters():
if param.requires_grad:
print(name, param.shape, param.numel())
param_count += param.numel()
print(f'In total: {param_count} trainable parameters. \n')
return
def main():
import sys
import argparse
from torchsummary import summary
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", type=int, default=3, help="which GPU to use")
parser.add_argument('--num_variable', type=int, default=207, help='number of variables (e.g., 207 in METR-LA, 325 in PEMS-BAY)')
parser.add_argument('--his_len', type=int, default=12, help='sequence length of historical observation')
parser.add_argument('--seq_len', type=int, default=12, help='sequence length of prediction')
parser.add_argument('--channelin', type=int, default=1, help='number of input channel')
parser.add_argument('--channelout', type=int, default=1, help='number of output channel')
parser.add_argument('--rnn_units', type=int, default=64, help='number of hidden units')
args = parser.parse_args()
device = torch.device("cuda:{}".format(args.gpu)) if torch.cuda.is_available() else torch.device("cpu")
model = MegaCRN(num_nodes=args.num_variable, input_dim=args.channelin, output_dim=args.channelout, horizon=args.seq_len, rnn_units=args.rnn_units).to(device)
summary(model, [(args.his_len, args.num_variable, args.channelin), (args.seq_len, args.num_variable, args.channelout)], device=device)
print_params(model)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from model.MegaCRN.MegaCRN import MegaCRN
class MegaCRNModel(nn.Module):
def __init__(self, args):
super(MegaCRNModel, self).__init__()
# 设置默认参数
if 'rnn_units' not in args:
args['rnn_units'] = 64
if 'num_layers' not in args:
args['num_layers'] = 1
if 'cheb_k' not in args:
args['cheb_k'] = 3
if 'ycov_dim' not in args:
args['ycov_dim'] = 1
if 'mem_num' not in args:
args['mem_num'] = 20
if 'mem_dim' not in args:
args['mem_dim'] = 64
if 'cl_decay_steps' not in args:
args['cl_decay_steps'] = 2000
if 'use_curriculum_learning' not in args:
args['use_curriculum_learning'] = True
if 'horizon' not in args:
args['horizon'] = 12
# 创建MegaCRN模型
self.model = MegaCRN(
num_nodes=args['num_nodes'],
input_dim=1, # 固定为1因为我们只使用第一个通道
output_dim=args['output_dim'],
horizon=args['horizon'],
rnn_units=args['rnn_units'],
num_layers=args['num_layers'],
cheb_k=args['cheb_k'],
ycov_dim=args['ycov_dim'],
mem_num=args['mem_num'],
mem_dim=args['mem_dim'],
cl_decay_steps=args['cl_decay_steps'],
use_curriculum_learning=args['use_curriculum_learning']
)
self.args = args
self.batches_seen = 0 # 添加batches_seen计数器
def forward(self, x):
# x shape: (batch_size, seq_len, num_nodes, features)
# 按照DDGCRN的模式只使用第一个通道
x = x[..., 0].unsqueeze(-1) # (batch_size, seq_len, num_nodes, 1)
# 创建y_cov (这里使用零张量,实际使用时可能需要根据具体需求调整)
y_cov = torch.zeros(x.shape[0], self.args['horizon'], x.shape[2], self.args['ycov_dim'], device=x.device)
# 创建labels (这里使用零张量,实际使用时可能需要根据具体需求调整)
labels = torch.zeros(x.shape[0], self.args['horizon'], x.shape[2], self.args['output_dim'], device=x.device)
# 前向传播
output, h_att, query, pos, neg = self.model(x, y_cov, labels=labels, batches_seen=self.batches_seen)
# 更新batches_seen
self.batches_seen += 1
return output

58
model/ST_SSL/ST-SSL.py Normal file
View File

@ -0,0 +1,58 @@
import torch
import torch.nn as nn
from model.ST-SSL.models import STSSL
from model.ST-SSL.layers import STEncoder, MLP
from data.get_adj import get_gso
class STSSLModel(nn.Module):
def __init__(self, args):
super(STSSLModel, self).__init__()
# 获取邻接矩阵
gso = get_gso(args)
# 设置默认参数
if 'd_model' not in args:
args['d_model'] = 64
if 'd_output' not in args:
args['d_output'] = args['output_dim']
if 'input_length' not in args:
args['input_length'] = args['n_his']
if 'dropout' not in args:
args['dropout'] = 0.1
if 'nmb_prototype' not in args:
args['nmb_prototype'] = 10
if 'batch_size' not in args:
args['batch_size'] = 64
if 'shm_temp' not in args:
args['shm_temp'] = 0.1
if 'yita' not in args:
args['yita'] = 0.5
if 'percent' not in args:
args['percent'] = 0.1
if 'device' not in args:
args['device'] = 'cpu'
# 创建ST-SSL模型
self.model = STSSL(args)
def forward(self, x):
# x shape: (batch_size, seq_len, num_nodes, features)
batch_size, seq_len, num_nodes, features = x.shape
# 获取邻接矩阵
graph = get_gso(self.args)
# 调整输入格式
x = x.permute(0, 2, 1, 3) # (batch_size, num_nodes, seq_len, features)
# 前向传播
repr1, repr2 = self.model(x, graph)
# 预测
pred = self.model.predict(repr1, repr2)
# 调整输出格式
pred = pred.permute(0, 2, 1, 3) # (batch_size, seq_len, num_nodes, features)
return pred

128
model/ST_SSL/ST_SSL.py Normal file
View File

@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from data.get_adj import get_gso
class STSSLModel(nn.Module):
def __init__(self, args):
super(STSSLModel, self).__init__()
# 设置默认参数
if 'd_model' not in args:
args['d_model'] = 64
if 'd_output' not in args:
args['d_output'] = args['output_dim']
if 'input_length' not in args:
args['input_length'] = args['n_his']
if 'dropout' not in args:
args['dropout'] = 0.1
if 'nmb_prototype' not in args:
args['nmb_prototype'] = 10
if 'batch_size' not in args:
args['batch_size'] = 64
if 'shm_temp' not in args:
args['shm_temp'] = 0.1
if 'yita' not in args:
args['yita'] = 0.5
if 'percent' not in args:
args['percent'] = 0.1
if 'device' not in args:
args['device'] = 'cpu'
if 'gso_type' not in args:
args['gso_type'] = 'sym_norm_lap'
if 'graph_conv_type' not in args:
args['graph_conv_type'] = 'cheb_graph_conv'
# 保存参数
self.args = args
self.num_nodes = args['num_nodes']
self.input_dim = args['input_dim']
self.output_dim = args['output_dim']
self.horizon = args['horizon']
self.d_model = args['d_model']
# 获取邻接矩阵
self.gso = get_gso(args)
# 时间嵌入
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['d_model']))
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['d_model']))
# 节点嵌入
self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
# 编码器 - 使用1个输入通道
self.encoder = STEncoder(
Kt=3, Ks=3,
input_dim=1, # 只使用第一个通道
hidden_dim=args['d_model'],
input_length=args['input_length'],
num_nodes=args['num_nodes'],
droprate=args['dropout']
)
# 预测头
self.predictor = nn.Linear(args['d_model'], args['output_dim'])
# 初始化参数
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.node_emb_u)
nn.init.xavier_uniform_(self.node_emb_d)
nn.init.xavier_uniform_(self.T_i_D_emb)
nn.init.xavier_uniform_(self.D_i_W_emb)
def forward(self, x):
# x shape: (batch_size, seq_len, num_nodes, features)
# 按照DDGCRN的模式只使用第一个通道
x = x[..., 0].unsqueeze(-1) # (batch_size, seq_len, num_nodes, 1)
# 编码
encoded = self.encoder(x, self.gso)
# 预测
# 取最后一个时间步的输出进行预测
last_hidden = encoded[:, -1, :, :] # (batch_size, num_nodes, d_model)
# 预测未来horizon个时间步
predictions = []
for t in range(self.horizon):
pred = self.predictor(last_hidden) # (batch_size, num_nodes, output_dim)
predictions.append(pred)
# 堆叠预测结果
output = torch.stack(predictions, dim=1) # (batch_size, horizon, num_nodes, output_dim)
return output
class STEncoder(nn.Module):
def __init__(self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate):
super(STEncoder, self).__init__()
self.num_nodes = num_nodes
self.input_length = input_length
# 简化的时空编码器 - 使用1个输入通道
self.conv1 = nn.Conv2d(input_dim, hidden_dim//2, kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
self.conv2 = nn.Conv2d(hidden_dim//2, hidden_dim, kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
self.dropout = nn.Dropout(droprate)
def forward(self, x, graph):
# x: (batch_size, seq_len, num_nodes, features)
batch_size, seq_len, num_nodes, features = x.shape
# 调整维度
x = x.permute(0, 3, 1, 2) # (batch_size, features, seq_len, num_nodes)
# 卷积操作
x = F.relu(self.conv1(x))
x = self.dropout(x)
x = F.relu(self.conv2(x))
# 调整回原维度
x = x.permute(0, 2, 3, 1) # (batch_size, seq_len, num_nodes, features)
return x

103
model/ST_SSL/aug.py Normal file
View File

@ -0,0 +1,103 @@
import copy
import numpy as np
import torch
def sim_global(flow_data, sim_type='cos'):
"""Calculate the global similarity of traffic flow data.
:param flow_data: tensor, original flow [n,l,v,c] or location embedding [n,v,c]
:param type: str, type of similarity, attention or cosine. ['att', 'cos']
:return sim: tensor, symmetric similarity, [v,v]
"""
if len(flow_data.shape) == 4:
n,l,v,c = flow_data.shape
att_scaling = n * l * c
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 1, 3)) ** -1 # cal 2-norm of each node, dim N
sim = torch.einsum('btnc, btmc->nm', flow_data, flow_data)
elif len(flow_data.shape) == 3:
n,v,c = flow_data.shape
att_scaling = n * c
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 2)) ** -1 # cal 2-norm of each node, dim N
sim = torch.einsum('bnc, bmc->nm', flow_data, flow_data)
else:
raise ValueError('sim_global only support shape length in [3, 4] but got {}.'.format(len(flow_data.shape)))
if sim_type == 'cos':
# cosine similarity
scaling = torch.einsum('i, j->ij', cos_scaling, cos_scaling)
sim = sim * scaling
elif sim_type == 'att':
# scaled dot product similarity
scaling = float(att_scaling) ** -0.5
sim = torch.softmax(sim * scaling, dim=-1)
else:
raise ValueError('sim_global only support sim_type in [att, cos].')
return sim
def aug_topology(sim_mx, input_graph, percent=0.2):
"""Generate the data augumentation from topology (graph structure) perspective
for undirected graph without self-loop.
:param sim_mx: tensor, symmetric similarity, [v,v]
:param input_graph: tensor, adjacency matrix without self-loop, [v,v]
:return aug_graph: tensor, augmented adjacency matrix on cuda, [v,v]
"""
## edge dropping starts here
drop_percent = percent / 2
index_list = input_graph.nonzero() # list of edges [row_idx, col_idx]
edge_num = int(index_list.shape[0] / 2) # treat one undirected edge as two edges
edge_mask = (input_graph > 0).tril(diagonal=-1)
add_drop_num = int(edge_num * drop_percent / 2)
aug_graph = copy.deepcopy(input_graph)
drop_prob = torch.softmax(sim_mx[edge_mask], dim=0)
drop_prob = (1. - drop_prob).numpy() # normalized similarity to get sampling probability
drop_prob /= drop_prob.sum()
drop_list = np.random.choice(edge_num, size=add_drop_num, p=drop_prob)
drop_index = index_list[drop_list]
zeros = torch.zeros_like(aug_graph[0, 0])
aug_graph[drop_index[:, 0], drop_index[:, 1]] = zeros
aug_graph[drop_index[:, 1], drop_index[:, 0]] = zeros
## edge adding starts here
node_num = input_graph.shape[0]
x, y = np.meshgrid(range(node_num), range(node_num), indexing='ij')
mask = y < x
x, y = x[mask], y[mask]
add_prob = sim_mx[torch.ones(sim_mx.size(), dtype=bool).tril(diagonal=-1)] # .numpy()
add_prob = torch.softmax(add_prob, dim=0).numpy()
add_list = np.random.choice(int((node_num * node_num - node_num) / 2),
size=add_drop_num, p=add_prob)
ones = torch.ones_like(aug_graph[0, 0])
aug_graph[x[add_list], y[add_list]] = ones
aug_graph[y[add_list], x[add_list]] = ones
return aug_graph
def aug_traffic(t_sim_mx, flow_data, percent=0.2):
"""Generate the data augumentation from traffic (node attribute) perspective.
:param t_sim_mx: temporal similarity matrix after softmax, [l,n,v]
:param flow_data: input flow data, [n,l,v,c]
"""
l, n, v = t_sim_mx.shape
mask_num = int(n * l * v * percent)
aug_flow = copy.deepcopy(flow_data)
mask_prob = (1. - t_sim_mx.permute(1, 0, 2).reshape(-1)).numpy()
mask_prob /= mask_prob.sum()
x, y, z = np.meshgrid(range(n), range(l), range(v), indexing='ij')
mask_list = np.random.choice(n * l * v, size=mask_num, p=mask_prob)
zeros = torch.zeros_like(aug_flow[0, 0, 0])
aug_flow[
x.reshape(-1)[mask_list],
y.reshape(-1)[mask_list],
z.reshape(-1)[mask_list]] = zeros
return aug_flow

103
model/ST_SSL/layers.py Normal file
View File

@ -0,0 +1,103 @@
import copy
import numpy as np
import torch
def sim_global(flow_data, sim_type='cos'):
"""Calculate the global similarity of traffic flow data.
:param flow_data: tensor, original flow [n,l,v,c] or location embedding [n,v,c]
:param type: str, type of similarity, attention or cosine. ['att', 'cos']
:return sim: tensor, symmetric similarity, [v,v]
"""
if len(flow_data.shape) == 4:
n,l,v,c = flow_data.shape
att_scaling = n * l * c
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 1, 3)) ** -1 # cal 2-norm of each node, dim N
sim = torch.einsum('btnc, btmc->nm', flow_data, flow_data)
elif len(flow_data.shape) == 3:
n,v,c = flow_data.shape
att_scaling = n * c
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 2)) ** -1 # cal 2-norm of each node, dim N
sim = torch.einsum('bnc, bmc->nm', flow_data, flow_data)
else:
raise ValueError('sim_global only support shape length in [3, 4] but got {}.'.format(len(flow_data.shape)))
if sim_type == 'cos':
# cosine similarity
scaling = torch.einsum('i, j->ij', cos_scaling, cos_scaling)
sim = sim * scaling
elif sim_type == 'att':
# scaled dot product similarity
scaling = float(att_scaling) ** -0.5
sim = torch.softmax(sim * scaling, dim=-1)
else:
raise ValueError('sim_global only support sim_type in [att, cos].')
return sim
def aug_topology(sim_mx, input_graph, percent=0.2):
"""Generate the data augumentation from topology (graph structure) perspective
for undirected graph without self-loop.
:param sim_mx: tensor, symmetric similarity, [v,v]
:param input_graph: tensor, adjacency matrix without self-loop, [v,v]
:return aug_graph: tensor, augmented adjacency matrix on cuda, [v,v]
"""
## edge dropping starts here
drop_percent = percent / 2
index_list = input_graph.nonzero() # list of edges [row_idx, col_idx]
edge_num = int(index_list.shape[0] / 2) # treat one undirected edge as two edges
edge_mask = (input_graph > 0).tril(diagonal=-1)
add_drop_num = int(edge_num * drop_percent / 2)
aug_graph = copy.deepcopy(input_graph)
drop_prob = torch.softmax(sim_mx[edge_mask], dim=0)
drop_prob = (1. - drop_prob).numpy() # normalized similarity to get sampling probability
drop_prob /= drop_prob.sum()
drop_list = np.random.choice(edge_num, size=add_drop_num, p=drop_prob)
drop_index = index_list[drop_list]
zeros = torch.zeros_like(aug_graph[0, 0])
aug_graph[drop_index[:, 0], drop_index[:, 1]] = zeros
aug_graph[drop_index[:, 1], drop_index[:, 0]] = zeros
## edge adding starts here
node_num = input_graph.shape[0]
x, y = np.meshgrid(range(node_num), range(node_num), indexing='ij')
mask = y < x
x, y = x[mask], y[mask]
add_prob = sim_mx[torch.ones(sim_mx.size(), dtype=bool).tril(diagonal=-1)] # .numpy()
add_prob = torch.softmax(add_prob, dim=0).numpy()
add_list = np.random.choice(int((node_num * node_num - node_num) / 2),
size=add_drop_num, p=add_prob)
ones = torch.ones_like(aug_graph[0, 0])
aug_graph[x[add_list], y[add_list]] = ones
aug_graph[y[add_list], x[add_list]] = ones
return aug_graph
def aug_traffic(t_sim_mx, flow_data, percent=0.2):
"""Generate the data augumentation from traffic (node attribute) perspective.
:param t_sim_mx: temporal similarity matrix after softmax, [l,n,v]
:param flow_data: input flow data, [n,l,v,c]
"""
l, n, v = t_sim_mx.shape
mask_num = int(n * l * v * percent)
aug_flow = copy.deepcopy(flow_data)
mask_prob = (1. - t_sim_mx.permute(1, 0, 2).reshape(-1)).numpy()
mask_prob /= mask_prob.sum()
x, y, z = np.meshgrid(range(n), range(l), range(v), indexing='ij')
mask_list = np.random.choice(n * l * v, size=mask_num, p=mask_prob)
zeros = torch.zeros_like(aug_flow[0, 0, 0])
aug_flow[
x.reshape(-1)[mask_list],
y.reshape(-1)[mask_list],
z.reshape(-1)[mask_list]] = zeros
return aug_flow

156
model/ST_SSL/models.py Normal file
View File

@ -0,0 +1,156 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# 简化的masked_mae_loss函数
def masked_mae_loss(mask_value=5.0):
def loss_fn(pred, target):
mask = (target != mask_value).float()
mae = F.l1_loss(pred * mask, target * mask, reduction='sum')
return mae / (mask.sum() + 1e-8)
return loss_fn
# 简化的数据增强函数
def aug_topology(sim_mx, graph, percent=0.1):
return graph
def aug_traffic(sim_mx, data, percent=0.1):
return data
class STEncoder(nn.Module):
def __init__(self, Kt, Ks, blocks, input_length, num_nodes, droprate):
super(STEncoder, self).__init__()
self.num_nodes = num_nodes
self.input_length = input_length
# 简化的编码器 - 修复输入通道数
self.conv1 = nn.Conv2d(blocks[0][0], blocks[0][1], kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
self.conv2 = nn.Conv2d(blocks[0][1], blocks[0][2], kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
self.dropout = nn.Dropout(droprate)
# 临时的相似度矩阵
self.s_sim_mx = torch.randn(num_nodes, num_nodes)
self.t_sim_mx = torch.randn(input_length, input_length)
def forward(self, x, graph):
# x: (batch_size, num_nodes, seq_len, features)
batch_size, num_nodes, seq_len, features = x.shape
# 调整维度
x = x.permute(0, 3, 1, 2) # (batch_size, features, num_nodes, seq_len)
# 确保输入通道数正确
if x.shape[1] != 2: # 如果不是2个通道需要调整
if x.shape[1] == 1:
x = x.repeat(1, 2, 1, 1) # 复制到2个通道
else:
x = x[:, :2, :, :] # 取前2个通道
# 卷积操作
x = F.relu(self.conv1(x))
x = self.dropout(x)
x = F.relu(self.conv2(x))
# 调整回原维度
x = x.permute(0, 2, 3, 1) # (batch_size, num_nodes, seq_len, features)
return x
class MLP(nn.Module):
def __init__(self, input_dim, output_dim):
super(MLP, self).__init__()
self.fc = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.fc(x)
class TemporalHeteroModel(nn.Module):
def __init__(self, d_model, batch_size, num_nodes, device):
super(TemporalHeteroModel, self).__init__()
self.fc = nn.Linear(d_model, 1)
def forward(self, z1, z2):
return F.mse_loss(self.fc(z1), self.fc(z2))
class SpatialHeteroModel(nn.Module):
def __init__(self, d_model, nmb_prototype, batch_size, shm_temp):
super(SpatialHeteroModel, self).__init__()
self.fc = nn.Linear(d_model, 1)
def forward(self, z1, z2):
return F.mse_loss(self.fc(z1), self.fc(z2))
class STSSL(nn.Module):
def __init__(self, args):
super(STSSL, self).__init__()
# spatial temporal encoder
self.encoder = STEncoder(Kt=3, Ks=3, blocks=[[2, int(args['d_model']//2), args['d_model']], [args['d_model'], int(args['d_model']//2), args['d_model']]],
input_length=args['input_length'], num_nodes=args['num_nodes'], droprate=args['dropout'])
# traffic flow prediction branch
self.mlp = MLP(args['d_model'], args['d_output'])
# temporal heterogenrity modeling branch
self.thm = TemporalHeteroModel(args['d_model'], args['batch_size'], args['num_nodes'], args['device'])
# spatial heterogenrity modeling branch
self.shm = SpatialHeteroModel(args['d_model'], args['nmb_prototype'], args['batch_size'], args['shm_temp'])
self.mae = masked_mae_loss(mask_value=5.0)
self.args = args
def forward(self, view1, graph):
repr1 = self.encoder(view1, graph) # view1: n,l,v,c; graph: v,v
s_sim_mx = self.fetch_spatial_sim()
graph2 = aug_topology(s_sim_mx, graph, percent=self.args['percent']*2)
t_sim_mx = self.fetch_temporal_sim()
view2 = aug_traffic(t_sim_mx, view1, percent=self.args['percent'])
repr2 = self.encoder(view2, graph2)
return repr1, repr2
def fetch_spatial_sim(self):
"""
Fetch the region similarity matrix generated by region embedding.
Note this can be called only when spatial_sim is True.
:return sim_mx: tensor, similarity matrix, (v, v)
"""
return self.encoder.s_sim_mx.cpu()
def fetch_temporal_sim(self):
return self.encoder.t_sim_mx.cpu()
def predict(self, z1, z2):
'''Predicting future traffic flow.
:param z1, z2 (tensor): shape nvc
:return: nlvc, l=1, c=2
'''
return self.mlp(z1)
def loss(self, z1, z2, y_true, scaler, loss_weights):
l1 = self.pred_loss(z1, z2, y_true, scaler)
sep_loss = [l1.item()]
loss = loss_weights[0] * l1
l2 = self.temporal_loss(z1, z2)
sep_loss.append(l2.item())
loss += loss_weights[1] * l2
l3 = self.spatial_loss(z1, z2)
sep_loss.append(l3.item())
loss += loss_weights[2] * l3
return loss, sep_loss
def pred_loss(self, z1, z2, y_true, scaler):
y_pred = scaler.inverse_transform(self.predict(z1, z2))
y_true = scaler.inverse_transform(y_true)
loss = self.args['yita'] * self.mae(y_pred[..., 0], y_true[..., 0]) + \
(1 - self.args['yita']) * self.mae(y_pred[..., 1], y_true[..., 1])
return loss
def temporal_loss(self, z1, z2):
return self.thm(z1, z2)
def spatial_loss(self, z1, z2):
return self.shm(z1, z2)

108
model/TEDDCF/ISTF.py Normal file
View File

@ -0,0 +1,108 @@
import torch.nn as nn
import torch
from torchinfo import summary
class AttentionLayer(nn.Module):
"""Perform attention across the -2 dim (the -1 dim is `model_dim`).
Make sure the tensor is permuted to correct shape before attention.
E.g.
- Input shape (batch_size, in_steps, num_nodes, model_dim).
- Then the attention will be performed across the nodes.
Also, it supports different src and tgt length.
But must `src length == K length == V length`.
"""
def __init__(self, model_dim, num_heads=8, mask=False):
super().__init__()
self.model_dim = model_dim#152
self.num_heads = num_heads
self.mask = mask
self.head_dim = model_dim // num_heads
self.FC_Q = nn.Linear(model_dim, model_dim)#[152,152]
self.FC_K = nn.Linear(model_dim, model_dim)
self.FC_V = nn.Linear(model_dim, model_dim)
self.out_proj = nn.Linear(model_dim, model_dim)
def forward(self, query, key, value):
# Q (batch_size, ..., tgt_length, model_dim)
# K, V (batch_size, ..., src_length, model_dim)
batch_size = query.shape[0]#16 #64
tgt_length = query.shape[-2]#12 #170
src_length = key.shape[-2]#12 #170
query = self.FC_Q(query)#[64,6,170,152]
key = self.FC_K(key)
value = self.FC_V(value)
# Qhead, Khead, Vhead (num_heads * batch_size, ..., length, head_dim)
query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)#[512,6,170,24]
key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)
key = key.transpose(
-1, -2
) # (num_heads * batch_size, ..., head_dim, src_length)
attn_score = (#[64,170,12,12]
query @ key
) / self.head_dim**0.5 # (num_heads * batch_size, ..., tgt_length, src_length)
if self.mask:
mask = torch.ones(
tgt_length, src_length, dtype=torch.bool, device=query.device
).tril() # lower triangular part of the matrix
attn_score.masked_fill_(~mask, -torch.inf) # fill in-place
attn_score = torch.softmax(attn_score, dim=-1)#[64,170,12,12]
out = attn_score @ value
out = torch.cat(
torch.split(out, batch_size, dim=0), dim=-1
) # (batch_size, ..., tgt_length, head_dim * num_heads = model_dim)[16,170,12,152]
out = self.out_proj(out)#[64,6,170,152]
return out
class SelfAttentionLayer(nn.Module):
def __init__(
self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
):
super().__init__()
self.attn = AttentionLayer(model_dim, num_heads, mask)
self.feed_forward = nn.Sequential(
nn.Linear(model_dim, feed_forward_dim),#[152,256]
nn.ReLU(inplace=True),
nn.Linear(feed_forward_dim, model_dim),#[256.152]
)
self.ln1 = nn.LayerNorm(model_dim)
self.ln2 = nn.LayerNorm(model_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, dim=-2):
x = x.transpose(dim, -2)
# x: (batch_size, ..., length, model_dim)
residual = x
out = self.attn(x, x, x) # (batch_size, ..., length, model_dim)[16,170,12,152]
out = self.dropout1(out)
out = self.ln1(residual + out)
residual = out
out = self.feed_forward(out) # (batch_size, ..., length, model_dim)
out = self.dropout2(out)
out = self.ln2(residual + out)
out = out.transpose(dim, -2)#[64,6,170,152]
return out

414
model/TEDDCF/model.py Normal file
View File

@ -0,0 +1,414 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
import sys
from model.TEDDCF.ISTF import SelfAttentionLayer
class GLU(nn.Module):
def __init__(self, features, dropout=0.1):#PEMS08: 192
super(GLU, self).__init__()
self.conv1 = nn.Conv2d(features, features, (1, 1))
self.conv2 = nn.Conv2d(features, features, (1, 1))
self.conv3 = nn.Conv2d(features, features, (1, 1))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
out = x1 * torch.sigmoid(x2)
out = self.dropout(out)
out = self.conv3(out)
return out#[64,192,170,12]
class TemporalEmbedding(nn.Module):
def __init__(self, time, features):
super(TemporalEmbedding, self).__init__()
#S08:time 288 features 96
self.time = time
self.time_day = nn.Parameter(torch.empty(time, features))#[288 96]
nn.init.xavier_uniform_(self.time_day)
self.time_week = nn.Parameter(torch.empty(7, features))#[7 96]
nn.init.xavier_uniform_(self.time_week)
def forward(self, x):
#x #in:[64,12,170,3]
day_emb = x[..., 1]
time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
time_day = time_day.transpose(1, 2).contiguous()
week_emb = x[..., 2]
time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]#[64,12,170,96]
time_week = time_week.transpose(1, 2).contiguous()#torch.Size([64, 170, 12, 96])
tem_emb = time_day + time_week#[64,170,12,96]
tem_emb = tem_emb.permute(0,3,1,2)#[64,96,170,12]
return tem_emb
class Diffusion_GCN(nn.Module):
def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
super().__init__()
self.diffusion_step = diffusion_step#1
self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))#[192,192,(1,1)]
self.dropout = nn.Dropout(dropout)
def forward(self, x, adj):
out = []
for i in range(0, self.diffusion_step):#1
if adj.dim() == 3:
x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
out.append(x)
elif adj.dim() == 2:
x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
out.append(x)
x = torch.cat(out, dim=1)
x = self.conv(x)
output = self.dropout(x)
return output
class EventGraph_Fusion(nn.Module):
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
super().__init__()
self.memory = nn.Parameter(torch.randn(channels, num_nodes))
nn.init.xavier_uniform_(self.memory)
self.fc = nn.Linear(2,1)
def forward(self, x):
adj_dyn_1 = torch.softmax(
F.relu(
torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
/ math.sqrt(x.shape[1])
),
-1,
)
adj_dyn_2 = torch.softmax(
F.relu(
torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
/ math.sqrt(x.shape[1])
),
-1,
)
adj_f = torch.cat([(adj_dyn_1).unsqueeze(-1)] + [(adj_dyn_2).unsqueeze(-1)], dim=-1)
adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1]*0.8), dim=-1)
mask = torch.zeros_like(adj_f)
mask.scatter_(-1, topk_indices, 1)
adj_f = adj_f * mask
return adj_f
class EventGCN(nn.Module):
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
super().__init__()
self.conv = nn.Conv2d(channels,channels,(1,1))
self.generator = EventGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
self.emb = emb
def forward(self, x):
skip = x
x = self.conv(x)
adj_dyn = self.generator(x)
x = self.gcn(x, adj_dyn)
x = x*self.emb + skip
return x
class TrendGCN(nn.Module):
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
super().__init__()
self.conv = nn.Conv2d(channels,channels,(1,1))
self.generator = TrendGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
self.emb = emb
def forward(self, x):
skip = x
x = self.conv(x)
adj_dyn = self.generator(x)
x = self.gcn(x, adj_dyn)
x = x*self.emb + skip
return x
class TrendGraph_Fusion(nn.Module):
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
super().__init__()
self.memory = nn.Parameter(
torch.randn(channels, num_nodes))
nn.init.xavier_uniform_(self.memory)
self.fc = nn.Linear(2, 1)
self.E_adaptive = nn.Parameter(torch.randn(num_nodes, 10))
def forward(self, x):
# adj_dyn_1 = torch.softmax(
# F.relu(
# torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
# / math.sqrt(x.shape[1])
# ),
# -1,
# )
adj_dyn_2 = torch.softmax(
F.relu(
torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
/ math.sqrt(x.shape[1])
),
-1,
)
adj_adp = F.softmax(F.relu(torch.mm(self.E_adaptive, self.E_adaptive.transpose(0, 1))), dim=1)
adj_adp_expanded = adj_adp.unsqueeze(0)
adj_adp = adj_adp_expanded.repeat(x.shape[0], 1, 1)
adj_f = torch.cat([(adj_dyn_2).unsqueeze(-1)] + [(adj_adp).unsqueeze(-1)], dim=-1)
adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
mask = torch.zeros_like(adj_f)
mask.scatter_(-1, topk_indices, 1)
adj_f = adj_f * mask
return adj_f
class Chomp1d(nn.Module):
"""
extra dimension will be added by padding, remove it
"""
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, :, :-self.chomp_size].contiguous()
class TemporalConvNet(nn.Module):
def __init__(self, features, kernel_size=2, dropout=0.2, levels=1):
super(TemporalConvNet, self).__init__()
layers = []
for i in range(levels):
dilation_size = 2 ** i
padding = (kernel_size - 1) * dilation_size
self.conv = nn.Conv2d(features, features, (1, kernel_size), dilation=(1, dilation_size),
padding=(0, padding))
self.chomp = Chomp1d(padding)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
self.tcn = nn.Sequential(*layers)
def forward(self, xh):
xh = self.tcn(xh)
return xh
pass
class FeedForward(nn.Module):
def __init__(self, fea, res_ln=False):
super(FeedForward, self).__init__()
self.res_ln = res_ln
self.L = len(fea) - 1#2
self.linear = nn.ModuleList([nn.Linear(fea[i], fea[i+1]) for i in range(self.L)])
self.ln = nn.LayerNorm(fea[self.L], elementwise_affine=False)
def forward(self, inputs):
x = inputs
for i in range(self.L):
x = self.linear[i](x)
if i != self.L-1:
x = F.relu(x)
if self.res_ln:
x += inputs
x = self.ln(x)
return x
class Adaptive_Fusion(nn.Module):
def __init__(self, heads, dims):
super(Adaptive_Fusion, self).__init__()
features = dims # 192
self.h = heads # 8
self.d = int(dims / heads) # 16
self.qlfc = FeedForward([features, features])
self.khfc = FeedForward([features, features])
self.vhfc = FeedForward([features, features])
self.ofc = FeedForward([features, features])
self.ln = nn.LayerNorm(features, elementwise_affine=False)
self.ff = FeedForward([features, features, features], True)
def forward(self, xl, xh, Mask=True):
'''
xl: [B,T,N,F]
xh: [B,T,N,F]
te: [B,T,N,F]
return: [B,T,N,F]
'''
# xl += te
# xh += te
query = self.qlfc(xl) # [B,T,N,F]
keyh = torch.relu(self.khfc(xh)) # [B,T,N,F]
valueh = torch.relu(self.vhfc(xh)) # [B,T,N,F]
query = torch.cat(torch.split(query, self.d, -1), 0).permute(0, 2, 1, 3) # [k*B,N,T,d]
keyh = torch.cat(torch.split(keyh, self.d, -1), 0).permute(0, 2, 3, 1) # [k*B,N,d,T]
valueh = torch.cat(torch.split(valueh, self.d, -1), 0).permute(0, 2, 1, 3) # [k*B,N,T,d]
attentionh = torch.matmul(query, keyh) # [k*B,N,T,T]
if Mask:
batch_size = xl.shape[0]
num_steps = xl.shape[1]
num_vertexs = xl.shape[2]
mask = torch.ones(num_steps, num_steps).to(xl.device) # [T,T]
mask = torch.tril(mask) # [T,T]
mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0) # [1,1,T,T]
mask = mask.repeat(self.h * batch_size, num_vertexs, 1, 1) # [k*B,N,T,T]
mask = mask.to(torch.bool)
zero_vec = (-2 ** 15 + 1) * torch.ones_like(attentionh).to(xl.device) # [k*B,N,T,T]
attentionh = torch.where(mask, attentionh, zero_vec)
attentionh /= (self.d ** 0.5) # scaled
attentionh = F.softmax(attentionh, -1) # [k*B,N,T,T]
value = torch.matmul(attentionh, valueh) # [k*B,N,T,d]
value = torch.cat(torch.split(value, value.shape[0] // self.h, 0), -1).permute(0, 2, 1, 3) # [B,T,N,F]
value = self.ofc(value)
value = value + xl
value = self.ln(value)
return self.ff(value) # [64,12,170,128]
class TEDDCF(nn.Module):
def __init__(
self, device, input_dim, num_nodes, channels, granularity, dropout=0.1
):
super().__init__()
self.device = device
self.num_nodes = num_nodes
self.output_len = 12
self.input_len = 12
self.heads = 8
diffusion_step = 1
self.Temb = TemporalEmbedding(granularity, channels)
self.start_conv = nn.Conv2d(
in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
)
self.glu = GLU(channels*2, dropout)
self.regression_layer = nn.Conv2d(
channels*2, self.output_len, kernel_size=(1, self.output_len)
)
self.temporal_conv = TemporalConvNet(channels*2)
self.pre_h = nn.Conv2d(in_channels=self.input_len, out_channels=self.output_len, kernel_size=(1,1))
self.adp_f = Adaptive_Fusion(self.heads, channels*2)
num_layers = 3
self.attn_layers_t = nn.ModuleList(
[
SelfAttentionLayer(channels*2, feed_forward_dim=256, num_heads=4, dropout=0.1)
for _ in range(num_layers) # 3
]
)
self.xh_emb = nn.Parameter(torch.randn(channels*2, num_nodes, 12))
self.xh_dgcn = EventGCN(channels*2, num_nodes, diffusion_step=1, dropout=0.1,emb=self.xh_emb)
self.xl_emb = nn.Parameter(torch.randn(channels*2, num_nodes, 12))
self.xl_dgcn = TrendGCN(channels*2, num_nodes, diffusion_step=1, dropout=0.1, emb=self.xl_emb)
def param_num(self):
return sum([param.nelement() for param in self.parameters()])
def forward(self, inputxl, inputxh):
xl = inputxl
xh = inputxh
# Encoder
# Data Embedding
time_embl = self.Temb(inputxl.permute(0, 3, 2, 1))
time_embh = self.Temb(inputxh.permute(0, 3, 2, 1))
#t = self.start_conv(x)#[64,96,170,12]
xl = torch.cat([self.start_conv(xl)] + [time_embl], dim=1)
xh = torch.cat([self.start_conv(xh)] + [time_embh], dim=1)
xl = xl.permute(0, 3, 2, 1)
for attn in self.attn_layers_t:
xl = attn(xl, dim=1)
xl = xl.permute(0, 3, 2, 1)
xl = self.xl_dgcn(xl)
xl = self.glu(xl) + xl
xh = self.temporal_conv(xh)
xh = self.xh_dgcn(xh)
#simple plus
x_all = xh + xl
#STwave_fusion
# xl = xl.transpose(1, 3)
# xh = self.pre_h(xh.transpose(1,3))#[64,12,170,192]
# x_all = self.adp_f(xl, xh)
# x_all = x_all.transpose(1, 3)
prediction = self.regression_layer(F.relu(x_all))
return prediction