import math
import torch
import warnings

from . import interpolation_base
from . import misc

_two_pi = 2 * math.pi
_inv_two_pi = 1 / _two_pi


def _linear_interpolation_coeffs_with_missing_values_scalar(t, x):
    # t and x both have shape (length,)

    not_nan = ~torch.isnan(x)
    path_no_nan = x.masked_select(not_nan)

    if path_no_nan.size(0) == 0:
        # Every entry is NaN, so we take a constant path with derivative zero: return zero coefficients.
        return torch.zeros(x.size(0), dtype=x.dtype, device=x.device)

    if path_no_nan.size(0) == x.size(0):
        # Every entry is not-NaN, so just return.
        return x

    x = x.clone()
    # How to deal with missing values at the start or end of the time series? We impute an observation at the very
    # start equal to the first actual observation made, and impute an observation at the very end equal to the last
    # actual observation made, and then proceed as normal.
    if torch.isnan(x[0]):
        x[0] = path_no_nan[0]
    if torch.isnan(x[-1]):
        x[-1] = path_no_nan[-1]

    nan_indices = torch.arange(x.size(0), device=x.device).masked_select(torch.isnan(x))

    if nan_indices.size(0) == 0:
        # We only had missing values at the start or end
        return x

    # For each NaN entry, find the index of the closest non-NaN entry to its left.
    prev_nan_index = nan_indices[0]
    prev_not_nan_index = prev_nan_index - 1
    prev_not_nan_indices = [prev_not_nan_index]
    for nan_index in nan_indices[1:]:
        if prev_nan_index != nan_index - 1:
            prev_not_nan_index = nan_index - 1
        prev_nan_index = nan_index
        prev_not_nan_indices.append(prev_not_nan_index)

    # Similarly, find the index of the closest non-NaN entry to its right.
    next_nan_index = nan_indices[-1]
    next_not_nan_index = next_nan_index + 1
    next_not_nan_indices = [next_not_nan_index]
    for nan_index in reversed(nan_indices[:-1]):
        if next_nan_index != nan_index + 1:
            next_not_nan_index = nan_index + 1
        next_nan_index = nan_index
        next_not_nan_indices.append(next_not_nan_index)
    next_not_nan_indices = reversed(next_not_nan_indices)

    # Linearly interpolate each NaN entry between its closest non-NaN neighbours.
    for prev_not_nan_index, nan_index, next_not_nan_index in zip(prev_not_nan_indices,
                                                                 nan_indices,
                                                                 next_not_nan_indices):
        prev_stream = x[prev_not_nan_index]
        next_stream = x[next_not_nan_index]
        prev_time = t[prev_not_nan_index]
        next_time = t[next_not_nan_index]
        time = t[nan_index]
        ratio = (time - prev_time) / (next_time - prev_time)
        x[nan_index] = prev_stream + ratio * (next_stream - prev_stream)

    return x

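# A minimal worked example of the scalar case (values chosen here purely for illustration):
#   t = torch.tensor([0., 1., 2., 3.])
#   x = torch.tensor([float('nan'), 1., float('nan'), 3.])
# The leading NaN is imputed from the first actual observation (1.), and the interior NaN is linearly
# interpolated between x[1] = 1. at t = 1. and x[3] = 3. at t = 3.:
#   _linear_interpolation_coeffs_with_missing_values_scalar(t, x)  # -> tensor([1., 1., 2., 3.])
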

def _linear_interpolation_coeffs_with_missing_values(t, x):
    if x.ndimension() == 1:
        # We have to break everything down to individual scalar paths because of the possibility of missing values
        # being different in different channels
        return _linear_interpolation_coeffs_with_missing_values_scalar(t, x)
    else:
        out_pieces = []
        for p in x.unbind(dim=0):  # TODO: parallelise over this
            out = _linear_interpolation_coeffs_with_missing_values(t, p)
            out_pieces.append(out)
        return misc.cheap_stack(out_pieces, dim=0)

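# For instance (illustrative values only), given channel-first input of shape (channels, length) -- which is how
# linear_interpolation_coeffs below calls this function, after transposing -- each channel is handled independently:
#   t = torch.tensor([0., 1., 2.])
#   x = torch.tensor([[1., float('nan'), 3.],
#                     [float('nan'), 5., 7.]])
#   _linear_interpolation_coeffs_with_missing_values(t, x)
#   # -> tensor([[1., 2., 3.],
#   #            [5., 5., 7.]])
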

def _prepare_rectilinear_interpolation(data, time_index):
    """Prepares data for rectilinear interpolation.

    This function performs the relevant filling and lagging of the data needed to convert raw data into a format such
    that standard linear interpolation will give the rectilinear interpolation.

    Arguments:
        data: tensor of values with one channel containing time, of shape (..., length, input_channels), where ... is
            some number of batch dimensions.
        time_index: integer giving the index of the time channel.

    Example:
        Suppose we have data:
            data = [(t1, x1), (t2, NaN), (t3, x3), ...]
        that we wish to interpolate using a rectilinear scheme. The key point is that this is equivalent to a linear
        interpolation on
            data_rect = [(t1, x1), (t2, x1), (t2, x1), (t3, x1), (t3, x3), ...]
        This function simply performs the conversion from `data` to `data_rect` so that we can apply the inbuilt
        torchcde linear interpolation scheme to achieve rectilinear interpolation.

    Returns:
        A tensor, now of shape (..., 2 * length - 1, input_channels), that can be fed to linear interpolation coeffs
        to give rectilinear coeffs.
    """
    # Check time_index is of the correct format
    n_channels = data.size(-1)
    assert isinstance(time_index, int), (
        "Index of the time channel must be an integer in [0, {}].".format(n_channels - 1)
    )
    assert 0 <= time_index < n_channels, (
        "Time index must be in [0, {}], was given {}.".format(n_channels - 1, time_index)
    )

    times = data[..., time_index]
    assert not torch.isnan(times).any(), (
        "There exist NaN values in the time column, which is not allowed. If the times are padded with NaNs after "
        "the final time, a simple solution is to forward-fill the final time."
    )

    # Forward fill and perform lag interleaving for rectilinear
    data_filled = misc.forward_fill(data)
    data_repeat = data_filled.repeat_interleave(2, dim=-2)
    data_repeat[..., :-1, time_index] = data_repeat[..., 1:, time_index]
    data_rect = data_repeat[..., :-1, :]

    return data_rect

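# A minimal worked example (values chosen here for illustration), with time in channel 0:
#   data = torch.tensor([[0., 1.],
#                        [1., float('nan')],
#                        [2., 3.]])
#   _prepare_rectilinear_interpolation(data, 0)
# forward-fills the NaN to 1. and lags the time channel, giving
#   tensor([[0., 1.],
#           [1., 1.],
#           [1., 1.],
#           [2., 1.],
#           [2., 3.]])
# i.e. the `data_rect` of the docstring above: each segment moves first in time, then in feature.
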

def linear_interpolation_coeffs(x, t=None, rectilinear=None):
    """Calculates the knots of the linear interpolation of the batch of controls given.

    Arguments:
        x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions.
            This is interpreted as a (batch of) paths taking values in an input_channels-dimensional real vector
            space, with length-many observations. Missing values are supported, and should be represented as NaNs.
        t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed will default to
            tensor([0., 1., ..., length - 1]). If you are using neural CDEs then you **do not need to use this
            argument**. See the Further Documentation in README.md.
        rectilinear: Optional integer. Used for performing rectilinear interpolation. This means that interpolation
            between each two adjacent points is done by first interpolating in the time direction, and then
            interpolating in the feature direction. (This is useful for causal missing data, see the Further
            Documentation in README.md.) Defaults to None, i.e. not performing rectilinear interpolation. For
            rectilinear interpolation time *must* be a channel in x, and the `rectilinear` parameter must be an
            integer specifying the channel index of the time channel in x.

    Warning:
        If there are missing values then calling this function can be pretty slow. Make sure to cache the result, and
        don't call it on every forward pass, if at all possible.

    Returns:
        A tensor, which should in turn be passed to `torchcde.LinearInterpolation`.

        See the docstring for `torchcde.natural_cubic_coeffs` for more information on why we do it this way.
    """
    if rectilinear is not None:
        if torch.isnan(x[..., 0, :]).any():
            warnings.warn(
                "The data `x` begins with missing values in some channels. The path will be constructed by "
                "backward-filling the first observed value, which is not causal. Raising a warning as the "
                "`rectilinear` argument has also been passed, which is nearly always only used when causality is "
                "desired. If you need causality then fill in the missing value at the start of each channel with "
                "whatever you'd like it to be. (The mean over that channel is a common choice.)"
            )
        x = _prepare_rectilinear_interpolation(x, rectilinear)

    t = misc.validate_input_path(x, t)

    if torch.isnan(x).any():
        x = _linear_interpolation_coeffs_with_missing_values(t, x.transpose(-1, -2)).transpose(-1, -2)
    return x

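# A minimal usage sketch (shapes and values here are illustrative, not prescriptive):
#   x = torch.rand(2, 10, 3)                 # (batch, length, channels)
#   coeffs = linear_interpolation_coeffs(x)  # cache this; avoid recomputing it every forward pass
#   path = LinearInterpolation(coeffs)
#   path.evaluate(torch.tensor(4.5))         # path value halfway between observations 4 and 5
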

class LinearInterpolation(interpolation_base.InterpolationBase):
    """Calculates the linear interpolation to the batch of controls given. Also calculates its derivative."""

    def __init__(self, coeffs, t=None, **kwargs):
        """
        Arguments:
            coeffs: As returned by linear_interpolation_coeffs.
            t: As passed to linear_interpolation_coeffs. (If it was passed. If you are using neural CDEs then you
                **do not need to use this argument**. See the Further Documentation in README.md.)
        """
        super(LinearInterpolation, self).__init__(**kwargs)

        if t is None:
            t = torch.linspace(0, coeffs.size(-2) - 1, coeffs.size(-2), dtype=coeffs.dtype, device=coeffs.device)

        derivs = (coeffs[..., 1:, :] - coeffs[..., :-1, :]) / (t[1:] - t[:-1]).unsqueeze(-1)

        self.register_buffer("_t", t)
        self.register_buffer("_coeffs", coeffs)
        self.register_buffer("_derivs", derivs)

    @property
    def grid_points(self):
        return self._t

    @property
    def interval(self):
        return torch.stack([self._t[0], self._t[-1]])

    def _interpret_t(self, t):
        t = torch.as_tensor(t, dtype=self._derivs.dtype, device=self._derivs.device)
        maxlen = self._derivs.size(-2) - 1
        # clamp because t may go outside of [t[0], t[-1]]; this is fine
        index = torch.bucketize(t.detach(), self._t.detach()).sub(1).clamp(0, maxlen)
        # will never access the last element of self._t; this is correct behaviour
        fractional_part = t - self._t[index]
        return fractional_part, index

    def evaluate(self, t):
        fractional_part, index = self._interpret_t(t)
        fractional_part = fractional_part.unsqueeze(-1)
        prev_coeff = self._coeffs[..., index, :]
        next_coeff = self._coeffs[..., index + 1, :]
        prev_t = self._t[index]
        next_t = self._t[index + 1]
        diff_t = next_t - prev_t
        return prev_coeff + fractional_part * (next_coeff - prev_coeff) / diff_t.unsqueeze(-1)

    def derivative(self, t):
        fractional_part, index = self._interpret_t(t)
        deriv = self._derivs[..., index, :]
        return deriv

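# A short end-to-end sketch (illustrative values only):
#   x = torch.tensor([[0., 0.], [1., 2.], [2., 3.]]).unsqueeze(0)  # shape (1, 3, 2)
#   coeffs = linear_interpolation_coeffs(x)
#   path = LinearInterpolation(coeffs)
#   path.evaluate(0.5)    # -> tensor([[0.5, 1.0]]): halfway along the first segment
#   path.derivative(0.5)  # -> tensor([[1., 2.]]): the constant slope of that segment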