# Source-viewer metadata: 189 lines, 5.6 KiB, Python, executable file.
import torch


class ConnectionMatrix:
    """Build per-time-step 0/1 connectivity matrices from node-value sequences.

    For one time step with node values ``x`` (shape ``(batch_size, num_node)``):

    * ``beta[b, i, j]  = exp(-|x[b, i] - x[b, j]|)``
    * ``gamma[b, i, j] = exp(-| |x_t[b, i] - x_t[b, j]| - |x_{t-1}[b, i] - x_{t-1}[b, j]| |)``

    At t = 0 an entry is 1 where ``beta >= 1 - beta`` and 0 elsewhere.
    At t > 0 an entry copies the previous step's matrix where ``gamma > 0.5``;
    elsewhere it is 1 where ``beta >= 0.5`` and 0 otherwise.

    Input (to :meth:`get`):
        - Tensor of shape (batch_size, seq_len, num_node, input_dim), with
          ``input_dim == 1`` (one scalar value per node).

    Output:
        - Tensor of shape (seq_len, batch_size, num_node, num_node) containing
          the connectivity matrices.
    """

    def __init__(self):
        # Dimensions of the most recent input passed to get(); None until then.
        self.input_dim = None
        self.num_node = None
        self.seq_len = None
        self.batch_size = None

    def blank(self, x):
        """Placeholder that returns ``x`` unchanged."""
        return x

    def get(self, inputs: torch.Tensor) -> torch.Tensor:
        """Compute the connectivity matrix for every time step.

        Parameters:
            - inputs: Tensor of shape (batch_size, seq_len, num_node, input_dim).
              Only ``input_dim == 1`` is supported.

        Returns:
            - Tensor of shape (seq_len, batch_size, num_node, num_node) containing
              the connectivity matrices. For ``seq_len == 0`` this is an empty
              tensor of that shape.

        Raises:
            - ValueError: if ``inputs`` is not 4-D or ``input_dim != 1``.
        """
        if inputs.dim() != 4:
            raise ValueError(
                "inputs must have shape (batch_size, seq_len, num_node, input_dim), "
                f"got {tuple(inputs.shape)}"
            )
        self.batch_size, self.seq_len, self.num_node, self.input_dim = inputs.size()
        if self.input_dim != 1:
            # The pairwise distances below use one scalar per node; without this
            # check the reshape fails with an opaque size-mismatch error.
            raise ValueError(
                f"only input_dim == 1 is supported, got input_dim={self.input_dim}"
            )

        # (batch, seq, node, 1) -> (seq, batch, node) so time can be stepped through.
        steps = inputs.permute(1, 0, 2, 3).reshape(
            self.seq_len, self.batch_size, self.num_node
        )

        if self.seq_len == 0:
            # No t=0 step exists; return an empty result with the documented shape.
            return torch.zeros(
                0, self.batch_size, self.num_node, self.num_node, device=inputs.device
            )

        connection_matrix = [self.t0(steps[0])]
        for t in range(1, self.seq_len):
            # Each step depends on the previous step's matrix, so this stays a loop.
            connection_matrix.append(
                self.tn(steps[t], steps[t - 1], connection_matrix[-1])
            )
        return torch.stack(connection_matrix)

    def t0(self, t0_input: torch.Tensor) -> torch.Tensor:
        """Compute the connectivity matrix for time t = 0.

        An entry is 1 where ``beta >= 1 - beta`` and 0 elsewhere.

        Parameters:
            - t0_input: Tensor of shape (batch_size, num_node)

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) in the default
              floating dtype, on the same device as ``t0_input``.
        """
        beta_rate = self.beta(t0_input)
        disappear_rate = 1 - beta_rate
        # Shape and device come from the data, so this works without a prior
        # get() call and for inputs that are not on the CPU.
        result = torch.zeros(beta_rate.shape, device=beta_rate.device)
        result[beta_rate >= disappear_rate] = 1
        return result

    def tn(
        self,
        tn_input: torch.Tensor,
        former_input: torch.Tensor,
        former_matrix: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the connectivity matrix for time t > 0.

        Where ``gamma > 0.5`` the entry is copied from ``former_matrix``.
        Elsewhere it is 1 if ``beta >= 0.5`` and 0 otherwise.

        Parameters:
            - tn_input: Tensor of shape (batch_size, num_node)
            - former_input: Tensor of shape (batch_size, num_node) from the previous time step
            - former_matrix: Tensor of shape (batch_size, num_node, num_node) from the previous time step

        Returns:
            - Tensor of shape (batch_size, num_node, num_node), on the same device
              as the inputs.
        """
        gamma_rate = self.gamma(tn_input, former_input)
        beta_rate = self.beta(tn_input)

        result = torch.zeros(gamma_rate.shape, device=gamma_rate.device)
        # Both conditions are kept explicit so NaN handling matches a plain mask.
        result[(gamma_rate <= 0.5) & (beta_rate >= 0.5)] = 1
        # Where the pairwise distance barely changed (gamma > 0.5), keep the old state.
        return torch.where(gamma_rate > 0.5, former_matrix, result)

    def delta_ij(self, inputs: torch.Tensor) -> torch.Tensor:
        """Compute pairwise absolute differences between node values.

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) where entry
              ``[b, i, j]`` is ``|inputs[b, i] - inputs[b, j]|``.
        """
        diff = inputs.unsqueeze(2) - inputs.unsqueeze(1)
        return diff.abs()

    def dynamic_edge(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ) -> torch.Tensor:
        """Compute how much each pairwise distance changed between two steps.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - Tensor of shape (batch_size, num_node, num_node):
              ``|delta_ij(inputs_t) - delta_ij(inputs_t_minus_1)|``.
        """
        delta_ij_t = self.delta_ij(inputs_t)
        delta_ij_t_minus_1 = self.delta_ij(inputs_t_minus_1)
        return (delta_ij_t - delta_ij_t_minus_1).abs()

    def beta(self, inputs: torch.Tensor) -> torch.Tensor:
        """Compute ``beta = exp(-delta_ij(inputs))``.

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) with values in (0, 1].
        """
        return self.exponential_calculation(self.delta_ij(inputs))

    def gamma(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ) -> torch.Tensor:
        """Compute ``gamma = exp(-dynamic_edge(inputs_t, inputs_t_minus_1))``.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) with values in (0, 1].
        """
        return self.exponential_calculation(
            self.dynamic_edge(inputs_t, inputs_t_minus_1)
        )

    def tn_disappear_rate(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ):
        """Compute st_rate, dy_rate and disappear_rate for time t > 0.

        ``dy_rate = gamma``, ``st_rate = beta * (1 - gamma)`` and
        ``disappear_rate = (1 - beta) * (1 - gamma)``; the three sum to 1.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - Tuple of (st_rate, dy_rate, disappear_rate), each of shape
              (batch_size, num_node, num_node)
        """
        gamma = self.gamma(inputs_t, inputs_t_minus_1)
        beta = self.beta(inputs_t)

        dy_rate = gamma
        st_rate = beta * (1 - gamma)
        disappear_rate = (1 - beta) * (1 - gamma)
        return st_rate, dy_rate, disappear_rate

    def t0_disappear_rate(self, inputs_t: torch.Tensor):
        """Compute st_rate and disappear_rate for time t = 0.

        ``st_rate = beta`` and ``disappear_rate = 1 - beta``.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node)

        Returns:
            - Tuple of (st_rate, disappear_rate), each of shape
              (batch_size, num_node, num_node)
        """
        st_rate = self.beta(inputs_t)
        disappear_rate = 1 - st_rate
        return st_rate, disappear_rate

    def exponential_calculation(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``exp(-tensor)`` elementwise (same shape as ``tensor``)."""
        return torch.exp(-tensor)