TrafficWheel/model/STGNRDE/torchcde/log_ode.py

try:
    import signatory
except ImportError:
    class DummyModule:
        def __getattr__(self, item):
            raise ImportError(
                "signatory has not been installed. Please install it from "
                "https://github.com/patrick-kidger/signatory to use the log-ODE method."
            )

    signatory = DummyModule()

import torch

from . import interpolation_linear
from . import misc


def _logsignature_windows(x, depth, window_length, t, _version):
    # Shared implementation of `logsignature_windows` (_version=0: rescales each logsignature by
    # the length of its window and also returns the window times) and `logsig_windows`
    # (_version=1: no rescaling, returns the values only).
    t = misc.validate_input_path(x, t)

    # slightly roundabout way of doing things (rather than using arange) so that it's constructed differentiably
    timespan = t[-1] - t[0]
    num_pieces = (timespan / window_length).ceil().to(int).item()
    end_t = t[0] + num_pieces * window_length
    new_t = torch.linspace(t[0], end_t, num_pieces + 1, dtype=t.dtype, device=t.device)
    new_t = torch.min(new_t, t.max())

    # Work out where each window boundary sits relative to the original times: `new_t_indices`
    # records the position of each boundary in the merged time grid, and `new_t_unique` collects
    # the boundaries that don't already (approximately) appear in `t` and so must be inserted.
    t_index = 0
    new_t_unique = []
    new_t_indices = []
    for new_t_elem in new_t:
        while True:
            lequal = new_t_elem <= t[t_index]
            close = new_t_elem.allclose(t[t_index])
            if lequal or close:
                break
            t_index += 1
        new_t_indices.append(t_index + len(new_t_unique))
        if close:
            continue
        new_t_unique.append(new_t_elem.unsqueeze(0))

    batch_dimensions = x.shape[:-2]

    missing_X = torch.full((1,), float("nan"), dtype=x.dtype, device=x.device).expand(
        *batch_dimensions, 1, x.size(-1)
    )
    if len(new_t_unique) > 0:  # no-op if len == 0, so skip for efficiency
        t, indices = torch.cat([t, *new_t_unique]).sort()
        # Indices pointing at inserted window boundaries are clamped onto the single appended NaN
        # row; those NaN entries are filled in by the linear interpolation below, so it doesn't
        # matter that they all share one row.
        x = torch.cat([x, missing_X], dim=-2)[..., indices.clamp(0, x.size(-2)), :]

    # Fill in any missing data linearly (linearly because that's what signatures do in between observations anyway)
    # and conveniently that's what this already does. Here 'missing data' includes the NaNs we've just added.
    x = interpolation_linear.linear_interpolation_coeffs(x, t)

    # Flatten batch dimensions for compatibility with Signatory. (`reshape` rather than `view`,
    # since `x` need not be contiguous at this point.)
    flatten_X = x.reshape(-1, x.size(-2), x.size(-1))

    # The first entry records the initial value of the path in the depth-1 channels of the
    # logsignature; the remaining channels are zero.
    first_increment = torch.zeros(
        *batch_dimensions,
        signatory.logsignature_channels(x.size(-1), depth),
        dtype=x.dtype,
        device=x.device,
    )
    first_increment[..., : x.size(-1)] = x[..., 0, :]
    logsignatures = [first_increment]
    compute_logsignature = signatory.Logsignature(depth=depth)
    for index, next_index, time, next_time in zip(
        new_t_indices[:-1], new_t_indices[1:], new_t[:-1], new_t[1:]
    ):
        logsignature = compute_logsignature(flatten_X[..., index : next_index + 1, :])
        logsignature = logsignature.view(*batch_dimensions, -1)
        if _version == 0:
            # Deprecated rescaling: multiply each logsignature by the length of its window.
            logsignature = logsignature * (next_time - time)
        elif _version == 1:
            pass
        else:
            raise RuntimeError
        logsignatures.append(logsignature)

    logsignatures = torch.stack(logsignatures, dim=-2)
    logsignatures = logsignatures.cumsum(dim=-2)

    if _version == 0:
        return logsignatures, new_t
    elif _version == 1:
        return logsignatures
    else:
        raise RuntimeError
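
# Shape sketch (an illustration, not executed): for `x` of shape (..., length, input_channels)
# split into `num_pieces` windows, `_logsignature_windows` produces logsignatures of shape
# (..., num_pieces + 1, signatory.logsignature_channels(input_channels, depth)); with
# `_version=0` it also returns the corresponding `num_pieces + 1` window times.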


def logsignature_windows(x, depth, window_length, t=None):
    """Calculates logsignatures over multiple windows, for the batch of controls given, as in the log-ODE method.

    ********************
    DEPRECATED: this now exists for backward compatibility. For new projects please use `logsig_windows` instead,
    which has a corrected rescaling coefficient.
    ********************

    This corresponds to a transform of the time series, and should be used prior to applying one of the interpolation
    schemes.

    Arguments:
        x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions.
            This is interpreted as a (batch of) paths taking values in an `input_channels`-dimensional real vector
            space, with `length`-many observations. Missing values are supported, and should be represented as NaNs.
        depth: What depth to compute the logsignatures to.
        window_length: How long a time interval to compute logsignatures over.
        t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed, defaults to
            `tensor([0., 1., ..., length - 1])`.

    Warning:
        If there are missing values then calling this function can be pretty slow. Make sure to cache the result,
        and don't reinstantiate it on every forward pass, if at all possible.

    Returns:
        A tuple of two tensors, which are the values and times of the transformed path.
    """
    return _logsignature_windows(x, depth, window_length, t, _version=0)
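
# A minimal usage sketch for the deprecated interface (comments only, not executed; assumes
# `signatory` is installed):
#
#     x = torch.randn(32, 100, 3)  # (batch, length, input_channels)
#     values, times = logsignature_windows(x, depth=2, window_length=10.0)
#     coeffs = interpolation_linear.linear_interpolation_coeffs(values, times)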


def logsig_windows(x, depth, window_length, t=None):
    """Calculates logsignatures over multiple windows, for the batch of controls given, as in the log-ODE method.

    This corresponds to a transform of the time series, and should be used prior to applying one of the interpolation
    schemes.

    Arguments:
        x: tensor of values, of shape (..., length, input_channels), where ... is some number of batch dimensions.
            This is interpreted as a (batch of) paths taking values in an `input_channels`-dimensional real vector
            space, with `length`-many observations. Missing values are supported, and should be represented as NaNs.
        depth: What depth to compute the logsignatures to.
        window_length: How long a time interval to compute logsignatures over.
        t: Optional one dimensional tensor of times. Must be monotonically increasing. If not passed, defaults to
            `tensor([0., 1., ..., length - 1])`.

    Warning:
        If there are missing values then calling this function can be pretty slow. Make sure to cache the result,
        and don't reinstantiate it on every forward pass, if at all possible.

    Returns:
        A tensor containing the values of the transformed path. Times are _not_ returned: the return value is
        always scaled such that the corresponding times are just `tensor([0., 1., ..., length - 1])`.
    """
    return _logsignature_windows(x, depth, window_length, t, _version=1)
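

if __name__ == "__main__":
    # A minimal smoke test / usage sketch. This is an illustration, not part of the original
    # library: it assumes `signatory` is installed and, because of the relative imports above,
    # must be run as a module from within the package rather than as a standalone script.
    batch, length, channels = 4, 100, 3
    x = torch.randn(batch, length, channels)
    x[0, 10:20, 1] = float("nan")  # missing values are represented as NaNs

    # Preferred interface: values only, with the corresponding times taken to be 0, 1, 2, ...
    logsig_x = logsig_windows(x, depth=2, window_length=10.0)
    print("logsig_windows output shape:", tuple(logsig_x.shape))

    # `logsig_x` would then be passed to one of the interpolation schemes (e.g.
    # `interpolation_linear.linear_interpolation_coeffs`) before solving a neural CDE/RDE.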