import torch


class ConnectionMatrix:
    """
    Computes connectivity matrices from input tensors.

    Input:
    - Tensor of shape (batch_size, seq_len, num_node, input_dim), as consumed by `get`

    Output:
    - Tensor of shape (seq_len, batch_size, num_node, num_node) containing the connectivity matrices

    A minimal usage sketch appears at the bottom of this module.
    """

    def __init__(self):
        self.input_dim = None
        self.num_node = None
        self.seq_len = None
        self.batch_size = None

    def blank(self, x):
        """
        Identity placeholder; returns its input unchanged.
        """
        return x

    def get(self, inputs: torch.Tensor):
        """
        Computes the connectivity matrices for each time step.

        Parameters:
        - inputs: Tensor of shape (batch_size, seq_len, num_node, input_dim)

        Returns:
        - Tensor of shape (seq_len, batch_size, num_node, num_node) containing the connectivity matrices
        """
        self.batch_size, self.seq_len, self.num_node, self.input_dim = inputs.size()
        # Move time to the leading dimension and drop the feature axis.
        # Note: this reshape assumes input_dim == 1 (one scalar signal per node);
        # for input_dim > 1 the element counts do not match and reshape raises an error.
        inputs = inputs.permute(1, 0, 2, 3).reshape(self.seq_len, self.batch_size, self.num_node)

        connection_matrix = []
        t0_matrix = self.t0(inputs[0])
        connection_matrix.append(t0_matrix)

        for t in range(1, self.seq_len):
            tn_matrix = self.tn(inputs[t], inputs[t - 1], connection_matrix[-1])
            connection_matrix.append(tn_matrix)

        return torch.stack(connection_matrix)

    def t0(self, t0_input):
        """
        Computes the connectivity matrix for time t = 0.

        Parameters:
        - t0_input: Tensor of shape (batch_size, num_node)

        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        beta = self.beta(t0_input)
        # Threshold beta: entries with beta >= 0.5 become 1.0, the rest 0.0.
        result = torch.where(beta >= 0.5, 1.0, 0.0).to(t0_input.device)
        return result

    def tn(self, tn_input, former_input, former_matrix):
        """
        Computes the connectivity matrix for time t > 0.

        Parameters:
        - tn_input: Tensor of shape (batch_size, num_node)
        - former_input: Tensor of shape (batch_size, num_node) from the previous time step
        - former_matrix: Tensor of shape (batch_size, num_node, num_node) from the previous time step

        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        gamma = self.gamma(tn_input, former_input)
        beta = self.beta(tn_input)

        # Update an entry only when gamma and beta fall on opposite sides of 0.5:
        #   gamma > 0.5 and beta < 0.5   -> set the entry to -1.0
        #   gamma <= 0.5 and beta >= 0.5 -> set the entry to 1.0
        # otherwise keep the entry from the previous time step.
        condition = ((gamma > 0.5) & (beta < 0.5)) | ((gamma <= 0.5) & (beta >= 0.5))
        update_value = torch.where(gamma > 0.5, -1.0, 1.0).to(tn_input.device)

        result = torch.where(condition, update_value, former_matrix)
        return result

    def delta_ij(self, inputs):
        """
        Computes the pairwise absolute differences delta_ij = |x_i - x_j| between node signals.

        Parameters:
        - inputs: Tensor of shape (batch_size, num_node)

        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        return torch.abs(inputs.unsqueeze(-1) - inputs.unsqueeze(-2)).to(inputs.device)

    def dynamic_edge(self, inputs_t, inputs_t_minus_1):
        """
        Computes the dynamic edge tensor |delta_ij(t) - delta_ij(t-1)|, i.e. the change in
        pairwise differences between consecutive time steps.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        delta_ij_t = self.delta_ij(inputs_t)
        delta_ij_t_minus_1 = self.delta_ij(inputs_t_minus_1)
        return torch.abs(delta_ij_t - delta_ij_t_minus_1).to(inputs_t.device)

    def beta(self, inputs):
        """
        Computes beta_ij = exp(-|x_i - x_j|), the exponential decay of the pairwise node differences.

        Parameters:
        - inputs: Tensor of shape (batch_size, num_node)

        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        delta_ij = self.delta_ij(inputs)
        return self.exponential_calculation(delta_ij)

    def gamma(self, inputs_t, inputs_t_minus_1):
        """
        Computes gamma_ij = exp(-|delta_ij(t) - delta_ij(t-1)|), the exponential decay of the
        change in pairwise differences between consecutive time steps.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
        - Tensor of shape (batch_size, num_node, num_node)
        """
        dynamic_edge = self.dynamic_edge(inputs_t, inputs_t_minus_1)
        return self.exponential_calculation(dynamic_edge)

    def tn_disappear_rate(self, inputs_t, inputs_t_minus_1):
        """
        Computes st_rate, dy_rate, and disappear_rate for time t > 0.

        The three rates partition the unit interval elementwise:
        gamma + beta * (1 - gamma) + (1 - beta) * (1 - gamma) = 1.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node) for time t
        - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
        - Tuple of (st_rate, dy_rate, disappear_rate), each of shape (batch_size, num_node, num_node)
        """
        gamma = self.gamma(inputs_t, inputs_t_minus_1)
        beta = self.beta(inputs_t)

        dy_rate = gamma
        st_rate = beta * (1 - gamma)
        disappear_rate = (1 - beta) * (1 - gamma)
        return st_rate, dy_rate, disappear_rate

    def t0_disappear_rate(self, inputs_t):
        """
        Computes st_rate and disappear_rate for time t = 0.

        Parameters:
        - inputs_t: Tensor of shape (batch_size, num_node)

        Returns:
        - Tuple of (st_rate, disappear_rate), each of shape (batch_size, num_node, num_node)
        """
        beta = self.beta(inputs_t)
        st_rate = beta
        disappear_rate = 1 - st_rate
        return st_rate, disappear_rate

    def exponential_calculation(self, tensor):
        """
        Computes the exponential decay of the tensor.

        Parameters:
        - tensor: Tensor of any shape

        Returns:
        - Tensor of the same shape with exponential decay applied
        """
        return torch.exp(-tensor).to(tensor.device)
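

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumes input_dim == 1, as required by `get`): builds a
# ConnectionMatrix, feeds it a random tensor of the documented shape
# (batch_size, seq_len, num_node, input_dim), and checks the output shape and
# the rate decomposition. The tensor sizes below are arbitrary illustration
# values, not part of the original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, seq_len, num_node, input_dim = 4, 10, 6, 1
    inputs = torch.randn(batch_size, seq_len, num_node, input_dim)

    cm = ConnectionMatrix()
    matrices = cm.get(inputs)

    # One (num_node x num_node) connectivity matrix per time step and batch element.
    assert matrices.shape == (seq_len, batch_size, num_node, num_node)

    # The three rates returned for t > 0 partition the unit interval elementwise.
    st_rate, dy_rate, disappear_rate = cm.tn_disappear_rate(inputs[:, 1, :, 0], inputs[:, 0, :, 0])
    assert torch.allclose(st_rate + dy_rate + disappear_rate, torch.ones_like(st_rate))

    print("connection matrices:", tuple(matrices.shape))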