TrafficWheel/model/STGNRDE/torchcde/interpolation_hermite_cubic...

45 lines
1.6 KiB
Python

import torch
from torchcde.interpolation_linear import linear_interpolation_coeffs
def _setup_hermite_cubic_coeffs_w_backward_differences(times, coeffs, derivs, device=None):
    """Compute backward-difference Hermite cubic coefficients from linear coeffs.

    For every interval between two consecutive knots, a cubic
    ``a + b*s + c*s**2 + d*s**3`` (``s`` = offset from the left knot) is built
    that matches the knot values at both ends, has slope ``derivs_prev`` at the
    left end and slope ``derivs_next`` at the right end. The slope used at a
    knot is the linear slope of the interval *before* it (a backward
    difference); the first interval reuses its own slope at its left end.

    Arguments:
        times: Knot times. Sliced along its first axis only; assumed to be a
            1-D tensor of shape (length,) -- TODO confirm against callers.
        coeffs: Knot values, sliced along the second-to-last axis, i.e. shape
            (..., length, channels).
        derivs: Per-interval slopes, shape (..., length - 1, channels).
        device: Forwarded unchanged to `Tensor.to` on the result.

    Returns:
        A tensor of shape (..., length - 1, 4 * channels): the concatenation
        along the last axis of ``a``, ``b``, ``2 * c`` and ``3 * d``.
    """
    x_prev = coeffs[..., :-1, :]
    x_next = coeffs[..., 1:, :]

    # Let x_0 - x_{-1} = x_1 - x_0, i.e. the slope "before" the first knot is
    # taken to be the slope of the first interval.
    derivs_prev = torch.cat((derivs[..., [0], :], derivs[..., :-1, :]), dim=-2)
    derivs_next = derivs

    x_diff = x_next - x_prev
    # Trailing unit axis so the interval lengths broadcast over channels.
    t_diff = (times[1:] - times[:-1]).unsqueeze(-1)

    # Coeffs
    a = x_prev
    b = derivs_prev
    # The quadratic and cubic terms are stored pre-multiplied by 2 and 3.
    two_c = 2 * (3 * (x_diff / t_diff - b) - derivs_next + derivs_prev) / t_diff
    three_d = (1 / t_diff ** 2) * (derivs_next - b) - two_c / t_diff

    return torch.cat([a, b, two_c, three_d], dim=-1).to(device)
def hermite_cubic_coefficients_with_backward_differences(x, t=None):
    """Computes the coefficients for hermite cubic splines with backward differences.

    The data is first run through `linear_interpolation_coeffs`; the slopes of
    the resulting piecewise-linear path between consecutive knots are then used
    as the derivatives from which the Hermite cubic is built.

    Arguments:
        As `torchcde.linear_interpolation_coeffs`. If `t` is None, the knots
        are taken to be evenly spaced at 0, 1, ..., length - 1, where length is
        the size of the second-to-last axis of the linear coefficients.

    Returns:
        A tensor, which should in turn be passed to `torchcde.CubicSpline`.
    """
    # Linear coeffs
    coeffs = linear_interpolation_coeffs(x, t=t, rectilinear=None)

    if t is None:
        # Default time grid, created with the same dtype/device as the coeffs.
        t = torch.linspace(0, coeffs.size(-2) - 1, coeffs.size(-2), dtype=coeffs.dtype, device=coeffs.device)

    # Linear derivs: one slope per interval between consecutive knots.
    derivs = (coeffs[..., 1:, :] - coeffs[..., :-1, :]) / (t[1:] - t[:-1]).unsqueeze(-1)

    # Use the above to compute hermite coeffs
    hermite_coeffs = _setup_hermite_cubic_coeffs_w_backward_differences(t, coeffs, derivs, device=coeffs.device)

    return hermite_coeffs