import numpy as np
import torch


def cheap_stack(tensors, dim):
    """Stacks `tensors` along dimension `dim`. When there is only a single tensor,
    the copy made by torch.stack is skipped in favour of a cheap unsqueeze."""
    if len(tensors) == 1:
        return tensors[0].unsqueeze(dim)
    else:
        return torch.stack(tensors, dim=dim)


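# A minimal usage sketch (not part of the original file): cheap_stack behaves
# like torch.stack, but with a single-element list it returns an unsqueezed
# view rather than a copy, e.g.
#
#     cheap_stack([torch.ones(3)], dim=0).shape                   # torch.Size([1, 3])
#     cheap_stack([torch.ones(3), torch.zeros(3)], dim=0).shape   # torch.Size([2, 3])

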
def tridiagonal_solve(b, A_upper, A_diagonal, A_lower):
    """Solves a tridiagonal system Ax = b.

    The arguments A_upper, A_diagonal, A_lower correspond to the three diagonals of A. Letting U = A_upper,
    D = A_diagonal and L = A_lower, and assuming for simplicity that there are no batch dimensions, then the matrix A
    is assumed to be of size (k, k), with entries:

    D[0] U[0]
    L[0] D[1] U[1]
         L[1] D[2] U[2]                        0
              L[2] D[3] U[3]
                   .        .        .
                        .        .        .
                             .        .        .
                         L[k - 4] D[k - 3] U[k - 3]
        0                         L[k - 3] D[k - 2] U[k - 2]
                                           L[k - 2] D[k - 1]

    Arguments:
        b: A tensor of shape (..., k), where '...' is zero or more batch dimensions.
        A_upper: A tensor of shape (..., k - 1).
        A_diagonal: A tensor of shape (..., k).
        A_lower: A tensor of shape (..., k - 1).

    Returns:
        A tensor of shape (..., k), corresponding to the x solving Ax = b.

    Warning:
        This implementation isn't super fast. You probably want to cache the result, if possible.
    """

    # This implementation is very much written for clarity rather than speed.
    # It is the Thomas algorithm: one forward elimination sweep followed by back substitution, in O(k) operations.

    # Broadcast the three diagonals against b so that everything shares the same batch dimensions.
    A_upper, _ = torch.broadcast_tensors(A_upper, b[..., :-1])
    A_lower, _ = torch.broadcast_tensors(A_lower, b[..., :-1])
    A_diagonal, b = torch.broadcast_tensors(A_diagonal, b)

    channels = b.size(-1)

    # numpy object arrays are used purely as containers for the per-index intermediate tensors.
    new_b = np.empty(channels, dtype=object)
    new_A_diagonal = np.empty(channels, dtype=object)
    outs = np.empty(channels, dtype=object)

    # Forward elimination: eliminate the lower diagonal, updating the main diagonal and the right-hand side.
    new_b[0] = b[..., 0]
    new_A_diagonal[0] = A_diagonal[..., 0]
    for i in range(1, channels):
        w = A_lower[..., i - 1] / new_A_diagonal[i - 1]
        new_A_diagonal[i] = A_diagonal[..., i] - w * A_upper[..., i - 1]
        new_b[i] = b[..., i] - w * new_b[i - 1]

    # Back substitution: the system is now upper bidiagonal, so solve from the last entry upwards.
    outs[channels - 1] = new_b[channels - 1] / new_A_diagonal[channels - 1]
    for i in range(channels - 2, -1, -1):
        outs[i] = (new_b[i] - A_upper[..., i] * outs[i + 1]) / new_A_diagonal[i]

    return torch.stack(outs.tolist(), dim=-1)
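

# A minimal verification sketch (not part of the original file): build a random
# diagonally dominant tridiagonal system, solve it with tridiagonal_solve, and
# check the result against a dense torch.linalg.solve.
if __name__ == "__main__":
    torch.manual_seed(0)
    k = 5
    A_diagonal = torch.rand(k) + 4.  # diagonal dominance keeps the system well-conditioned
    A_upper = torch.rand(k - 1)
    A_lower = torch.rand(k - 1)
    b = torch.rand(k)

    # Assemble the equivalent dense matrix for comparison.
    A = torch.diag(A_diagonal) + torch.diag(A_upper, 1) + torch.diag(A_lower, -1)

    x = tridiagonal_solve(b, A_upper, A_diagonal, A_lower)
    assert torch.allclose(x, torch.linalg.solve(A, b), atol=1e-5)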