45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
import torch
|
|
from torchcde.interpolation_linear import linear_interpolation_coeffs
|
|
|
|
|
|
def _setup_hermite_cubic_coeffs_w_backward_differences(times, coeffs, derivs, device=None):
|
|
"""Compute backward hermite from linear coeffs."""
|
|
x_prev = coeffs[..., :-1, :]
|
|
x_next = coeffs[..., 1:, :]
|
|
# Let x_0 - x_{-1} = x_1 - x_0
|
|
derivs_prev = torch.cat((derivs[..., [0], :], derivs[..., :-1, :]), axis=-2)
|
|
derivs_next = derivs
|
|
x_diff = x_next - x_prev
|
|
t_diff = (times[1:] - times[:-1]).unsqueeze(-1)
|
|
# Coeffs
|
|
a = x_prev
|
|
b = derivs_prev
|
|
two_c = 2 * (3 * (x_diff / t_diff - b) - derivs_next + derivs_prev) / t_diff
|
|
three_d = (1 / t_diff ** 2) * (derivs_next - b) - (two_c) / t_diff
|
|
coeffs = torch.cat([a, b, two_c, three_d], dim=-1).to(device)
|
|
return coeffs
|
|
|
|
|
|
def hermite_cubic_coefficients_with_backward_differences(x, t=None):
|
|
"""Computes the coefficients for hermite cubic splines with backward differences.
|
|
|
|
Arguments:
|
|
As `torchcde.linear_interpolation_coeffs`.
|
|
|
|
Returns:
|
|
A tensor, which should in turn be passed to `torchcde.CubicSpline`.
|
|
"""
|
|
# Linear coeffs
|
|
coeffs = linear_interpolation_coeffs(x, t=t, rectilinear=None)
|
|
|
|
if t is None:
|
|
t = torch.linspace(0, coeffs.size(-2) - 1, coeffs.size(-2), dtype=coeffs.dtype, device=coeffs.device)
|
|
|
|
# Linear derivs
|
|
derivs = (coeffs[..., 1:, :] - coeffs[..., :-1, :]) / (t[1:] - t[:-1]).unsqueeze(-1)
|
|
|
|
# Use the above to compute hermite coeffs
|
|
hermite_coeffs = _setup_hermite_cubic_coeffs_w_backward_differences(t, coeffs, derivs, device=coeffs.device)
|
|
|
|
return hermite_coeffs
|