import math
import torch
import warnings

from . import interpolation_base
from . import misc

_two_pi = 2 * math.pi
_inv_two_pi = 1 / _two_pi


def _linear_interpolation_coeffs_with_missing_values_scalar(t, x):
    # t and x both have shape (length,)

    not_nan = ~torch.isnan(x)
    path_no_nan = x.masked_select(not_nan)

    if path_no_nan.size(0) == 0:
        # Every entry is a NaN, so we take a constant path with derivative zero, so return zero coefficients.
        return torch.zeros(x.size(0), dtype=x.dtype, device=x.device)

    if path_no_nan.size(0) == x.size(0):
        # Every entry is not-NaN, so just return.
        return x

    x = x.clone()
    # How to deal with missing values at the start or end of the time series? We impute an observation at the very
    # start equal to the first actual observation made, and impute an observation at the very end equal to the last
    # actual observation made, and then proceed as normal.
    if torch.isnan(x[0]):
        x[0] = path_no_nan[0]
    if torch.isnan(x[-1]):
        x[-1] = path_no_nan[-1]

    nan_indices = torch.arange(x.size(0), device=x.device).masked_select(torch.isnan(x))

    if nan_indices.size(0) == 0:
        # We only had missing values at the start or end
        return x

    prev_nan_index = nan_indices[0]
    prev_not_nan_index = prev_nan_index - 1
    prev_not_nan_indices = [prev_not_nan_index]
    for nan_index in nan_indices[1:]:
        if prev_nan_index != nan_index - 1:
            prev_not_nan_index = nan_index - 1
        prev_nan_index = nan_index
        prev_not_nan_indices.append(prev_not_nan_index)

    next_nan_index = nan_indices[-1]
    next_not_nan_index = next_nan_index + 1
    next_not_nan_indices = [next_not_nan_index]
    for nan_index in reversed(nan_indices[:-1]):
        if next_nan_index != nan_index + 1:
            next_not_nan_index = nan_index + 1
        next_nan_index = nan_index
        next_not_nan_indices.append(next_not_nan_index)
    next_not_nan_indices = reversed(next_not_nan_indices)

    for prev_not_nan_index, nan_index, next_not_nan_index in zip(prev_not_nan_indices,
                                                                 nan_indices,
                                                                 next_not_nan_indices):
        prev_stream = x[prev_not_nan_index]
        next_stream = x[next_not_nan_index]
        prev_time = t[prev_not_nan_index]
        next_time = t[next_not_nan_index]
        time = t[nan_index]
        ratio = (time - prev_time) / (next_time - prev_time)
        x[nan_index] = prev_stream + ratio * (next_stream - prev_stream)

    return x


def _linear_interpolation_coeffs_with_missing_values(t, x):
    if x.ndimension() == 1:
        # We have to break everything down to individual scalar paths because of the possibility of missing values
        # being different in different channels
        return _linear_interpolation_coeffs_with_missing_values_scalar(t, x)
    else:
        out_pieces = []
        for p in x.unbind(dim=0):  # TODO: parallelise over this
            out = _linear_interpolation_coeffs_with_missing_values(t, p)
            out_pieces.append(out)
        return misc.cheap_stack(out_pieces, dim=0)
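
# Illustrative sketch of the helper above (hypothetical values, not part of the
# public API): interior NaNs are imputed by linearly interpolating between the
# surrounding observations, e.g.
#
#     >>> t = torch.tensor([0., 1., 2., 3.])
#     >>> x = torch.tensor([1., float('nan'), float('nan'), 4.])
#     >>> _linear_interpolation_coeffs_with_missing_values(t, x)
#     tensor([1., 2., 3., 4.])
#
# NaNs at the very start or end are instead filled with the first/last observed
# value, as described in the comments of the scalar helper.
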
def _prepare_rectilinear_interpolation(data, time_index):
    """Prepares data for rectilinear interpolation.

    This function performs the relevant filling and lagging of the data needed to convert raw data into a format such
    that standard linear interpolation will give the rectilinear interpolation.

    Arguments:
        data: tensor of values with one channel being time, of shape (..., length, input_channels), where ... is some
            number of batch dimensions.
        time_index: integer giving the index of the time channel.

    Example:
        Suppose we have data
            data = [(t1, x1), (t2, NaN), (t3, x3), ...]
        that we wish to interpolate using a rectilinear scheme. The key point is that this is equivalent to a linear
        interpolation on
            data_rect = [(t1, x1), (t2, x1), (t2, x1), (t3, x1), (t3, x3), ...]
        This function simply performs the conversion from `data` to `data_rect` so that we can apply the inbuilt
        torchcde linear interpolation scheme to achieve rectilinear interpolation.

    Returns:
        A tensor, now of shape (..., 2 * length - 1, input_channels), that can be fed to `linear_interpolation_coeffs`
        to give rectilinear coefficients.
    """
    # Check time_index is of the correct format
    n_channels = data.size(-1)
    assert isinstance(time_index, int), ("Index of the time channel must be an integer in "
                                         "[0, {}].".format(n_channels - 1))
    assert 0 <= time_index < n_channels, ("Time index must be in [0, {}], was given {}."
                                          .format(n_channels - 1, time_index))
    times = data[..., time_index]
    assert not torch.isnan(times).any(), ("There exist nan values in the time column, which is not allowed. If the "
                                          "times are padded with nans after the final time, a simple solution is to "
                                          "forward-fill the final time.")

    # Forward fill, then interleave each observation with a lagged copy of itself, shifting the time channel so that
    # the path first moves along in time and only then in feature space.
    data_filled = misc.forward_fill(data)
    data_repeat = data_filled.repeat_interleave(2, dim=-2)
    data_repeat[..., :-1, time_index] = data_repeat[..., 1:, time_index]
    data_rect = data_repeat[..., :-1, :]
    return data_rect
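
# Illustrative sketch (hypothetical values): with the time channel at index 0,
# a length-3 path with one missing value becomes
#
#     >>> data = torch.tensor([[0., 1.],
#     ...                      [1., float('nan')],
#     ...                      [2., 3.]])
#     >>> _prepare_rectilinear_interpolation(data, 0)
#     tensor([[0., 1.],
#             [1., 1.],
#             [1., 1.],
#             [2., 1.],
#             [2., 3.]])
#
# i.e. the (t1, x1), (t2, x1), (t2, x1), (t3, x1), (t3, x3) pattern from the
# docstring above, so that plain linear interpolation of the result traces out
# the rectilinear path.
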
def linear_interpolation_coeffs(x, t=None, rectilinear=None):
    """Calculates the knots of the linear interpolation of the batch of controls given.

    Arguments:
        x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions. This
            is interpreted as a (batch of) paths taking values in an input_channels-dimensional real vector space, with
            length-many observations. Missing values are supported, and should be represented as NaNs.
        t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed will default to
            tensor([0., 1., ..., length - 1]). If you are using neural CDEs then you **do not need to use this
            argument**. See the Further Documentation in README.md.
        rectilinear: Optional integer. Used for performing rectilinear interpolation. This means that interpolation
            between each two adjacent points is done by first interpolating in the time direction, and then
            interpolating in the feature direction. (This is useful for causal missing data, see the Further
            Documentation in README.md.) Defaults to None, i.e. not performing rectilinear interpolation. For
            rectilinear interpolation time *must* be a channel in x, and the `rectilinear` parameter must be an integer
            specifying the channel index location of the time channel in x.

    Warning:
        If there are missing values then calling this function can be pretty slow. Make sure to cache the result, and
        don't call it on every forward pass, if at all possible.

    Returns:
        A tensor, which should in turn be passed to `torchcde.LinearInterpolation`.

        See the docstring for `torchcde.natural_cubic_coeffs` for more information on why we do it this way.
    """
    if rectilinear is not None:
        if torch.isnan(x[..., 0, :]).any():
            warnings.warn("The data `x` begins with missing values in some channels. The path will be constructed by "
                          "backward-filling the first observed value, which is not causal. Raising a warning as the "
                          "`rectilinear` argument has also been passed, which is nearly always only used when "
                          "causality is desired. If you need causality then fill in the missing value at the start of "
                          "each channel with whatever you'd like it to be. (The mean over that channel is a common "
                          "choice.)")
        x = _prepare_rectilinear_interpolation(x, rectilinear)

    t = misc.validate_input_path(x, t)

    if torch.isnan(x).any():
        x = _linear_interpolation_coeffs_with_missing_values(t, x.transpose(-1, -2)).transpose(-1, -2)
    return x


class LinearInterpolation(interpolation_base.InterpolationBase):
    """Calculates the linear interpolation to the batch of controls given. Also calculates its derivative."""

    def __init__(self, coeffs, t=None, **kwargs):
        """
        Arguments:
            coeffs: As returned by linear_interpolation_coeffs.
            t: As passed to linear_interpolation_coeffs. (If it was passed. If you are using neural CDEs then you **do
                not need to use this argument**. See the Further Documentation in README.md.)
        """
        super(LinearInterpolation, self).__init__(**kwargs)

        if t is None:
            t = torch.linspace(0, coeffs.size(-2) - 1, coeffs.size(-2), dtype=coeffs.dtype, device=coeffs.device)

        derivs = (coeffs[..., 1:, :] - coeffs[..., :-1, :]) / (t[1:] - t[:-1]).unsqueeze(-1)

        self.register_buffer("_t", t)
        self.register_buffer("_coeffs", coeffs)
        self.register_buffer("_derivs", derivs)

    @property
    def grid_points(self):
        return self._t

    @property
    def interval(self):
        return torch.stack([self._t[0], self._t[-1]])

    def _interpret_t(self, t):
        t = torch.as_tensor(t, dtype=self._derivs.dtype, device=self._derivs.device)
        maxlen = self._derivs.size(-2) - 1
        # clamp because t may go outside of [t[0], t[-1]]; this is fine
        index = torch.bucketize(t.detach(), self._t.detach()).sub(1).clamp(0, maxlen)
        # will never access the last element of self._t; this is correct behaviour
        fractional_part = t - self._t[index]
        return fractional_part, index

    def evaluate(self, t):
        fractional_part, index = self._interpret_t(t)
        fractional_part = fractional_part.unsqueeze(-1)
        prev_coeff = self._coeffs[..., index, :]
        next_coeff = self._coeffs[..., index + 1, :]
        prev_t = self._t[index]
        next_t = self._t[index + 1]
        diff_t = next_t - prev_t
        return prev_coeff + fractional_part * (next_coeff - prev_coeff) / diff_t.unsqueeze(-1)

    def derivative(self, t):
        fractional_part, index = self._interpret_t(t)
        deriv = self._derivs[..., index, :]
        return deriv
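

if __name__ == "__main__":
    # Minimal usage sketch / smoke test (values and shapes are arbitrary, not
    # part of the library). This module uses relative imports, so run it as
    # `python -m torchcde.interpolation_linear`, assuming the package is
    # importable as `torchcde`.
    x = torch.rand(2, 10, 3)                 # (batch, length, channels)
    coeffs = linear_interpolation_coeffs(x)  # knots; cache these in practice
    path = LinearInterpolation(coeffs)
    t0, t1 = path.interval
    midpoint = 0.5 * (t0 + t1)
    print(path.evaluate(midpoint).shape)     # torch.Size([2, 3])
    print(path.derivative(midpoint).shape)   # torch.Size([2, 3])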