483 lines
22 KiB
Python
483 lines
22 KiB
Python
import torch
|
|
import torchdiffeq
|
|
import torchsde
|
|
import warnings
|
|
|
|
|
|
def _check_compatability_per_tensor_base(control_gradient, z0):
|
|
if control_gradient.shape[:-1] != z0.shape[:-1]:
|
|
raise ValueError(
|
|
"X.derivative did not return a tensor with the same number of batch dimensions as z0. "
|
|
"X.derivative returned shape {} (meaning {} batch dimensions), whilst z0 has shape {} "
|
|
"(meaning {} batch dimensions)."
|
|
"".format(
|
|
tuple(control_gradient.shape),
|
|
tuple(control_gradient.shape[:-1]),
|
|
tuple(z0.shape),
|
|
tuple(z0.shape[:-1]),
|
|
)
|
|
)
|
|
|
|
|
|
def _check_compatability_per_tensor_forward(control_gradient, system, z0):
|
|
_check_compatability_per_tensor_base(control_gradient, z0)
|
|
if system.shape[:-2] != z0.shape[:-1]:
|
|
raise ValueError(
|
|
"func did not return a tensor with the same number of batch dimensions as z0. func returned "
|
|
"shape {} (meaning {} batch dimensions), whilst z0 has shape {} (meaning {} batch"
|
|
" dimensions)."
|
|
"".format(
|
|
tuple(system.shape),
|
|
tuple(system.shape[:-2]),
|
|
tuple(z0.shape),
|
|
tuple(z0.shape[:-1]),
|
|
)
|
|
)
|
|
if system.size(-2) != z0.size(-1):
|
|
raise ValueError(
|
|
"func did not return a tensor with the same number of hidden channels as z0. func returned "
|
|
"shape {} (meaning {} channels), whilst z0 has shape {} (meaning {} channels)."
|
|
"".format(
|
|
tuple(system.shape), system.size(-2), tuple(z0.shape), z0.size(-1)
|
|
)
|
|
)
|
|
if system.size(-1) != control_gradient.size(-1):
|
|
raise ValueError(
|
|
"func did not return a tensor with the same number of input channels as X.derivative "
|
|
"returned. func returned shape {} (meaning {} channels), whilst X.derivative returned shape "
|
|
"{} (meaning {} channels)."
|
|
"".format(
|
|
tuple(system.shape),
|
|
system.size(-1),
|
|
tuple(control_gradient.shape),
|
|
control_gradient.size(-1),
|
|
)
|
|
)
|
|
|
|
|
|
def _check_compatability_per_tensor_prod(control_gradient, vector_field, z0):
|
|
_check_compatability_per_tensor_base(control_gradient, z0)
|
|
if vector_field.shape != z0.shape:
|
|
raise ValueError(
|
|
"func.prod did not return a tensor with the same shape as z0. func.prod returned shape {} "
|
|
"whilst z0 has shape {}."
|
|
"".format(tuple(vector_field.shape), tuple(z0.shape))
|
|
)
|
|
|
|
|
|
def _check_compatability(X, func, z0, t):
|
|
if not hasattr(X, "derivative"):
|
|
raise ValueError("X must have a 'derivative' method.")
|
|
control_gradient = X.derivative(t[0].detach())
|
|
if hasattr(func, "prod"):
|
|
is_prod = True
|
|
vector_field = func.prod(t[0], z0, control_gradient)
|
|
else:
|
|
is_prod = False
|
|
system = func(t[0], z0)
|
|
|
|
if isinstance(z0, torch.Tensor):
|
|
is_tensor = True
|
|
if not isinstance(control_gradient, torch.Tensor):
|
|
raise ValueError(
|
|
"z0 is a tensor and so X.derivative must return a tensor as well."
|
|
)
|
|
if is_prod:
|
|
if not isinstance(vector_field, torch.Tensor):
|
|
raise ValueError(
|
|
"z0 is a tensor and so func.prod must return a tensor as well."
|
|
)
|
|
_check_compatability_per_tensor_prod(control_gradient, vector_field, z0)
|
|
else:
|
|
if not isinstance(system, torch.Tensor):
|
|
raise ValueError(
|
|
"z0 is a tensor and so func must return a tensor as well."
|
|
)
|
|
_check_compatability_per_tensor_forward(control_gradient, system, z0)
|
|
|
|
elif isinstance(z0, (tuple, list)):
|
|
is_tensor = False
|
|
if not isinstance(control_gradient, (tuple, list)):
|
|
raise ValueError(
|
|
"z0 is a tuple/list and so X.derivative must return a tuple/list as well."
|
|
)
|
|
if len(z0) != len(control_gradient):
|
|
raise ValueError(
|
|
"z0 and X.derivative(t) must be tuples of the same length."
|
|
)
|
|
if is_prod:
|
|
if not isinstance(vector_field, (tuple, list)):
|
|
raise ValueError(
|
|
"z0 is a tuple/list and so func.prod must return a tuple/list as well."
|
|
)
|
|
if len(z0) != len(vector_field):
|
|
raise ValueError(
|
|
"z0 and func.prod(t, z, dXdt) must be tuples of the same length."
|
|
)
|
|
for control_gradient_, vector_Field_, z0_ in zip(
|
|
control_gradient, vector_field, z0
|
|
):
|
|
if not isinstance(control_gradient_, torch.Tensor):
|
|
raise ValueError(
|
|
"X.derivative must return a tensor or tuple of tensors."
|
|
)
|
|
if not isinstance(vector_Field_, torch.Tensor):
|
|
raise ValueError(
|
|
"func.prod must return a tensor or tuple/list of tensors."
|
|
)
|
|
_check_compatability_per_tensor_prod(
|
|
control_gradient_, vector_Field_, z0_
|
|
)
|
|
else:
|
|
if not isinstance(system, (tuple, list)):
|
|
raise ValueError(
|
|
"z0 is a tuple/list and so func must return a tuple/list as well."
|
|
)
|
|
if len(z0) != len(system):
|
|
raise ValueError("z0 and func(t, z) must be tuples of the same length.")
|
|
for control_gradient_, system_, z0_ in zip(control_gradient, system, z0):
|
|
if not isinstance(control_gradient_, torch.Tensor):
|
|
raise ValueError(
|
|
"X.derivative must return a tensor or tuple of tensors."
|
|
)
|
|
if not isinstance(system_, torch.Tensor):
|
|
raise ValueError(
|
|
"func must return a tensor or tuple/list of tensors."
|
|
)
|
|
_check_compatability_per_tensor_forward(control_gradient_, system_, z0_)
|
|
|
|
else:
|
|
raise ValueError("z0 must either a tensor or a tuple/list of tensors.")
|
|
|
|
return is_tensor, is_prod
|
|
|
|
|
|
class _VectorField(torch.nn.Module):
|
|
def __init__(self, X, func, is_tensor, is_prod):
|
|
super(_VectorField, self).__init__()
|
|
|
|
self.X = X
|
|
self.func = func
|
|
self.is_tensor = is_tensor
|
|
self.is_prod = is_prod
|
|
|
|
# torchsde backend
|
|
self.sde_type = getattr(func, "sde_type", "stratonovich")
|
|
self.noise_type = getattr(func, "noise_type", "additive")
|
|
|
|
# torchdiffeq backend
|
|
def forward(self, t, z):
|
|
# control_gradient is of shape (..., input_channels)
|
|
control_gradient = self.X.derivative(t)
|
|
|
|
if self.is_prod:
|
|
# out is of shape (..., hidden_channels)
|
|
out = self.func.prod(t, z, control_gradient)
|
|
else:
|
|
# vector_field is of shape (..., hidden_channels, input_channels)
|
|
vector_field = self.func(t, z)
|
|
if self.is_tensor:
|
|
# out is of shape (..., hidden_channels)
|
|
# (The squeezing is necessary to make the matrix-multiply properly batch in all cases)
|
|
out = (vector_field @ control_gradient.unsqueeze(-1)).squeeze(-1)
|
|
else:
|
|
out = tuple(
|
|
(vector_field_ @ control_gradient_.unsqueeze(-1)).squeeze(-1)
|
|
for vector_field_, control_gradient_ in zip(
|
|
vector_field, control_gradient
|
|
)
|
|
)
|
|
|
|
return out
|
|
|
|
# torchsde backend
|
|
f = forward
|
|
|
|
def g(self, t, z):
|
|
return torch.zeros_like(z).unsqueeze(-1)
|
|
|
|
|
|
def cdeint(X, func, z0, t, adjoint=True, backend="torchdiffeq", **kwargs):
|
|
r"""Solves a system of controlled differential equations.
|
|
|
|
Solves the controlled problem:
|
|
```
|
|
z_t = z_{t_0} + \int_{t_0}^t f(s, z_s) dX_s
|
|
```
|
|
where z is a tensor of any shape, and X is some controlling signal.
|
|
|
|
Arguments:
|
|
X: The control. This should be a instance of `torch.nn.Module`, with a `derivative` method. For example
|
|
`torchcde.CubicSpline`. This represents a continuous path derived from the data. The
|
|
derivative at a point will be computed via `X.derivative(t)`, where t is a scalar tensor. The returned
|
|
tensor should have shape (..., input_channels), where '...' is some number of batch dimensions and
|
|
input_channels is the number of channels in the input path.
|
|
func: Should be a callable describing the vector field f(t, z). If using `adjoint=True` (the default), then
|
|
should be an instance of `torch.nn.Module`, to collect the parameters for the adjoint pass. Will be called
|
|
with a scalar tensor t and a tensor z of shape (..., hidden_channels), and should return a tensor of shape
|
|
(..., hidden_channels, input_channels), where hidden_channels and input_channels are integers defined by the
|
|
`hidden_shape` and `X` arguments as above. The '...' corresponds to some number of batch dimensions. If it
|
|
has a method `prod` then that will be called to calculate the matrix-vector product f(t, z) dX_t/dt, via
|
|
`func.prod(t, z, dXdt)`.
|
|
z0: The initial state of the solution. It should have shape (..., hidden_channels), where '...' is some number
|
|
of batch dimensions.
|
|
t: a one dimensional tensor describing the times to range of times to integrate over and output the results at.
|
|
The initial time will be t[0] and the final time will be t[-1].
|
|
adjoint: A boolean; whether to use the adjoint method to backpropagate. Defaults to True.
|
|
backend: Either "torchdiffeq" or "torchsde". Which library to use for the solvers. Note that if using torchsde
|
|
that the Brownian motion component is completely ignored -- so it's still reducing the CDE to an ODE --
|
|
but it makes it possible to e.g. use an SDE solver there as the ODE/CDE solver here, if for some reason
|
|
that's desired.
|
|
**kwargs: Any additional kwargs to pass to the odeint solver of torchdiffeq (the most common are `rtol`, `atol`,
|
|
`method`, `options`) or the sdeint solver of torchsde.
|
|
|
|
Returns:
|
|
The value of each z_{t_i} of the solution to the CDE z_t = z_{t_0} + \int_0^t f(s, z_s)dX_s, where t_i = t[i].
|
|
This will be a tensor of shape (..., len(t), hidden_channels).
|
|
|
|
Raises:
|
|
ValueError for malformed inputs.
|
|
|
|
Note:
|
|
Supports tupled input, i.e. z0 can be a tuple of tensors, and X.derivative and func can return tuples of tensors
|
|
of the same length.
|
|
|
|
Warnings:
|
|
Note that the returned tensor puts the sequence dimension second-to-last, rather than first like in
|
|
`torchdiffeq.odeint` or `torchsde.sdeint`.
|
|
"""
|
|
|
|
# Reduce the default values for the tolerances because CDEs are difficult to solve with the default high tolerances.
|
|
if "atol" not in kwargs:
|
|
kwargs["atol"] = 1e-6
|
|
if "rtol" not in kwargs:
|
|
kwargs["rtol"] = 1e-4
|
|
if adjoint:
|
|
if "adjoint_atol" not in kwargs:
|
|
kwargs["adjoint_atol"] = kwargs["atol"]
|
|
if "adjoint_rtol" not in kwargs:
|
|
kwargs["adjoint_rtol"] = kwargs["rtol"]
|
|
|
|
is_tensor, is_prod = _check_compatability(X, func, z0, t)
|
|
|
|
if adjoint and "adjoint_params" not in kwargs:
|
|
for buffer in X.buffers():
|
|
# Compare based on id to avoid PyTorch not playing well with using `in` on tensors.
|
|
if buffer.requires_grad:
|
|
warnings.warn(
|
|
"One of the inputs to the control path X requires gradients but "
|
|
"`kwargs['adjoint_params']` has not been passed. This is probably a mistake: these "
|
|
"inputs will not receive a gradient when using the adjoint method. Either have the input "
|
|
"not require gradients (if that was unintended), or include it (and every other "
|
|
"parameter needing gradients) in `adjoint_params`. For example:\n"
|
|
"```\n"
|
|
"coeffs = ...\n"
|
|
"func = ...\n"
|
|
"X = CubicSpline(coeffs)\n"
|
|
"adjoint_params = tuple(func.parameters()) + (coeffs,)\n"
|
|
"cdeint(X=X, func=func, ..., adjoint_params=adjoint_params)\n"
|
|
"```"
|
|
)
|
|
|
|
vector_field = _VectorField(X=X, func=func, is_tensor=is_tensor, is_prod=is_prod)
|
|
if backend == "torchdiffeq":
|
|
odeint = torchdiffeq.odeint_adjoint if adjoint else torchdiffeq.odeint
|
|
out = odeint(func=vector_field, y0=z0, t=t, **kwargs)
|
|
elif backend == "torchsde":
|
|
sdeint = torchsde.sdeint_adjoint if adjoint else torchsde.sdeint
|
|
out = sdeint(sde=vector_field, y0=z0, ts=t, **kwargs)
|
|
else:
|
|
raise ValueError(f"Unrecognised backend={backend}")
|
|
|
|
if is_tensor:
|
|
batch_dims = range(1, len(out.shape) - 1)
|
|
out = out.permute(*batch_dims, 0, -1)
|
|
else:
|
|
out_ = []
|
|
for outi in out:
|
|
batch_dims = range(1, len(outi.shape) - 1)
|
|
outi = outi.permute(*batch_dims, 0, -1)
|
|
out_.append(outi)
|
|
out = tuple(out_)
|
|
|
|
return out
|
|
|
|
|
|
class _VectorFieldCustom(torch.nn.Module):
|
|
def __init__(self, X, func_f, func_g, is_tensor, is_prod):
|
|
super(_VectorFieldCustom, self).__init__()
|
|
|
|
self.X = X
|
|
self.func_f = func_f
|
|
self.func_g = func_g
|
|
self.is_tensor = is_tensor
|
|
self.is_prod = is_prod
|
|
|
|
# torchsde backend
|
|
self.sde_type = getattr(func_f, "sde_type", "stratonovich")
|
|
self.noise_type = getattr(func_f, "noise_type", "additive")
|
|
|
|
self.sde_type = getattr(func_g, "sde_type", "stratonovich")
|
|
self.noise_type = getattr(func_g, "noise_type", "additive")
|
|
|
|
# torchdiffeq backend
|
|
def forward(self, t, z):
|
|
# control_gradient is of shape (..., input_channels)
|
|
control_gradient = self.X.derivative(t)
|
|
h = z[0]
|
|
z = z[1]
|
|
|
|
if self.is_prod:
|
|
# out is of shape (..., hidden_channels)
|
|
# FIXME: i don't know it is correct
|
|
out_f = self.func_f.prod(t, h, control_gradient)
|
|
out_g = self.func_g.prod(t, z, control_gradient)
|
|
else:
|
|
# vector_field is of shape (..., hidden_channels, input_channels)
|
|
vector_field_f = self.func_f(t, h)
|
|
vector_field_g = self.func_g(t, z)
|
|
if self.is_tensor:
|
|
# out is of shape (..., hidden_channels)
|
|
# (The squeezing is necessary to make the matrix-multiply properly batch in all cases)
|
|
dh = (vector_field_f @ control_gradient.unsqueeze(-1)).squeeze(-1)
|
|
vector_field_gf = vector_field_g @ vector_field_f
|
|
out = (vector_field_gf @ control_gradient.unsqueeze(-1)).squeeze(-1)
|
|
else:
|
|
dh = tuple(
|
|
(vector_field_f_ @ control_gradient_.unsqueeze(-1)).squeeze(-1)
|
|
for vector_field_f_, control_gradient_ in zip(
|
|
vector_field_f, control_gradient
|
|
)
|
|
)
|
|
vector_field_gf = vector_field_g @ vector_field_f
|
|
out = tuple(
|
|
(vector_field_gf_ @ control_gradient_.unsqueeze(-1)).squeeze(-1)
|
|
for vector_field_gf_, control_gradient_ in zip(
|
|
vector_field_gf, control_gradient
|
|
)
|
|
)
|
|
# FIXME: return value to tuple
|
|
# import pdb; pdb.set_trace()
|
|
return tuple([dh, out])
|
|
|
|
# torchsde backend
|
|
f = forward
|
|
|
|
def g(self, t, z):
|
|
return torch.zeros_like(z).unsqueeze(-1)
|
|
|
|
|
|
def cdeint_custom(
|
|
X, func_f, func_g, h0, z0, t, adjoint=True, backend="torchdiffeq", **kwargs
|
|
):
|
|
r"""Solves a system of controlled differential equations.
|
|
|
|
Solves the controlled problem:
|
|
```
|
|
z_t = z_{t_0} + \int_{t_0}^t f(s, z_s) dX_s
|
|
```
|
|
where z is a tensor of any shape, and X is some controlling signal.
|
|
|
|
Arguments:
|
|
X: The control. This should be a instance of `torch.nn.Module`, with a `derivative` method. For example
|
|
`torchcde.CubicSpline`. This represents a continuous path derived from the data. The
|
|
derivative at a point will be computed via `X.derivative(t)`, where t is a scalar tensor. The returned
|
|
tensor should have shape (..., input_channels), where '...' is some number of batch dimensions and
|
|
input_channels is the number of channels in the input path.
|
|
func: Should be a callable describing the vector field f(t, z). If using `adjoint=True` (the default), then
|
|
should be an instance of `torch.nn.Module`, to collect the parameters for the adjoint pass. Will be called
|
|
with a scalar tensor t and a tensor z of shape (..., hidden_channels), and should return a tensor of shape
|
|
(..., hidden_channels, input_channels), where hidden_channels and input_channels are integers defined by the
|
|
`hidden_shape` and `X` arguments as above. The '...' corresponds to some number of batch dimensions. If it
|
|
has a method `prod` then that will be called to calculate the matrix-vector product f(t, z) dX_t/dt, via
|
|
`func.prod(t, z, dXdt)`.
|
|
z0: The initial state of the solution. It should have shape (..., hidden_channels), where '...' is some number
|
|
of batch dimensions.
|
|
t: a one dimensional tensor describing the times to range of times to integrate over and output the results at.
|
|
The initial time will be t[0] and the final time will be t[-1].
|
|
adjoint: A boolean; whether to use the adjoint method to backpropagate. Defaults to True.
|
|
backend: Either "torchdiffeq" or "torchsde". Which library to use for the solvers. Note that if using torchsde
|
|
that the Brownian motion component is completely ignored -- so it's still reducing the CDE to an ODE --
|
|
but it makes it possible to e.g. use an SDE solver there as the ODE/CDE solver here, if for some reason
|
|
that's desired.
|
|
**kwargs: Any additional kwargs to pass to the odeint solver of torchdiffeq (the most common are `rtol`, `atol`,
|
|
`method`, `options`) or the sdeint solver of torchsde.
|
|
|
|
Returns:
|
|
The value of each z_{t_i} of the solution to the CDE z_t = z_{t_0} + \int_0^t f(s, z_s)dX_s, where t_i = t[i].
|
|
This will be a tensor of shape (..., len(t), hidden_channels).
|
|
|
|
Raises:
|
|
ValueError for malformed inputs.
|
|
|
|
Note:
|
|
Supports tupled input, i.e. z0 can be a tuple of tensors, and X.derivative and func can return tuples of tensors
|
|
of the same length.
|
|
|
|
Warnings:
|
|
Note that the returned tensor puts the sequence dimension second-to-last, rather than first like in
|
|
`torchdiffeq.odeint` or `torchsde.sdeint`.
|
|
"""
|
|
|
|
# Reduce the default values for the tolerances because CDEs are difficult to solve with the default high tolerances.
|
|
if "atol" not in kwargs:
|
|
kwargs["atol"] = 1e-6
|
|
if "rtol" not in kwargs:
|
|
kwargs["rtol"] = 1e-4
|
|
if adjoint:
|
|
if "adjoint_atol" not in kwargs:
|
|
kwargs["adjoint_atol"] = kwargs["atol"]
|
|
if "adjoint_rtol" not in kwargs:
|
|
kwargs["adjoint_rtol"] = kwargs["rtol"]
|
|
|
|
is_tensor, is_prod = _check_compatability(X, func_f, h0, t)
|
|
# is_tensor, is_prod = _check_compatability(X, func_g, z0, t)
|
|
|
|
if adjoint and "adjoint_params" not in kwargs:
|
|
for buffer in X.buffers():
|
|
# Compare based on id to avoid PyTorch not playing well with using `in` on tensors.
|
|
if buffer.requires_grad:
|
|
warnings.warn(
|
|
"One of the inputs to the control path X requires gradients but "
|
|
"`kwargs['adjoint_params']` has not been passed. This is probably a mistake: these "
|
|
"inputs will not receive a gradient when using the adjoint method. Either have the input "
|
|
"not require gradients (if that was unintended), or include it (and every other "
|
|
"parameter needing gradients) in `adjoint_params`. For example:\n"
|
|
"```\n"
|
|
"coeffs = ...\n"
|
|
"func = ...\n"
|
|
"X = CubicSpline(coeffs)\n"
|
|
"adjoint_params = tuple(func.parameters()) + (coeffs,)\n"
|
|
"cdeint(X=X, func=func, ..., adjoint_params=adjoint_params)\n"
|
|
"```"
|
|
)
|
|
|
|
vector_field = _VectorFieldCustom(
|
|
X=X, func_f=func_f, func_g=func_g, is_tensor=is_tensor, is_prod=is_prod
|
|
)
|
|
if backend == "torchdiffeq":
|
|
# import pdb; pdb.set_trace()
|
|
z0 = (h0, z0)
|
|
odeint = torchdiffeq.odeint_adjoint if adjoint else torchdiffeq.odeint
|
|
out = odeint(func=vector_field, y0=z0, t=t, **kwargs)
|
|
elif backend == "torchsde":
|
|
sdeint = torchsde.sdeint_adjoint if adjoint else torchsde.sdeint
|
|
out = sdeint(sde=vector_field, y0=z0, ts=t, **kwargs)
|
|
else:
|
|
raise ValueError(f"Unrecognised backend={backend}")
|
|
|
|
if is_tensor:
|
|
# import pdb; pdb.set_trace()
|
|
out = out[-1]
|
|
# batch_dims = range(1, len(out[-1].shape) - 1)
|
|
batch_dims = range(1, len(out.shape) - 1)
|
|
out = out.permute(*batch_dims, 0, -1)
|
|
else:
|
|
out_ = []
|
|
for outi in out:
|
|
batch_dims = range(1, len(outi.shape) - 1)
|
|
outi = outi.permute(*batch_dims, 0, -1)
|
|
out_.append(outi)
|
|
out = tuple(out_)
|
|
return out
|