# TrafficWheel/model/STGNRDE/vector_fields.py
#
# (Source-viewer metadata captured with this file: 341 lines, 16 KiB, Python, executable file.)

import torch
import torch.nn as nn
import torch.nn.functional as F
class FinalTanh_f(nn.Module):
    """MLP vector field returning a (hidden_channels x input_channels) matrix.

    Pipeline: ``linear_in`` -> ReLU -> (``num_hidden_layers - 1`` times:
    Linear -> ReLU) -> ``linear_out`` -> reshape -> tanh. The last input axis
    must have size ``hidden_channels``; all leading axes are kept and two
    trailing axes (hidden_channels, input_channels) are produced.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers

        width = hidden_hidden_channels
        self.linear_in = nn.Linear(hidden_channels, width)
        inner_layers = [torch.nn.Linear(width, width) for _ in range(num_hidden_layers - 1)]
        self.linears = nn.ModuleList(inner_layers)
        # One scalar per (hidden, input) pair; reshaped into a matrix in forward().
        self.linear_out = nn.Linear(width, input_channels * hidden_channels)

    def extra_repr(self):
        template = ("input_channels: {}, hidden_channels: {}, "
                    "hidden_hidden_channels: {}, num_hidden_layers: {}")
        return template.format(self.input_channels, self.hidden_channels,
                               self.hidden_hidden_channels, self.num_hidden_layers)

    def forward(self, *args):
        """Evaluate the field; accepts ``forward(z)`` or ``forward(t, z)`` (t unused)."""
        hidden = args[1] if len(args) != 1 else args[0]
        hidden = self.linear_in(hidden).relu()
        for layer in self.linears:
            hidden = layer(hidden).relu()
        leading_shape = hidden.shape[:-1]
        matrix = self.linear_out(hidden).view(*leading_shape, self.hidden_channels, self.input_channels)
        return matrix.tanh()
class FinalTanh_f_prime(nn.Module):
    """MLP vector field returning a square (hidden_channels x hidden_channels) matrix.

    Pipeline: ``linear_in`` -> ReLU -> (``num_hidden_layers - 1`` times:
    Linear -> ReLU) -> ``linear_out`` -> reshape -> tanh. ``input_channels`` is
    stored and shown by ``extra_repr`` but does not size any layer.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers

        width = hidden_hidden_channels
        self.linear_in = nn.Linear(hidden_channels, width)
        inner_layers = [torch.nn.Linear(width, width) for _ in range(num_hidden_layers - 1)]
        self.linears = nn.ModuleList(inner_layers)
        # Square output: hidden_channels * hidden_channels values per position.
        self.linear_out = nn.Linear(width, hidden_channels * hidden_channels)

    def extra_repr(self):
        template = ("input_channels: {}, hidden_channels: {}, "
                    "hidden_hidden_channels: {}, num_hidden_layers: {}")
        return template.format(self.input_channels, self.hidden_channels,
                               self.hidden_hidden_channels, self.num_hidden_layers)

    def forward(self, *args):
        """Evaluate the field; accepts ``forward(z)`` or ``forward(t, z)`` (t unused)."""
        hidden = args[1] if len(args) != 1 else args[0]
        hidden = self.linear_in(hidden).relu()
        for layer in self.linears:
            hidden = layer(hidden).relu()
        leading_shape = hidden.shape[:-1]
        side = self.hidden_channels
        matrix = self.linear_out(hidden).view(*leading_shape, side, side)
        return matrix.tanh()
class FinalTanh_f2(torch.nn.Module):
    """Convolutional variant of the tanh-headed MLP vector field.

    Uses 1x1 ``Conv2d`` layers (a per-position linear map over channels)
    instead of ``nn.Linear``. Every intermediate layer keeps
    ``hidden_channels`` channels; ``hidden_hidden_channels`` is only stored
    for ``extra_repr``.

    ``forward`` expects a 3-D state (B, N, hidden_channels) and returns
    (B, N, hidden_channels, input_channels) after tanh.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers):
        super(FinalTanh_f2, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers
        self.start_conv = torch.nn.Conv2d(in_channels=hidden_channels,
                                          out_channels=hidden_channels,
                                          kernel_size=(1, 1))
        self.linears = torch.nn.ModuleList(torch.nn.Conv2d(in_channels=hidden_channels,
                                                           out_channels=hidden_channels,
                                                           kernel_size=(1, 1))
                                           for _ in range(num_hidden_layers - 1))
        self.linear_out = torch.nn.Conv2d(in_channels=hidden_channels,
                                          out_channels=input_channels * hidden_channels,
                                          kernel_size=(1, 1))

    def extra_repr(self):
        return "input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}" \
               "".format(self.input_channels, self.hidden_channels, self.hidden_hidden_channels, self.num_hidden_layers)

    def forward(self, *args):
        """Evaluate the field; accepts ``forward(z)`` or ``forward(t, z)`` (t unused)."""
        z = args[0] if len(args) == 1 else args[1]  # (B, N, H)
        # Conv2d wants (B, C, H, W): move channels to dim 1 and add a width-1 axis.
        z = self.start_conv(z.transpose(1, 2).unsqueeze(-1))  # (B, H, N, 1)
        z = z.relu()
        for linear in self.linears:
            z = linear(z)
            z = z.relu()
        z = self.linear_out(z)  # (B, I*H, N, 1)
        # Drop ONLY the width axis. A bare squeeze() also removed a batch, node
        # or channel axis of size 1, which made the transpose below fail.
        z = z.squeeze(-1).transpose(1, 2)  # (B, N, I*H)
        z = z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z.tanh()
class VectorField_g(torch.nn.Module):
    """Graph vector field: linear lift, adaptive graph convolution, matrix head.

    ``forward`` maps a state indexed as (batch, node, hidden_channels) to a
    tanh-squashed tensor of shape (batch, node, hidden_channels, input_channels).

    Args:
        input_channels: size of the last output axis.
        hidden_channels: size of the incoming feature axis and of the
            second-to-last output axis.
        hidden_hidden_channels: width of the representation inside ``agc``.
        num_hidden_layers: stored for ``extra_repr`` only; no extra layers are built.
        num_nodes: number of rows of ``node_embeddings``.
        cheb_k: number of support terms (identity, adjacency, ...) used by ``agc``.
        embed_dim: node-embedding dimension.
        g_type: only ``'agc'`` is supported; any other value makes ``forward``
            raise ``ValueError``.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers, num_nodes, cheb_k, embed_dim,
                 g_type):
        super(VectorField_g, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers
        self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
        # project to (hidden_channels, input_channels) for cdeint requirement
        self.linear_out = torch.nn.Linear(hidden_hidden_channels, hidden_channels * input_channels)
        self.g_type = g_type
        if self.g_type == 'agc':
            self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)
            self.cheb_k = cheb_k
            # Explicitly initialised: uninitialised storage (torch.empty or the
            # legacy torch.FloatTensor(shape)) may contain NaN/Inf. Callers may
            # still re-initialise all parameters after construction.
            self.weights_pool = nn.Parameter(
                torch.empty(embed_dim, cheb_k, hidden_hidden_channels, hidden_hidden_channels))
            self.bias_pool = nn.Parameter(torch.zeros(embed_dim, hidden_hidden_channels))
            nn.init.xavier_uniform_(self.weights_pool)

    def extra_repr(self):
        return "input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}" \
               "".format(self.input_channels, self.hidden_channels, self.hidden_hidden_channels, self.num_hidden_layers)

    def forward(self, *args):
        """Evaluate the field; accepts ``forward(z)`` or ``forward(t, z)`` (t unused)."""
        z = args[0] if len(args) == 1 else args[1]
        z = self.linear_in(z)
        z = z.relu()
        if self.g_type == 'agc':
            z = self.agc(z)
        else:
            raise ValueError('Check g_type argument')
        # output shape (..., hidden_channels, input_channels)
        z = self.linear_out(z).view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z.tanh()

    def agc(self, z):
        """Adaptive graph convolution with node-specific weights.

        The adjacency is learned from ``node_embeddings`` as
        ``softmax(relu(E @ E.T))`` (row-normalised). Supports follow the Chebyshev
        recurrence T_0 = I, T_1 = A, T_k = 2 A T_{k-1} - T_{k-2}, truncated to
        ``cheb_k`` terms.

        Args:
            z: tensor (batch, node, hidden_hidden_channels).

        Returns:
            Tensor (batch, node, hidden_hidden_channels).
        """
        emb = self.node_embeddings
        node_num = emb.shape[0]
        supports = F.softmax(F.relu(torch.mm(emb, emb.transpose(0, 1))), dim=1)
        support_set = [torch.eye(node_num, device=supports.device, dtype=supports.dtype), supports]
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        # Keep exactly cheb_k terms so the stack matches weights_pool's k axis
        # (cheb_k == 1 previously stacked two terms and crashed in einsum).
        supports = torch.stack(support_set[:self.cheb_k], dim=0)  # K, N, N
        weights = torch.einsum('nd,dkio->nkio', emb, self.weights_pool)  # N, K, dim_in, dim_out
        bias = torch.matmul(emb, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, z)  # B, K, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, K, dim_in
        return torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out
class VectorField_only_g(torch.nn.Module):
    """Graph vector field: linear lift, adaptive graph convolution, matrix head.

    ``forward`` maps a state indexed as (batch, node, hidden_channels) to a
    tanh-squashed tensor of shape (batch, node, hidden_channels, input_channels).

    Args:
        input_channels: size of the last output axis.
        hidden_channels: size of the incoming feature axis and of the
            second-to-last output axis.
        hidden_hidden_channels: width of the representation inside ``agc``.
        num_hidden_layers: stored for ``extra_repr`` only; no extra layers are built.
        num_nodes: number of rows of ``node_embeddings``.
        cheb_k: number of support terms (identity, adjacency, ...) used by ``agc``.
        embed_dim: node-embedding dimension.
        g_type: only ``'agc'`` is supported; any other value makes ``forward``
            raise ``ValueError``.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers, num_nodes, cheb_k, embed_dim,
                 g_type):
        super(VectorField_only_g, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers
        self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
        # NOTE(review): upstream marked this projection FIXME; a square
        # (hidden_channels * hidden_channels) variant existed as commented-out
        # code. Confirm the intended output shape.
        self.linear_out = torch.nn.Linear(hidden_hidden_channels, input_channels * hidden_channels)
        self.g_type = g_type
        if self.g_type == 'agc':
            self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)
            self.cheb_k = cheb_k
            # Explicitly initialised: uninitialised storage (torch.empty or the
            # legacy torch.FloatTensor(shape)) may contain NaN/Inf. Callers may
            # still re-initialise all parameters after construction.
            self.weights_pool = nn.Parameter(
                torch.empty(embed_dim, cheb_k, hidden_hidden_channels, hidden_hidden_channels))
            self.bias_pool = nn.Parameter(torch.zeros(embed_dim, hidden_hidden_channels))
            nn.init.xavier_uniform_(self.weights_pool)

    def extra_repr(self):
        return "input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}" \
               "".format(self.input_channels, self.hidden_channels, self.hidden_hidden_channels, self.num_hidden_layers)

    def forward(self, *args):
        """Evaluate the field; accepts ``forward(z)`` or ``forward(t, z)`` (t unused)."""
        z = args[0] if len(args) == 1 else args[1]
        z = self.linear_in(z)
        z = z.relu()
        if self.g_type == 'agc':
            z = self.agc(z)
        else:
            raise ValueError('Check g_type argument')
        z = self.linear_out(z).view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z.tanh()

    def agc(self, z):
        """Adaptive graph convolution with node-specific weights.

        The adjacency is learned from ``node_embeddings`` as
        ``softmax(relu(E @ E.T))`` (row-normalised). Supports follow the Chebyshev
        recurrence T_0 = I, T_1 = A, T_k = 2 A T_{k-1} - T_{k-2}, truncated to
        ``cheb_k`` terms.

        Args:
            z: tensor (batch, node, hidden_hidden_channels).

        Returns:
            Tensor (batch, node, hidden_hidden_channels).
        """
        emb = self.node_embeddings
        node_num = emb.shape[0]
        supports = F.softmax(F.relu(torch.mm(emb, emb.transpose(0, 1))), dim=1)
        support_set = [torch.eye(node_num, device=supports.device, dtype=supports.dtype), supports]
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        # Keep exactly cheb_k terms so the stack matches weights_pool's k axis
        # (cheb_k == 1 previously stacked two terms and crashed in einsum).
        supports = torch.stack(support_set[:self.cheb_k], dim=0)  # K, N, N
        weights = torch.einsum('nd,dkio->nkio', emb, self.weights_pool)  # N, K, dim_in, dim_out
        bias = torch.matmul(emb, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, z)  # B, K, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, K, dim_in
        return torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out
class VectorField_g_prime(torch.nn.Module):
    """Graph vector field taking only the state: lift, adaptive graph conv, matrix head.

    Unlike the ``*args`` variants, ``forward`` takes exactly one argument ``z``,
    indexed as (batch, node, hidden_channels). It returns a tanh-squashed tensor
    of shape (batch, node, hidden_channels, input_channels).

    Args:
        input_channels: size of the last output axis.
        hidden_channels: size of the incoming feature axis and of the
            second-to-last output axis.
        hidden_hidden_channels: width of the representation inside ``agc``.
        num_hidden_layers: stored for ``extra_repr`` only; no extra layers are built.
        num_nodes: number of rows of ``node_embeddings``.
        cheb_k: number of support terms (identity, adjacency, ...) used by ``agc``.
        embed_dim: node-embedding dimension.
        g_type: only ``'agc'`` is supported; any other value makes ``forward``
            raise ``ValueError``.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_hidden_layers, num_nodes, cheb_k, embed_dim,
                 g_type):
        super(VectorField_g_prime, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_hidden_channels = hidden_hidden_channels
        self.num_hidden_layers = num_hidden_layers
        self.linear_in = torch.nn.Linear(hidden_channels, hidden_hidden_channels)
        self.linear_out = torch.nn.Linear(hidden_hidden_channels, input_channels * hidden_channels)
        self.g_type = g_type
        if self.g_type == 'agc':
            self.node_embeddings = nn.Parameter(torch.randn(num_nodes, embed_dim), requires_grad=True)
            self.cheb_k = cheb_k
            # Explicitly initialised: uninitialised storage (torch.empty or the
            # legacy torch.FloatTensor(shape)) may contain NaN/Inf. Callers may
            # still re-initialise all parameters after construction.
            self.weights_pool = nn.Parameter(
                torch.empty(embed_dim, cheb_k, hidden_hidden_channels, hidden_hidden_channels))
            self.bias_pool = nn.Parameter(torch.zeros(embed_dim, hidden_hidden_channels))
            nn.init.xavier_uniform_(self.weights_pool)

    def extra_repr(self):
        return "input_channels: {}, hidden_channels: {}, hidden_hidden_channels: {}, num_hidden_layers: {}" \
               "".format(self.input_channels, self.hidden_channels, self.hidden_hidden_channels, self.num_hidden_layers)

    def forward(self, z):
        """Evaluate the field at state ``z``."""
        z = self.linear_in(z)
        z = z.relu()
        if self.g_type == 'agc':
            z = self.agc(z)
        else:
            raise ValueError('Check g_type argument')
        z = self.linear_out(z).view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z.tanh()

    def agc(self, z):
        """Adaptive graph convolution with node-specific weights.

        The adjacency is learned from ``node_embeddings`` as
        ``softmax(relu(E @ E.T))`` (row-normalised). Supports follow the Chebyshev
        recurrence T_0 = I, T_1 = A, T_k = 2 A T_{k-1} - T_{k-2}, truncated to
        ``cheb_k`` terms.

        Args:
            z: tensor (batch, node, hidden_hidden_channels).

        Returns:
            Tensor (batch, node, hidden_hidden_channels).
        """
        emb = self.node_embeddings
        node_num = emb.shape[0]
        supports = F.softmax(F.relu(torch.mm(emb, emb.transpose(0, 1))), dim=1)
        support_set = [torch.eye(node_num, device=supports.device, dtype=supports.dtype), supports]
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        # Keep exactly cheb_k terms so the stack matches weights_pool's k axis
        # (cheb_k == 1 previously stacked two terms and crashed in einsum).
        supports = torch.stack(support_set[:self.cheb_k], dim=0)  # K, N, N
        weights = torch.einsum('nd,dkio->nkio', emb, self.weights_pool)  # N, K, dim_in, dim_out
        bias = torch.matmul(emb, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, z)  # B, K, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, K, dim_in
        return torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out