# TrafficWheel/model/STIDGCN/STIDGCN.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GLU(nn.Module):
    """Gated linear unit applied pointwise over the channel dimension."""

    def __init__(self, features, dropout=0.1):
        super(GLU, self).__init__()
        self.conv1 = nn.Conv2d(features, features, (1, 1))
        self.conv2 = nn.Conv2d(features, features, (1, 1))
        self.conv3 = nn.Conv2d(features, features, (1, 1))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, features, nodes, time)
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        out = x1 * torch.sigmoid(x2)  # gate the value branch with a sigmoid
        out = self.dropout(out)
        out = self.conv3(out)
        return out
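

# A minimal shape-check sketch, not part of the original file; the sizes are
# illustrative assumptions. It shows that GLU preserves the
# (batch, channels, nodes, time) layout.
def _demo_glu():
    glu = GLU(features=8)
    x = torch.randn(2, 8, 5, 12)  # (B, C, N, T)
    assert glu(x).shape == (2, 8, 5, 12)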

class TemporalEmbedding(nn.Module):
    """Learned time-of-day and day-of-week embeddings.

    `time` is the number of slots per day (e.g. 288 for 5-minute data).
    """

    def __init__(self, time, features):
        super(TemporalEmbedding, self).__init__()
        self.time = time
        self.time_day = nn.Parameter(torch.empty(time, features))
        nn.init.xavier_uniform_(self.time_day)
        self.time_week = nn.Parameter(torch.empty(7, features))
        nn.init.xavier_uniform_(self.time_week)

    def forward(self, x):
        # x: (batch, time, nodes, features); feature 1 is time of day,
        # assumed normalized to [0, 1), feature 2 is the day-of-week index.
        day_emb = x[..., 1]
        # .long() keeps the indices on x's device; a torch.LongTensor cast
        # would silently move them to the CPU and break CUDA runs.
        time_day = self.time_day[(day_emb * self.time).long()]
        time_day = time_day.transpose(1, 2).contiguous()
        week_emb = x[..., 2]
        time_week = self.time_week[week_emb.long()]
        time_week = time_week.transpose(1, 2).contiguous()
        tem_emb = time_day + time_week
        return tem_emb.permute(0, 3, 1, 2)  # (batch, features, nodes, time)
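

# Hedged usage sketch, not from the source: it shows the assumed input layout,
# with the time-of-day feature in [0, 1) and the day-of-week feature an
# integer index; time=288 assumes 5-minute slots over one day.
def _demo_temporal_embedding():
    temb = TemporalEmbedding(time=288, features=8)
    x = torch.zeros(2, 12, 5, 3)  # (B, T, N, F)
    x[..., 1] = torch.rand(2, 12, 5)  # time of day in [0, 1)
    x[..., 2] = torch.randint(0, 7, (2, 12, 5)).float()  # day of week
    assert temb(x).shape == (2, 8, 5, 12)  # (B, C, N, T)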

class Diffusion_GCN(nn.Module):
    """Graph diffusion convolution: propagate features along the adjacency
    for `diffusion_step` hops, then mix the hops with a 1x1 convolution."""

    def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
        super().__init__()
        self.diffusion_step = diffusion_step
        self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        # x: (batch, channels, nodes, time);
        # adj: (nodes, nodes) or (batch, nodes, nodes).
        out = []
        for _ in range(self.diffusion_step):
            if adj.dim() == 3:
                x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
            elif adj.dim() == 2:
                x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
            out.append(x)
        x = torch.cat(out, dim=1)
        x = self.conv(x)
        return self.dropout(x)
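

# Illustrative sketch (sizes are assumptions, not from the source): diffusion
# over a static (N, N) adjacency; a batched (B, N, N) adjacency also works.
def _demo_diffusion_gcn():
    gcn = Diffusion_GCN(channels=8, diffusion_step=1)
    x = torch.randn(2, 8, 5, 12)  # (B, C, N, T)
    adj = torch.softmax(torch.randn(5, 5), dim=-1)  # row-normalized adjacency
    assert gcn(x, adj).shape == (2, 8, 5, 12)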

class Graph_Generator(nn.Module):
    """Builds a sparse dynamic adjacency by fusing two attention-style graphs."""

    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
        nn.init.xavier_uniform_(self.memory)
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        # Graph 1: scaled similarity between node features and a learned memory.
        adj_dyn_1 = torch.softmax(
            F.relu(
                torch.einsum("bcnt,cm->bnm", x, self.memory).contiguous()
                / math.sqrt(x.shape[1])
            ),
            -1,
        )
        # Graph 2: scaled self-similarity of time-aggregated node features.
        adj_dyn_2 = torch.softmax(
            F.relu(
                torch.einsum("bcn,bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
                / math.sqrt(x.shape[1])
            ),
            -1,
        )
        # Fuse the two graphs with a learned per-entry weighting.
        adj_f = torch.cat([adj_dyn_1.unsqueeze(-1), adj_dyn_2.unsqueeze(-1)], dim=-1)
        # squeeze(-1) rather than a bare squeeze(), which would also drop a
        # batch dimension of size 1.
        adj_f = torch.softmax(self.fc(adj_f).squeeze(-1), -1)
        # Sparsify: keep only the top 80% of entries in each row.
        topk_values, topk_indices = torch.topk(
            adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1
        )
        mask = torch.zeros_like(adj_f)
        mask.scatter_(-1, topk_indices, 1)
        return adj_f * mask
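

# Hedged sketch with illustrative sizes: the generator maps node features to
# one sparsified (N, N) adjacency per batch element.
def _demo_graph_generator():
    gen = Graph_Generator(channels=8, num_nodes=5)
    x = torch.randn(2, 8, 5, 12)  # (B, C, N, T)
    adj = gen(x)
    assert adj.shape == (2, 5, 5)
    assert (adj == 0).any()  # top-k masking zeroes ~20% of each row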

class DGCN(nn.Module):
    """Dynamic graph convolution block with a gated residual connection."""

    def __init__(
        self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None
    ):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, (1, 1))
        self.generator = Graph_Generator(channels, num_nodes, diffusion_step, dropout)
        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
        self.emb = emb

    def forward(self, x):
        skip = x
        x = self.conv(x)
        adj_dyn = self.generator(x)  # data-dependent adjacency, built per batch
        x = self.gcn(x, adj_dyn)
        return x * self.emb + skip  # gate with the learned embedding, add skip
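

# Hedged sketch: DGCN is shape-preserving; `emb` must broadcast against
# (B, C, N, T), as the (C, N, T) memories in IDGCN_Tree below do.
def _demo_dgcn():
    emb = nn.Parameter(torch.randn(8, 5, 12))  # (C, N, T) gating tensor
    block = DGCN(channels=8, num_nodes=5, emb=emb)
    x = torch.randn(2, 8, 5, 12)
    assert block(x).shape == (2, 8, 5, 12)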

class Splitting(nn.Module):
    """Splits the time axis into even- and odd-indexed subsequences."""

    def __init__(self):
        super(Splitting, self).__init__()

    def even(self, x):
        return x[:, :, :, ::2]

    def odd(self, x):
        return x[:, :, :, 1::2]

    def forward(self, x):
        return (self.even(x), self.odd(x))

class IDGCN(nn.Module):
    """Interactive dynamic graph convolution: a lifting-style exchange between
    the even and odd halves of the time axis."""

    def __init__(
        self,
        device,
        channels=64,
        diffusion_step=1,
        splitting=True,
        num_nodes=170,
        dropout=0.2,
        emb=None,
    ):
        super(IDGCN, self).__init__()
        self.device = device
        self.dropout = dropout
        self.num_nodes = num_nodes
        self.splitting = splitting
        self.split = Splitting()

        pad_l, pad_r = 3, 3
        k1, k2 = 5, 3

        def temporal_conv():
            # Length-preserving temporal convolution stack, identical for all
            # four branches: pad 3+3, then kernels of width 5 and 3.
            return nn.Sequential(
                nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
                nn.Conv2d(channels, channels, kernel_size=(1, k1)),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.Dropout(self.dropout),
                nn.Conv2d(channels, channels, kernel_size=(1, k2)),
                nn.Tanh(),
            )

        self.conv1 = temporal_conv()
        self.conv2 = temporal_conv()
        self.conv3 = temporal_conv()
        self.conv4 = temporal_conv()
        self.dgcn = DGCN(channels, num_nodes, diffusion_step, dropout, emb)

    def forward(self, x):
        if self.splitting:
            (x_even, x_odd) = self.split(x)
        else:
            (x_even, x_odd) = x

        # Cross-update: each half is scaled by a transform of the other.
        x1 = self.dgcn(self.conv1(x_even))
        d = x_odd.mul(torch.tanh(x1))

        x2 = self.dgcn(self.conv2(x_odd))
        c = x_even.mul(torch.tanh(x2))

        x3 = self.dgcn(self.conv3(c))
        x_odd_update = d + x3

        x4 = self.dgcn(self.conv4(d))
        x_even_update = c + x4

        return (x_even_update, x_odd_update)
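

# Hedged sketch with illustrative sizes: IDGCN halves the time axis, so `emb`
# must match T/2; the `device` argument is only stored, so CPU is fine here.
def _demo_idgcn():
    emb = nn.Parameter(torch.randn(8, 5, 6))  # gate for the halved time axis
    block = IDGCN(device="cpu", channels=8, num_nodes=5, emb=emb)
    even, odd = block(torch.randn(2, 8, 5, 12))
    assert even.shape == odd.shape == (2, 8, 5, 6)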

class IDGCN_Tree(nn.Module):
    """Two-level binary tree of IDGCN blocks over the time axis (12 -> 6 -> 3)."""

    def __init__(
        self, device, channels=64, diffusion_step=1, num_nodes=170, dropout=0.1
    ):
        super().__init__()
        self.memory1 = nn.Parameter(torch.randn(channels, num_nodes, 6))
        self.memory2 = nn.Parameter(torch.randn(channels, num_nodes, 3))
        self.memory3 = nn.Parameter(torch.randn(channels, num_nodes, 3))
        self.IDGCN1 = IDGCN(
            device=device,
            splitting=True,
            channels=channels,
            diffusion_step=diffusion_step,
            num_nodes=num_nodes,
            dropout=dropout,
            emb=self.memory1,
        )
        self.IDGCN2 = IDGCN(
            device=device,
            splitting=True,
            channels=channels,
            diffusion_step=diffusion_step,
            num_nodes=num_nodes,
            dropout=dropout,
            emb=self.memory2,
        )
        self.IDGCN3 = IDGCN(
            device=device,
            splitting=True,
            channels=channels,
            diffusion_step=diffusion_step,
            num_nodes=num_nodes,
            dropout=dropout,
            emb=self.memory3,
        )

    def concat(self, even, odd):
        # Interleave the even and odd halves back along the time axis:
        # (B, C, N, T) + (B, C, N, T) -> (B, C, N, 2T).
        b, c, n, t = even.shape
        return torch.stack([even, odd], dim=-1).reshape(b, c, n, 2 * t)

    def forward(self, x):
        x_even_update1, x_odd_update1 = self.IDGCN1(x)
        x_even_update2, x_odd_update2 = self.IDGCN2(x_even_update1)
        x_even_update3, x_odd_update3 = self.IDGCN3(x_odd_update1)
        concat1 = self.concat(x_even_update2, x_odd_update2)
        concat2 = self.concat(x_even_update3, x_odd_update3)
        concat0 = self.concat(concat1, concat2)
        return concat0 + x
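

# Hedged sketch: with the fixed (.., 6) and (.., 3) memories, the tree assumes
# an input time length of exactly 12 and returns the same shape it was given.
def _demo_idgcn_tree():
    tree = IDGCN_Tree(device="cpu", channels=8, num_nodes=5)
    x = torch.randn(2, 8, 5, 12)
    assert tree(x).shape == (2, 8, 5, 12)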

class STIDGCN(nn.Module):
    def __init__(self, args):
        """
        args keys: device, input_dim, num_nodes, channels, granularity, dropout
        """
        super().__init__()
        device = args["device"]
        input_dim = args["input_dim"]
        self.num_nodes = args["num_nodes"]
        self.output_len = 12
        channels = args["channels"]
        granularity = args["granularity"]
        dropout = args["dropout"]
        diffusion_step = 1
        self.Temb = TemporalEmbedding(granularity, channels)
        self.start_conv = nn.Conv2d(
            in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
        )
        self.tree = IDGCN_Tree(
            device=device,
            channels=channels * 2,
            diffusion_step=diffusion_step,
            num_nodes=self.num_nodes,
            dropout=dropout,
        )
        self.glu = GLU(channels * 2, dropout)
        self.regression_layer = nn.Conv2d(
            channels * 2, self.output_len, kernel_size=(1, self.output_len)
        )

    def param_num(self):
        return sum(param.nelement() for param in self.parameters())

    def forward(self, x):
        # x: (batch, time, nodes, features) -> (batch, features, nodes, time)
        x = x.transpose(1, 3)
        # Encoder: concatenate the value embedding with the temporal embedding.
        time_emb = self.Temb(x.permute(0, 3, 2, 1))
        x = torch.cat([self.start_conv(x), time_emb], dim=1)
        # Interactive dynamic graph convolution tree.
        x = self.tree(x)
        # Decoder: gated skip connection, then regress all horizons at once.
        gcn = self.glu(x) + x
        prediction = self.regression_layer(F.relu(gcn))
        return prediction  # (batch, output_len, nodes, 1)
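

if __name__ == "__main__":
    # Hedged end-to-end smoke test, not from the original file. The arg values
    # are assumptions: input_dim=3 matches the three features consumed above,
    # and granularity=288 assumes 5-minute slots over one day.
    args = {
        "device": "cpu",
        "input_dim": 3,
        "num_nodes": 5,
        "channels": 8,
        "granularity": 288,
        "dropout": 0.1,
    }
    model = STIDGCN(args)
    x = torch.zeros(2, 12, 5, 3)  # (B, T, N, F)
    x[..., 1] = torch.rand(2, 12, 5)  # time of day in [0, 1)
    x[..., 2] = torch.randint(0, 7, (2, 12, 5)).float()  # day of week
    print(model(x).shape)  # expected: torch.Size([2, 12, 5, 1])
    print(f"parameters: {model.param_num()}")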