import torch


class ConnectionMatrix2:
    """
    Computes connectivity matrices from input tensors.

    Input:
        - Tensor of shape (batch_size, seq_len, num_node, input_dim)
    Output:
        - Tuple of tensors, each stacked along time with shape
          (seq_len, batch_size, num_node, num_node)
    """

    def __init__(self):
        self.input_dim = None
        self.num_node = None
        self.seq_len = None
        self.batch_size = None

    def blank(self, x):
        """Placeholder function."""
        return x

    def get(self, inputs: torch.Tensor):
        """
        Computes various connectivity matrices from the inputs.

        Parameters:
            - inputs: Tensor of shape (batch_size, seq_len, num_node, input_dim)

        Returns:
            - Tuple of tensors (origin_tensor, delta_ij_tensor, dynamic_edge_tensor,
              om_beta_tensor, beta_tensor, om_gamma_tensor, gamma_tensor,
              st_rate_tensor, dy_rate_tensor, disappear_rate_tensor)
        """
        self.batch_size, self.seq_len, self.num_node, self.input_dim = inputs.size()
        # Reorder to (seq_len, batch_size, num_node, input_dim) and collapse the
        # feature dimension; this assumes input_dim == 1 (one scalar per node).
        inputs = inputs.permute(1, 0, 2, 3).reshape(self.seq_len, self.batch_size, self.num_node)

        origin_list, delta_ij_list, dynamic_edge_list = [], [], []
        om_beta_list, beta_list, om_gamma_list = [], [], []
        gamma_list, st_rate_list, dy_rate_list, disappear_rate_list = [], [], [], []

        for t in range(self.seq_len):
            if t == 0:
                # At t = 0 there is no previous frame, so all dynamic quantities are zero.
                delta_ij, origin = self.delta_ij(inputs[t])
                beta, om_beta = self.beta(inputs[t])
                st_rate, disappear_rate = self.t0_disappear_rate(inputs[t])
                dynamic_edge = torch.zeros_like(delta_ij)
                gamma = torch.zeros_like(delta_ij)
                om_gamma = torch.zeros_like(delta_ij)
                dy_rate = torch.zeros_like(delta_ij)
            else:
                delta_ij, origin = self.delta_ij(inputs[t])
                dynamic_edge = self.dynamic_edge(inputs[t], inputs[t - 1])
                beta, om_beta = self.beta(inputs[t])
                gamma, om_gamma = self.gamma(inputs[t], inputs[t - 1])
                st_rate, dy_rate, disappear_rate = self.tn_disappear_rate(inputs[t], inputs[t - 1])

            # Collect results
            origin_list.append(origin)
            delta_ij_list.append(delta_ij)
            dynamic_edge_list.append(dynamic_edge)
            om_beta_list.append(om_beta)
            beta_list.append(beta)
            om_gamma_list.append(om_gamma)
            gamma_list.append(gamma)
            st_rate_list.append(st_rate)
            dy_rate_list.append(dy_rate)
            disappear_rate_list.append(disappear_rate)

        # Convert lists to tensors, stacked along the time dimension
        origin_tensor = torch.stack(origin_list)
        delta_ij_tensor = torch.stack(delta_ij_list)
        dynamic_edge_tensor = torch.stack(dynamic_edge_list)
        om_beta_tensor = torch.stack(om_beta_list)
        beta_tensor = torch.stack(beta_list)
        om_gamma_tensor = torch.stack(om_gamma_list)
        gamma_tensor = torch.stack(gamma_list)
        st_rate_tensor = torch.stack(st_rate_list)
        dy_rate_tensor = torch.stack(dy_rate_list)
        disappear_rate_tensor = torch.stack(disappear_rate_list)

        return (origin_tensor, delta_ij_tensor, dynamic_edge_tensor, om_beta_tensor,
                beta_tensor, om_gamma_tensor, gamma_tensor, st_rate_tensor,
                dy_rate_tensor, disappear_rate_tensor)

    def delta_ij(self, inputs):
        """
        Calculates delta_ij and origin.

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - delta_ij: Tensor of shape (batch_size, num_node, num_node)
            - origin: Tensor of shape (batch_size, 1, num_node)
        """
        origin = inputs.unsqueeze(1)
        # Pairwise absolute differences between node values.
        delta_ij = inputs.unsqueeze(2) - inputs.unsqueeze(1)
        return torch.abs(delta_ij), origin

    def dynamic_edge(self, inputs_t, inputs_t_minus_1):
        """
        Calculates the dynamic edge tensor.

        Parameters:
            - inputs_t: Tensor for time t
            - inputs_t_minus_1: Tensor for time t-1

        Returns:
            - dynamic_edge: Tensor of shape (batch_size, num_node, num_node)
        """
        delta_ij_t, _ = self.delta_ij(inputs_t)
        delta_ij_t_minus_1, _ = self.delta_ij(inputs_t_minus_1)
        dynamic_edge = torch.abs(delta_ij_t - delta_ij_t_minus_1)
        return dynamic_edge

    def beta(self, inputs):
        """
        Calculates beta and om_beta (one minus beta).

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - beta: Tensor of shape (batch_size, num_node, num_node)
            - om_beta: Tensor of shape (batch_size, num_node, num_node)
        """
        delta_ij, _ = self.delta_ij(inputs)
        beta = self.exponential_calculation(delta_ij)
        om_beta = 1 - beta
        return beta, om_beta

    def gamma(self, inputs_t, inputs_t_minus_1):
        """
        Calculates gamma and om_gamma (one minus gamma).

        Parameters:
            - inputs_t: Tensor for time t
            - inputs_t_minus_1: Tensor for time t-1

        Returns:
            - gamma: Tensor of shape (batch_size, num_node, num_node)
            - om_gamma: Tensor of shape (batch_size, num_node, num_node)
        """
        dynamic_edge = self.dynamic_edge(inputs_t, inputs_t_minus_1)
        gamma = self.exponential_calculation(dynamic_edge)
        om_gamma = 1 - gamma
        return gamma, om_gamma

    def tn_disappear_rate(self, inputs_t, inputs_t_minus_1):
        """
        Calculates st_rate, dy_rate, and disappear_rate for time t > 0.

        Parameters:
            - inputs_t: Tensor for time t
            - inputs_t_minus_1: Tensor for time t-1

        Returns:
            - st_rate: Tensor of shape (batch_size, num_node, num_node)
            - dy_rate: Tensor of shape (batch_size, num_node, num_node)
            - disappear_rate: Tensor of shape (batch_size, num_node, num_node)
        """
        gamma, _ = self.gamma(inputs_t, inputs_t_minus_1)
        beta, _ = self.beta(inputs_t)
        dy_rate = gamma
        st_rate = beta * (1 - gamma)
        disappear_rate = (1 - beta) * (1 - gamma)
        return st_rate, dy_rate, disappear_rate

    def t0_disappear_rate(self, inputs_t):
        """
        Calculates st_rate and disappear_rate for time t = 0.

        Parameters:
            - inputs_t: Tensor for time t

        Returns:
            - st_rate: Tensor of shape (batch_size, num_node, num_node)
            - disappear_rate: Tensor of shape (batch_size, num_node, num_node)
        """
        beta, _ = self.beta(inputs_t)
        st_rate = beta
        disappear_rate = 1 - st_rate
        return st_rate, disappear_rate

    def exponential_calculation(self, tensor):
        """
        Calculates the exponential decay of the tensor.

        Parameters:
            - tensor: Tensor of any shape

        Returns:
            - Tensor of the same shape with exponential decay applied
        """
        return torch.exp(-tensor)
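

# Minimal usage sketch (illustrative, not part of the original module). The shapes
# below are assumptions chosen to match the docstring of `get`, with input_dim = 1
# because the per-node feature dimension is collapsed before the per-timestep loop.
if __name__ == "__main__":
    batch_size, seq_len, num_node, input_dim = 4, 10, 6, 1
    x = torch.rand(batch_size, seq_len, num_node, input_dim)

    cm = ConnectionMatrix2()
    outputs = cm.get(x)

    # Each returned tensor is stacked along the time dimension, e.g.
    # delta_ij: (seq_len, batch_size, num_node, num_node).
    names = ("origin", "delta_ij", "dynamic_edge", "om_beta", "beta",
             "om_gamma", "gamma", "st_rate", "dy_rate", "disappear_rate")
    for name, tensor in zip(names, outputs):
        print(f"{name}: {tuple(tensor.shape)}")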