TrafficWheel/model/TWDGCN/ConnectionMatrix.py

187 lines
5.6 KiB
Python
Executable File

import torch
class ConnectionMatrix:
    """
    Computes time-evolving node-to-node connectivity matrices from input tensors.

    For node values x at a time step, the pairwise absolute difference is
    ``delta_ij = |x_i - x_j|``. Two similarity-like quantities are derived
    from it with ``exp(-.)``:

    - ``beta``  = exp(-delta_ij(t))
    - ``gamma`` = exp(-|delta_ij(t) - delta_ij(t-1)|)

    Input (to :meth:`get`):
    - Tensor of shape (batch_size, seq_len, num_node, input_dim); only
      ``input_dim == 1`` is supported.
    Output:
    - Tensor of shape (seq_len, batch_size, num_node, num_node) containing
      the connectivity matrices.
    """

    def __init__(self) -> None:
        # Populated from the input shape on every call to ``get``.
        self.input_dim = None
        self.num_node = None
        self.seq_len = None
        self.batch_size = None

    def blank(self, x):
        """
        Placeholder function that returns its argument unchanged.
        """
        return x

    def get(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Computes the connectivity matrices for each time step.

        Parameters:
        - inputs: Tensor of shape (batch_size, seq_len, num_node, input_dim)
          with ``input_dim == 1`` and ``seq_len >= 1``.
        Returns:
        - Tensor of shape (seq_len, batch_size, num_node, num_node). Step 0 is
          produced by :meth:`t0`; every later step is derived from the
          previous one by :meth:`tn`.
        Raises:
        - ValueError: if ``inputs`` is not 4-D, if ``input_dim != 1``, or if
          the time axis is empty.
        Side effects:
        - Stores batch_size, seq_len, num_node and input_dim on the instance.
        """
        if inputs.dim() != 4:
            raise ValueError(
                "expected a 4-D tensor (batch_size, seq_len, num_node, input_dim), "
                f"got shape {tuple(inputs.shape)}"
            )
        self.batch_size, self.seq_len, self.num_node, self.input_dim = inputs.size()
        if self.input_dim != 1:
            # Pairwise differences are defined on one scalar per node, so the
            # feature axis must collapse away in the reshape below.
            raise ValueError(
                f"ConnectionMatrix only supports input_dim == 1, got {self.input_dim}"
            )
        if self.seq_len == 0:
            raise ValueError("inputs must contain at least one time step")

        # (batch, seq, node, 1) -> (seq, batch, node) so inputs[t] is one time step.
        inputs = inputs.permute(1, 0, 2, 3).reshape(
            self.seq_len, self.batch_size, self.num_node
        )
        connection_matrix = [self.t0(inputs[0])]
        for t in range(1, self.seq_len):
            # Each step depends on the previous step's matrix, so this stays sequential.
            connection_matrix.append(
                self.tn(inputs[t], inputs[t - 1], connection_matrix[-1])
            )
        return torch.stack(connection_matrix)

    def t0(self, t0_input: torch.Tensor) -> torch.Tensor:
        """
        Computes the connectivity matrix for time t=0.

        Parameters:
        - t0_input: Tensor of shape (batch_size, num_node)
        Returns:
        - Tensor of shape (batch_size, num_node, num_node) holding 1.0 where
          ``beta >= 0.5`` and 0.0 elsewhere.
        """
        beta = self.beta(t0_input)
        return torch.where(beta >= 0.5, 1.0, 0.0).to(t0_input.device)

    def tn(
        self,
        tn_input: torch.Tensor,
        former_input: torch.Tensor,
        former_matrix: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the connectivity matrix for time t > 0.

        Entries are overwritten only where ``gamma`` and ``beta`` disagree:
        - ``gamma > 0.5`` and ``beta < 0.5``   -> -1.0
        - ``gamma <= 0.5`` and ``beta >= 0.5`` ->  1.0
        - otherwise the entry of ``former_matrix`` is kept.

        Parameters:
        - tn_input: Tensor of shape (batch_size, num_node)
        - former_input: Tensor of shape (batch_size, num_node) from the previous time step
        - former_matrix: Tensor of shape (batch_size, num_node, num_node) from the previous time step
        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        gamma = self.gamma(tn_input, former_input)
        beta = self.beta(tn_input)
        # Determine where to update the matrix.
        condition = ((gamma > 0.5) & (beta < 0.5)) | ((gamma <= 0.5) & (beta >= 0.5))
        # NOTE(review): t0 encodes an absent edge as 0.0 while this writes -1.0;
        # confirm the mixed encoding is intended by downstream consumers.
        update_value = torch.where(gamma > 0.5, -1.0, 1.0).to(tn_input.device)
        return torch.where(condition, update_value, former_matrix)

    def delta_ij(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Computes pairwise absolute differences ``|x_i - x_j|`` along the last axis.

        Parameters:
        - inputs: Tensor of shape (batch_size, num_node)
        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        # Broadcasting (..., N, 1) - (..., 1, N) yields all node pairs.
        return torch.abs(inputs.unsqueeze(-1) - inputs.unsqueeze(-2))

    def dynamic_edge(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes how much each pairwise difference changed between two steps.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1
        Returns:
        - Tensor of shape (batch_size, num_node, num_node) equal to
          ``|delta_ij(t) - delta_ij(t-1)|``
        """
        delta_ij_t = self.delta_ij(inputs_t)
        delta_ij_t_minus_1 = self.delta_ij(inputs_t_minus_1)
        return torch.abs(delta_ij_t - delta_ij_t_minus_1)

    def beta(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Computes ``beta = exp(-|x_i - x_j|)``.

        Parameters:
        - inputs: Tensor of shape (batch_size, num_node)
        Returns:
        - Tensor of shape (batch_size, num_node, num_node) with values in (0, 1]
        """
        return self.exponential_calculation(self.delta_ij(inputs))

    def gamma(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes ``gamma = exp(-|delta_ij(t) - delta_ij(t-1)|)``.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1
        Returns:
        - Tensor of shape (batch_size, num_node, num_node) with values in (0, 1]
        """
        return self.exponential_calculation(
            self.dynamic_edge(inputs_t, inputs_t_minus_1)
        )

    def tn_disappear_rate(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ):
        """
        Computes st_rate, dy_rate, and disappear_rate for time t > 0.

        ``dy_rate = gamma``, ``st_rate = beta * (1 - gamma)`` and
        ``disappear_rate = (1 - beta) * (1 - gamma)``; the three sum to 1.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1
        Returns:
        - Tuple of (st_rate, dy_rate, disappear_rate), each of shape
          (batch_size, num_node, num_node)
        """
        gamma = self.gamma(inputs_t, inputs_t_minus_1)
        beta = self.beta(inputs_t)
        dy_rate = gamma
        st_rate = beta * (1 - gamma)
        disappear_rate = (1 - beta) * (1 - gamma)
        return st_rate, dy_rate, disappear_rate

    def t0_disappear_rate(self, inputs_t: torch.Tensor):
        """
        Computes st_rate and disappear_rate for time t = 0.

        ``st_rate = beta`` and ``disappear_rate = 1 - beta``.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node)
        Returns:
        - Tuple of (st_rate, disappear_rate), each of shape
          (batch_size, num_node, num_node)
        """
        st_rate = self.beta(inputs_t)
        disappear_rate = 1 - st_rate
        return st_rate, disappear_rate

    def exponential_calculation(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Computes the exponential decay ``exp(-tensor)`` elementwise.

        Parameters:
        - tensor: Tensor of any shape
        Returns:
        - Tensor of the same shape (and device) with exponential decay applied
        """
        return torch.exp(-tensor)