import torch


class ConnectionMatrix:
    """
    Computes binary connectivity matrices from a sequence of node values.

    Input (to ``get``):
        - Tensor of shape (batch_size, seq_len, num_node, input_dim). Only
          input_dim == 1 is supported: each node carries a single scalar.
    Output:
        - Tensor of shape (seq_len, batch_size, num_node, num_node) whose
          entries are 0.0 or 1.0 (default floating dtype).

    Edge (i, j) scoring, with d_ij = |x_i - x_j| and beta = exp(-d_ij):
        * t = 0: the edge is 1 when beta >= 1 - beta, else 0.
        * t > 0: with gamma = exp(-|d_ij(t) - d_ij(t-1)|), the previous
          step's value is carried over when gamma > 0.5; otherwise the edge
          is 1 when beta >= 0.5, else 0.
    """

    def __init__(self):
        # Populated by ``get`` from the shape of the most recent input.
        self.input_dim = None
        self.num_node = None
        self.seq_len = None
        self.batch_size = None

    def blank(self, x):
        """
        Placeholder that returns ``x`` unchanged.
        """
        return x

    def get(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Computes the connectivity matrices for each time step.

        Parameters:
            - inputs: Tensor of shape (batch_size, seq_len, num_node, input_dim)
              with input_dim == 1 (the reshape below fails otherwise for
              non-empty inputs).

        Returns:
            - Tensor of shape (seq_len, batch_size, num_node, num_node) on the
              same device as ``inputs``. An empty sequence (seq_len == 0)
              yields a tensor whose leading dimension is 0.

        Side effects:
            - Stores batch_size, seq_len, num_node and input_dim on ``self``.
        """
        self.batch_size, self.seq_len, self.num_node, self.input_dim = inputs.size()
        # Switch to time-major layout (seq_len, batch_size, num_node);
        # dropping the trailing axis only works when input_dim == 1.
        inputs = inputs.permute(1, 0, 2, 3).reshape(
            self.seq_len, self.batch_size, self.num_node
        )

        if self.seq_len == 0:
            # Nothing to compute; avoid indexing inputs[0] on an empty sequence.
            return torch.zeros(
                0, self.batch_size, self.num_node, self.num_node, device=inputs.device
            )

        connection_matrix = [self.t0(inputs[0])]
        for t in range(1, self.seq_len):
            # Each step depends on the previous step's matrix, so this loop
            # is inherently sequential.
            connection_matrix.append(
                self.tn(inputs[t], inputs[t - 1], connection_matrix[-1])
            )
        return torch.stack(connection_matrix)

    def t0(self, t0_input: torch.Tensor) -> torch.Tensor:
        """
        Computes the connectivity matrix for time t = 0.

        Parameters:
            - t0_input: Tensor of shape (batch_size, num_node)

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) on the same
              device as ``t0_input``; entry (i, j) is 1 where
              beta >= 1 - beta, else 0.
        """
        # Sizes come from the argument so t0 also works without a prior get().
        batch_size, num_node = t0_input.shape
        result = torch.zeros(batch_size, num_node, num_node, device=t0_input.device)
        beta_rate = self.beta(t0_input)
        disappear_rate = 1 - beta_rate
        result[beta_rate >= disappear_rate] = 1
        return result

    def tn(
        self,
        tn_input: torch.Tensor,
        former_input: torch.Tensor,
        former_matrix: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the connectivity matrix for time t > 0.

        Parameters:
            - tn_input: Tensor of shape (batch_size, num_node)
            - former_input: Tensor of shape (batch_size, num_node) from the
              previous time step
            - former_matrix: Tensor of shape (batch_size, num_node, num_node)
              from the previous time step

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) on the same
              device as ``tn_input``.
        """
        batch_size, num_node = tn_input.shape
        gamma_rate = self.gamma(tn_input, former_input)
        beta_rate = self.beta(tn_input)

        result = torch.zeros(batch_size, num_node, num_node, device=tn_input.device)
        result[(gamma_rate <= 0.5) & (beta_rate >= 0.5)] = 1
        # gamma > 0.5  <=>  the pairwise distance changed by less than ln(2)
        # since the previous step: keep the previous step's edge value.
        return torch.where(gamma_rate > 0.5, former_matrix, result)

    def delta_ij(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Computes pairwise absolute differences.

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) where entry
              [b, i, j] = |inputs[b, i] - inputs[b, j]|.
        """
        return (inputs.unsqueeze(2) - inputs.unsqueeze(1)).abs()

    def dynamic_edge(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes how much each pairwise distance changed between two steps.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) equal to
              |delta_ij(inputs_t) - delta_ij(inputs_t_minus_1)|.
        """
        delta_ij_t = self.delta_ij(inputs_t)
        delta_ij_t_minus_1 = self.delta_ij(inputs_t_minus_1)
        return (delta_ij_t - delta_ij_t_minus_1).abs()

    def beta(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Computes beta = exp(-delta_ij(inputs)).

        Parameters:
            - inputs: Tensor of shape (batch_size, num_node)

        Returns:
            - Tensor of shape (batch_size, num_node, num_node) with values in
              (0, 1] for finite inputs (the diagonal is exactly 1).
        """
        return self.exponential_calculation(self.delta_ij(inputs))

    def gamma(
        self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes gamma = exp(-dynamic_edge(inputs_t, inputs_t_minus_1)).

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - Tensor of shape (batch_size, num_node, num_node)
        """
        return self.exponential_calculation(
            self.dynamic_edge(inputs_t, inputs_t_minus_1)
        )

    def tn_disappear_rate(self, inputs_t: torch.Tensor, inputs_t_minus_1: torch.Tensor):
        """
        Computes st_rate, dy_rate and disappear_rate for time t > 0.

        With gamma and beta as above: dy_rate = gamma,
        st_rate = beta * (1 - gamma), disappear_rate = (1 - beta) * (1 - gamma).
        The three rates sum to 1 elementwise. Not called by ``get``.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node) for time t
            - inputs_t_minus_1: Tensor of shape (batch_size, num_node) for time t-1

        Returns:
            - Tuple (st_rate, dy_rate, disappear_rate), each of shape
              (batch_size, num_node, num_node)
        """
        gamma = self.gamma(inputs_t, inputs_t_minus_1)
        beta = self.beta(inputs_t)
        dy_rate = gamma
        st_rate = beta * (1 - gamma)
        disappear_rate = (1 - beta) * (1 - gamma)
        return st_rate, dy_rate, disappear_rate

    def t0_disappear_rate(self, inputs_t: torch.Tensor):
        """
        Computes st_rate and disappear_rate for time t = 0.

        st_rate = beta and disappear_rate = 1 - beta. Not called by ``get``.

        Parameters:
            - inputs_t: Tensor of shape (batch_size, num_node)

        Returns:
            - Tuple (st_rate, disappear_rate), each of shape
              (batch_size, num_node, num_node)
        """
        st_rate = self.beta(inputs_t)
        disappear_rate = 1 - st_rate
        return st_rate, disappear_rate

    def exponential_calculation(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Applies exponential decay elementwise.

        Parameters:
            - tensor: Tensor of any shape

        Returns:
            - Tensor of the same shape equal to exp(-tensor)
        """
        return torch.exp(-tensor)