TrafficWheel/model/STGNCDE/controldiffeq/interpolate.py

282 lines
13 KiB
Python
Executable File

import math
import torch
from . import misc
def _natural_cubic_spline_coeffs_without_missing_values(times, path):
# path should be a tensor of shape (..., length)
# Will return the b, two_c, three_d coefficients of the derivative of the cubic spline interpolating the path.
length = path.size(-1)
if length < 2:
# In practice this should always already be caught in __init__.
raise ValueError("Must have a time dimension of size at least 2.")
elif length == 2:
a = path[..., :1]
b = (path[..., 1:] - path[..., :1]) / (times[..., 1:] - times[..., :1])
two_c = torch.zeros(*path.shape[:-1], 1, dtype=path.dtype, device=path.device)
three_d = torch.zeros(*path.shape[:-1], 1, dtype=path.dtype, device=path.device)
else:
# Set up some intermediate values
time_diffs = times[1:] - times[:-1]
time_diffs_reciprocal = time_diffs.reciprocal()
time_diffs_reciprocal_squared = time_diffs_reciprocal ** 2
three_path_diffs = 3 * (path[..., 1:] - path[..., :-1])
six_path_diffs = 2 * three_path_diffs
path_diffs_scaled = three_path_diffs * time_diffs_reciprocal_squared
# Solve a tridiagonal linear system to find the derivatives at the knots
system_diagonal = torch.empty(length, dtype=path.dtype, device=path.device)
system_diagonal[:-1] = time_diffs_reciprocal
system_diagonal[-1] = 0
system_diagonal[1:] += time_diffs_reciprocal
system_diagonal *= 2
system_rhs = torch.empty_like(path)
system_rhs[..., :-1] = path_diffs_scaled
system_rhs[..., -1] = 0
system_rhs[..., 1:] += path_diffs_scaled
knot_derivatives = misc.tridiagonal_solve(system_rhs, time_diffs_reciprocal, system_diagonal,
time_diffs_reciprocal)
# Do some algebra to find the coefficients of the spline
a = path[..., :-1]
b = knot_derivatives[..., :-1]
two_c = (six_path_diffs * time_diffs_reciprocal
- 4 * knot_derivatives[..., :-1]
- 2 * knot_derivatives[..., 1:]) * time_diffs_reciprocal
three_d = (-six_path_diffs * time_diffs_reciprocal
+ 3 * (knot_derivatives[..., :-1]
+ knot_derivatives[..., 1:])) * time_diffs_reciprocal_squared
return a, b, two_c, three_d
def _natural_cubic_spline_coeffs_with_missing_values(t, path):
    """Compute natural cubic spline coefficients for a path that may contain NaNs (missing values).

    The missing entries can differ from one channel to the next, so the path is peeled apart along its leading
    dimension, recursively, until only one-dimensional scalar paths remain. Each scalar path is solved on its
    own and the results are stacked back together along dimension 0.

    Arguments:
        t: tensor of times, passed through unchanged to the scalar routine.
        path: tensor whose last dimension is the time dimension.

    Returns:
        A tuple (a, b, two_c, three_d) of stacked coefficient tensors.
    """
    if len(path.shape) == 1:
        # Base case: a single scalar path.
        return _natural_cubic_spline_coeffs_with_missing_values_scalar(t, path)

    # TODO: parallelise over this
    per_slice = [_natural_cubic_spline_coeffs_with_missing_values(t, sub_path)
                 for sub_path in path.unbind(dim=0)]

    # Regroup the per-slice results by coefficient (a, b, two_c, three_d in that order) and stack each group.
    return tuple(misc.cheap_stack([coeffs[which] for coeffs in per_slice], dim=0)
                 for which in range(4))
def _natural_cubic_spline_coeffs_with_missing_values_scalar(times, path):
    """Compute natural cubic spline coefficients for one scalar path that may contain NaNs.

    Arguments:
        times: tensor of shape (length,).
        path: tensor of shape (length,), with NaN marking a missing observation.

    Returns:
        A tuple (a, b, two_c, three_d), each of shape (length - 1,), giving coefficients on every interval of
        `times`, including the intervals that start at a missing observation.

    How to deal with missing values at the start or end of the time series? One option would be to extend the
    first spline piece backwards and the final piece forwards, but polynomials tend to behave badly outside the
    interval they were constructed on. Instead an observation equal to the first actual observation is imputed
    at the very start, and one equal to the last actual observation at the very end, and then we proceed with
    splines as normal.
    """
    num_intervals = path.size(0) - 1

    observed = ~torch.isnan(path)
    observed_values = path.masked_select(observed)

    if observed_values.size(0) == 0:
        # Every entry is NaN: treat the path as constant with zero derivative, i.e. all-zero coefficients.
        # path.size(0) >= 2 may be assumed because of the checks made by the caller.
        return tuple(torch.zeros(num_intervals, dtype=path.dtype, device=path.device) for _ in range(4))

    # From here on there is at least one real observation. Because the path has length at least 2, imputing the
    # endpoints guarantees at least two non-NaN entries, so the no-missing-values solver can be called safely.
    start_missing = bool(torch.isnan(path[0]))
    end_missing = bool(torch.isnan(path[-1]))
    if start_missing or end_missing:
        # Copy before writing so that the tensor handed in by the caller is left untouched.
        path = path.clone()
        if start_missing:
            path[0] = observed_values[0]
        if end_missing:
            path[-1] = observed_values[-1]
        observed = ~torch.isnan(path)
        observed_values = path.masked_select(observed)
    observed_times = times.masked_select(observed)

    # Coefficients on the pieces between consecutive observed knots; one entry per such piece.
    (knot_a,
     knot_b,
     knot_two_c,
     knot_three_d) = _natural_cubic_spline_coeffs_without_missing_values(observed_times, observed_values)

    # Now re-express those pieces so that there is one set of coefficients for every interval of `times`.
    a_out = []
    b_out = []
    two_c_out = []
    three_d_out = []

    knot_time_iter = iter(observed_times)
    knot_coeff_iter = iter(zip(knot_a, knot_b, knot_two_c, knot_three_d))
    upcoming_knot = next(knot_time_iter)
    for time in times[:-1]:
        # Always triggers on the first iteration, because the start of the series has been imputed.
        if time >= upcoming_knot:
            current_knot = upcoming_knot
            upcoming_knot = next(knot_time_iter)
            piece_a, piece_b, piece_two_c, piece_three_d = next(knot_coeff_iter)
        # The current piece is written as a polynomial in (s - current_knot); re-centre it at `time`.
        # `shift` is zero whenever `time` is itself an observed knot, leaving the coefficients unchanged.
        shift = current_knot - time
        a_inner = (0.5 * piece_two_c - piece_three_d * shift / 3) * shift
        a_out.append(piece_a + (a_inner - piece_b) * shift)
        b_out.append(piece_b + (piece_three_d * shift - piece_two_c) * shift)
        two_c_out.append(piece_two_c - 2 * piece_three_d * shift)
        three_d_out.append(piece_three_d)

    return (misc.cheap_stack(a_out, dim=0),
            misc.cheap_stack(b_out, dim=0),
            misc.cheap_stack(two_c_out, dim=0),
            misc.cheap_stack(three_d_out, dim=0))
# The mathematics of this are adapted from http://mathworld.wolfram.com/CubicSpline.html, although they only treat the
# case of each piece being parameterised by [0, 1]. (We instead take the length of each piece to be the difference in
# time stamps.)
def natural_cubic_spline_coeffs(t, X):
    """Calculates the coefficients of the natural cubic spline approximation to the batch of controls given.

    Arguments:
        t: One dimensional tensor of times. Must be strictly increasing.
        X: tensor of values, of shape (..., L, C), where ... is some number of batch dimensions, L is some length
            that must be the same as the length of t, and C is some number of channels. This is interpreted as a
            (batch of) paths taking values in a C-dimensional real vector space, with L observations. Missing
            values are supported, and should be represented as NaNs.

    In particular, the support for missing values allows for batching together elements that are observed at
    different times; just set them to have missing values at each other's observation times.

    Warning:
        Calling this function can be pretty slow. Make sure to cache the result, and don't reinstantiate it on
        every forward pass, if at all possible.

    Returns:
        Four tensors (a, b, two_c, three_d), each of shape (..., L - 1, C), which should in turn be passed to
        `controldiffeq.NaturalCubicSpline`.

    Raises:
        ValueError: if t or X is not floating point, if t is not one dimensional, if t is not strictly
            increasing, if X has fewer than two dimensions, if the time dimension of X does not match the length
            of t, or if there are fewer than two times.

    Why do we do it like this? Because typically you want to use PyTorch tensors at various interfaces, for example
    when loading a batch from a DataLoader. If we wrapped all of this up into just the
    `controldiffeq.NaturalCubicSpline` class then that sort of thing wouldn't be possible.

    As such the suggested use is to:
    (a) Load your data.
    (b) Preprocess it with this function.
    (c) Save the result.
    (d) Treat the result as your dataset as far as PyTorch's `torch.utils.data.Dataset` and
        `torch.utils.data.DataLoader` classes are concerned.
    (e) Call NaturalCubicSpline as the first part of your model.

    See also the accompanying example.py.
    """
    if not t.is_floating_point():
        raise ValueError("t and X must both be floating point/")
    if not X.is_floating_point():
        raise ValueError("t and X must both be floating point/")
    if len(t.shape) != 1:
        raise ValueError("t must be one dimensional.")
    # Every time must be strictly greater than the one before it. Equal or decreasing times would give a zero or
    # negative interval length, and the spline construction takes reciprocals of the interval lengths.
    if (t[1:] <= t[:-1]).any():
        raise ValueError("t must be monotonically increasing.")

    if len(X.shape) < 2:
        raise ValueError("X must have at least two dimensions, corresponding to time and channels.")
    if X.size(-2) != t.size(0):
        raise ValueError("The time dimension of X must equal the length of t.")
    if t.size(0) < 2:
        raise ValueError("Must have a time dimension of size at least 2.")

    # Transpose because channels are a batch dimension for the purpose of finding interpolating polynomials;
    # the helpers expect time to be the last dimension.
    X_time_last = X.transpose(-1, -2)
    if torch.isnan(X).any():
        # a, b, two_c, three_d have shape (..., channels, length - 1)
        a, b, two_c, three_d = _natural_cubic_spline_coeffs_with_missing_values(t, X_time_last)
    else:
        # Can do things more quickly in this case.
        a, b, two_c, three_d = _natural_cubic_spline_coeffs_without_missing_values(t, X_time_last)

    # Move channels back to the end: these all have shape (..., length - 1, channels)
    a = a.transpose(-1, -2)
    b = b.transpose(-1, -2)
    two_c = two_c.transpose(-1, -2)
    three_d = three_d.transpose(-1, -2)
    return a, b, two_c, three_d
class NaturalCubicSpline:
    """Calculates the natural cubic spline approximation to the batch of controls given. Also calculates its derivative.

    Example:
        times = torch.linspace(0, 1, 7)
        # (2, 1) are batch dimensions. 7 is the time dimension (of the same length as t). 3 is the channel dimension.
        X = torch.rand(2, 1, 7, 3)
        coeffs = natural_cubic_spline_coeffs(times, X)
        # ...at this point you can save the coeffs, put them through PyTorch's Datasets and DataLoaders, etc...
        spline = NaturalCubicSpline(times, coeffs)
        t = torch.tensor(0.4)
        # will be a tensor of shape (2, 1, 3), corresponding to batch and channel dimensions
        out = spline.derivative(t)
    """

    def __init__(self, times, coeffs, **kwargs):
        """
        Arguments:
            times: As was passed as an argument to natural_cubic_spline_coeffs.
            coeffs: As returned by natural_cubic_spline_coeffs, i.e. the four tensors (a, b, two_c, three_d).
            **kwargs: Forwarded unchanged to the next `__init__` in the method resolution order.
        """
        super().__init__(**kwargs)

        # As we're typically computing derivatives, the two highest-order coefficients are kept in the
        # pre-multiplied form (two_c, three_d) that the derivative formula uses directly.
        self._a, self._b, self._two_c, self._three_d = coeffs
        self._times = times

    def _interpret_t(self, t):
        """Locate t within the knots.

        Returns:
            A pair (fractional_part, index): the index of the spline piece to use, and the time elapsed since
            the start of that piece.
        """
        last_piece = self._b.size(-2) - 1
        # Number of knots strictly before t, minus one, is the piece containing t.
        piece = (t > self._times).sum() - 1
        # Clamp because t may go outside of [times[0], times[-1]]; this is fine, the first or last piece is
        # simply extended.
        piece = piece.clamp(0, last_piece)
        # The clamp means the last element of self._times is never used as a piece start; this is correct.
        elapsed = t - self._times[piece]
        return elapsed, piece

    def evaluate(self, t):
        """Evaluates the natural cubic spline interpolation at a point t, which should be a scalar tensor."""
        elapsed, piece = self._interpret_t(t)
        # Nested evaluation of a + b*s + (two_c / 2)*s^2 + (three_d / 3)*s^3, with s the elapsed time.
        higher_order = 0.5 * self._two_c[..., piece, :] + self._three_d[..., piece, :] * elapsed / 3
        first_order = self._b[..., piece, :] + higher_order * elapsed
        return self._a[..., piece, :] + first_order * elapsed

    def derivative(self, t):
        """Evaluates the derivative of the natural cubic spline at a point t, which should be a scalar tensor."""
        elapsed, piece = self._interpret_t(t)
        # Nested evaluation of b + two_c*s + three_d*s^2, with s the elapsed time.
        higher_order = self._two_c[..., piece, :] + self._three_d[..., piece, :] * elapsed
        return self._b[..., piece, :] + higher_order * elapsed