# TrafficWheel/model/STGNRDE/torchcde/interpolation_cubic.py
#
# 347 lines
# 16 KiB
# Python

import torch
from torchcde import interpolation_base
from . import misc
def _natural_cubic_spline_coeffs_without_missing_values(t, x):
    """Compute the natural cubic spline through a path that has no missing values.

    Arguments:
        t: one-dimensional tensor of times, of shape (length,).
        x: tensor of shape (..., length). Every leading dimension is treated as a batch dimension; the spline is
            fitted along the last dimension only.

    Returns:
        A tuple (a, b, two_c, three_d), each of shape (..., length - 1). On the interval [t[i], t[i + 1]] the spline
        is a + b * s + (two_c / 2) * s ** 2 + (three_d / 3) * s ** 3 with s = time - t[i], so b, two_c and three_d
        are directly the coefficients of its derivative b + two_c * s + three_d * s ** 2.

    Raises:
        ValueError: if length < 2.
    """
    length = x.size(-1)

    if length < 2:
        # In practice this should already have been caught by input validation upstream
        # (presumably misc.validate_input_path — TODO confirm).
        raise ValueError("Must have a time dimension of size at least 2.")
    elif length == 2:
        # With only two knots, the natural spline (zero second derivative at both ends) is the straight line between
        # them. The quadratic and cubic coefficients are therefore zero.
        a = x[..., :1]
        b = (x[..., 1:] - x[..., :1]) / (t[..., 1:] - t[..., :1])
        two_c = torch.zeros(*x.shape[:-1], 1, dtype=x.dtype, device=x.device)
        three_d = torch.zeros(*x.shape[:-1], 1, dtype=x.dtype, device=x.device)
    else:
        # Set up some intermediate values. Notation: h_i = t[i + 1] - t[i] and dx_i = x[..., i + 1] - x[..., i].
        time_diffs = t[1:] - t[:-1]  # h_i
        time_diffs_reciprocal = time_diffs.reciprocal()  # 1 / h_i
        time_diffs_reciprocal_squared = time_diffs_reciprocal ** 2  # 1 / h_i ** 2
        three_path_diffs = 3 * (x[..., 1:] - x[..., :-1])  # 3 * dx_i
        six_path_diffs = 2 * three_path_diffs  # 6 * dx_i
        path_diffs_scaled = three_path_diffs * time_diffs_reciprocal_squared  # 3 * dx_i / h_i ** 2

        # Solve a tridiagonal linear system to find the derivatives k_i at the knots.
        # Diagonal entry i is 2 * (1 / h_{i-1} + 1 / h_i), with the out-of-range term dropped at either end.
        system_diagonal = torch.empty(length, dtype=x.dtype, device=x.device)
        system_diagonal[:-1] = time_diffs_reciprocal
        system_diagonal[-1] = 0
        system_diagonal[1:] += time_diffs_reciprocal
        system_diagonal *= 2
        # Right-hand side entry i is 3 * (dx_{i-1} / h_{i-1} ** 2 + dx_i / h_i ** 2), again dropping out-of-range terms.
        system_rhs = torch.empty_like(x)
        system_rhs[..., :-1] = path_diffs_scaled
        system_rhs[..., -1] = 0
        system_rhs[..., 1:] += path_diffs_scaled
        # Both off-diagonals are 1 / h_i, so the system is symmetric. The same vector is passed for both; see
        # misc.tridiagonal_solve for the argument order it expects.
        knot_derivatives = misc.tridiagonal_solve(system_rhs, time_diffs_reciprocal, system_diagonal,
                                                  time_diffs_reciprocal)

        # Do some algebra to find the coefficients of each piece from its endpoint values and its endpoint
        # derivatives k_i, k_{i+1} (cubic Hermite form).
        a = x[..., :-1]
        b = knot_derivatives[..., :-1]
        two_c = (six_path_diffs * time_diffs_reciprocal
                 - 4 * knot_derivatives[..., :-1]
                 - 2 * knot_derivatives[..., 1:]) * time_diffs_reciprocal
        three_d = (-six_path_diffs * time_diffs_reciprocal
                   + 3 * (knot_derivatives[..., :-1]
                          + knot_derivatives[..., 1:])) * time_diffs_reciprocal_squared
    return a, b, two_c, three_d
def _natural_cubic_spline_coeffs_with_missing_values(t, x, _version):
    """Compute spline coefficients for paths that may contain NaNs, one scalar path at a time.

    Arguments:
        t: one-dimensional tensor of times, of shape (length,).
        x: tensor of shape (..., length). Each leading dimension is unbound and recursed over until only
            one-dimensional scalar paths remain.
        _version: forwarded unchanged to `_natural_cubic_spline_coeffs_with_missing_values_scalar`.

    Returns:
        A tuple (a, b, two_c, three_d), each of shape (..., length - 1).
    """
    if x.ndimension() == 1:
        return _natural_cubic_spline_coeffs_with_missing_values_scalar(t, x, _version)

    # Missing values may fall at different times in different channels. Each scalar path is therefore fitted
    # separately, and the results are stacked back together along the dimension that was unbound.
    collected = ([], [], [], [])
    for sub_path in x.unbind(dim=0):  # TODO: parallelise over this
        sub_coeffs = _natural_cubic_spline_coeffs_with_missing_values(t, sub_path, _version)
        for bucket, coeff in zip(collected, sub_coeffs):
            bucket.append(coeff)
    return tuple(misc.cheap_stack(bucket, dim=0) for bucket in collected)
def _natural_cubic_spline_coeffs_with_missing_values_scalar(t, x, _version):
    """Compute natural cubic spline coefficients for a single channel that may contain NaNs.

    Arguments:
        t: one-dimensional tensor of times, of shape (length,).
        x: one-dimensional tensor of values, of shape (length,). NaN marks a missing observation.
        _version: how missing values at the start/end of the series are handled.
            - 0: impute only the very first and very last entries, using the first/last observed values.
            - Any other value: fill every leading NaN with the first observed value and every trailing NaN with
              the last observed value.

    Returns:
        A tuple (a, b, two_c, three_d) of tensors of shape (length - 1,). There is one cubic piece per interval
        [t[i], t[i + 1]], each expressed relative to t[i].
    """
    observed = ~torch.isnan(x)
    observed_values = x.masked_select(observed)

    if observed_values.size(0) == 0:
        # Nothing was observed, so use a constant path with zero derivative: all-zero coefficients on each of the
        # length - 1 intervals. (Upstream validation is expected to guarantee length >= 2.)
        num_intervals = x.size(0) - 1
        return tuple(torch.zeros(num_intervals, dtype=x.dtype, device=x.device) for _ in range(4))

    # One option for missing data at the start/end is to extend the first/last polynomial pieces beyond their
    # intervals. Polynomials behave badly outside the interval they were fitted on, so the ends of the series are
    # filled in before fitting instead. Because the path has length >= 2, its start and end are different entries.
    # After filling there are therefore at least two observed entries, and t[0] and t[-1] are both knots.
    if _version == 0:
        # Legacy behaviour: impute only the very first and very last entries.
        start_missing = bool(torch.isnan(x[0]))
        end_missing = bool(torch.isnan(x[-1]))
        if start_missing or end_missing:
            x = x.clone()
            if start_missing:
                x[0] = observed_values[0]
            if end_missing:
                x[-1] = observed_values[-1]
            observed = ~torch.isnan(x)
            observed_values = x.masked_select(observed)
    else:
        # Forward/backward fill, so that the spline rapidly stabilises to the first/last observed value.
        observed_positions = observed.nonzero()  # shape (num_observed, 1), sorted ascending
        first_index = observed_positions[0, 0]
        last_index = observed_positions[-1, 0]
        x = x.clone()
        x[:first_index] = x[first_index]
        x[last_index + 1:] = x[last_index]
        observed = ~torch.isnan(x)
        observed_values = x.masked_select(observed)
    observed_times = t.masked_select(observed)

    # Fit the spline through the observed points only. Each fitted coefficient has shape (num_observed - 1,).
    fitted = _natural_cubic_spline_coeffs_without_missing_values(observed_times, observed_values)

    # Re-express the fitted pieces on every interval [t[i], t[i + 1]] of the full grid by re-centring each
    # polynomial at t[i]. The first iteration always advances onto the first fitted piece, because t[0] is observed
    # after the filling above.
    knot_iter = iter(observed_times)
    piece_iter = zip(*fitted)
    upcoming_knot = next(knot_iter)
    a_out, b_out, two_c_out, three_d_out = [], [], [], []
    for time in t[:-1]:
        if time >= upcoming_knot:
            piece_start = upcoming_knot
            upcoming_knot = next(knot_iter)
            a, b, two_c, three_d = next(piece_iter)
        offset = piece_start - time
        a_inner = (0.5 * two_c - three_d * offset / 3) * offset
        a_out.append(a + (a_inner - b) * offset)
        b_out.append(b + (three_d * offset - two_c) * offset)
        two_c_out.append(two_c - 2 * three_d * offset)
        three_d_out.append(three_d)
    return tuple(misc.cheap_stack(pieces, dim=0) for pieces in (a_out, b_out, two_c_out, three_d_out))
# The mathematics of this are adapted from http://mathworld.wolfram.com/CubicSpline.html, although that page only
# treats the case of each piece being parameterised over [0, 1]. (We instead take the length of each piece to be the
# difference between consecutive timestamps.)
def _natural_cubic_spline_coeffs(x, t, _version):
    """Validate the inputs and compute packed cubic spline coefficients.

    Arguments:
        x: tensor of shape (..., length, channels). NaNs mark missing values.
        t: optional tensor of times, as accepted by `misc.validate_input_path`.
        _version: missing-data handling scheme, forwarded to the missing-values code path.

    Returns:
        A tensor of shape (..., length - 1, 4 * channels): the a, b, two_c and three_d coefficients concatenated
        along the last dimension.
    """
    t = misc.validate_input_path(x, t)

    # Channels act as a batch dimension when fitting the interpolating polynomials, so put time on the last axis.
    x_time_last = x.transpose(-1, -2)
    if torch.isnan(x).any():
        coeff_parts = _natural_cubic_spline_coeffs_with_missing_values(t, x_time_last, _version)
    else:
        # Fast path: with no NaNs every channel can be solved at once.
        coeff_parts = _natural_cubic_spline_coeffs_without_missing_values(t, x_time_last)

    # Each part is (..., channels, length - 1). Move channels back to the end and pack them side by side.
    return torch.cat([part.transpose(-1, -2) for part in coeff_parts], dim=-1)
def natural_cubic_spline_coeffs(x, t=None):
    """Compute natural cubic spline coefficients for a batch of controls, using the legacy missing-data scheme.

    ********************
    DEPRECATED: kept only for backward compatibility. New code should call `natural_cubic_coeffs` instead, which
    handles missing data at the start/end of a time series better.
    ********************

    Arguments:
        x: tensor of shape (..., length, input_channels), where ... is any number of batch dimensions. Each entry
            along ... is a path in input_channels-dimensional real space observed at `length` times. Missing
            observations are supported and must be encoded as NaN. In this legacy scheme, only the very first and
            very last entries of a channel are imputed, from its first/last observed values.
        t: optional one-dimensional, monotonically increasing tensor of observation times. Defaults to
            tensor([0., 1., ..., length - 1]). Neural CDE users **do not need this argument**; see the Further
            Documentation in README.md.

    Warning:
        When missing values are present this can be slow. Compute the result once and cache it, rather than
        recomputing it on every forward pass.

    Returns:
        A tensor to be passed to `torchcde.CubicSpline`.

    Why split the work this way? You usually want plain PyTorch tensors at interfaces such as DataLoader batches.
    That would not be possible if everything lived inside `torchcde.CubicSpline`. The suggested workflow is:
        (a) Load your data.
        (b) Preprocess it with this function.
        (c) Save the result.
        (d) Treat the result as your dataset as far as `torch.utils.data.Dataset` and
            `torch.utils.data.DataLoader` are concerned.
        (e) Construct a CubicSpline as the first part of your model.
    See also the accompanying example.py.
    """
    # Version 0 selects the original endpoint-imputation scheme.
    return _natural_cubic_spline_coeffs(x=x, t=t, _version=0)
def natural_cubic_coeffs(x, t=None):
    """Compute natural cubic spline coefficients for a batch of controls.

    Arguments:
        x: tensor of shape (..., length, input_channels), where ... is any number of batch dimensions. Each entry
            along ... is a path in input_channels-dimensional real space observed at `length` times. Missing
            observations are supported and must be encoded as NaN. Leading NaNs in a channel are filled with its
            first observed value, and trailing NaNs with its last observed value.
        t: optional one-dimensional, monotonically increasing tensor of observation times. Defaults to
            tensor([0., 1., ..., length - 1]). Neural CDE users **do not need this argument**; see the Further
            Documentation in README.md.

    Warning:
        When missing values are present this can be slow. Compute the result once and cache it, rather than
        recomputing it on every forward pass.

    Returns:
        A tensor to be passed to `torchcde.CubicSpline`.

    Why split the work this way? You usually want plain PyTorch tensors at interfaces such as DataLoader batches.
    That would not be possible if everything lived inside `torchcde.CubicSpline`. The suggested workflow is:
        (a) Load your data.
        (b) Preprocess it with this function.
        (c) Save the result.
        (d) Treat the result as your dataset as far as `torch.utils.data.Dataset` and
            `torch.utils.data.DataLoader` are concerned.
        (e) Construct a CubicSpline as the first part of your model.
    See also the accompanying example.py.
    """
    # Version 1 selects forward/backward filling of missing values at the ends of each channel.
    return _natural_cubic_spline_coeffs(x=x, t=t, _version=1)
class CubicSpline(interpolation_base.InterpolationBase):
    """Calculates the cubic spline approximation to the batch of controls given. Also calculates its derivative.

    Example:
        # (2, 1) are batch dimensions. 7 is the time dimension. 3 is the channel dimension.
        x = torch.rand(2, 1, 7, 3)
        coeffs = natural_cubic_coeffs(x)
        # ...at this point you can save coeffs, put it through PyTorch's Datasets and DataLoaders, etc...
        spline = CubicSpline(coeffs)
        point = torch.tensor(0.4)
        # will be a tensor of shape (2, 1, 3), corresponding to batch and channel dimensions
        out = spline.derivative(point)
    """

    def __init__(self, coeffs, t=None, **kwargs):
        """
        Arguments:
            coeffs: As returned by `torchcde.natural_cubic_coeffs` (or the deprecated
                `torchcde.natural_cubic_spline_coeffs`). This is a tensor of shape (..., length - 1, 4 * channels)
                holding the a, b, two_c and three_d coefficients concatenated along the last dimension.
            t: As passed to `torchcde.natural_cubic_coeffs`. (If it was passed. If you are using neural CDEs then you
                **do not need to use this argument**. See the Further Documentation in README.md.) Defaults to
                tensor([0., 1., ..., length - 1]).
            **kwargs: Passed on to the base class initialiser.

        Raises:
            ValueError: if the size of the last dimension of `coeffs` is not a multiple of 4.
        """
        super().__init__(**kwargs)

        if t is None:
            # Same default grid as the coefficient functions: knots at 0, 1, ..., length - 1.
            t = torch.linspace(0, coeffs.size(-2), coeffs.size(-2) + 1, dtype=coeffs.dtype, device=coeffs.device)

        channels = coeffs.size(-1) // 4
        if channels * 4 != coeffs.size(-1):  # check that it's a multiple of 4
            raise ValueError("Passed invalid coeffs.")
        a, b, two_c, three_d = (coeffs[..., :channels], coeffs[..., channels:2 * channels],
                                coeffs[..., 2 * channels:3 * channels], coeffs[..., 3 * channels:])

        self.register_buffer('_t', t)
        self.register_buffer('_a', a)
        self.register_buffer('_b', b)
        # As we're typically computing derivatives, we store the multiples of these coefficients that are more useful.
        self.register_buffer('_two_c', two_c)
        self.register_buffer('_three_d', three_d)

    @property
    def grid_points(self):
        """The knot times of the spline."""
        return self._t

    @property
    def interval(self):
        """A two-element tensor holding the first and last knot times."""
        return torch.stack([self._t[0], self._t[-1]])

    def _interpret_t(self, t):
        """Locate the spline piece containing `t`.

        Returns:
            (fractional_part, index): `index` is the piece whose start knot is the last one strictly before `t`,
            clamped to a valid piece index. `fractional_part` is `t` minus that start knot.
        """
        t = torch.as_tensor(t, dtype=self._b.dtype, device=self._b.device)
        maxlen = self._b.size(-2) - 1  # index of the last piece
        # Clamp because t may go outside of [t[0], t[-1]]; the first/last piece is then extrapolated, which is fine.
        index = torch.bucketize(t.detach(), self._t.detach()).sub(1).clamp(0, maxlen)
        # Will never access the last element of self._t; this is correct behaviour.
        fractional_part = t - self._t[index]
        return fractional_part, index

    def evaluate(self, t):
        """Evaluate the spline at `t`: a + b*s + (two_c/2)*s**2 + (three_d/3)*s**3, with s the offset into the piece.

        For a scalar `t` the result has shape (..., channels).
        """
        fractional_part, index = self._interpret_t(t)
        fractional_part = fractional_part.unsqueeze(-1)
        inner = 0.5 * self._two_c[..., index, :] + self._three_d[..., index, :] * fractional_part / 3
        inner = self._b[..., index, :] + inner * fractional_part
        return self._a[..., index, :] + inner * fractional_part

    def derivative(self, t):
        """Evaluate the spline's time derivative at `t`: b + two_c*s + three_d*s**2, with s the offset into the piece.

        For a scalar `t` the result has shape (..., channels).
        """
        fractional_part, index = self._interpret_t(t)
        fractional_part = fractional_part.unsqueeze(-1)
        inner = self._two_c[..., index, :] + self._three_d[..., index, :] * fractional_part
        deriv = self._b[..., index, :] + inner * fractional_part
        return deriv
class NaturalCubicSpline(CubicSpline):
    """Cubic spline interpolation of a batch of controls, under its old name (subclass of `CubicSpline`).

    ********************
    DEPRECATED: this now exists for backward compatibility. For new projects please use `CubicSpline` instead, which
    is general for any cubic coeffs (currently natural cubic or Hermite with backwards differences).
    ********************
    """