import numpy as np
import torch


def cheap_stack(tensors, dim):
    # Stacks a list of tensors along `dim`; in the common single-tensor case this avoids the copy made by
    # torch.stack and simply unsqueezes instead.
    if len(tensors) == 1:
        return tensors[0].unsqueeze(dim)
    else:
        return torch.stack(tensors, dim=dim)


def tridiagonal_solve(b, A_upper, A_diagonal, A_lower):
    """Solves a tridiagonal system Ax = b.

    The arguments A_upper, A_diagonal, A_lower correspond to the three diagonals of A. Letting U = A_upper,
    D = A_diagonal and L = A_lower, and assuming for simplicity that there are no batch dimensions, then the
    matrix A is assumed to be of size (k, k), with entries:

    D[0] U[0]
    L[0] D[1] U[1]
         L[1] D[2] U[2]                          0
              L[2] D[3] U[3]
                   .    .    .
                        .    .    .
                             .    .    .
                          L[k - 3] D[k - 2] U[k - 2]
       0                           L[k - 2] D[k - 1]

    Arguments:
        b: A tensor of shape (..., k), where '...' is zero or more batch dimensions.
        A_upper: A tensor of shape (..., k - 1).
        A_diagonal: A tensor of shape (..., k).
        A_lower: A tensor of shape (..., k - 1).

    Returns:
        A tensor of shape (..., k), corresponding to the x solving Ax = b.

    Warning:
        This implementation isn't super fast. You probably want to cache the result, if possible.
    """

    # This implementation is very much written for clarity rather than speed. It is the Thomas algorithm: a
    # forward elimination pass that reduces A to upper bidiagonal form, followed by back substitution.
    A_upper, _ = torch.broadcast_tensors(A_upper, b[..., :-1])
    A_lower, _ = torch.broadcast_tensors(A_lower, b[..., :-1])
    A_diagonal, b = torch.broadcast_tensors(A_diagonal, b)

    channels = b.size(-1)

    # Object-dtype numpy arrays are used purely as containers of tensors, indexed along the channel dimension.
    new_b = np.empty(channels, dtype=object)
    new_A_diagonal = np.empty(channels, dtype=object)
    outs = np.empty(channels, dtype=object)

    new_b[0] = b[..., 0]
    new_A_diagonal[0] = A_diagonal[..., 0]
    # Forward elimination: zero out the lower diagonal, updating the main diagonal and right-hand side as we go.
    for i in range(1, channels):
        w = A_lower[..., i - 1] / new_A_diagonal[i - 1]
        new_A_diagonal[i] = A_diagonal[..., i] - w * A_upper[..., i - 1]
        new_b[i] = b[..., i] - w * new_b[i - 1]

    # Back substitution: solve the resulting upper bidiagonal system from the last channel backwards.
    outs[channels - 1] = new_b[channels - 1] / new_A_diagonal[channels - 1]
    for i in range(channels - 2, -1, -1):
        outs[i] = (new_b[i] - A_upper[..., i] * outs[i + 1]) / new_A_diagonal[i]

    return torch.stack(outs.tolist(), dim=-1)
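
if __name__ == '__main__':
    # Minimal usage sketch: solve a small batched tridiagonal system and compare against a dense solve.
    # The batch size, system size, dtype and tolerance below are illustrative assumptions, not part of the API.
    batch, k = 3, 8
    A_upper = torch.rand(batch, k - 1, dtype=torch.float64)
    A_lower = torch.rand(batch, k - 1, dtype=torch.float64)
    # Make the diagonal dominant so the system is well conditioned and the elimination never divides by ~zero.
    A_diagonal = 4 + torch.rand(batch, k, dtype=torch.float64)
    b = torch.rand(batch, k, dtype=torch.float64)

    x = tridiagonal_solve(b, A_upper, A_diagonal, A_lower)

    # Reference solution: assemble the dense tridiagonal matrices and solve with torch.linalg.solve.
    A = (torch.diag_embed(A_diagonal)
         + torch.diag_embed(A_upper, offset=1)
         + torch.diag_embed(A_lower, offset=-1))
    x_ref = torch.linalg.solve(A, b.unsqueeze(-1)).squeeze(-1)

    assert torch.allclose(x, x_ref, atol=1e-10)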