200 lines
6.9 KiB
Python
200 lines
6.9 KiB
Python
import torch
|
|
|
|
class ConnectionMatrix2:
    """
    Computes pairwise connectivity matrices from node-value sequences.

    Input:
        - Tensor of shape (batch_size, seq_len, num_node, input_dim); only
          input_dim == 1 is supported.

    Output:
        - Tuple of ten tensors stacked over time (leading dimension seq_len).
          ``origin`` has shape (seq_len, batch_size, 1, num_node); all the
          others have shape (seq_len, batch_size, num_node, num_node).
          See :meth:`get` for the order.
    """

    def __init__(self):
        # Dimensions of the most recent input passed to ``get``; None until then.
        self.input_dim = None
        self.num_node = None
        self.seq_len = None
        self.batch_size = None

    def blank(self, x):
        """
        Placeholder function: returns ``x`` unchanged.
        """
        return x

    def get(self, inputs: torch.Tensor):
        """
        Computes the connectivity matrices for every time step of ``inputs``.

        All tensors are created on the same device and with the same dtype as
        ``inputs``. As a side effect, stores the input dimensions on
        ``self.batch_size``, ``self.seq_len``, ``self.num_node`` and
        ``self.input_dim``.

        Parameters:
            - inputs: Tensor of shape (batch_size, seq_len, num_node, input_dim)
              with input_dim == 1.

        Returns:
            - Tuple (origin, delta_ij, dynamic_edge, om_beta, beta, om_gamma,
              gamma, st_rate, dy_rate, disappear_rate), each stacked over time.
              ``origin`` has shape (seq_len, batch_size, 1, num_node); every
              other entry has shape (seq_len, batch_size, num_node, num_node).

            At t == 0 there is no previous step, so dynamic_edge, gamma,
            om_gamma and dy_rate are all zero, st_rate == beta and
            disappear_rate == 1 - beta. For every t,
            st_rate + dy_rate + disappear_rate == 1.

        Raises:
            - ValueError: if ``inputs`` is not 4-D, if input_dim != 1, or if
              seq_len == 0.
        """
        if inputs.dim() != 4:
            raise ValueError(
                "inputs must have shape (batch_size, seq_len, num_node, input_dim), "
                f"got {tuple(inputs.shape)}"
            )
        self.batch_size, self.seq_len, self.num_node, self.input_dim = inputs.size()
        if self.input_dim != 1:
            # Each per-step matrix is built from one scalar per node, so the
            # (seq_len, batch_size, num_node) reshape below requires input_dim == 1.
            raise ValueError(f"only input_dim == 1 is supported, got {self.input_dim}")
        if self.seq_len == 0:
            raise ValueError("inputs must contain at least one time step")

        # (batch, seq, node, 1) -> (seq, batch, node) so that inputs[t] is one step.
        inputs = inputs.permute(1, 0, 2, 3).reshape(self.seq_len, self.batch_size, self.num_node)

        steps = []
        for t in range(self.seq_len):
            x_t = inputs[t]
            delta_ij, origin = self.delta_ij(x_t)
            beta, om_beta = self.beta(x_t)

            if t == 0:
                st_rate, disappear_rate = self.t0_disappear_rate(x_t)
                # No previous step: the temporal terms are all zero. zeros_like
                # keeps them on the same device/dtype as the inputs, so do not
                # hard-code a device here.
                dynamic_edge = torch.zeros_like(delta_ij)
                gamma = torch.zeros_like(delta_ij)
                # NOTE(review): om_gamma is 0 here, not 1 - gamma (== 1);
                # confirm this asymmetry with later steps is intended.
                om_gamma = torch.zeros_like(delta_ij)
                # dy_rate equals gamma at later steps, so it is zero here too.
                dy_rate = torch.zeros_like(delta_ij)
            else:
                x_prev = inputs[t - 1]
                dynamic_edge = self.dynamic_edge(x_t, x_prev)
                gamma, om_gamma = self.gamma(x_t, x_prev)
                st_rate, dy_rate, disappear_rate = self.tn_disappear_rate(x_t, x_prev)

            steps.append((origin, delta_ij, dynamic_edge, om_beta, beta,
                          om_gamma, gamma, st_rate, dy_rate, disappear_rate))

        # Regroup per-step tuples into ten per-quantity sequences; stack each over time.
        return tuple(torch.stack(series) for series in zip(*steps))

    def delta_ij(self, inputs):
        """
        Calculates pairwise absolute differences between node values.

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - delta_ij: Tensor of shape (batch_size, num_node, num_node) with
              delta_ij[b, i, j] = |inputs[b, i] - inputs[b, j]|
            - origin: ``inputs`` viewed with shape (batch_size, 1, num_node)
        """
        origin = inputs.unsqueeze(1)
        # Broadcast (B, N, 1) - (B, 1, N) -> (B, N, N).
        delta_ij = inputs.unsqueeze(2) - inputs.unsqueeze(1)
        return delta_ij.abs(), origin

    def dynamic_edge(self, inputs_t, inputs_t_minus_1):
        """
        Calculates how much the pairwise differences changed between two steps.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - dynamic_edge: Tensor of shape (batch_size, num_node, num_node),
              |delta_ij(inputs_t) - delta_ij(inputs_t_minus_1)|
        """
        delta_ij_t, _ = self.delta_ij(inputs_t)
        delta_ij_t_minus_1, _ = self.delta_ij(inputs_t_minus_1)
        return (delta_ij_t - delta_ij_t_minus_1).abs()

    def beta(self, inputs):
        """
        Calculates beta = exp(-delta_ij) and its complement.

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - beta: Tensor of shape (batch_size, num_node, num_node)
            - om_beta: 1 - beta, same shape
        """
        delta_ij, _ = self.delta_ij(inputs)
        beta = self.exponential_calculation(delta_ij)
        om_beta = 1 - beta
        return beta, om_beta

    def gamma(self, inputs_t, inputs_t_minus_1):
        """
        Calculates gamma = exp(-dynamic_edge) and its complement.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - gamma: Tensor of shape (batch_size, num_node, num_node)
            - om_gamma: 1 - gamma, same shape
        """
        dynamic_edge = self.dynamic_edge(inputs_t, inputs_t_minus_1)
        gamma = self.exponential_calculation(dynamic_edge)
        om_gamma = 1 - gamma
        return gamma, om_gamma

    def tn_disappear_rate(self, inputs_t, inputs_t_minus_1):
        """
        Calculates st_rate, dy_rate and disappear_rate for a step t > 0.

        The three outputs sum to 1 element-wise.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - st_rate: beta * (1 - gamma), shape (batch_size, num_node, num_node)
            - dy_rate: gamma, same shape
            - disappear_rate: (1 - beta) * (1 - gamma), same shape
        """
        gamma, _ = self.gamma(inputs_t, inputs_t_minus_1)
        beta, _ = self.beta(inputs_t)

        dy_rate = gamma
        st_rate = beta * (1 - gamma)
        disappear_rate = (1 - beta) * (1 - gamma)
        return st_rate, dy_rate, disappear_rate

    def t0_disappear_rate(self, inputs_t):
        """
        Calculates st_rate and disappear_rate for the first step (t == 0).

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time 0

        Returns:
            - st_rate: beta, shape (batch_size, num_node, num_node)
            - disappear_rate: 1 - beta, same shape
        """
        beta, _ = self.beta(inputs_t)
        st_rate = beta
        disappear_rate = 1 - st_rate
        return st_rate, disappear_rate

    def exponential_calculation(self, tensor):
        """
        Applies element-wise exponential decay exp(-tensor).

        Parameters:
            - tensor: Tensor of any shape

        Returns:
            - Tensor of the same shape containing exp(-tensor)
        """
        return torch.exp(-tensor)