TrafficWheel/model/TWDGCN/ConnectionMatrix2.py

214 lines
7.0 KiB
Python
Executable File

import torch
class ConnectionMatrix2:
    """
    Computes pairwise connectivity matrices between nodes for every time step.

    Input (see ``get``):
    - Tensor of shape (batch_size, seq_len, num_node, input_dim). The reshape
      in ``get`` only succeeds when input_dim == 1.

    Output:
    - Tuple of ten tensors. Nine have shape
      (seq_len, batch_size, num_node, num_node); the first (``origin``) has
      shape (seq_len, batch_size, 1, num_node).
    """

    def __init__(self):
        # Populated from the input tensor on every call to ``get``.
        self.input_dim = None
        self.num_node = None
        self.seq_len = None
        self.batch_size = None

    def blank(self, x):
        """
        Placeholder function; returns its argument unchanged.
        """
        return x

    def get(self, inputs: torch.Tensor):
        """
        Computes various connectivity matrices from the inputs.

        Parameters:
        - inputs: Tensor of shape (batch_size, seq_len, num_node, input_dim).
          The element count must equal seq_len * batch_size * num_node
          (i.e. input_dim == 1), otherwise the reshape below raises.

        Returns:
        - Tuple of tensors (origin_tensor, delta_ij_tensor, dynamic_edge_tensor,
          om_bata_tensor, beta_tensor, om_gamma_tensor,
          gamma_tensor, st_rate_tensor, dy_rate_tensor, disappear_rate_tensor),
          stacked along a leading time axis.

        Notes:
        - At t == 0 there is no previous step, so dynamic_edge, gamma,
          om_gamma and dy_rate are all zeros.
        - All outputs live on the same device and use the same dtype as
          ``inputs``; nothing is moved to a hard-coded device.
        """
        self.batch_size, self.seq_len, self.num_node, self.input_dim = inputs.size()
        # Time-major layout with the trailing feature axis folded away.
        inputs = inputs.permute(1, 0, 2, 3).reshape(
            self.seq_len, self.batch_size, self.num_node
        )

        origin_list, delta_ij_list, dynamic_edge_list = [], [], []
        om_bata_list, beta_list, om_gamma_list = [], [], []
        gamma_list, st_rate_list, dy_rate_list, disappear_rate_list = [], [], [], []

        for t in range(self.seq_len):
            current = inputs[t]
            # These quantities depend on the current step only.
            delta_ij, origin = self.delta_ij(current)
            beta, om_bata = self.beta(current)

            if t == 0:
                # zeros_like inherits device and dtype from delta_ij, so the
                # placeholders always match the tensors they are stacked with.
                dynamic_edge = torch.zeros_like(delta_ij)
                gamma = torch.zeros_like(delta_ij)
                om_gamma = torch.zeros_like(delta_ij)
                # No temporal change exists yet; a zero dynamic rate keeps
                # st_rate + dy_rate + disappear_rate == 1 as at later steps.
                dy_rate = torch.zeros_like(delta_ij)
                st_rate, disappear_rate = self.t0_disappear_rate(current)
            else:
                previous = inputs[t - 1]
                dynamic_edge = self.dynamic_edge(current, previous)
                gamma, om_gamma = self.gamma(current, previous)
                st_rate, dy_rate, disappear_rate = self.tn_disappear_rate(
                    current, previous
                )

            # Collect results
            origin_list.append(origin)
            delta_ij_list.append(delta_ij)
            dynamic_edge_list.append(dynamic_edge)
            om_bata_list.append(om_bata)
            beta_list.append(beta)
            om_gamma_list.append(om_gamma)
            gamma_list.append(gamma)
            st_rate_list.append(st_rate)
            dy_rate_list.append(dy_rate)
            disappear_rate_list.append(disappear_rate)

        # Convert lists to tensors with a new leading time axis.
        return (
            torch.stack(origin_list),
            torch.stack(delta_ij_list),
            torch.stack(dynamic_edge_list),
            torch.stack(om_bata_list),
            torch.stack(beta_list),
            torch.stack(om_gamma_list),
            torch.stack(gamma_list),
            torch.stack(st_rate_list),
            torch.stack(dy_rate_list),
            torch.stack(disappear_rate_list),
        )

    def delta_ij(self, inputs):
        """
        Calculates delta_ij and origin.

        Parameters:
        - inputs: Tensor of shape (batch_size, num_node)

        Returns:
        - delta_ij: Tensor of shape (batch_size, num_node, num_node) holding
          |inputs[b, i] - inputs[b, j]| at position [b, i, j]
        - origin: Tensor of shape (batch_size, 1, num_node), the inputs with
          an axis inserted at position 1
        """
        origin = inputs.unsqueeze(1)
        # Broadcasting (B, N, 1) against (B, 1, N) yields all pairwise differences.
        delta_ij = inputs.unsqueeze(2) - origin
        return torch.abs(delta_ij), origin

    def dynamic_edge(self, inputs_t, inputs_t_minus_1):
        """
        Calculates the dynamic edge tensor.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
        - dynamic_edge: Tensor of shape (batch_size, num_node, num_node), the
          absolute change of delta_ij between the two steps
        """
        delta_ij_t, _ = self.delta_ij(inputs_t)
        delta_ij_t_minus_1, _ = self.delta_ij(inputs_t_minus_1)
        return torch.abs(delta_ij_t - delta_ij_t_minus_1)

    def beta(self, inputs):
        """
        Calculates beta and om_bata.

        Parameters:
        - inputs: Tensor of shape (batch_size, num_node)

        Returns:
        - beta: Tensor of shape (batch_size, num_node, num_node), exp(-delta_ij)
        - om_bata: Tensor of shape (batch_size, num_node, num_node), 1 - beta
        """
        delta_ij, _ = self.delta_ij(inputs)
        beta = self.exponential_calculation(delta_ij)
        om_bata = 1 - beta
        return beta, om_bata

    def gamma(self, inputs_t, inputs_t_minus_1):
        """
        Calculates gamma and om_gamma.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
        - gamma: Tensor of shape (batch_size, num_node, num_node),
          exp(-dynamic_edge)
        - om_gamma: Tensor of shape (batch_size, num_node, num_node), 1 - gamma
        """
        dynamic_edge = self.dynamic_edge(inputs_t, inputs_t_minus_1)
        gamma = self.exponential_calculation(dynamic_edge)
        om_gamma = 1 - gamma
        return gamma, om_gamma

    def tn_disappear_rate(self, inputs_t, inputs_t_minus_1):
        """
        Calculates st_rate, dy_rate, and disappear_rate for time t > 0.

        The three rates sum to one element-wise.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
        - st_rate: Tensor of shape (batch_size, num_node, num_node),
          beta * (1 - gamma)
        - dy_rate: Tensor of shape (batch_size, num_node, num_node), gamma
        - disappear_rate: Tensor of shape (batch_size, num_node, num_node),
          (1 - beta) * (1 - gamma)
        """
        gamma, om_gamma = self.gamma(inputs_t, inputs_t_minus_1)
        beta, om_bata = self.beta(inputs_t)
        dy_rate = gamma
        st_rate = beta * om_gamma
        disappear_rate = om_bata * om_gamma
        return st_rate, dy_rate, disappear_rate

    def t0_disappear_rate(self, inputs_t):
        """
        Calculates st_rate and disappear_rate for time t = 0.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time 0

        Returns:
        - st_rate: Tensor of shape (batch_size, num_node, num_node), beta
        - disappear_rate: Tensor of shape (batch_size, num_node, num_node),
          1 - st_rate
        """
        beta, _ = self.beta(inputs_t)
        st_rate = beta
        disappear_rate = 1 - st_rate
        return st_rate, disappear_rate

    def exponential_calculation(self, tensor):
        """
        Calculates the exponential decay of the tensor.

        Parameters:
        - tensor: Tensor of any shape

        Returns:
        - Tensor of the same shape holding exp(-tensor) element-wise
        """
        return torch.exp(-tensor)