TrafficWheel/model/STGNCDE/vector_fields.py

506 lines
17 KiB
Python
Executable File

import torch
import torch.nn as nn
import torch.nn.functional as F
class FinalTanh_f(nn.Module):
def __init__(
self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers
):
super(FinalTanh_f, self).__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.hidden_hidden_channels = hidden_hidden_channels
self.num_hidden_layers = num_hidden_layers
self.linear_in = nn.Linear(hidden_channels, hidden_hidden_channels)
self.linears = nn.ModuleList(
torch.nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
for _ in range(num_hidden_layers - 1)
)
self.linear_out = nn.Linear(
hidden_hidden_channels, input_channels * hidden_channels
) # 32,32*4 -> # 32,32,4
def extra_repr(self):
return (
"input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}"
"".format(
self.input_channels,
self.hidden_channels,
self.hidden_hidden_channels,
self.num_hidden_layers,
)
)
def forward(self, z):
z = self.linear_in(z)
z = z.relu()
for linear in self.linears:
z = linear(z)
z = z.relu()
# z: torch.Size([64, 207, 32])
# self.linear_out(z): torch.Size([64, 207, 64])
z = self.linear_out(z).view(
*z.shape[:-1], self.hidden_channels, self.input_channels
)
z = z.tanh()
return z
class FinalTanh_f_prime(nn.Module):
def __init__(
self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers
):
super(FinalTanh_f_prime, self).__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.hidden_hidden_channels = hidden_hidden_channels
self.num_hidden_layers = num_hidden_layers
self.linear_in = nn.Linear(hidden_channels, hidden_hidden_channels)
self.linears = nn.ModuleList(
torch.nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
for _ in range(num_hidden_layers - 1)
)
# self.linear_out = nn.Linear(hidden_hidden_channels, input_channels * hidden_channels) #32,32*4 -> # 32,32,4
self.linear_out = nn.Linear(
hidden_hidden_channels, hidden_channels * hidden_channels
) # 32,32*4 -> # 32,32,4
def extra_repr(self):
return (
"input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}"
"".format(
self.input_channels,
self.hidden_channels,
self.hidden_hidden_channels,
self.num_hidden_layers,
)
)
def forward(self, z):
z = self.linear_in(z)
z = z.relu()
for linear in self.linears:
z = linear(z)
z = z.relu()
# z: torch.Size([64, 207, 32])
# self.linear_out(z): torch.Size([64, 207, 64])
# z = self.linear_out(z).view(*z.shape[:-1], self.hidden_channels, self.input_channels)
z = self.linear_out(z).view(
*z.shape[:-1], self.hidden_channels, self.hidden_channels
)
z = z.tanh()
return z
class FinalTanh_f2(torch.nn.Module):
def __init__(
self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers
):
super(FinalTanh_f2, self).__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.hidden_hidden_channels = hidden_hidden_channels
self.num_hidden_layers = num_hidden_layers
# self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
# self.linears = torch.nn.ModuleList(torch.nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
# for _ in range(num_hidden_layers - 1))
# self.linear_out = torch.nn.Linear(hidden_hidden_channels, input_channels * hidden_channels) #32,32*4 -> # 32,32,4
self.start_conv = torch.nn.Conv2d(
in_channels=hidden_channels,
out_channels=hidden_channels,
kernel_size=(1, 1),
)
# self.linear = torch.nn.Conv2d(in_channels=hidden_channels,
# out_channels=hidden_channels,
# kernel_size=(1,1))
self.linears = torch.nn.ModuleList(
torch.nn.Conv2d(
in_channels=hidden_channels,
out_channels=hidden_channels,
kernel_size=(1, 1),
)
for _ in range(num_hidden_layers - 1)
)
self.linear_out = torch.nn.Conv2d(
in_channels=hidden_channels,
out_channels=input_channels * hidden_channels,
kernel_size=(1, 1),
)
def extra_repr(self):
return (
"input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}"
"".format(
self.input_channels,
self.hidden_channels,
self.hidden_hidden_channels,
self.num_hidden_layers,
)
)
def forward(self, z):
# z: torch.Size([64, 207, 32])
z = self.start_conv(z.transpose(1, 2).unsqueeze(-1))
z = z.relu()
for linear in self.linears:
z = linear(z)
z = z.relu()
z = (
self.linear_out(z)
.squeeze()
.transpose(1, 2)
.view(
*z.transpose(1, 2).shape[:-2], self.hidden_channels, self.input_channels
)
)
z = z.tanh()
return z
class VectorField_g(torch.nn.Module):
def __init__(
self,
input_channels,
hidden_channels,
hidden_hidden_channels,
num_hidden_layers,
num_nodes,
cheb_k,
embed_dim,
g_type,
):
super(VectorField_g, self).__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.hidden_hidden_channels = hidden_hidden_channels
self.num_hidden_layers = num_hidden_layers
self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
# self.linears = torch.nn.ModuleList(torch.nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
# for _ in range(num_hidden_layers - 1))
# FIXME:
# self.linear_out = torch.nn.Linear(hidden_hidden_channels, input_channels * hidden_channels) #32,32*4 -> # 32,32,4
self.linear_out = torch.nn.Linear(
hidden_hidden_channels, hidden_channels * hidden_channels
) # 32,32*4 -> # 32,32,4
self.g_type = g_type
if self.g_type == "agc":
self.node_embeddings = nn.Parameter(
torch.randn(num_nodes, embed_dim), requires_grad=True
)
self.cheb_k = cheb_k
self.weights_pool = nn.Parameter(
torch.FloatTensor(
embed_dim, cheb_k, hidden_hidden_channels, hidden_hidden_channels
)
)
self.bias_pool = nn.Parameter(
torch.FloatTensor(embed_dim, hidden_hidden_channels)
)
def extra_repr(self):
return (
"input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}"
"".format(
self.input_channels,
self.hidden_channels,
self.hidden_hidden_channels,
self.num_hidden_layers,
)
)
def forward(self, z):
z = self.linear_in(z)
z = z.relu()
if self.g_type == "agc":
z = self.agc(z)
else:
raise ValueError("Check g_type argument")
# for linear in self.linears:
# z = linear(x_gconv)
# z = z.relu()
# FIXME:
# z = self.linear_out(z).view(*z.shape[:-1], self.hidden_channels, self.input_channels)
z = self.linear_out(z).view(
*z.shape[:-1], self.hidden_channels, self.hidden_channels
)
z = z.tanh()
return z # torch.Size([64, 307, 64, 1])
def agc(self, z):
"""
Adaptive Graph Convolution
- Node Adaptive Parameter Learning
- Data Adaptive Graph Generation
"""
node_num = self.node_embeddings.shape[0]
supports = F.softmax(
F.relu(
torch.mm(self.node_embeddings, self.node_embeddings.transpose(0, 1))
),
dim=1,
)
# laplacian=False
laplacian = False
if laplacian == True:
# support_set = [torch.eye(node_num).to(supports.device), -supports]
support_set = [supports, -torch.eye(node_num).to(supports.device)]
# support_set = [torch.eye(node_num).to(supports.device), -supports]
# support_set = [-supports]
else:
support_set = [torch.eye(node_num).to(supports.device), supports]
# default cheb_k = 3
for k in range(2, self.cheb_k):
support_set.append(
torch.matmul(2 * supports, support_set[-1]) - support_set[-2]
)
supports = torch.stack(support_set, dim=0)
weights = torch.einsum(
"nd,dkio->nkio", self.node_embeddings, self.weights_pool
) # N, cheb_k, dim_in, dim_out
bias = torch.matmul(self.node_embeddings, self.bias_pool) # N, dim_out
x_g = torch.einsum("knm,bmc->bknc", supports, z) # B, cheb_k, N, dim_in
x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in
z = torch.einsum("bnki,nkio->bno", x_g, weights) + bias # b, N, dim_out
return z
class VectorField_only_g(torch.nn.Module):
def __init__(
self,
input_channels,
hidden_channels,
hidden_hidden_channels,
num_hidden_layers,
num_nodes,
cheb_k,
embed_dim,
g_type,
):
super(VectorField_only_g, self).__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.hidden_hidden_channels = hidden_hidden_channels
self.num_hidden_layers = num_hidden_layers
self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
# self.linears = torch.nn.ModuleList(torch.nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
# for _ in range(num_hidden_layers - 1))
# FIXME:
self.linear_out = torch.nn.Linear(
hidden_hidden_channels, input_channels * hidden_channels
) # 32,32*4 -> # 32,32,4
# self.linear_out = torch.nn.Linear(hidden_hidden_channels, hidden_channels * hidden_channels) #32,32*4 -> # 32,32,4
self.g_type = g_type
if self.g_type == "agc":
self.node_embeddings = nn.Parameter(
torch.randn(num_nodes, embed_dim), requires_grad=True
)
self.cheb_k = cheb_k
self.weights_pool = nn.Parameter(
torch.FloatTensor(
embed_dim, cheb_k, hidden_hidden_channels, hidden_hidden_channels
)
)
self.bias_pool = nn.Parameter(
torch.FloatTensor(embed_dim, hidden_hidden_channels)
)
def extra_repr(self):
return (
"input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}"
"".format(
self.input_channels,
self.hidden_channels,
self.hidden_hidden_channels,
self.num_hidden_layers,
)
)
def forward(self, z):
z = self.linear_in(z)
z = z.relu()
if self.g_type == "agc":
z = self.agc(z)
else:
raise ValueError("Check g_type argument")
# for linear in self.linears:
# z = linear(x_gconv)
# z = z.relu()
# FIXME:
z = self.linear_out(z).view(
*z.shape[:-1], self.hidden_channels, self.input_channels
)
# z = self.linear_out(z).view(*z.shape[:-1], self.hidden_channels, self.hidden_channels)
z = z.tanh()
return z # torch.Size([64, 307, 64, 1])
def agc(self, z):
"""
Adaptive Graph Convolution
- Node Adaptive Parameter Learning
- Data Adaptive Graph Generation
"""
node_num = self.node_embeddings.shape[0]
supports = F.softmax(
F.relu(
torch.mm(self.node_embeddings, self.node_embeddings.transpose(0, 1))
),
dim=1,
)
laplacian = False
if laplacian == True:
# support_set = [torch.eye(node_num).to(supports.device), -supports]
support_set = [supports, -torch.eye(node_num).to(supports.device)]
# support_set = [torch.eye(node_num).to(supports.device), -supports]
# support_set = [-supports]
else:
support_set = [torch.eye(node_num).to(supports.device), supports]
# default cheb_k = 3
for k in range(2, self.cheb_k):
support_set.append(
torch.matmul(2 * supports, support_set[-1]) - support_set[-2]
)
supports = torch.stack(support_set, dim=0)
weights = torch.einsum(
"nd,dkio->nkio", self.node_embeddings, self.weights_pool
) # N, cheb_k, dim_in, dim_out
bias = torch.matmul(self.node_embeddings, self.bias_pool) # N, dim_out
x_g = torch.einsum("knm,bmc->bknc", supports, z) # B, cheb_k, N, dim_in
x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in
z = torch.einsum("bnki,nkio->bno", x_g, weights) + bias # b, N, dim_out
return z
class VectorField_g_prime(torch.nn.Module):
def __init__(
self,
input_channels,
hidden_channels,
hidden_hidden_channels,
num_hidden_layers,
num_nodes,
cheb_k,
embed_dim,
g_type,
):
super(VectorField_g_prime, self).__init__()
self.input_channels = input_channels
self.hidden_channels = hidden_channels
self.hidden_hidden_channels = hidden_hidden_channels
self.num_hidden_layers = num_hidden_layers
self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
# self.linears = torch.nn.ModuleList(torch.nn.Linear(hidden_hidden_channels, hidden_hidden_channels)
# for _ in range(num_hidden_layers - 1))
self.linear_out = torch.nn.Linear(
hidden_hidden_channels, input_channels * hidden_channels
) # 32,32*4 -> # 32,32,4
self.g_type = g_type
if self.g_type == "agc":
self.node_embeddings = nn.Parameter(
torch.randn(num_nodes, embed_dim), requires_grad=True
)
self.cheb_k = cheb_k
self.weights_pool = nn.Parameter(
torch.FloatTensor(
embed_dim, cheb_k, hidden_hidden_channels, hidden_hidden_channels
)
)
self.bias_pool = nn.Parameter(
torch.FloatTensor(embed_dim, hidden_hidden_channels)
)
def extra_repr(self):
return (
"input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}"
"".format(
self.input_channels,
self.hidden_channels,
self.hidden_hidden_channels,
self.num_hidden_layers,
)
)
def forward(self, z):
z = self.linear_in(z)
z = z.relu()
if self.g_type == "agc":
z = self.agc(z)
else:
raise ValueError("Check g_type argument")
# for linear in self.linears:
# z = linear(x_gconv)
# z = z.relu()
z = self.linear_out(z).view(
*z.shape[:-1], self.hidden_channels, self.input_channels
)
z = z.tanh()
return z # torch.Size([64, 307, 64, 1])
def agc(self, z):
"""
Adaptive Graph Convolution
- Node Adaptive Parameter Learning
- Data Adaptive Graph Generation
"""
node_num = self.node_embeddings.shape[0]
supports = F.softmax(
F.relu(
torch.mm(self.node_embeddings, self.node_embeddings.transpose(0, 1))
),
dim=1,
)
support_set = [torch.eye(node_num).to(supports.device), supports]
# default cheb_k = 3
for k in range(2, self.cheb_k):
support_set.append(
torch.matmul(2 * supports, support_set[-1]) - support_set[-2]
)
supports = torch.stack(support_set, dim=0)
weights = torch.einsum(
"nd,dkio->nkio", self.node_embeddings, self.weights_pool
) # N, cheb_k, dim_in, dim_out
bias = torch.matmul(self.node_embeddings, self.bias_pool) # N, dim_out
x_g = torch.einsum("knm,bmc->bknc", supports, z) # B, cheb_k, N, dim_in
x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in
z = torch.einsum("bnki,nkio->bno", x_g, weights) + bias # b, N, dim_out
return z