import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from model.TEDDCF.ISTF import SelfAttentionLayer


class GLU(nn.Module):
    """Gated linear unit over the channel dimension, built from 1x1 convolutions."""

    def __init__(self, features, dropout=0.1):  # PEMS08: features = 192
        super(GLU, self).__init__()
        self.conv1 = nn.Conv2d(features, features, (1, 1))
        self.conv2 = nn.Conv2d(features, features, (1, 1))
        self.conv3 = nn.Conv2d(features, features, (1, 1))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [B, C, N, T]
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        out = x1 * torch.sigmoid(x2)  # gating
        out = self.dropout(out)
        out = self.conv3(out)
        return out  # [64, 192, 170, 12]
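

# Usage sketch for GLU (illustrative; the 64/170/12 sizes follow the PEMS08-style
# comments in this file and are assumptions, not requirements):
#   glu = GLU(features=192)
#   y = glu(torch.randn(64, 192, 170, 12))  # -> [64, 192, 170, 12]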


class TemporalEmbedding(nn.Module):
    """Time-of-day and day-of-week embedding tables (PEMS08: time = 288, features = 96)."""

    def __init__(self, time, features):
        super(TemporalEmbedding, self).__init__()
        self.time = time
        self.time_day = nn.Parameter(torch.empty(time, features))  # [288, 96]
        nn.init.xavier_uniform_(self.time_day)

        self.time_week = nn.Parameter(torch.empty(7, features))  # [7, 96]
        nn.init.xavier_uniform_(self.time_week)

    def forward(self, x):
        # x: [64, 12, 170, 3]; channel 1 holds the normalized time-of-day,
        # channel 2 holds the day-of-week index.
        day_emb = x[..., 1]
        # .long() keeps the indices on the same device as the embedding tables.
        time_day = self.time_day[(day_emb * self.time).long()]
        time_day = time_day.transpose(1, 2).contiguous()

        week_emb = x[..., 2]
        time_week = self.time_week[week_emb.long()]  # [64, 12, 170, 96]
        time_week = time_week.transpose(1, 2).contiguous()  # [64, 170, 12, 96]

        tem_emb = time_day + time_week  # [64, 170, 12, 96]
        tem_emb = tem_emb.permute(0, 3, 1, 2)  # [64, 96, 170, 12]

        return tem_emb
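

# Usage sketch for TemporalEmbedding (illustrative; 288 steps per day and the
# [64, 12, 170, 3] input follow the comments above and are assumptions):
#   temb = TemporalEmbedding(time=288, features=96)
#   e = temb(x)  # x: [B, T, N, 3], time-of-day in channel 1, weekday in channel 2 -> [B, 96, N, T]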


class Diffusion_GCN(nn.Module):
    """Graph convolution that diffuses features over a batched or static adjacency."""

    def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
        super().__init__()
        self.diffusion_step = diffusion_step  # 1
        self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))  # [192, 192, (1, 1)]
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        # x: [B, C, N, T]; adj: [B, N, N] (dynamic) or [N, N] (static)
        out = []
        for _ in range(self.diffusion_step):
            if adj.dim() == 3:
                x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
                out.append(x)
            elif adj.dim() == 2:
                x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
                out.append(x)
        x = torch.cat(out, dim=1)  # [B, diffusion_step * C, N, T]
        x = self.conv(x)
        output = self.dropout(x)
        return output
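

# Usage sketch for Diffusion_GCN (illustrative sizes, assuming 170 nodes and 192 channels):
#   gcn = Diffusion_GCN(channels=192, diffusion_step=1)
#   y = gcn(torch.randn(64, 192, 170, 12), torch.rand(64, 170, 170))  # -> [64, 192, 170, 12]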


class EventGraph_Fusion(nn.Module):
    """Builds a dynamic adjacency by fusing a memory-based affinity with a data-driven affinity, then keeps the top-k entries per row."""

    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
        nn.init.xavier_uniform_(self.memory)
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        # x: [B, C, N, T]
        # Affinity between node features and the learned memory.
        adj_dyn_1 = torch.softmax(
            F.relu(
                torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
                / math.sqrt(x.shape[1])
            ),
            -1,
        )
        # Affinity between time-aggregated node features.
        adj_dyn_2 = torch.softmax(
            F.relu(
                torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
                / math.sqrt(x.shape[1])
            ),
            -1,
        )
        adj_f = torch.cat([adj_dyn_1.unsqueeze(-1), adj_dyn_2.unsqueeze(-1)], dim=-1)  # [B, N, N, 2]

        # squeeze(-1) keeps the batch dimension even when B == 1.
        adj_f = torch.softmax(self.fc(adj_f).squeeze(-1), -1)

        # Keep only the 80% largest entries per row.
        topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
        mask = torch.zeros_like(adj_f)
        mask.scatter_(-1, topk_indices, 1)
        adj_f = adj_f * mask

        return adj_f
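

# Usage sketch for EventGraph_Fusion (illustrative; 170 nodes / 192 channels assumed):
#   gen = EventGraph_Fusion(channels=192, num_nodes=170)
#   adj = gen(torch.randn(64, 192, 170, 12))  # -> [64, 170, 170], 80% of entries kept per row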


class EventGCN(nn.Module):
    """Dynamic-graph block: 1x1 conv -> EventGraph_Fusion -> Diffusion_GCN, with the output scaled by a learned embedding and added to a skip connection."""

    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, (1, 1))
        self.generator = EventGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
        self.emb = emb

    def forward(self, x):
        skip = x
        x = self.conv(x)
        adj_dyn = self.generator(x)
        x = self.gcn(x, adj_dyn)
        x = x * self.emb + skip
        return x


class TrendGCN(nn.Module):
    """Same block structure as EventGCN, but using TrendGraph_Fusion to build the dynamic graph."""

    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, (1, 1))
        self.generator = TrendGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
        self.emb = emb

    def forward(self, x):
        skip = x
        x = self.conv(x)
        adj_dyn = self.generator(x)
        x = self.gcn(x, adj_dyn)
        x = x * self.emb + skip
        return x


class TrendGraph_Fusion(nn.Module):
    """Builds a dynamic adjacency by fusing a data-driven affinity with an adaptive node-embedding graph, then keeps the top-k entries per row."""

    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
        nn.init.xavier_uniform_(self.memory)
        self.fc = nn.Linear(2, 1)
        self.E_adaptive = nn.Parameter(torch.randn(num_nodes, 10))

    def forward(self, x):
        # x: [B, C, N, T]
        # adj_dyn_1 = torch.softmax(
        #     F.relu(
        #         torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
        #         / math.sqrt(x.shape[1])
        #     ),
        #     -1,
        # )

        # Affinity between time-aggregated node features.
        adj_dyn_2 = torch.softmax(
            F.relu(
                torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
                / math.sqrt(x.shape[1])
            ),
            -1,
        )
        # Adaptive adjacency from the learned node embedding, shared across the batch.
        adj_adp = F.softmax(F.relu(torch.mm(self.E_adaptive, self.E_adaptive.transpose(0, 1))), dim=1)
        adj_adp = adj_adp.unsqueeze(0).repeat(x.shape[0], 1, 1)

        adj_f = torch.cat([adj_dyn_2.unsqueeze(-1), adj_adp.unsqueeze(-1)], dim=-1)  # [B, N, N, 2]

        # squeeze(-1) keeps the batch dimension even when B == 1.
        adj_f = torch.softmax(self.fc(adj_f).squeeze(-1), -1)

        # Keep only the 80% largest entries per row.
        topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
        mask = torch.zeros_like(adj_f)
        mask.scatter_(-1, topk_indices, 1)
        adj_f = adj_f * mask

        return adj_f


class Chomp1d(nn.Module):
    """
    Remove the extra time steps introduced by causal padding.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :, :-self.chomp_size].contiguous()
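

# Example: a (1, 2) convolution with padding (0, 1) turns [B, C, N, T] into [B, C, N, T + 1];
# Chomp1d(1) trims the right side back to [B, C, N, T], keeping the convolution causal.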


class TemporalConvNet(nn.Module):
    """Stack of dilated causal convolutions over the time dimension (TCN-style)."""

    def __init__(self, features, kernel_size=2, dropout=0.2, levels=1):
        super(TemporalConvNet, self).__init__()
        layers = []
        for i in range(levels):
            dilation_size = 2 ** i
            padding = (kernel_size - 1) * dilation_size
            self.conv = nn.Conv2d(features, features, (1, kernel_size), dilation=(1, dilation_size),
                                  padding=(0, padding))
            self.chomp = Chomp1d(padding)  # drop the right-side steps added by padding
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)

            layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
        self.tcn = nn.Sequential(*layers)

    def forward(self, xh):
        xh = self.tcn(xh)
        return xh
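

# Usage sketch for TemporalConvNet (illustrative; 192 channels as configured in TEDDCF below):
#   tcn = TemporalConvNet(features=192)
#   y = tcn(torch.randn(64, 192, 170, 12))  # causal along the last (time) axis -> [64, 192, 170, 12]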


class FeedForward(nn.Module):
    """MLP defined by a list of layer widths, with an optional residual connection and LayerNorm."""

    def __init__(self, fea, res_ln=False):
        super(FeedForward, self).__init__()
        self.res_ln = res_ln
        self.L = len(fea) - 1  # number of linear layers
        self.linear = nn.ModuleList([nn.Linear(fea[i], fea[i + 1]) for i in range(self.L)])
        self.ln = nn.LayerNorm(fea[self.L], elementwise_affine=False)

    def forward(self, inputs):
        x = inputs
        for i in range(self.L):
            x = self.linear[i](x)
            if i != self.L - 1:
                x = F.relu(x)

        if self.res_ln:
            x += inputs
            x = self.ln(x)
        return x
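

# Example: FeedForward([192, 192, 192], res_ln=True) maps [..., 192] -> [..., 192]
# through two linear layers (ReLU after the first), then adds the input and applies LayerNorm.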


class Adaptive_Fusion(nn.Module):
    """Multi-head cross-attention in which xl queries attend to keys and values from xh."""

    def __init__(self, heads, dims):
        super(Adaptive_Fusion, self).__init__()
        features = dims  # 192
        self.h = heads  # 8
        self.d = int(dims / heads)  # 192 / 8 = 24

        self.qlfc = FeedForward([features, features])
        self.khfc = FeedForward([features, features])
        self.vhfc = FeedForward([features, features])
        self.ofc = FeedForward([features, features])

        self.ln = nn.LayerNorm(features, elementwise_affine=False)
        self.ff = FeedForward([features, features, features], True)

    def forward(self, xl, xh, Mask=True):
        '''
        xl: [B,T,N,F]
        xh: [B,T,N,F]
        return: [B,T,N,F]
        '''
        query = self.qlfc(xl)  # [B,T,N,F]
        keyh = torch.relu(self.khfc(xh))  # [B,T,N,F]
        valueh = torch.relu(self.vhfc(xh))  # [B,T,N,F]

        # Split the feature dimension into h heads and fold the heads into the batch dimension.
        query = torch.cat(torch.split(query, self.d, -1), 0).permute(0, 2, 1, 3)  # [h*B,N,T,d]
        keyh = torch.cat(torch.split(keyh, self.d, -1), 0).permute(0, 2, 3, 1)  # [h*B,N,d,T]
        valueh = torch.cat(torch.split(valueh, self.d, -1), 0).permute(0, 2, 1, 3)  # [h*B,N,T,d]

        attentionh = torch.matmul(query, keyh)  # [h*B,N,T,T]

        if Mask:
            # Causal mask: each step may only attend to itself and earlier steps.
            batch_size = xl.shape[0]
            num_steps = xl.shape[1]
            num_vertexs = xl.shape[2]
            mask = torch.ones(num_steps, num_steps).to(xl.device)  # [T,T]
            mask = torch.tril(mask)  # [T,T]
            mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0)  # [1,1,T,T]
            mask = mask.repeat(self.h * batch_size, num_vertexs, 1, 1)  # [h*B,N,T,T]
            mask = mask.to(torch.bool)
            # Large negative value so masked positions vanish after the softmax.
            zero_vec = (-2 ** 15 + 1) * torch.ones_like(attentionh).to(xl.device)  # [h*B,N,T,T]
            attentionh = torch.where(mask, attentionh, zero_vec)

        attentionh /= (self.d ** 0.5)  # scaled dot-product
        attentionh = F.softmax(attentionh, -1)  # [h*B,N,T,T]

        value = torch.matmul(attentionh, valueh)  # [h*B,N,T,d]

        # Undo the head split and restore [B,T,N,F].
        value = torch.cat(torch.split(value, value.shape[0] // self.h, 0), -1).permute(0, 2, 1, 3)  # [B,T,N,F]
        value = self.ofc(value)
        value = value + xl

        value = self.ln(value)

        return self.ff(value)  # [B,T,N,F]
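

# Usage sketch for Adaptive_Fusion (illustrative; B=64, T=12, N=170, F=192 assumed):
#   fusion = Adaptive_Fusion(heads=8, dims=192)
#   out = fusion(xl, xh)  # xl, xh: [B, T, N, F] -> out: [B, T, N, F]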


class TEDDCF(nn.Module):
    def __init__(
        self, device, input_dim, num_nodes, channels, granularity, dropout=0.1
    ):
        super().__init__()

        self.device = device
        self.num_nodes = num_nodes
        self.output_len = 12
        self.input_len = 12
        self.heads = 8
        diffusion_step = 1

        self.Temb = TemporalEmbedding(granularity, channels)

        self.start_conv = nn.Conv2d(
            in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
        )

        self.glu = GLU(channels * 2, dropout)

        self.regression_layer = nn.Conv2d(
            channels * 2, self.output_len, kernel_size=(1, self.output_len)
        )

        self.temporal_conv = TemporalConvNet(channels * 2)
        self.pre_h = nn.Conv2d(in_channels=self.input_len, out_channels=self.output_len, kernel_size=(1, 1))
        self.adp_f = Adaptive_Fusion(self.heads, channels * 2)

        num_layers = 3
        self.attn_layers_t = nn.ModuleList(
            [
                SelfAttentionLayer(channels * 2, feed_forward_dim=256, num_heads=4, dropout=0.1)
                for _ in range(num_layers)
            ]
        )

        self.xh_emb = nn.Parameter(torch.randn(channels * 2, num_nodes, 12))
        self.xh_dgcn = EventGCN(channels * 2, num_nodes, diffusion_step=diffusion_step, dropout=0.1, emb=self.xh_emb)

        self.xl_emb = nn.Parameter(torch.randn(channels * 2, num_nodes, 12))
        self.xl_dgcn = TrendGCN(channels * 2, num_nodes, diffusion_step=diffusion_step, dropout=0.1, emb=self.xl_emb)

    def param_num(self):
        return sum([param.nelement() for param in self.parameters()])

    def forward(self, inputxl, inputxh):
        # inputxl / inputxh: [B, input_dim, N, T]
        xl = inputxl
        xh = inputxh

        # Encoder
        # Data embedding: 1x1 value convolution concatenated with the temporal embeddings.
        time_embl = self.Temb(inputxl.permute(0, 3, 2, 1))  # TemporalEmbedding expects [B, T, N, input_dim]
        time_embh = self.Temb(inputxh.permute(0, 3, 2, 1))
        xl = torch.cat([self.start_conv(xl), time_embl], dim=1)  # [B, 2 * channels, N, T]
        xh = torch.cat([self.start_conv(xh), time_embh], dim=1)

        # xl branch: temporal self-attention -> TrendGCN -> GLU with residual.
        xl = xl.permute(0, 3, 2, 1)
        for attn in self.attn_layers_t:
            xl = attn(xl, dim=1)
        xl = xl.permute(0, 3, 2, 1)

        xl = self.xl_dgcn(xl)
        xl = self.glu(xl) + xl

        # xh branch: causal temporal convolution -> EventGCN.
        xh = self.temporal_conv(xh)
        xh = self.xh_dgcn(xh)

        # Simple additive fusion.
        x_all = xh + xl

        # STwave fusion (alternative):
        # xl = xl.transpose(1, 3)
        # xh = self.pre_h(xh.transpose(1, 3))  # [64, 12, 170, 192]
        # x_all = self.adp_f(xl, xh)
        # x_all = x_all.transpose(1, 3)

        prediction = self.regression_layer(F.relu(x_all))

        return prediction
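

if __name__ == "__main__":
    # Minimal smoke test (a sketch: the sizes below follow the PEMS08-style comments
    # in this file and are assumptions, not requirements of the model).
    device = torch.device("cpu")
    model = TEDDCF(device, input_dim=3, num_nodes=170, channels=96, granularity=288)
    # Inputs are [B, input_dim, N, T]: channel 0 carries the signal, channel 1 the
    # normalized time-of-day in [0, 1), channel 2 the day-of-week index.
    xl = torch.rand(2, 3, 170, 12)
    xh = torch.rand(2, 3, 170, 12)
    xl[:, 2] = torch.randint(0, 7, (2, 170, 12)).float()
    xh[:, 2] = xl[:, 2]
    print(model.param_num())
    print(model(xl, xh).shape)  # expected: torch.Size([2, 12, 170, 1])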