282 lines
13 KiB
Python
282 lines
13 KiB
Python
import math
|
|
import torch
|
|
|
|
from . import misc
|
|
|
|
|
|
def _natural_cubic_spline_coeffs_without_missing_values(times, path):
|
|
# path should be a tensor of shape (..., length)
|
|
# Will return the b, two_c, three_d coefficients of the derivative of the cubic spline interpolating the path.
|
|
|
|
length = path.size(-1)
|
|
|
|
if length < 2:
|
|
# In practice this should always already be caught in __init__.
|
|
raise ValueError("Must have a time dimension of size at least 2.")
|
|
elif length == 2:
|
|
a = path[..., :1]
|
|
b = (path[..., 1:] - path[..., :1]) / (times[..., 1:] - times[..., :1])
|
|
two_c = torch.zeros(*path.shape[:-1], 1, dtype=path.dtype, device=path.device)
|
|
three_d = torch.zeros(*path.shape[:-1], 1, dtype=path.dtype, device=path.device)
|
|
else:
|
|
# Set up some intermediate values
|
|
time_diffs = times[1:] - times[:-1]
|
|
time_diffs_reciprocal = time_diffs.reciprocal()
|
|
time_diffs_reciprocal_squared = time_diffs_reciprocal ** 2
|
|
three_path_diffs = 3 * (path[..., 1:] - path[..., :-1])
|
|
six_path_diffs = 2 * three_path_diffs
|
|
path_diffs_scaled = three_path_diffs * time_diffs_reciprocal_squared
|
|
|
|
# Solve a tridiagonal linear system to find the derivatives at the knots
|
|
system_diagonal = torch.empty(length, dtype=path.dtype, device=path.device)
|
|
system_diagonal[:-1] = time_diffs_reciprocal
|
|
system_diagonal[-1] = 0
|
|
system_diagonal[1:] += time_diffs_reciprocal
|
|
system_diagonal *= 2
|
|
system_rhs = torch.empty_like(path)
|
|
system_rhs[..., :-1] = path_diffs_scaled
|
|
system_rhs[..., -1] = 0
|
|
system_rhs[..., 1:] += path_diffs_scaled
|
|
knot_derivatives = misc.tridiagonal_solve(system_rhs, time_diffs_reciprocal, system_diagonal,
|
|
time_diffs_reciprocal)
|
|
|
|
# Do some algebra to find the coefficients of the spline
|
|
a = path[..., :-1]
|
|
b = knot_derivatives[..., :-1]
|
|
two_c = (six_path_diffs * time_diffs_reciprocal
|
|
- 4 * knot_derivatives[..., :-1]
|
|
- 2 * knot_derivatives[..., 1:]) * time_diffs_reciprocal
|
|
three_d = (-six_path_diffs * time_diffs_reciprocal
|
|
+ 3 * (knot_derivatives[..., :-1]
|
|
+ knot_derivatives[..., 1:])) * time_diffs_reciprocal_squared
|
|
|
|
return a, b, two_c, three_d
|
|
|
|
|
|
def _natural_cubic_spline_coeffs_with_missing_values(t, path):
    """Compute spline coefficients for a path that may contain NaNs.

    Arguments:
        t: one dimensional tensor of times.
        path: tensor of shape (..., length), possibly containing NaNs.

    Returns:
        A tuple (a, b, two_c, three_d), each of shape (..., length - 1).
    """
    if path.dim() == 1:
        # Missing values may sit at different times in different channels, so every scalar series is handled
        # individually.
        return _natural_cubic_spline_coeffs_with_missing_values_scalar(t, path)

    # One list per coefficient, in the order (a, b, two_c, three_d).
    columns = ([], [], [], [])
    for sub_path in path.unbind(dim=0):  # TODO: parallelise over this
        sub_coeffs = _natural_cubic_spline_coeffs_with_missing_values(t, sub_path)
        for column, coeff in zip(columns, sub_coeffs):
            column.append(coeff)
    return tuple(misc.cheap_stack(column, dim=0) for column in columns)
|
|
|
|
|
|
def _natural_cubic_spline_coeffs_with_missing_values_scalar(times, path):
|
|
# times and path both have shape (length,)
|
|
|
|
# How to deal with missing values at the start or end of the time series? We're creating some splines, so one
|
|
# option is just to extend the first piece backwards, and the final piece forwards. But polynomials tend to
|
|
# behave badly when extended beyond the interval they were constructed on, so the results can easily end up
|
|
# being awful.
|
|
# Instead we impute an observation at the very start equal to the first actual observation made, and impute an
|
|
# observation at the very end equal to the last actual observation made, and then procede with splines as
|
|
# normal.
|
|
|
|
not_nan = ~torch.isnan(path)
|
|
path_no_nan = path.masked_select(not_nan)
|
|
|
|
if path_no_nan.size(0) == 0:
|
|
# Every entry is a NaN, so we take a constant path with derivative zero, so return zero coefficients.
|
|
# Note that we may assume that path.size(0) >= 2 by the checks in __init__ so "path.size(0) - 1" is a valid
|
|
# thing to do.
|
|
return (torch.zeros(path.size(0) - 1, dtype=path.dtype, device=path.device),
|
|
torch.zeros(path.size(0) - 1, dtype=path.dtype, device=path.device),
|
|
torch.zeros(path.size(0) - 1, dtype=path.dtype, device=path.device),
|
|
torch.zeros(path.size(0) - 1, dtype=path.dtype, device=path.device))
|
|
# else we have at least one non-NaN entry, in which case we're going to impute at least one more entry (as
|
|
# the path is of length at least 2 so the start and the end aren't the same), so we will then have at least two
|
|
# non-Nan entries. In particular we can call _compute_coeffs safely later.
|
|
|
|
need_new_not_nan = False
|
|
if torch.isnan(path[0]):
|
|
if not need_new_not_nan:
|
|
path = path.clone()
|
|
need_new_not_nan = True
|
|
path[0] = path_no_nan[0]
|
|
if torch.isnan(path[-1]):
|
|
if not need_new_not_nan:
|
|
path = path.clone()
|
|
need_new_not_nan = True
|
|
path[-1] = path_no_nan[-1]
|
|
if need_new_not_nan:
|
|
not_nan = ~torch.isnan(path)
|
|
path_no_nan = path.masked_select(not_nan)
|
|
times_no_nan = times.masked_select(not_nan)
|
|
|
|
# Find the coefficients on the pieces we do understand
|
|
# These all have shape (len - 1,)
|
|
(a_pieces_no_nan,
|
|
b_pieces_no_nan,
|
|
two_c_pieces_no_nan,
|
|
three_d_pieces_no_nan) = _natural_cubic_spline_coeffs_without_missing_values(times_no_nan, path_no_nan)
|
|
|
|
# Now we're going to normalise them to give coefficients on every interval
|
|
a_pieces = []
|
|
b_pieces = []
|
|
two_c_pieces = []
|
|
three_d_pieces = []
|
|
|
|
iter_times_no_nan = iter(times_no_nan)
|
|
iter_coeffs_no_nan = iter(zip(a_pieces_no_nan, b_pieces_no_nan, two_c_pieces_no_nan, three_d_pieces_no_nan))
|
|
next_time_no_nan = next(iter_times_no_nan)
|
|
for time in times[:-1]:
|
|
# will always trigger on the first iteration because of how we've imputed missing values at the start and
|
|
# end of the time series.
|
|
if time >= next_time_no_nan:
|
|
prev_time_no_nan = next_time_no_nan
|
|
next_time_no_nan = next(iter_times_no_nan)
|
|
next_a_no_nan, next_b_no_nan, next_two_c_no_nan, next_three_d_no_nan = next(iter_coeffs_no_nan)
|
|
offset = prev_time_no_nan - time
|
|
a_inner = (0.5 * next_two_c_no_nan - next_three_d_no_nan * offset / 3) * offset
|
|
a_pieces.append(next_a_no_nan + (a_inner - next_b_no_nan) * offset)
|
|
b_pieces.append(next_b_no_nan + (next_three_d_no_nan * offset - next_two_c_no_nan) * offset)
|
|
two_c_pieces.append(next_two_c_no_nan - 2 * next_three_d_no_nan * offset)
|
|
three_d_pieces.append(next_three_d_no_nan)
|
|
|
|
return (misc.cheap_stack(a_pieces, dim=0),
|
|
misc.cheap_stack(b_pieces, dim=0),
|
|
misc.cheap_stack(two_c_pieces, dim=0),
|
|
misc.cheap_stack(three_d_pieces, dim=0))
|
|
|
|
|
|
# The mathematics of this are adapted from http://mathworld.wolfram.com/CubicSpline.html, although they only treat the
# case of each piece being parameterised by [0, 1]. (We instead take the length of each piece to be the difference in
# time stamps.)
def natural_cubic_spline_coeffs(t, X):
    """Calculates the coefficients of the natural cubic spline approximation to the batch of controls given.

    Arguments:
        t: One dimensional tensor of times. Must be strictly increasing.
        X: tensor of values, of shape (..., L, C), where ... is some number of batch dimensions, L is some length
            that must be the same as the length of t, and C is some number of channels. This is interpreted as a
            (batch of) paths taking values in a C-dimensional real vector space, with L observations. Missing values
            are supported, and should be represented as NaNs.

    In particular, the support for missing values allows for batching together elements that are observed at
    different times; just set them to have missing values at each other's observation times.

    Warning:
        Calling this function can be pretty slow. Make sure to cache the result, and don't reinstantiate it on every
        forward pass, if at all possible.

    Returns:
        Four tensors (a, b, two_c, three_d), each of shape (..., L - 1, C), which should in turn be passed to
        `controldiffeq.NaturalCubicSpline`.

        Why do we do it like this? Because typically you want to use PyTorch tensors at various interfaces, for example
        when loading a batch from a DataLoader. If we wrapped all of this up into just the
        `controldiffeq.NaturalCubicSpline` class then that sort of thing wouldn't be possible.

        As such the suggested use is to:
        (a) Load your data.
        (b) Preprocess it with this function.
        (c) Save the result.
        (d) Treat the result as your dataset as far as PyTorch's `torch.utils.data.Dataset` and
            `torch.utils.data.DataLoader` classes are concerned.
        (e) Call NaturalCubicSpline as the first part of your model.

        See also the accompanying example.py.

    Raises:
        ValueError: if t or X is not floating point; if t is not one dimensional or not strictly increasing
            (NaNs in t count as not increasing); if X has fewer than two dimensions; if the time dimension of X
            does not match the length of t; or if there are fewer than two times.
    """
    if not t.is_floating_point():
        raise ValueError("t and X must both be floating point.")
    if not X.is_floating_point():
        raise ValueError("t and X must both be floating point.")
    if len(t.shape) != 1:
        raise ValueError("t must be one dimensional.")
    # The spline solver divides by consecutive time differences, and the missing-value handling walks the times
    # in order, so strictly increasing times are required. The previous element-by-element loop never updated its
    # running maximum and so accepted any t. This check is phrased as "not all strictly increasing" so that NaNs
    # (for which every comparison is False) are rejected as well.
    if not (t[1:] > t[:-1]).all():
        raise ValueError("t must be monotonically increasing.")

    if len(X.shape) < 2:
        raise ValueError("X must have at least two dimensions, corresponding to time and channels.")

    if X.size(-2) != t.size(0):
        raise ValueError("The time dimension of X must equal the length of t.")

    if t.size(0) < 2:
        raise ValueError("Must have a time dimension of size at least 2.")

    if torch.isnan(X).any():
        # Transpose because channels are a batch dimension for the purpose of finding interpolating polynomials.
        # a, b, two_c, three_d have shape (..., channels, length - 1)
        a, b, two_c, three_d = _natural_cubic_spline_coeffs_with_missing_values(t, X.transpose(-1, -2))
    else:
        # Can do things more quickly in this case.
        a, b, two_c, three_d = _natural_cubic_spline_coeffs_without_missing_values(t, X.transpose(-1, -2))

    # These all have shape (..., length - 1, channels)
    a = a.transpose(-1, -2)
    b = b.transpose(-1, -2)
    two_c = two_c.transpose(-1, -2)
    three_d = three_d.transpose(-1, -2)
    return a, b, two_c, three_d
|
|
|
|
|
|
class NaturalCubicSpline:
    """Natural cubic spline through a (batch of) path(s), together with its derivative.

    The spline is described by the coefficients produced by `natural_cubic_spline_coeffs`. That function is
    intended to be run once as preprocessing; this class only evaluates the resulting piecewise cubic.

    Example:
        times = torch.linspace(0, 1, 7)
        # (2, 1) are batch dimensions. 7 is the time dimension (of the same length as t). 3 is the channel dimension.
        X = torch.rand(2, 1, 7, 3)
        coeffs = natural_cubic_spline_coeffs(times, X)
        # ...at this point you can save the coeffs, put them through PyTorch's Datasets and DataLoaders, etc...
        spline = NaturalCubicSpline(times, coeffs)
        t = torch.tensor(0.4)
        # will be a tensor of shape (2, 1, 3), corresponding to batch and channel dimensions
        out = spline.derivative(t)
    """

    def __init__(self, times, coeffs, **kwargs):
        """
        Arguments:
            times: As was passed as an argument to natural_cubic_spline_coeffs.
            coeffs: As returned by natural_cubic_spline_coeffs; a sequence (a, b, two_c, three_d).
        """
        super().__init__(**kwargs)

        # two_c and three_d are stored as these multiples of c and d because derivatives are the typical use case.
        self._a, self._b, self._two_c, self._three_d = coeffs
        self._times = times

    def _interpret_t(self, t):
        """Return (t - start of its piece, piece index) for a scalar tensor t."""
        last_piece = self._b.size(-2) - 1
        # The number of knots strictly before t, minus one, identifies the piece containing t. Clamping allows t to
        # lie outside [times[0], times[-1]]; the first or last piece is then extrapolated. The final element of
        # self._times is never used as a piece start, which is the intended behaviour.
        piece = ((t > self._times).sum() - 1).clamp(0, last_piece)
        return t - self._times[piece], piece

    def evaluate(self, t):
        """Evaluates the natural cubic spline interpolation at a point t, which should be a scalar tensor."""
        s, piece = self._interpret_t(t)
        a = self._a[..., piece, :]
        b = self._b[..., piece, :]
        two_c = self._two_c[..., piece, :]
        three_d = self._three_d[..., piece, :]
        # Horner evaluation of a + b*s + c*s^2 + d*s^3, with c = two_c / 2 and d = three_d / 3.
        return a + (b + (0.5 * two_c + three_d * s / 3) * s) * s

    def derivative(self, t):
        """Evaluates the derivative of the natural cubic spline at a point t, which should be a scalar tensor."""
        s, piece = self._interpret_t(t)
        # Horner evaluation of b + 2c*s + 3d*s^2.
        return self._b[..., piece, :] + (self._two_c[..., piece, :] + self._three_d[..., piece, :] * s) * s
|