add-model
This commit is contained in:
parent
af795043c8
commit
538548db0b
|
|
@ -0,0 +1,227 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class AGCN(nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out, cheb_k):
|
||||||
|
super(AGCN, self).__init__()
|
||||||
|
self.cheb_k = cheb_k
|
||||||
|
self.weights = nn.Parameter(torch.FloatTensor(2*cheb_k*dim_in, dim_out)) # 2 is the length of support
|
||||||
|
self.bias = nn.Parameter(torch.FloatTensor(dim_out))
|
||||||
|
nn.init.xavier_normal_(self.weights)
|
||||||
|
nn.init.constant_(self.bias, val=0)
|
||||||
|
|
||||||
|
def forward(self, x, supports):
|
||||||
|
x_g = []
|
||||||
|
support_set = []
|
||||||
|
for support in supports:
|
||||||
|
support_ks = [torch.eye(support.shape[0]).to(support.device), support]
|
||||||
|
for k in range(2, self.cheb_k):
|
||||||
|
support_ks.append(torch.matmul(2 * support, support_ks[-1]) - support_ks[-2])
|
||||||
|
support_set.extend(support_ks)
|
||||||
|
for support in support_set:
|
||||||
|
x_g.append(torch.einsum("nm,bmc->bnc", support, x))
|
||||||
|
x_g = torch.cat(x_g, dim=-1) # B, N, 2 * cheb_k * dim_in
|
||||||
|
x_gconv = torch.einsum('bni,io->bno', x_g, self.weights) + self.bias # b, N, dim_out
|
||||||
|
return x_gconv
|
||||||
|
|
||||||
|
class AGCRNCell(nn.Module):
|
||||||
|
def __init__(self, node_num, dim_in, dim_out, cheb_k):
|
||||||
|
super(AGCRNCell, self).__init__()
|
||||||
|
self.node_num = node_num
|
||||||
|
self.hidden_dim = dim_out
|
||||||
|
self.gate = AGCN(dim_in+self.hidden_dim, 2*dim_out, cheb_k)
|
||||||
|
self.update = AGCN(dim_in+self.hidden_dim, dim_out, cheb_k)
|
||||||
|
|
||||||
|
def forward(self, x, state, supports):
|
||||||
|
#x: B, num_nodes, input_dim
|
||||||
|
#state: B, num_nodes, hidden_dim
|
||||||
|
state = state.to(x.device)
|
||||||
|
input_and_state = torch.cat((x, state), dim=-1)
|
||||||
|
z_r = torch.sigmoid(self.gate(input_and_state, supports))
|
||||||
|
z, r = torch.split(z_r, self.hidden_dim, dim=-1)
|
||||||
|
candidate = torch.cat((x, z*state), dim=-1)
|
||||||
|
hc = torch.tanh(self.update(candidate, supports))
|
||||||
|
h = r*state + (1-r)*hc
|
||||||
|
return h
|
||||||
|
|
||||||
|
def init_hidden_state(self, batch_size):
|
||||||
|
return torch.zeros(batch_size, self.node_num, self.hidden_dim)
|
||||||
|
|
||||||
|
class ADCRNN_Encoder(nn.Module):
|
||||||
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, num_layers):
|
||||||
|
super(ADCRNN_Encoder, self).__init__()
|
||||||
|
assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
|
||||||
|
self.node_num = node_num
|
||||||
|
self.input_dim = dim_in
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.dcrnn_cells = nn.ModuleList()
|
||||||
|
self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k))
|
||||||
|
for _ in range(1, num_layers):
|
||||||
|
self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k))
|
||||||
|
|
||||||
|
def forward(self, x, init_state, supports):
|
||||||
|
#shape of x: (B, T, N, D), shape of init_state: (num_layers, B, N, hidden_dim)
|
||||||
|
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
||||||
|
seq_length = x.shape[1]
|
||||||
|
current_inputs = x
|
||||||
|
output_hidden = []
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
state = init_state[i]
|
||||||
|
inner_states = []
|
||||||
|
for t in range(seq_length):
|
||||||
|
state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, supports)
|
||||||
|
inner_states.append(state)
|
||||||
|
output_hidden.append(state)
|
||||||
|
current_inputs = torch.stack(inner_states, dim=1)
|
||||||
|
#current_inputs: the outputs of last layer: (B, T, N, hidden_dim)
|
||||||
|
#last_state: (B, N, hidden_dim)
|
||||||
|
#output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim)
|
||||||
|
#return current_inputs, torch.stack(output_hidden, dim=0)
|
||||||
|
return current_inputs, output_hidden
|
||||||
|
|
||||||
|
def init_hidden(self, batch_size):
|
||||||
|
init_states = []
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size))
|
||||||
|
return init_states
|
||||||
|
|
||||||
|
class ADCRNN_Decoder(nn.Module):
|
||||||
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, num_layers):
|
||||||
|
super(ADCRNN_Decoder, self).__init__()
|
||||||
|
assert num_layers >= 1, 'At least one DCRNN layer in the Decoder.'
|
||||||
|
self.node_num = node_num
|
||||||
|
self.input_dim = dim_in
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.dcrnn_cells = nn.ModuleList()
|
||||||
|
self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k))
|
||||||
|
for _ in range(1, num_layers):
|
||||||
|
self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k))
|
||||||
|
|
||||||
|
def forward(self, xt, init_state, supports):
|
||||||
|
# xt: (B, N, D)
|
||||||
|
# init_state: (num_layers, B, N, hidden_dim)
|
||||||
|
assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
|
||||||
|
current_inputs = xt
|
||||||
|
output_hidden = []
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
state = self.dcrnn_cells[i](current_inputs, init_state[i], supports)
|
||||||
|
output_hidden.append(state)
|
||||||
|
current_inputs = state
|
||||||
|
return current_inputs, output_hidden
|
||||||
|
|
||||||
|
|
||||||
|
class MegaCRN(nn.Module):
|
||||||
|
def __init__(self, num_nodes, input_dim, output_dim, horizon, rnn_units, num_layers=1, cheb_k=3,
|
||||||
|
ycov_dim=1, mem_num=20, mem_dim=64, cl_decay_steps=2000, use_curriculum_learning=True):
|
||||||
|
super(MegaCRN, self).__init__()
|
||||||
|
self.num_nodes = num_nodes
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.rnn_units = rnn_units
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.horizon = horizon
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.cheb_k = cheb_k
|
||||||
|
self.ycov_dim = ycov_dim
|
||||||
|
self.cl_decay_steps = cl_decay_steps
|
||||||
|
self.use_curriculum_learning = use_curriculum_learning
|
||||||
|
|
||||||
|
# memory
|
||||||
|
self.mem_num = mem_num
|
||||||
|
self.mem_dim = mem_dim
|
||||||
|
self.memory = self.construct_memory()
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
self.encoder = ADCRNN_Encoder(self.num_nodes, self.input_dim, self.rnn_units, self.cheb_k, self.num_layers)
|
||||||
|
|
||||||
|
# deocoder
|
||||||
|
self.decoder_dim = self.rnn_units + self.mem_dim
|
||||||
|
self.decoder = ADCRNN_Decoder(self.num_nodes, self.output_dim + self.ycov_dim, self.decoder_dim, self.cheb_k, self.num_layers)
|
||||||
|
|
||||||
|
# output
|
||||||
|
self.proj = nn.Sequential(nn.Linear(self.decoder_dim, self.output_dim, bias=True))
|
||||||
|
|
||||||
|
def compute_sampling_threshold(self, batches_seen):
|
||||||
|
return self.cl_decay_steps / (self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
|
||||||
|
|
||||||
|
def construct_memory(self):
|
||||||
|
memory_dict = nn.ParameterDict()
|
||||||
|
memory_dict['Memory'] = nn.Parameter(torch.randn(self.mem_num, self.mem_dim), requires_grad=True) # (M, d)
|
||||||
|
memory_dict['Wq'] = nn.Parameter(torch.randn(self.rnn_units, self.mem_dim), requires_grad=True) # project to query
|
||||||
|
memory_dict['We1'] = nn.Parameter(torch.randn(self.num_nodes, self.mem_num), requires_grad=True) # project memory to embedding
|
||||||
|
memory_dict['We2'] = nn.Parameter(torch.randn(self.num_nodes, self.mem_num), requires_grad=True) # project memory to embedding
|
||||||
|
for param in memory_dict.values():
|
||||||
|
nn.init.xavier_normal_(param)
|
||||||
|
return memory_dict
|
||||||
|
|
||||||
|
def query_memory(self, h_t:torch.Tensor):
|
||||||
|
query = torch.matmul(h_t, self.memory['Wq']) # (B, N, d)
|
||||||
|
att_score = torch.softmax(torch.matmul(query, self.memory['Memory'].t()), dim=-1) # alpha: (B, N, M)
|
||||||
|
value = torch.matmul(att_score, self.memory['Memory']) # (B, N, d)
|
||||||
|
_, ind = torch.topk(att_score, k=2, dim=-1)
|
||||||
|
pos = self.memory['Memory'][ind[:, :, 0]] # B, N, d
|
||||||
|
neg = self.memory['Memory'][ind[:, :, 1]] # B, N, d
|
||||||
|
return value, query, pos, neg
|
||||||
|
|
||||||
|
def forward(self, x, y_cov, labels=None, batches_seen=None):
|
||||||
|
node_embeddings1 = torch.matmul(self.memory['We1'], self.memory['Memory'])
|
||||||
|
node_embeddings2 = torch.matmul(self.memory['We2'], self.memory['Memory'])
|
||||||
|
g1 = F.softmax(F.relu(torch.mm(node_embeddings1, node_embeddings2.T)), dim=-1)
|
||||||
|
g2 = F.softmax(F.relu(torch.mm(node_embeddings2, node_embeddings1.T)), dim=-1)
|
||||||
|
supports = [g1, g2]
|
||||||
|
init_state = self.encoder.init_hidden(x.shape[0])
|
||||||
|
h_en, state_en = self.encoder(x, init_state, supports) # B, T, N, hidden
|
||||||
|
h_t = h_en[:, -1, :, :] # B, N, hidden (last state)
|
||||||
|
|
||||||
|
h_att, query, pos, neg = self.query_memory(h_t)
|
||||||
|
h_t = torch.cat([h_t, h_att], dim=-1)
|
||||||
|
|
||||||
|
ht_list = [h_t]*self.num_layers
|
||||||
|
go = torch.zeros((x.shape[0], self.num_nodes, self.output_dim), device=x.device)
|
||||||
|
out = []
|
||||||
|
for t in range(self.horizon):
|
||||||
|
h_de, ht_list = self.decoder(torch.cat([go, y_cov[:, t, ...]], dim=-1), ht_list, supports)
|
||||||
|
go = self.proj(h_de)
|
||||||
|
out.append(go)
|
||||||
|
if self.training and self.use_curriculum_learning:
|
||||||
|
c = np.random.uniform(0, 1)
|
||||||
|
if c < self.compute_sampling_threshold(batches_seen):
|
||||||
|
go = labels[:, t, ...]
|
||||||
|
output = torch.stack(out, dim=1)
|
||||||
|
|
||||||
|
return output, h_att, query, pos, neg
|
||||||
|
|
||||||
|
def print_params(model):
|
||||||
|
# print trainable params
|
||||||
|
param_count = 0
|
||||||
|
print('Trainable parameter list:')
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
print(name, param.shape, param.numel())
|
||||||
|
param_count += param.numel()
|
||||||
|
print(f'In total: {param_count} trainable parameters. \n')
|
||||||
|
return
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from torchsummary import summary
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--gpu", type=int, default=3, help="which GPU to use")
|
||||||
|
parser.add_argument('--num_variable', type=int, default=207, help='number of variables (e.g., 207 in METR-LA, 325 in PEMS-BAY)')
|
||||||
|
parser.add_argument('--his_len', type=int, default=12, help='sequence length of historical observation')
|
||||||
|
parser.add_argument('--seq_len', type=int, default=12, help='sequence length of prediction')
|
||||||
|
parser.add_argument('--channelin', type=int, default=1, help='number of input channel')
|
||||||
|
parser.add_argument('--channelout', type=int, default=1, help='number of output channel')
|
||||||
|
parser.add_argument('--rnn_units', type=int, default=64, help='number of hidden units')
|
||||||
|
args = parser.parse_args()
|
||||||
|
device = torch.device("cuda:{}".format(args.gpu)) if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = MegaCRN(num_nodes=args.num_variable, input_dim=args.channelin, output_dim=args.channelout, horizon=args.seq_len, rnn_units=args.rnn_units).to(device)
|
||||||
|
summary(model, [(args.his_len, args.num_variable, args.channelin), (args.seq_len, args.num_variable, args.channelout)], device=device)
|
||||||
|
print_params(model)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from model.MegaCRN.MegaCRN import MegaCRN
|
||||||
|
|
||||||
|
class MegaCRNModel(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super(MegaCRNModel, self).__init__()
|
||||||
|
|
||||||
|
# 设置默认参数
|
||||||
|
if 'rnn_units' not in args:
|
||||||
|
args['rnn_units'] = 64
|
||||||
|
if 'num_layers' not in args:
|
||||||
|
args['num_layers'] = 1
|
||||||
|
if 'cheb_k' not in args:
|
||||||
|
args['cheb_k'] = 3
|
||||||
|
if 'ycov_dim' not in args:
|
||||||
|
args['ycov_dim'] = 1
|
||||||
|
if 'mem_num' not in args:
|
||||||
|
args['mem_num'] = 20
|
||||||
|
if 'mem_dim' not in args:
|
||||||
|
args['mem_dim'] = 64
|
||||||
|
if 'cl_decay_steps' not in args:
|
||||||
|
args['cl_decay_steps'] = 2000
|
||||||
|
if 'use_curriculum_learning' not in args:
|
||||||
|
args['use_curriculum_learning'] = True
|
||||||
|
if 'horizon' not in args:
|
||||||
|
args['horizon'] = 12
|
||||||
|
|
||||||
|
# 创建MegaCRN模型
|
||||||
|
self.model = MegaCRN(
|
||||||
|
num_nodes=args['num_nodes'],
|
||||||
|
input_dim=1, # 固定为1,因为我们只使用第一个通道
|
||||||
|
output_dim=args['output_dim'],
|
||||||
|
horizon=args['horizon'],
|
||||||
|
rnn_units=args['rnn_units'],
|
||||||
|
num_layers=args['num_layers'],
|
||||||
|
cheb_k=args['cheb_k'],
|
||||||
|
ycov_dim=args['ycov_dim'],
|
||||||
|
mem_num=args['mem_num'],
|
||||||
|
mem_dim=args['mem_dim'],
|
||||||
|
cl_decay_steps=args['cl_decay_steps'],
|
||||||
|
use_curriculum_learning=args['use_curriculum_learning']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.args = args
|
||||||
|
self.batches_seen = 0 # 添加batches_seen计数器
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x shape: (batch_size, seq_len, num_nodes, features)
|
||||||
|
# 按照DDGCRN的模式,只使用第一个通道
|
||||||
|
x = x[..., 0].unsqueeze(-1) # (batch_size, seq_len, num_nodes, 1)
|
||||||
|
|
||||||
|
# 创建y_cov (这里使用零张量,实际使用时可能需要根据具体需求调整)
|
||||||
|
y_cov = torch.zeros(x.shape[0], self.args['horizon'], x.shape[2], self.args['ycov_dim'], device=x.device)
|
||||||
|
|
||||||
|
# 创建labels (这里使用零张量,实际使用时可能需要根据具体需求调整)
|
||||||
|
labels = torch.zeros(x.shape[0], self.args['horizon'], x.shape[2], self.args['output_dim'], device=x.device)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
output, h_att, query, pos, neg = self.model(x, y_cov, labels=labels, batches_seen=self.batches_seen)
|
||||||
|
|
||||||
|
# 更新batches_seen
|
||||||
|
self.batches_seen += 1
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from model.ST-SSL.models import STSSL
|
||||||
|
from model.ST-SSL.layers import STEncoder, MLP
|
||||||
|
from data.get_adj import get_gso
|
||||||
|
|
||||||
|
class STSSLModel(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super(STSSLModel, self).__init__()
|
||||||
|
# 获取邻接矩阵
|
||||||
|
gso = get_gso(args)
|
||||||
|
|
||||||
|
# 设置默认参数
|
||||||
|
if 'd_model' not in args:
|
||||||
|
args['d_model'] = 64
|
||||||
|
if 'd_output' not in args:
|
||||||
|
args['d_output'] = args['output_dim']
|
||||||
|
if 'input_length' not in args:
|
||||||
|
args['input_length'] = args['n_his']
|
||||||
|
if 'dropout' not in args:
|
||||||
|
args['dropout'] = 0.1
|
||||||
|
if 'nmb_prototype' not in args:
|
||||||
|
args['nmb_prototype'] = 10
|
||||||
|
if 'batch_size' not in args:
|
||||||
|
args['batch_size'] = 64
|
||||||
|
if 'shm_temp' not in args:
|
||||||
|
args['shm_temp'] = 0.1
|
||||||
|
if 'yita' not in args:
|
||||||
|
args['yita'] = 0.5
|
||||||
|
if 'percent' not in args:
|
||||||
|
args['percent'] = 0.1
|
||||||
|
if 'device' not in args:
|
||||||
|
args['device'] = 'cpu'
|
||||||
|
|
||||||
|
# 创建ST-SSL模型
|
||||||
|
self.model = STSSL(args)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x shape: (batch_size, seq_len, num_nodes, features)
|
||||||
|
batch_size, seq_len, num_nodes, features = x.shape
|
||||||
|
|
||||||
|
# 获取邻接矩阵
|
||||||
|
graph = get_gso(self.args)
|
||||||
|
|
||||||
|
# 调整输入格式
|
||||||
|
x = x.permute(0, 2, 1, 3) # (batch_size, num_nodes, seq_len, features)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
repr1, repr2 = self.model(x, graph)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
pred = self.model.predict(repr1, repr2)
|
||||||
|
|
||||||
|
# 调整输出格式
|
||||||
|
pred = pred.permute(0, 2, 1, 3) # (batch_size, seq_len, num_nodes, features)
|
||||||
|
|
||||||
|
return pred
|
||||||
|
|
||||||
|
|
@ -0,0 +1,128 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from data.get_adj import get_gso
|
||||||
|
|
||||||
|
class STSSLModel(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super(STSSLModel, self).__init__()
|
||||||
|
|
||||||
|
# 设置默认参数
|
||||||
|
if 'd_model' not in args:
|
||||||
|
args['d_model'] = 64
|
||||||
|
if 'd_output' not in args:
|
||||||
|
args['d_output'] = args['output_dim']
|
||||||
|
if 'input_length' not in args:
|
||||||
|
args['input_length'] = args['n_his']
|
||||||
|
if 'dropout' not in args:
|
||||||
|
args['dropout'] = 0.1
|
||||||
|
if 'nmb_prototype' not in args:
|
||||||
|
args['nmb_prototype'] = 10
|
||||||
|
if 'batch_size' not in args:
|
||||||
|
args['batch_size'] = 64
|
||||||
|
if 'shm_temp' not in args:
|
||||||
|
args['shm_temp'] = 0.1
|
||||||
|
if 'yita' not in args:
|
||||||
|
args['yita'] = 0.5
|
||||||
|
if 'percent' not in args:
|
||||||
|
args['percent'] = 0.1
|
||||||
|
if 'device' not in args:
|
||||||
|
args['device'] = 'cpu'
|
||||||
|
if 'gso_type' not in args:
|
||||||
|
args['gso_type'] = 'sym_norm_lap'
|
||||||
|
if 'graph_conv_type' not in args:
|
||||||
|
args['graph_conv_type'] = 'cheb_graph_conv'
|
||||||
|
|
||||||
|
# 保存参数
|
||||||
|
self.args = args
|
||||||
|
self.num_nodes = args['num_nodes']
|
||||||
|
self.input_dim = args['input_dim']
|
||||||
|
self.output_dim = args['output_dim']
|
||||||
|
self.horizon = args['horizon']
|
||||||
|
self.d_model = args['d_model']
|
||||||
|
|
||||||
|
# 获取邻接矩阵
|
||||||
|
self.gso = get_gso(args)
|
||||||
|
|
||||||
|
# 时间嵌入
|
||||||
|
self.T_i_D_emb = nn.Parameter(torch.empty(288, args['d_model']))
|
||||||
|
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['d_model']))
|
||||||
|
|
||||||
|
# 节点嵌入
|
||||||
|
self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
|
||||||
|
self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
|
||||||
|
|
||||||
|
# 编码器 - 使用1个输入通道
|
||||||
|
self.encoder = STEncoder(
|
||||||
|
Kt=3, Ks=3,
|
||||||
|
input_dim=1, # 只使用第一个通道
|
||||||
|
hidden_dim=args['d_model'],
|
||||||
|
input_length=args['input_length'],
|
||||||
|
num_nodes=args['num_nodes'],
|
||||||
|
droprate=args['dropout']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测头
|
||||||
|
self.predictor = nn.Linear(args['d_model'], args['output_dim'])
|
||||||
|
|
||||||
|
# 初始化参数
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.xavier_uniform_(self.node_emb_u)
|
||||||
|
nn.init.xavier_uniform_(self.node_emb_d)
|
||||||
|
nn.init.xavier_uniform_(self.T_i_D_emb)
|
||||||
|
nn.init.xavier_uniform_(self.D_i_W_emb)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x shape: (batch_size, seq_len, num_nodes, features)
|
||||||
|
# 按照DDGCRN的模式,只使用第一个通道
|
||||||
|
x = x[..., 0].unsqueeze(-1) # (batch_size, seq_len, num_nodes, 1)
|
||||||
|
|
||||||
|
# 编码
|
||||||
|
encoded = self.encoder(x, self.gso)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
# 取最后一个时间步的输出进行预测
|
||||||
|
last_hidden = encoded[:, -1, :, :] # (batch_size, num_nodes, d_model)
|
||||||
|
|
||||||
|
# 预测未来horizon个时间步
|
||||||
|
predictions = []
|
||||||
|
for t in range(self.horizon):
|
||||||
|
pred = self.predictor(last_hidden) # (batch_size, num_nodes, output_dim)
|
||||||
|
predictions.append(pred)
|
||||||
|
|
||||||
|
# 堆叠预测结果
|
||||||
|
output = torch.stack(predictions, dim=1) # (batch_size, horizon, num_nodes, output_dim)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class STEncoder(nn.Module):
|
||||||
|
def __init__(self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate):
|
||||||
|
super(STEncoder, self).__init__()
|
||||||
|
self.num_nodes = num_nodes
|
||||||
|
self.input_length = input_length
|
||||||
|
|
||||||
|
# 简化的时空编码器 - 使用1个输入通道
|
||||||
|
self.conv1 = nn.Conv2d(input_dim, hidden_dim//2, kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
|
||||||
|
self.conv2 = nn.Conv2d(hidden_dim//2, hidden_dim, kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
|
||||||
|
self.dropout = nn.Dropout(droprate)
|
||||||
|
|
||||||
|
def forward(self, x, graph):
|
||||||
|
# x: (batch_size, seq_len, num_nodes, features)
|
||||||
|
batch_size, seq_len, num_nodes, features = x.shape
|
||||||
|
|
||||||
|
# 调整维度
|
||||||
|
x = x.permute(0, 3, 1, 2) # (batch_size, features, seq_len, num_nodes)
|
||||||
|
|
||||||
|
# 卷积操作
|
||||||
|
x = F.relu(self.conv1(x))
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = F.relu(self.conv2(x))
|
||||||
|
|
||||||
|
# 调整回原维度
|
||||||
|
x = x.permute(0, 2, 3, 1) # (batch_size, seq_len, num_nodes, features)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def sim_global(flow_data, sim_type='cos'):
|
||||||
|
"""Calculate the global similarity of traffic flow data.
|
||||||
|
:param flow_data: tensor, original flow [n,l,v,c] or location embedding [n,v,c]
|
||||||
|
:param type: str, type of similarity, attention or cosine. ['att', 'cos']
|
||||||
|
:return sim: tensor, symmetric similarity, [v,v]
|
||||||
|
"""
|
||||||
|
if len(flow_data.shape) == 4:
|
||||||
|
n,l,v,c = flow_data.shape
|
||||||
|
att_scaling = n * l * c
|
||||||
|
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 1, 3)) ** -1 # cal 2-norm of each node, dim N
|
||||||
|
sim = torch.einsum('btnc, btmc->nm', flow_data, flow_data)
|
||||||
|
elif len(flow_data.shape) == 3:
|
||||||
|
n,v,c = flow_data.shape
|
||||||
|
att_scaling = n * c
|
||||||
|
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 2)) ** -1 # cal 2-norm of each node, dim N
|
||||||
|
sim = torch.einsum('bnc, bmc->nm', flow_data, flow_data)
|
||||||
|
else:
|
||||||
|
raise ValueError('sim_global only support shape length in [3, 4] but got {}.'.format(len(flow_data.shape)))
|
||||||
|
|
||||||
|
if sim_type == 'cos':
|
||||||
|
# cosine similarity
|
||||||
|
scaling = torch.einsum('i, j->ij', cos_scaling, cos_scaling)
|
||||||
|
sim = sim * scaling
|
||||||
|
elif sim_type == 'att':
|
||||||
|
# scaled dot product similarity
|
||||||
|
scaling = float(att_scaling) ** -0.5
|
||||||
|
sim = torch.softmax(sim * scaling, dim=-1)
|
||||||
|
else:
|
||||||
|
raise ValueError('sim_global only support sim_type in [att, cos].')
|
||||||
|
|
||||||
|
return sim
|
||||||
|
|
||||||
|
def aug_topology(sim_mx, input_graph, percent=0.2):
|
||||||
|
"""Generate the data augumentation from topology (graph structure) perspective
|
||||||
|
for undirected graph without self-loop.
|
||||||
|
:param sim_mx: tensor, symmetric similarity, [v,v]
|
||||||
|
:param input_graph: tensor, adjacency matrix without self-loop, [v,v]
|
||||||
|
:return aug_graph: tensor, augmented adjacency matrix on cuda, [v,v]
|
||||||
|
"""
|
||||||
|
## edge dropping starts here
|
||||||
|
drop_percent = percent / 2
|
||||||
|
|
||||||
|
index_list = input_graph.nonzero() # list of edges [row_idx, col_idx]
|
||||||
|
|
||||||
|
edge_num = int(index_list.shape[0] / 2) # treat one undirected edge as two edges
|
||||||
|
edge_mask = (input_graph > 0).tril(diagonal=-1)
|
||||||
|
add_drop_num = int(edge_num * drop_percent / 2)
|
||||||
|
aug_graph = copy.deepcopy(input_graph)
|
||||||
|
|
||||||
|
drop_prob = torch.softmax(sim_mx[edge_mask], dim=0)
|
||||||
|
drop_prob = (1. - drop_prob).numpy() # normalized similarity to get sampling probability
|
||||||
|
drop_prob /= drop_prob.sum()
|
||||||
|
drop_list = np.random.choice(edge_num, size=add_drop_num, p=drop_prob)
|
||||||
|
drop_index = index_list[drop_list]
|
||||||
|
|
||||||
|
zeros = torch.zeros_like(aug_graph[0, 0])
|
||||||
|
aug_graph[drop_index[:, 0], drop_index[:, 1]] = zeros
|
||||||
|
aug_graph[drop_index[:, 1], drop_index[:, 0]] = zeros
|
||||||
|
|
||||||
|
## edge adding starts here
|
||||||
|
node_num = input_graph.shape[0]
|
||||||
|
x, y = np.meshgrid(range(node_num), range(node_num), indexing='ij')
|
||||||
|
mask = y < x
|
||||||
|
x, y = x[mask], y[mask]
|
||||||
|
|
||||||
|
add_prob = sim_mx[torch.ones(sim_mx.size(), dtype=bool).tril(diagonal=-1)] # .numpy()
|
||||||
|
add_prob = torch.softmax(add_prob, dim=0).numpy()
|
||||||
|
add_list = np.random.choice(int((node_num * node_num - node_num) / 2),
|
||||||
|
size=add_drop_num, p=add_prob)
|
||||||
|
|
||||||
|
ones = torch.ones_like(aug_graph[0, 0])
|
||||||
|
aug_graph[x[add_list], y[add_list]] = ones
|
||||||
|
aug_graph[y[add_list], x[add_list]] = ones
|
||||||
|
|
||||||
|
return aug_graph
|
||||||
|
|
||||||
|
def aug_traffic(t_sim_mx, flow_data, percent=0.2):
|
||||||
|
"""Generate the data augumentation from traffic (node attribute) perspective.
|
||||||
|
:param t_sim_mx: temporal similarity matrix after softmax, [l,n,v]
|
||||||
|
:param flow_data: input flow data, [n,l,v,c]
|
||||||
|
"""
|
||||||
|
l, n, v = t_sim_mx.shape
|
||||||
|
mask_num = int(n * l * v * percent)
|
||||||
|
aug_flow = copy.deepcopy(flow_data)
|
||||||
|
|
||||||
|
mask_prob = (1. - t_sim_mx.permute(1, 0, 2).reshape(-1)).numpy()
|
||||||
|
mask_prob /= mask_prob.sum()
|
||||||
|
|
||||||
|
x, y, z = np.meshgrid(range(n), range(l), range(v), indexing='ij')
|
||||||
|
mask_list = np.random.choice(n * l * v, size=mask_num, p=mask_prob)
|
||||||
|
|
||||||
|
zeros = torch.zeros_like(aug_flow[0, 0, 0])
|
||||||
|
aug_flow[
|
||||||
|
x.reshape(-1)[mask_list],
|
||||||
|
y.reshape(-1)[mask_list],
|
||||||
|
z.reshape(-1)[mask_list]] = zeros
|
||||||
|
|
||||||
|
return aug_flow
|
||||||
|
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def sim_global(flow_data, sim_type='cos'):
|
||||||
|
"""Calculate the global similarity of traffic flow data.
|
||||||
|
:param flow_data: tensor, original flow [n,l,v,c] or location embedding [n,v,c]
|
||||||
|
:param type: str, type of similarity, attention or cosine. ['att', 'cos']
|
||||||
|
:return sim: tensor, symmetric similarity, [v,v]
|
||||||
|
"""
|
||||||
|
if len(flow_data.shape) == 4:
|
||||||
|
n,l,v,c = flow_data.shape
|
||||||
|
att_scaling = n * l * c
|
||||||
|
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 1, 3)) ** -1 # cal 2-norm of each node, dim N
|
||||||
|
sim = torch.einsum('btnc, btmc->nm', flow_data, flow_data)
|
||||||
|
elif len(flow_data.shape) == 3:
|
||||||
|
n,v,c = flow_data.shape
|
||||||
|
att_scaling = n * c
|
||||||
|
cos_scaling = torch.norm(flow_data, p=2, dim=(0, 2)) ** -1 # cal 2-norm of each node, dim N
|
||||||
|
sim = torch.einsum('bnc, bmc->nm', flow_data, flow_data)
|
||||||
|
else:
|
||||||
|
raise ValueError('sim_global only support shape length in [3, 4] but got {}.'.format(len(flow_data.shape)))
|
||||||
|
|
||||||
|
if sim_type == 'cos':
|
||||||
|
# cosine similarity
|
||||||
|
scaling = torch.einsum('i, j->ij', cos_scaling, cos_scaling)
|
||||||
|
sim = sim * scaling
|
||||||
|
elif sim_type == 'att':
|
||||||
|
# scaled dot product similarity
|
||||||
|
scaling = float(att_scaling) ** -0.5
|
||||||
|
sim = torch.softmax(sim * scaling, dim=-1)
|
||||||
|
else:
|
||||||
|
raise ValueError('sim_global only support sim_type in [att, cos].')
|
||||||
|
|
||||||
|
return sim
|
||||||
|
|
||||||
|
def aug_topology(sim_mx, input_graph, percent=0.2):
|
||||||
|
"""Generate the data augumentation from topology (graph structure) perspective
|
||||||
|
for undirected graph without self-loop.
|
||||||
|
:param sim_mx: tensor, symmetric similarity, [v,v]
|
||||||
|
:param input_graph: tensor, adjacency matrix without self-loop, [v,v]
|
||||||
|
:return aug_graph: tensor, augmented adjacency matrix on cuda, [v,v]
|
||||||
|
"""
|
||||||
|
## edge dropping starts here
|
||||||
|
drop_percent = percent / 2
|
||||||
|
|
||||||
|
index_list = input_graph.nonzero() # list of edges [row_idx, col_idx]
|
||||||
|
|
||||||
|
edge_num = int(index_list.shape[0] / 2) # treat one undirected edge as two edges
|
||||||
|
edge_mask = (input_graph > 0).tril(diagonal=-1)
|
||||||
|
add_drop_num = int(edge_num * drop_percent / 2)
|
||||||
|
aug_graph = copy.deepcopy(input_graph)
|
||||||
|
|
||||||
|
drop_prob = torch.softmax(sim_mx[edge_mask], dim=0)
|
||||||
|
drop_prob = (1. - drop_prob).numpy() # normalized similarity to get sampling probability
|
||||||
|
drop_prob /= drop_prob.sum()
|
||||||
|
drop_list = np.random.choice(edge_num, size=add_drop_num, p=drop_prob)
|
||||||
|
drop_index = index_list[drop_list]
|
||||||
|
|
||||||
|
zeros = torch.zeros_like(aug_graph[0, 0])
|
||||||
|
aug_graph[drop_index[:, 0], drop_index[:, 1]] = zeros
|
||||||
|
aug_graph[drop_index[:, 1], drop_index[:, 0]] = zeros
|
||||||
|
|
||||||
|
## edge adding starts here
|
||||||
|
node_num = input_graph.shape[0]
|
||||||
|
x, y = np.meshgrid(range(node_num), range(node_num), indexing='ij')
|
||||||
|
mask = y < x
|
||||||
|
x, y = x[mask], y[mask]
|
||||||
|
|
||||||
|
add_prob = sim_mx[torch.ones(sim_mx.size(), dtype=bool).tril(diagonal=-1)] # .numpy()
|
||||||
|
add_prob = torch.softmax(add_prob, dim=0).numpy()
|
||||||
|
add_list = np.random.choice(int((node_num * node_num - node_num) / 2),
|
||||||
|
size=add_drop_num, p=add_prob)
|
||||||
|
|
||||||
|
ones = torch.ones_like(aug_graph[0, 0])
|
||||||
|
aug_graph[x[add_list], y[add_list]] = ones
|
||||||
|
aug_graph[y[add_list], x[add_list]] = ones
|
||||||
|
|
||||||
|
return aug_graph
|
||||||
|
|
||||||
|
def aug_traffic(t_sim_mx, flow_data, percent=0.2):
|
||||||
|
"""Generate the data augumentation from traffic (node attribute) perspective.
|
||||||
|
:param t_sim_mx: temporal similarity matrix after softmax, [l,n,v]
|
||||||
|
:param flow_data: input flow data, [n,l,v,c]
|
||||||
|
"""
|
||||||
|
l, n, v = t_sim_mx.shape
|
||||||
|
mask_num = int(n * l * v * percent)
|
||||||
|
aug_flow = copy.deepcopy(flow_data)
|
||||||
|
|
||||||
|
mask_prob = (1. - t_sim_mx.permute(1, 0, 2).reshape(-1)).numpy()
|
||||||
|
mask_prob /= mask_prob.sum()
|
||||||
|
|
||||||
|
x, y, z = np.meshgrid(range(n), range(l), range(v), indexing='ij')
|
||||||
|
mask_list = np.random.choice(n * l * v, size=mask_num, p=mask_prob)
|
||||||
|
|
||||||
|
zeros = torch.zeros_like(aug_flow[0, 0, 0])
|
||||||
|
aug_flow[
|
||||||
|
x.reshape(-1)[mask_list],
|
||||||
|
y.reshape(-1)[mask_list],
|
||||||
|
z.reshape(-1)[mask_list]] = zeros
|
||||||
|
|
||||||
|
return aug_flow
|
||||||
|
|
||||||
|
|
@ -0,0 +1,156 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# 简化的masked_mae_loss函数
|
||||||
|
def masked_mae_loss(mask_value=5.0):
|
||||||
|
def loss_fn(pred, target):
|
||||||
|
mask = (target != mask_value).float()
|
||||||
|
mae = F.l1_loss(pred * mask, target * mask, reduction='sum')
|
||||||
|
return mae / (mask.sum() + 1e-8)
|
||||||
|
return loss_fn
|
||||||
|
|
||||||
|
# 简化的数据增强函数
|
||||||
|
def aug_topology(sim_mx, graph, percent=0.1):
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def aug_traffic(sim_mx, data, percent=0.1):
|
||||||
|
return data
|
||||||
|
|
||||||
|
class STEncoder(nn.Module):
|
||||||
|
def __init__(self, Kt, Ks, blocks, input_length, num_nodes, droprate):
|
||||||
|
super(STEncoder, self).__init__()
|
||||||
|
self.num_nodes = num_nodes
|
||||||
|
self.input_length = input_length
|
||||||
|
|
||||||
|
# 简化的编码器 - 修复输入通道数
|
||||||
|
self.conv1 = nn.Conv2d(blocks[0][0], blocks[0][1], kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
|
||||||
|
self.conv2 = nn.Conv2d(blocks[0][1], blocks[0][2], kernel_size=(Kt, Ks), padding=(Kt//2, Ks//2))
|
||||||
|
self.dropout = nn.Dropout(droprate)
|
||||||
|
|
||||||
|
# 临时的相似度矩阵
|
||||||
|
self.s_sim_mx = torch.randn(num_nodes, num_nodes)
|
||||||
|
self.t_sim_mx = torch.randn(input_length, input_length)
|
||||||
|
|
||||||
|
def forward(self, x, graph):
|
||||||
|
# x: (batch_size, num_nodes, seq_len, features)
|
||||||
|
batch_size, num_nodes, seq_len, features = x.shape
|
||||||
|
|
||||||
|
# 调整维度
|
||||||
|
x = x.permute(0, 3, 1, 2) # (batch_size, features, num_nodes, seq_len)
|
||||||
|
|
||||||
|
# 确保输入通道数正确
|
||||||
|
if x.shape[1] != 2: # 如果不是2个通道,需要调整
|
||||||
|
if x.shape[1] == 1:
|
||||||
|
x = x.repeat(1, 2, 1, 1) # 复制到2个通道
|
||||||
|
else:
|
||||||
|
x = x[:, :2, :, :] # 取前2个通道
|
||||||
|
|
||||||
|
# 卷积操作
|
||||||
|
x = F.relu(self.conv1(x))
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = F.relu(self.conv2(x))
|
||||||
|
|
||||||
|
# 调整回原维度
|
||||||
|
x = x.permute(0, 2, 3, 1) # (batch_size, num_nodes, seq_len, features)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim):
|
||||||
|
super(MLP, self).__init__()
|
||||||
|
self.fc = nn.Linear(input_dim, output_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fc(x)
|
||||||
|
|
||||||
|
class TemporalHeteroModel(nn.Module):
|
||||||
|
def __init__(self, d_model, batch_size, num_nodes, device):
|
||||||
|
super(TemporalHeteroModel, self).__init__()
|
||||||
|
self.fc = nn.Linear(d_model, 1)
|
||||||
|
|
||||||
|
def forward(self, z1, z2):
|
||||||
|
return F.mse_loss(self.fc(z1), self.fc(z2))
|
||||||
|
|
||||||
|
class SpatialHeteroModel(nn.Module):
|
||||||
|
def __init__(self, d_model, nmb_prototype, batch_size, shm_temp):
|
||||||
|
super(SpatialHeteroModel, self).__init__()
|
||||||
|
self.fc = nn.Linear(d_model, 1)
|
||||||
|
|
||||||
|
def forward(self, z1, z2):
|
||||||
|
return F.mse_loss(self.fc(z1), self.fc(z2))
|
||||||
|
|
||||||
|
class STSSL(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super(STSSL, self).__init__()
|
||||||
|
# spatial temporal encoder
|
||||||
|
self.encoder = STEncoder(Kt=3, Ks=3, blocks=[[2, int(args['d_model']//2), args['d_model']], [args['d_model'], int(args['d_model']//2), args['d_model']]],
|
||||||
|
input_length=args['input_length'], num_nodes=args['num_nodes'], droprate=args['dropout'])
|
||||||
|
|
||||||
|
# traffic flow prediction branch
|
||||||
|
self.mlp = MLP(args['d_model'], args['d_output'])
|
||||||
|
# temporal heterogenrity modeling branch
|
||||||
|
self.thm = TemporalHeteroModel(args['d_model'], args['batch_size'], args['num_nodes'], args['device'])
|
||||||
|
# spatial heterogenrity modeling branch
|
||||||
|
self.shm = SpatialHeteroModel(args['d_model'], args['nmb_prototype'], args['batch_size'], args['shm_temp'])
|
||||||
|
self.mae = masked_mae_loss(mask_value=5.0)
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def forward(self, view1, graph):
|
||||||
|
repr1 = self.encoder(view1, graph) # view1: n,l,v,c; graph: v,v
|
||||||
|
|
||||||
|
s_sim_mx = self.fetch_spatial_sim()
|
||||||
|
graph2 = aug_topology(s_sim_mx, graph, percent=self.args['percent']*2)
|
||||||
|
|
||||||
|
t_sim_mx = self.fetch_temporal_sim()
|
||||||
|
view2 = aug_traffic(t_sim_mx, view1, percent=self.args['percent'])
|
||||||
|
|
||||||
|
repr2 = self.encoder(view2, graph2)
|
||||||
|
return repr1, repr2
|
||||||
|
|
||||||
|
def fetch_spatial_sim(self):
|
||||||
|
"""
|
||||||
|
Fetch the region similarity matrix generated by region embedding.
|
||||||
|
Note this can be called only when spatial_sim is True.
|
||||||
|
:return sim_mx: tensor, similarity matrix, (v, v)
|
||||||
|
"""
|
||||||
|
return self.encoder.s_sim_mx.cpu()
|
||||||
|
|
||||||
|
def fetch_temporal_sim(self):
|
||||||
|
return self.encoder.t_sim_mx.cpu()
|
||||||
|
|
||||||
|
def predict(self, z1, z2):
|
||||||
|
'''Predicting future traffic flow.
|
||||||
|
:param z1, z2 (tensor): shape nvc
|
||||||
|
:return: nlvc, l=1, c=2
|
||||||
|
'''
|
||||||
|
return self.mlp(z1)
|
||||||
|
|
||||||
|
def loss(self, z1, z2, y_true, scaler, loss_weights):
|
||||||
|
l1 = self.pred_loss(z1, z2, y_true, scaler)
|
||||||
|
sep_loss = [l1.item()]
|
||||||
|
loss = loss_weights[0] * l1
|
||||||
|
|
||||||
|
l2 = self.temporal_loss(z1, z2)
|
||||||
|
sep_loss.append(l2.item())
|
||||||
|
loss += loss_weights[1] * l2
|
||||||
|
|
||||||
|
l3 = self.spatial_loss(z1, z2)
|
||||||
|
sep_loss.append(l3.item())
|
||||||
|
loss += loss_weights[2] * l3
|
||||||
|
return loss, sep_loss
|
||||||
|
|
||||||
|
def pred_loss(self, z1, z2, y_true, scaler):
|
||||||
|
y_pred = scaler.inverse_transform(self.predict(z1, z2))
|
||||||
|
y_true = scaler.inverse_transform(y_true)
|
||||||
|
|
||||||
|
loss = self.args['yita'] * self.mae(y_pred[..., 0], y_true[..., 0]) + \
|
||||||
|
(1 - self.args['yita']) * self.mae(y_pred[..., 1], y_true[..., 1])
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def temporal_loss(self, z1, z2):
|
||||||
|
return self.thm(z1, z2)
|
||||||
|
|
||||||
|
def spatial_loss(self, z1, z2):
|
||||||
|
return self.shm(z1, z2)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from torchinfo import summary
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLayer(nn.Module):
|
||||||
|
"""Perform attention across the -2 dim (the -1 dim is `model_dim`).
|
||||||
|
|
||||||
|
Make sure the tensor is permuted to correct shape before attention.
|
||||||
|
|
||||||
|
E.g.
|
||||||
|
- Input shape (batch_size, in_steps, num_nodes, model_dim).
|
||||||
|
- Then the attention will be performed across the nodes.
|
||||||
|
|
||||||
|
Also, it supports different src and tgt length.
|
||||||
|
|
||||||
|
But must `src length == K length == V length`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_dim, num_heads=8, mask=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model_dim = model_dim#152
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
self.head_dim = model_dim // num_heads
|
||||||
|
|
||||||
|
self.FC_Q = nn.Linear(model_dim, model_dim)#[152,152]
|
||||||
|
self.FC_K = nn.Linear(model_dim, model_dim)
|
||||||
|
self.FC_V = nn.Linear(model_dim, model_dim)
|
||||||
|
|
||||||
|
self.out_proj = nn.Linear(model_dim, model_dim)
|
||||||
|
|
||||||
|
def forward(self, query, key, value):
|
||||||
|
# Q (batch_size, ..., tgt_length, model_dim)
|
||||||
|
# K, V (batch_size, ..., src_length, model_dim)
|
||||||
|
batch_size = query.shape[0]#16 #64
|
||||||
|
tgt_length = query.shape[-2]#12 #170
|
||||||
|
src_length = key.shape[-2]#12 #170
|
||||||
|
|
||||||
|
query = self.FC_Q(query)#[64,6,170,152]
|
||||||
|
key = self.FC_K(key)
|
||||||
|
value = self.FC_V(value)
|
||||||
|
|
||||||
|
# Qhead, Khead, Vhead (num_heads * batch_size, ..., length, head_dim)
|
||||||
|
query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)#[512,6,170,24]
|
||||||
|
key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
|
||||||
|
value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)
|
||||||
|
|
||||||
|
key = key.transpose(
|
||||||
|
-1, -2
|
||||||
|
) # (num_heads * batch_size, ..., head_dim, src_length)
|
||||||
|
|
||||||
|
attn_score = (#[64,170,12,12]
|
||||||
|
query @ key
|
||||||
|
) / self.head_dim**0.5 # (num_heads * batch_size, ..., tgt_length, src_length)
|
||||||
|
|
||||||
|
if self.mask:
|
||||||
|
mask = torch.ones(
|
||||||
|
tgt_length, src_length, dtype=torch.bool, device=query.device
|
||||||
|
).tril() # lower triangular part of the matrix
|
||||||
|
attn_score.masked_fill_(~mask, -torch.inf) # fill in-place
|
||||||
|
|
||||||
|
attn_score = torch.softmax(attn_score, dim=-1)#[64,170,12,12]
|
||||||
|
out = attn_score @ value
|
||||||
|
out = torch.cat(
|
||||||
|
torch.split(out, batch_size, dim=0), dim=-1
|
||||||
|
) # (batch_size, ..., tgt_length, head_dim * num_heads = model_dim)[16,170,12,152]
|
||||||
|
|
||||||
|
out = self.out_proj(out)#[64,6,170,152]
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class SelfAttentionLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn = AttentionLayer(model_dim, num_heads, mask)
|
||||||
|
self.feed_forward = nn.Sequential(
|
||||||
|
nn.Linear(model_dim, feed_forward_dim),#[152,256]
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(feed_forward_dim, model_dim),#[256.152]
|
||||||
|
)
|
||||||
|
self.ln1 = nn.LayerNorm(model_dim)
|
||||||
|
self.ln2 = nn.LayerNorm(model_dim)
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x, dim=-2):
|
||||||
|
x = x.transpose(dim, -2)
|
||||||
|
# x: (batch_size, ..., length, model_dim)
|
||||||
|
residual = x
|
||||||
|
out = self.attn(x, x, x) # (batch_size, ..., length, model_dim)[16,170,12,152]
|
||||||
|
out = self.dropout1(out)
|
||||||
|
out = self.ln1(residual + out)
|
||||||
|
|
||||||
|
residual = out
|
||||||
|
out = self.feed_forward(out) # (batch_size, ..., length, model_dim)
|
||||||
|
out = self.dropout2(out)
|
||||||
|
out = self.ln2(residual + out)
|
||||||
|
|
||||||
|
out = out.transpose(dim, -2)#[64,6,170,152]
|
||||||
|
return out
|
||||||
|
|
@ -0,0 +1,414 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
import pandas as pd
|
||||||
|
import sys
|
||||||
|
from model.TEDDCF.ISTF import SelfAttentionLayer
|
||||||
|
|
||||||
|
|
||||||
|
class GLU(nn.Module):
|
||||||
|
def __init__(self, features, dropout=0.1):#PEMS08: 192
|
||||||
|
super(GLU, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(features, features, (1, 1))
|
||||||
|
self.conv2 = nn.Conv2d(features, features, (1, 1))
|
||||||
|
self.conv3 = nn.Conv2d(features, features, (1, 1))
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
|
||||||
|
x1 = self.conv1(x)
|
||||||
|
x2 = self.conv2(x)
|
||||||
|
out = x1 * torch.sigmoid(x2)
|
||||||
|
out = self.dropout(out)
|
||||||
|
out = self.conv3(out)
|
||||||
|
return out#[64,192,170,12]
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalEmbedding(nn.Module):
|
||||||
|
def __init__(self, time, features):
|
||||||
|
super(TemporalEmbedding, self).__init__()
|
||||||
|
#S08:time 288 features 96
|
||||||
|
self.time = time
|
||||||
|
self.time_day = nn.Parameter(torch.empty(time, features))#[288 96]
|
||||||
|
nn.init.xavier_uniform_(self.time_day)
|
||||||
|
|
||||||
|
self.time_week = nn.Parameter(torch.empty(7, features))#[7 96]
|
||||||
|
nn.init.xavier_uniform_(self.time_week)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
#x #in:[64,12,170,3]
|
||||||
|
day_emb = x[..., 1]
|
||||||
|
|
||||||
|
time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
|
||||||
|
|
||||||
|
|
||||||
|
time_day = time_day.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
week_emb = x[..., 2]
|
||||||
|
|
||||||
|
|
||||||
|
time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]#[64,12,170,96]
|
||||||
|
time_week = time_week.transpose(1, 2).contiguous()#torch.Size([64, 170, 12, 96])
|
||||||
|
|
||||||
|
|
||||||
|
tem_emb = time_day + time_week#[64,170,12,96]
|
||||||
|
|
||||||
|
tem_emb = tem_emb.permute(0,3,1,2)#[64,96,170,12]
|
||||||
|
|
||||||
|
return tem_emb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Diffusion_GCN(nn.Module):
|
||||||
|
def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.diffusion_step = diffusion_step#1
|
||||||
|
self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))#[192,192,(1,1)]
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x, adj):
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for i in range(0, self.diffusion_step):#1
|
||||||
|
if adj.dim() == 3:
|
||||||
|
x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
|
||||||
|
out.append(x)
|
||||||
|
elif adj.dim() == 2:
|
||||||
|
x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
|
||||||
|
out.append(x)
|
||||||
|
x = torch.cat(out, dim=1)
|
||||||
|
x = self.conv(x)
|
||||||
|
output = self.dropout(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class EventGraph_Fusion(nn.Module):
|
||||||
|
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.memory = nn.Parameter(torch.randn(channels, num_nodes))
|
||||||
|
nn.init.xavier_uniform_(self.memory)
|
||||||
|
self.fc = nn.Linear(2,1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
adj_dyn_1 = torch.softmax(
|
||||||
|
F.relu(
|
||||||
|
torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
|
||||||
|
/ math.sqrt(x.shape[1])
|
||||||
|
),
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
adj_dyn_2 = torch.softmax(
|
||||||
|
F.relu(
|
||||||
|
torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
|
||||||
|
/ math.sqrt(x.shape[1])
|
||||||
|
),
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
adj_f = torch.cat([(adj_dyn_1).unsqueeze(-1)] + [(adj_dyn_2).unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
|
adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
|
||||||
|
|
||||||
|
topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1]*0.8), dim=-1)
|
||||||
|
|
||||||
|
mask = torch.zeros_like(adj_f)
|
||||||
|
|
||||||
|
mask.scatter_(-1, topk_indices, 1)
|
||||||
|
|
||||||
|
adj_f = adj_f * mask
|
||||||
|
|
||||||
|
return adj_f
|
||||||
|
|
||||||
|
|
||||||
|
class EventGCN(nn.Module):
|
||||||
|
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Conv2d(channels,channels,(1,1))
|
||||||
|
self.generator = EventGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
|
||||||
|
self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
|
||||||
|
self.emb = emb
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
skip = x
|
||||||
|
x = self.conv(x)
|
||||||
|
adj_dyn = self.generator(x)
|
||||||
|
x = self.gcn(x, adj_dyn)
|
||||||
|
x = x*self.emb + skip
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TrendGCN(nn.Module):
|
||||||
|
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Conv2d(channels,channels,(1,1))
|
||||||
|
self.generator = TrendGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
|
||||||
|
self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
|
||||||
|
self.emb = emb
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
skip = x
|
||||||
|
x = self.conv(x)
|
||||||
|
adj_dyn = self.generator(x)
|
||||||
|
x = self.gcn(x, adj_dyn)
|
||||||
|
x = x*self.emb + skip
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TrendGraph_Fusion(nn.Module):
|
||||||
|
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.memory = nn.Parameter(
|
||||||
|
torch.randn(channels, num_nodes))
|
||||||
|
nn.init.xavier_uniform_(self.memory)
|
||||||
|
self.fc = nn.Linear(2, 1)
|
||||||
|
self.E_adaptive = nn.Parameter(torch.randn(num_nodes, 10))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# adj_dyn_1 = torch.softmax(
|
||||||
|
# F.relu(
|
||||||
|
# torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
|
||||||
|
# / math.sqrt(x.shape[1])
|
||||||
|
# ),
|
||||||
|
# -1,
|
||||||
|
# )
|
||||||
|
|
||||||
|
adj_dyn_2 = torch.softmax(
|
||||||
|
F.relu(
|
||||||
|
torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
|
||||||
|
/ math.sqrt(x.shape[1])
|
||||||
|
),
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
adj_adp = F.softmax(F.relu(torch.mm(self.E_adaptive, self.E_adaptive.transpose(0, 1))), dim=1)
|
||||||
|
|
||||||
|
adj_adp_expanded = adj_adp.unsqueeze(0)
|
||||||
|
|
||||||
|
adj_adp = adj_adp_expanded.repeat(x.shape[0], 1, 1)
|
||||||
|
|
||||||
|
adj_f = torch.cat([(adj_dyn_2).unsqueeze(-1)] + [(adj_adp).unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
|
adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
|
||||||
|
|
||||||
|
topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
|
||||||
|
|
||||||
|
mask = torch.zeros_like(adj_f)
|
||||||
|
|
||||||
|
mask.scatter_(-1, topk_indices, 1)
|
||||||
|
|
||||||
|
adj_f = adj_f * mask
|
||||||
|
|
||||||
|
return adj_f
|
||||||
|
|
||||||
|
class Chomp1d(nn.Module):
|
||||||
|
"""
|
||||||
|
extra dimension will be added by padding, remove it
|
||||||
|
"""
|
||||||
|
def __init__(self, chomp_size):
|
||||||
|
super(Chomp1d, self).__init__()
|
||||||
|
self.chomp_size = chomp_size
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x[:, :, :, :-self.chomp_size].contiguous()
|
||||||
|
|
||||||
|
class TemporalConvNet(nn.Module):
|
||||||
|
def __init__(self, features, kernel_size=2, dropout=0.2, levels=1):
|
||||||
|
super(TemporalConvNet, self).__init__()
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
for i in range(levels):
|
||||||
|
dilation_size = 2 ** i
|
||||||
|
padding = (kernel_size - 1) * dilation_size
|
||||||
|
self.conv = nn.Conv2d(features, features, (1, kernel_size), dilation=(1, dilation_size),
|
||||||
|
padding=(0, padding))
|
||||||
|
self.chomp = Chomp1d(padding)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
|
||||||
|
self.tcn = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, xh):
|
||||||
|
xh = self.tcn(xh)
|
||||||
|
return xh
|
||||||
|
pass
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, fea, res_ln=False):
|
||||||
|
super(FeedForward, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.res_ln = res_ln
|
||||||
|
self.L = len(fea) - 1#2
|
||||||
|
self.linear = nn.ModuleList([nn.Linear(fea[i], fea[i+1]) for i in range(self.L)])
|
||||||
|
self.ln = nn.LayerNorm(fea[self.L], elementwise_affine=False)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
|
||||||
|
x = inputs
|
||||||
|
for i in range(self.L):
|
||||||
|
x = self.linear[i](x)
|
||||||
|
if i != self.L-1:
|
||||||
|
x = F.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
if self.res_ln:
|
||||||
|
x += inputs
|
||||||
|
x = self.ln(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Adaptive_Fusion(nn.Module):
|
||||||
|
def __init__(self, heads, dims):
|
||||||
|
super(Adaptive_Fusion, self).__init__()
|
||||||
|
features = dims # 192
|
||||||
|
self.h = heads # 8
|
||||||
|
self.d = int(dims / heads) # 16
|
||||||
|
|
||||||
|
self.qlfc = FeedForward([features, features])
|
||||||
|
self.khfc = FeedForward([features, features])
|
||||||
|
self.vhfc = FeedForward([features, features])
|
||||||
|
self.ofc = FeedForward([features, features])
|
||||||
|
|
||||||
|
self.ln = nn.LayerNorm(features, elementwise_affine=False)
|
||||||
|
self.ff = FeedForward([features, features, features], True)
|
||||||
|
|
||||||
|
def forward(self, xl, xh, Mask=True):
|
||||||
|
'''
|
||||||
|
xl: [B,T,N,F]
|
||||||
|
xh: [B,T,N,F]
|
||||||
|
te: [B,T,N,F]
|
||||||
|
return: [B,T,N,F]
|
||||||
|
'''
|
||||||
|
# xl += te
|
||||||
|
# xh += te
|
||||||
|
|
||||||
|
query = self.qlfc(xl) # [B,T,N,F]
|
||||||
|
keyh = torch.relu(self.khfc(xh)) # [B,T,N,F]
|
||||||
|
valueh = torch.relu(self.vhfc(xh)) # [B,T,N,F]
|
||||||
|
|
||||||
|
query = torch.cat(torch.split(query, self.d, -1), 0).permute(0, 2, 1, 3) # [k*B,N,T,d]
|
||||||
|
keyh = torch.cat(torch.split(keyh, self.d, -1), 0).permute(0, 2, 3, 1) # [k*B,N,d,T]
|
||||||
|
valueh = torch.cat(torch.split(valueh, self.d, -1), 0).permute(0, 2, 1, 3) # [k*B,N,T,d]
|
||||||
|
|
||||||
|
attentionh = torch.matmul(query, keyh) # [k*B,N,T,T]
|
||||||
|
|
||||||
|
if Mask:
|
||||||
|
batch_size = xl.shape[0]
|
||||||
|
num_steps = xl.shape[1]
|
||||||
|
num_vertexs = xl.shape[2]
|
||||||
|
mask = torch.ones(num_steps, num_steps).to(xl.device) # [T,T]
|
||||||
|
mask = torch.tril(mask) # [T,T]
|
||||||
|
mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0) # [1,1,T,T]
|
||||||
|
mask = mask.repeat(self.h * batch_size, num_vertexs, 1, 1) # [k*B,N,T,T]
|
||||||
|
mask = mask.to(torch.bool)
|
||||||
|
zero_vec = (-2 ** 15 + 1) * torch.ones_like(attentionh).to(xl.device) # [k*B,N,T,T]
|
||||||
|
attentionh = torch.where(mask, attentionh, zero_vec)
|
||||||
|
|
||||||
|
attentionh /= (self.d ** 0.5) # scaled
|
||||||
|
attentionh = F.softmax(attentionh, -1) # [k*B,N,T,T]
|
||||||
|
|
||||||
|
value = torch.matmul(attentionh, valueh) # [k*B,N,T,d]
|
||||||
|
|
||||||
|
value = torch.cat(torch.split(value, value.shape[0] // self.h, 0), -1).permute(0, 2, 1, 3) # [B,T,N,F]
|
||||||
|
value = self.ofc(value)
|
||||||
|
value = value + xl
|
||||||
|
|
||||||
|
value = self.ln(value)
|
||||||
|
|
||||||
|
return self.ff(value) # [64,12,170,128]
|
||||||
|
|
||||||
|
class TEDDCF(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, device, input_dim, num_nodes, channels, granularity, dropout=0.1
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
self.num_nodes = num_nodes
|
||||||
|
self.output_len = 12
|
||||||
|
self.input_len = 12
|
||||||
|
self.heads = 8
|
||||||
|
diffusion_step = 1
|
||||||
|
|
||||||
|
self.Temb = TemporalEmbedding(granularity, channels)
|
||||||
|
|
||||||
|
self.start_conv = nn.Conv2d(
|
||||||
|
in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.glu = GLU(channels*2, dropout)
|
||||||
|
|
||||||
|
self.regression_layer = nn.Conv2d(
|
||||||
|
channels*2, self.output_len, kernel_size=(1, self.output_len)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.temporal_conv = TemporalConvNet(channels*2)
|
||||||
|
self.pre_h = nn.Conv2d(in_channels=self.input_len, out_channels=self.output_len, kernel_size=(1,1))
|
||||||
|
self.adp_f = Adaptive_Fusion(self.heads, channels*2)
|
||||||
|
|
||||||
|
num_layers = 3
|
||||||
|
self.attn_layers_t = nn.ModuleList(
|
||||||
|
[
|
||||||
|
SelfAttentionLayer(channels*2, feed_forward_dim=256, num_heads=4, dropout=0.1)
|
||||||
|
for _ in range(num_layers) # 3
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.xh_emb = nn.Parameter(torch.randn(channels*2, num_nodes, 12))
|
||||||
|
self.xh_dgcn = EventGCN(channels*2, num_nodes, diffusion_step=1, dropout=0.1,emb=self.xh_emb)
|
||||||
|
|
||||||
|
self.xl_emb = nn.Parameter(torch.randn(channels*2, num_nodes, 12))
|
||||||
|
self.xl_dgcn = TrendGCN(channels*2, num_nodes, diffusion_step=1, dropout=0.1, emb=self.xl_emb)
|
||||||
|
|
||||||
|
|
||||||
|
def param_num(self):
|
||||||
|
return sum([param.nelement() for param in self.parameters()])
|
||||||
|
|
||||||
|
def forward(self, inputxl, inputxh):
|
||||||
|
|
||||||
|
xl = inputxl
|
||||||
|
xh = inputxh
|
||||||
|
|
||||||
|
# Encoder
|
||||||
|
# Data Embedding
|
||||||
|
time_embl = self.Temb(inputxl.permute(0, 3, 2, 1))
|
||||||
|
time_embh = self.Temb(inputxh.permute(0, 3, 2, 1))
|
||||||
|
#t = self.start_conv(x)#[64,96,170,12]
|
||||||
|
xl = torch.cat([self.start_conv(xl)] + [time_embl], dim=1)
|
||||||
|
xh = torch.cat([self.start_conv(xh)] + [time_embh], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
xl = xl.permute(0, 3, 2, 1)
|
||||||
|
for attn in self.attn_layers_t:
|
||||||
|
xl = attn(xl, dim=1)
|
||||||
|
xl = xl.permute(0, 3, 2, 1)
|
||||||
|
|
||||||
|
xl = self.xl_dgcn(xl)
|
||||||
|
xl = self.glu(xl) + xl
|
||||||
|
|
||||||
|
|
||||||
|
xh = self.temporal_conv(xh)
|
||||||
|
|
||||||
|
|
||||||
|
xh = self.xh_dgcn(xh)
|
||||||
|
|
||||||
|
#simple plus
|
||||||
|
x_all = xh + xl
|
||||||
|
#STwave_fusion
|
||||||
|
# xl = xl.transpose(1, 3)
|
||||||
|
# xh = self.pre_h(xh.transpose(1,3))#[64,12,170,192]
|
||||||
|
# x_all = self.adp_f(xl, xh)
|
||||||
|
# x_all = x_all.transpose(1, 3)
|
||||||
|
|
||||||
|
prediction = self.regression_layer(F.relu(x_all))
|
||||||
|
|
||||||
|
|
||||||
|
return prediction
|
||||||
Loading…
Reference in New Issue