# Source: TrafficWheel/model/TEDDCF/model.py (listed by the code viewer as 445 lines, 13 KiB, Python)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
import sys
from model.TEDDCF.ISTF import SelfAttentionLayer
class GLU(nn.Module):
    """Gated linear unit assembled from three 1x1 convolutions.

    Two parallel pointwise convolutions produce a value branch and a gate
    branch.  The value is multiplied by the sigmoid of the gate, passed
    through dropout and finally projected by a third pointwise convolution.
    The number of channels is the same at input and output.

    Args:
        features: channel count of the input (and of the output).
        dropout: dropout probability applied to the gated activation.
    """

    def __init__(self, features, dropout=0.1):
        super().__init__()
        pointwise = (1, 1)
        self.conv1 = nn.Conv2d(features, features, pointwise)
        self.conv2 = nn.Conv2d(features, features, pointwise)
        self.conv3 = nn.Conv2d(features, features, pointwise)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``conv3(dropout(conv1(x) * sigmoid(conv2(x))))``.

        ``x`` must be an input accepted by ``nn.Conv2d`` with ``features``
        channels; the output has the same shape as ``x``.
        """
        gate = torch.sigmoid(self.conv2(x))
        gated = self.conv1(x) * gate
        return self.conv3(self.dropout(gated))
class TemporalEmbedding(nn.Module):
    """Learned lookup embedding built from two feature columns of the input.

    Two tables are learned: ``time_day`` with ``time`` rows and ``time_week``
    with 7 rows, each row holding ``features`` values.  Both are initialised
    with Xavier-uniform.

    Args:
        time: number of rows in the ``time_day`` table; also the factor the
            day column is multiplied by to obtain a row index.
        features: embedding width.
    """

    def __init__(self, time, features):
        super().__init__()
        self.time = time
        self.time_day = nn.Parameter(torch.empty(time, features))
        nn.init.xavier_uniform_(self.time_day)
        self.time_week = nn.Parameter(torch.empty(7, features))
        nn.init.xavier_uniform_(self.time_week)

    def forward(self, x):
        """Look up and sum the two embeddings.

        Args:
            x: 4-D tensor whose last axis holds at least three features.
               Feature 1 is scaled by ``self.time`` and truncated to an index
               into ``time_day``; feature 2 is truncated to an index into
               ``time_week``.  Presumably these are a time-of-day fraction in
               [0, 1) and a day-of-week id in 0..6 -- confirm against the
               data loader.

        Returns:
            Tensor of shape ``[d0, features, d2, d1]`` for an input of shape
            ``[d0, d1, d2, *]`` (e.g. ``[B, T, N, 3] -> [B, F, N, T]``).
        """
        # ``.long()`` keeps the indices on the same device as ``x``.  The
        # previous ``.type(torch.LongTensor)`` always produced CPU tensors,
        # which costs a device round trip on every call for accelerator input.
        day_idx = (x[..., 1] * self.time).long()
        week_idx = x[..., 2].long()
        tem_emb = self.time_day[day_idx] + self.time_week[week_idx]  # [d0, d1, d2, F]
        # Same values as transpose(1, 2) followed by permute(0, 3, 1, 2).
        return tem_emb.permute(0, 3, 2, 1)
class Diffusion_GCN(nn.Module):
    """Multi-hop graph propagation followed by a 1x1 convolution.

    The input is propagated over the adjacency ``diffusion_step`` times; the
    result of every hop is concatenated along the channel axis and projected
    back to ``channels`` by a pointwise convolution, then passed through
    dropout.

    Args:
        channels: channel count of the input and of the output.
        diffusion_step: number of propagation hops.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
        super().__init__()
        self.diffusion_step = diffusion_step
        self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        """Propagate ``x`` over ``adj``.

        Args:
            x: tensor indexed as ``b c n t`` by the einsum below.
            adj: adjacency, either shared (``n m``, 2-D) or per-sample
                (``b n m``, 3-D).

        Returns:
            Tensor indexed as ``b c m t`` with ``channels`` channels.

        Raises:
            ValueError: if ``adj`` is neither 2-D nor 3-D.
        """
        # The adjacency rank does not change between hops, so choose the
        # contraction once instead of re-testing inside the loop.
        if adj.dim() == 3:
            equation = "bcnt,bnm->bcmt"
        elif adj.dim() == 2:
            equation = "bcnt,nm->bcmt"
        else:
            # Previously this fell through with an empty hop list and failed
            # inside torch.cat with an unrelated error message.
            raise ValueError(
                f"adj must be 2-D or 3-D, got a {adj.dim()}-D tensor"
            )

        hops = []
        for _ in range(self.diffusion_step):
            x = torch.einsum(equation, x, adj).contiguous()
            hops.append(x)

        return self.dropout(self.conv(torch.cat(hops, dim=1)))
class EventGraph_Fusion(nn.Module):
    """Build a sparse, input-dependent adjacency from two dynamic graphs.

    Graph 1 correlates the input with a learned ``memory`` matrix, graph 2
    correlates the time-summed input with itself.  Both are row-softmaxed,
    fused by a learned ``Linear(2, 1)``, row-softmaxed again, and finally
    only the largest 80% of entries in each row are kept.

    Args:
        channels: channel count of the input (rows of ``memory``).
        num_nodes: node count (columns of ``memory``).
        diffusion_step: accepted for interface compatibility; not used here.
        dropout: accepted for interface compatibility; not used here.
    """

    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
        nn.init.xavier_uniform_(self.memory)
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        """Return the fused adjacency for ``x``.

        Args:
            x: tensor indexed as ``b c n t`` by the einsums below; ``n`` must
               equal ``num_nodes`` so the two graphs can be stacked.

        Returns:
            Tensor of shape ``[b, n, n]``.  Each row holds softmax weights
            with everything outside the row's top ``int(0.8 * n)`` entries
            set to zero (rows are not re-normalised after masking).
        """
        scale = math.sqrt(x.shape[1])

        adj_memory = torch.softmax(
            F.relu(torch.einsum("bcnt,cm->bnm", x, self.memory).contiguous() / scale),
            -1,
        )
        pooled = x.sum(-1)
        adj_self = torch.softmax(
            F.relu(torch.einsum("bcn,bcm->bnm", pooled, pooled).contiguous() / scale),
            -1,
        )

        stacked = torch.stack([adj_memory, adj_self], dim=-1)  # [b, n, n, 2]
        # squeeze(-1) removes only the fc output axis.  A bare squeeze() also
        # dropped the batch axis whenever the batch size was 1.
        adj_f = torch.softmax(self.fc(stacked).squeeze(-1), -1)

        keep = int(adj_f.shape[1] * 0.8)
        _, topk_indices = torch.topk(adj_f, k=keep, dim=-1)
        mask = torch.zeros_like(adj_f)
        mask.scatter_(-1, topk_indices, 1)
        return adj_f * mask
class EventGCN(nn.Module):
    """Residual graph-convolution block driven by an ``EventGraph_Fusion`` graph.

    The input goes through a 1x1 convolution, an adjacency is generated from
    the result, the result is propagated over that adjacency, optionally
    scaled elementwise by ``emb``, and the original input is added back.

    Args:
        channels: channel count of the input and of the output.
        num_nodes: node count forwarded to the graph generator.
        diffusion_step: hop count forwarded to the generator and the GCN.
        dropout: dropout probability forwarded to the generator and the GCN.
        emb: optional elementwise gate that must broadcast against the GCN
            output.  When an ``nn.Parameter`` is passed it is registered on
            this module under the name ``emb``.  ``None`` disables the gate.
    """

    def __init__(
        self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None
    ):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, (1, 1))
        self.generator = EventGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
        self.emb = emb

    def forward(self, x):
        """Return ``gcn(conv(x), adj) * emb + x`` (without the gate if ``emb`` is None)."""
        skip = x
        x = self.conv(x)
        adj_dyn = self.generator(x)
        x = self.gcn(x, adj_dyn)
        # With the default emb=None the multiplication used to raise a
        # TypeError; treat a missing embedding as "no gating".
        if self.emb is not None:
            x = x * self.emb
        return x + skip
class TrendGCN(nn.Module):
    """Residual graph-convolution block driven by a ``TrendGraph_Fusion`` graph.

    The input goes through a 1x1 convolution, an adjacency is generated from
    the result, the result is propagated over that adjacency, optionally
    scaled elementwise by ``emb``, and the original input is added back.

    Args:
        channels: channel count of the input and of the output.
        num_nodes: node count forwarded to the graph generator.
        diffusion_step: hop count forwarded to the generator and the GCN.
        dropout: dropout probability forwarded to the generator and the GCN.
        emb: optional elementwise gate that must broadcast against the GCN
            output.  When an ``nn.Parameter`` is passed it is registered on
            this module under the name ``emb``.  ``None`` disables the gate.
    """

    def __init__(
        self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None
    ):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, (1, 1))
        self.generator = TrendGraph_Fusion(channels, num_nodes, diffusion_step, dropout)
        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
        self.emb = emb

    def forward(self, x):
        """Return ``gcn(conv(x), adj) * emb + x`` (without the gate if ``emb`` is None)."""
        skip = x
        x = self.conv(x)
        adj_dyn = self.generator(x)
        x = self.gcn(x, adj_dyn)
        # With the default emb=None the multiplication used to raise a
        # TypeError; treat a missing embedding as "no gating".
        if self.emb is not None:
            x = x * self.emb
        return x + skip
class TrendGraph_Fusion(nn.Module):
    """Build a sparse adjacency from a dynamic graph and a learned static graph.

    The dynamic graph correlates the time-summed input with itself; the static
    graph is ``softmax(relu(E @ E^T))`` for a learned node embedding ``E`` of
    width 10.  The two are fused by a learned ``Linear(2, 1)``, row-softmaxed,
    and only the largest 80% of entries in each row are kept.

    Args:
        channels: channel count of the input (rows of ``memory``).
        num_nodes: node count.
        diffusion_step: accepted for interface compatibility; not used here.
        dropout: accepted for interface compatibility; not used here.
    """

    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
        super().__init__()
        # NOTE: ``memory`` is not read by forward(); it stays registered so
        # the module's parameter set (and state_dict keys) do not change.
        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
        nn.init.xavier_uniform_(self.memory)
        self.fc = nn.Linear(2, 1)
        self.E_adaptive = nn.Parameter(torch.randn(num_nodes, 10))

    def forward(self, x):
        """Return the fused adjacency for ``x``.

        Args:
            x: tensor indexed as ``b c n t`` (it is summed over its last axis
               and then contracted as ``b c n``); ``n`` must equal
               ``num_nodes``.

        Returns:
            Tensor of shape ``[b, n, n]``.  Each row holds softmax weights
            with everything outside the row's top ``int(0.8 * n)`` entries
            set to zero (rows are not re-normalised after masking).
        """
        pooled = x.sum(-1)
        adj_self = torch.softmax(
            F.relu(
                torch.einsum("bcn,bcm->bnm", pooled, pooled).contiguous()
                / math.sqrt(x.shape[1])
            ),
            -1,
        )

        adj_adp = F.softmax(
            F.relu(torch.mm(self.E_adaptive, self.E_adaptive.transpose(0, 1))), dim=1
        )
        # expand() broadcasts the static graph over the batch without copying.
        adj_adp = adj_adp.unsqueeze(0).expand(x.shape[0], -1, -1)

        stacked = torch.stack([adj_self, adj_adp], dim=-1)  # [b, n, n, 2]
        # squeeze(-1) removes only the fc output axis.  A bare squeeze() also
        # dropped the batch axis whenever the batch size was 1.
        adj_f = torch.softmax(self.fc(stacked).squeeze(-1), -1)

        keep = int(adj_f.shape[1] * 0.8)
        _, topk_indices = torch.topk(adj_f, k=keep, dim=-1)
        mask = torch.zeros_like(adj_f)
        mask.scatter_(-1, topk_indices, 1)
        return adj_f * mask
class Chomp1d(nn.Module):
    """Trim the trailing ``chomp_size`` entries of the last axis.

    Used after a convolution that padded the last axis, to remove the extra
    positions the padding introduced.

    Args:
        chomp_size: number of trailing positions to drop from axis 3.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy of 4-D ``x`` without its last ``chomp_size`` steps."""
        if self.chomp_size == 0:
            # ``x[..., :-0]`` is ``x[..., :0]``, i.e. an EMPTY tensor, so the
            # no-padding case has to be handled explicitly.
            return x.contiguous()
        return x[:, :, :, : -self.chomp_size].contiguous()
class TemporalConvNet(nn.Module):
    """Stack of dilated convolutions along the last axis.

    Level ``i`` uses dilation ``2**i`` and pads the last axis by
    ``(kernel_size - 1) * 2**i`` on both sides; the following ``Chomp1d``
    removes that many trailing positions, so the length of the last axis is
    unchanged.  Each level is conv -> chomp -> ReLU -> dropout.

    Args:
        features: channel count of the input and of the output.
        kernel_size: kernel width along the last axis.
        dropout: dropout probability applied at the end of every level.
        levels: number of stacked levels.
    """

    def __init__(self, features, kernel_size=2, dropout=0.2, levels=1):
        super().__init__()
        stages = []
        for level in range(levels):
            dilation = 2 ** level
            pad = (kernel_size - 1) * dilation
            # These attributes are rebound on every level (so they end up
            # pointing at the last level's modules); they are kept because
            # they contribute to the module's registered children.
            self.conv = nn.Conv2d(
                features,
                features,
                (1, kernel_size),
                dilation=(1, dilation),
                padding=(0, pad),
            )
            self.chomp = Chomp1d(pad)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)
            stages.append(
                nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)
            )
        self.tcn = nn.Sequential(*stages)

    def forward(self, xh):
        """Run ``xh`` through every level in order and return the result."""
        return self.tcn(xh)
class FeedForward(nn.Module):
    """Chain of linear layers with ReLU between them (none after the last).

    Args:
        fea: sequence of layer widths; ``len(fea) - 1`` linear layers are
            created, mapping ``fea[i]`` to ``fea[i + 1]``.
        res_ln: when true, the input is added to the output of the last
            layer and the sum is layer-normalised (without learnable affine
            parameters) over the last axis.
    """

    def __init__(self, fea, res_ln=False):
        super().__init__()
        self.res_ln = res_ln
        self.L = len(fea) - 1
        self.linear = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(fea[:-1], fea[1:])]
        )
        self.ln = nn.LayerNorm(fea[self.L], elementwise_affine=False)

    def forward(self, inputs):
        """Apply the layer chain to ``inputs`` (last axis of width ``fea[0]``)."""
        hidden = inputs
        final_idx = self.L - 1
        for idx, layer in enumerate(self.linear):
            hidden = layer(hidden)
            if idx < final_idx:
                hidden = F.relu(hidden)
        if not self.res_ln:
            return hidden
        # Residual connection followed by layer normalisation.
        hidden += inputs
        return self.ln(hidden)
class Adaptive_Fusion(nn.Module):
    """Multi-head cross-attention that fuses ``xh`` into ``xl``.

    Queries come from ``xl``; keys and values come from ``xh``.  The attended
    values are projected, added to ``xl``, layer-normalised and passed
    through a residual feed-forward block.

    Args:
        heads: number of attention heads.
        dims: feature width of both inputs; each head works on
            ``int(dims / heads)`` features.
    """

    def __init__(self, heads, dims):
        super().__init__()
        features = dims
        self.h = heads
        self.d = int(dims / heads)
        self.qlfc = FeedForward([features, features])
        self.khfc = FeedForward([features, features])
        self.vhfc = FeedForward([features, features])
        self.ofc = FeedForward([features, features])
        self.ln = nn.LayerNorm(features, elementwise_affine=False)
        self.ff = FeedForward([features, features, features], True)

    def forward(self, xl, xh, Mask=True):
        """Attend from ``xl`` over ``xh`` along axis 1.

        Args:
            xl: query source, 4-D, presumably ``[B, T, N, F]`` -- confirm
                against the caller.
            xh: key/value source with the same layout as ``xl``.
            Mask: when true, position ``i`` along axis 1 may only attend to
                positions ``j <= i`` (lower-triangular mask).

        Returns:
            Tensor with the same shape as ``xl``.
        """
        query = self.qlfc(xl)
        keyh = torch.relu(self.khfc(xh))
        valueh = torch.relu(self.vhfc(xh))

        # Split the feature axis into heads and stack the heads on the batch
        # axis: [B, T, N, F] -> [k*B, T, N, d], then move T next to d.
        query = torch.cat(torch.split(query, self.d, -1), 0).permute(
            0, 2, 1, 3
        )  # [k*B, N, T, d]
        keyh = torch.cat(torch.split(keyh, self.d, -1), 0).permute(
            0, 2, 3, 1
        )  # [k*B, N, d, T]
        valueh = torch.cat(torch.split(valueh, self.d, -1), 0).permute(
            0, 2, 1, 3
        )  # [k*B, N, T, d]

        attentionh = torch.matmul(query, keyh)  # [k*B, N, T, T]

        if Mask:
            num_steps = xl.shape[1]
            # One [T, T] mask broadcast over batch, heads and nodes; there is
            # no need to materialise a full-size mask and fill tensor.
            allowed = torch.tril(
                torch.ones(num_steps, num_steps, device=xl.device)
            ).to(torch.bool)
            # Same large negative fill value as before, so masked positions
            # get (numerically) zero weight after the softmax.
            attentionh = attentionh.masked_fill(~allowed, -(2**15) + 1)

        attentionh = attentionh / self.d**0.5  # scaled dot-product
        attentionh = F.softmax(attentionh, -1)

        value = torch.matmul(attentionh, valueh)  # [k*B, N, T, d]
        # Undo the head stacking: [k*B, N, T, d] -> [B, N, T, F] -> [B, T, N, F].
        value = torch.cat(torch.split(value, value.shape[0] // self.h, 0), -1).permute(
            0, 2, 1, 3
        )
        value = self.ofc(value)
        value = value + xl
        value = self.ln(value)
        return self.ff(value)
class TEDDCF(nn.Module):
    """Two-branch spatio-temporal model that sums an "l" and an "h" branch.

    Both inputs are embedded by a shared 1x1 convolution and concatenated
    with a shared temporal embedding (giving ``2 * channels`` features).

    * ``xl`` branch: self-attention layers -> ``TrendGCN`` -> residual GLU.
    * ``xh`` branch: ``TemporalConvNet`` -> ``EventGCN``.

    The branches are added, passed through ReLU and mapped to the output by
    a convolution whose kernel spans ``output_len`` steps of the last axis.

    ``pre_h`` and ``adp_f`` are constructed but not used by ``forward``; they
    remain registered so the module's parameter set is unchanged.

    Args:
        device: stored on the instance; not otherwise used in this class.
        input_dim: channel count of the raw inputs.
        num_nodes: number of graph nodes.
        channels: width of the data embedding and of the temporal embedding.
        granularity: number of rows of the time-of-day embedding table.
        dropout: dropout probability of the GLU.
    """

    def __init__(
        self, device, input_dim, num_nodes, channels, granularity, dropout=0.1
    ):
        super().__init__()
        self.device = device
        self.num_nodes = num_nodes
        self.output_len = 12
        self.input_len = 12
        self.heads = 8

        hidden = channels * 2  # data embedding + temporal embedding
        attn_depth = 3

        # NOTE: the construction order below is kept fixed so that parameter
        # initialisation consumes random numbers in the same sequence.
        self.Temb = TemporalEmbedding(granularity, channels)
        self.start_conv = nn.Conv2d(
            in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
        )
        self.glu = GLU(hidden, dropout)
        self.regression_layer = nn.Conv2d(
            hidden, self.output_len, kernel_size=(1, self.output_len)
        )
        self.temporal_conv = TemporalConvNet(hidden)
        self.pre_h = nn.Conv2d(
            in_channels=self.input_len, out_channels=self.output_len, kernel_size=(1, 1)
        )
        self.adp_f = Adaptive_Fusion(self.heads, hidden)
        self.attn_layers_t = nn.ModuleList(
            [
                SelfAttentionLayer(
                    hidden, feed_forward_dim=256, num_heads=4, dropout=0.1
                )
                for _ in range(attn_depth)
            ]
        )

        # Per-branch elementwise gates, shared with the GCN blocks below.
        self.xh_emb = nn.Parameter(torch.randn(hidden, num_nodes, self.input_len))
        self.xh_dgcn = EventGCN(
            hidden, num_nodes, diffusion_step=1, dropout=0.1, emb=self.xh_emb
        )
        self.xl_emb = nn.Parameter(torch.randn(hidden, num_nodes, self.input_len))
        self.xl_dgcn = TrendGCN(
            hidden, num_nodes, diffusion_step=1, dropout=0.1, emb=self.xl_emb
        )

    def param_num(self):
        """Return the total number of scalar parameters of the model."""
        return sum(param.nelement() for param in self.parameters())

    def forward(self, inputxl, inputxh):
        """Predict from the two input components.

        Args:
            inputxl: tensor with ``input_dim`` channels on axis 1; presumably
                ``[B, input_dim, N, T]`` with channels 1 and 2 carrying the
                time-of-day and day-of-week features -- confirm against the
                data loader.
            inputxh: second component with the same layout as ``inputxl``.

        Returns:
            Output of ``regression_layer``: ``output_len`` channels on axis 1.
        """
        # Data embedding: shared temporal embedding + shared 1x1 convolution.
        time_embl = self.Temb(inputxl.permute(0, 3, 2, 1))
        time_embh = self.Temb(inputxh.permute(0, 3, 2, 1))
        xl = torch.cat((self.start_conv(inputxl), time_embl), dim=1)
        xh = torch.cat((self.start_conv(inputxh), time_embh), dim=1)

        # "l" branch: attention runs on the axis-1/axis-3 swapped layout.
        xl = xl.permute(0, 3, 2, 1)
        for attn in self.attn_layers_t:
            xl = attn(xl, dim=1)
        xl = self.xl_dgcn(xl.permute(0, 3, 2, 1))
        xl = self.glu(xl) + xl

        # "h" branch.
        xh = self.xh_dgcn(self.temporal_conv(xh))

        # Fusion by plain addition, then regression head.
        x_all = xh + xl
        return self.regression_layer(F.relu(x_all))