import torch as pt
from torch.distributions.transforms import ComposeTransform
from torch.distributions.constraints import independent, real

from .torch_utils import jit_script


def calc_arma_transform(x, x_is_in_not_out, input, output, i_coefs, o_coefs, drift):
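    """Broadcast the input/output tail buffers to the batch shape of ``x``
    and delegate to the scripted core recursion. ``x_is_in_not_out`` flags,
    per position along the last axis, whether the corresponding element of
    ``x`` is an input (innovation) or an output (observation).
    """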
    assert x_is_in_not_out.shape == x.shape[-1:]
    # Match dimensions: left-pad with singleton axes, then expand to the
    # batch shape of x. (Tuple indexing replaces the deprecated list form.)
    input = input[(None,) * (len(x.shape) - len(input.shape)) + (Ellipsis,)].expand(
        *(x.shape[:-1] + (input.shape[-1],))
    )
    output = output[(None,) * (len(x.shape) - len(output.shape)) + (Ellipsis,)].expand(
        *(x.shape[:-1] + (output.shape[-1],))
    )
    return calc_arma_transform_core(
        x, x_is_in_not_out, input, output, i_coefs, o_coefs, drift
    )


@jit_script
def calc_arma_transform_core(
    x, x_is_in_not_out, input, output, i_coefs, o_coefs, drift
):
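    """Run the ARMA recursion along the last axis of ``x``.

    Where ``x`` carries an input, the next output is
    ``x_n + (input * i_coefs).sum() - (output * o_coefs).sum() + drift``;
    where it carries an output, the same relation is solved for the input.
    Each result depends only on the current element and strictly earlier
    values, which keeps the transform triangular.
    """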
    # The last coefficient of each vector is assumed to be one and is dropped.
    i_coefs = i_coefs[..., :-1]
    o_coefs = o_coefs[..., :-1]
    # Trim input and output to the number of remaining coefficients. The
    # start index is computed explicitly because the negative form
    # input[..., -0:] would keep the whole tensor when a coefficient
    # vector is empty.
    input = input[..., input.shape[-1] - i_coefs.shape[-1] :]
    output = output[..., output.shape[-1] - o_coefs.shape[-1] :]
    # Loop over the last axis of x, rolling the input/output buffers.
    ret_val = []
    for n in range(x.shape[-1]):
        next_x = x[..., n][..., None]
        if x_is_in_not_out[n]:
            # x[..., n] is an input (innovation): compute the next output.
            next_val = (
                next_x
                + ((input * i_coefs).sum(-1) - (output * o_coefs).sum(-1))[..., None]
                + drift
            )
            input = pt.cat([input[..., 1:], next_x], dim=-1)
            output = pt.cat([output[..., 1:], next_val], dim=-1)
        else:
            # x[..., n] is an output (observation): recover the innovation.
            next_val = (
                next_x
                + ((output * o_coefs).sum(-1) - (input * i_coefs).sum(-1))[..., None]
                - drift
            )
            input = pt.cat([input[..., 1:], next_val], dim=-1)
            output = pt.cat([output[..., 1:], next_x], dim=-1)
        ret_val.append(next_val)
    return pt.cat(ret_val, dim=-1)


class ARMATransform(pt.distributions.transforms.Transform):
    """
    Invertible ARMA transform with support for transforming only part of the samples.
    The transform has an absolute Jacobian determinant of one even when only part of the sample is used as input.
    See a discussion with ChatGPT on the subject at https://chat.openai.com/share/55d34600-6b9d-49ea-b7de-0b70b0e2382f.
    """

    domain = independent(real, 1)
    codomain = independent(real, 1)
    bijective = True

    def __init__(
        self,
        i_tail,
        o_tail,
        i_coefs,
        o_coefs,
        drift,
        x=None,
        idx=None,
        x_is_in_not_out=None,
    ):
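        """
        ``i_tail`` and ``o_tail`` hold the input/output values that precede
        the transformed window; ``i_coefs`` and ``o_coefs`` are the
        coefficients applied to past inputs and past outputs (their last
        entry is assumed to be one), and ``drift`` is a constant offset.
        When ``x``, ``idx`` and ``x_is_in_not_out`` are given, only
        ``x[..., idx]`` is transformed; those positions must be flagged as
        innovations, and the remaining entries of ``x`` are held fixed.
        """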
        super().__init__()
        self.i_tail, self.o_tail, self.i_coefs, self.o_coefs, self.drift = (
            i_tail,
            o_tail,
            i_coefs,
            o_coefs,
            drift,
        )
        self.x, self.idx, self.x_is_in_not_out = x, idx, x_is_in_not_out
        if x_is_in_not_out is not None:
            if not x_is_in_not_out[idx].all():
                # A failed argument check should raise an error, not a warning class.
                raise ValueError("Inputs must be innovations.")

    def log_abs_det_jacobian(self, x, y):
        # Each output element is the current element plus a function of
        # strictly earlier ones, so the Jacobian is triangular with unit
        # diagonal and its log-absolute-determinant is zero.
        return x.new_zeros(x.shape[:-1])

    def get_x(self, x):
        # Default: every position is an input (innovation).
        x_is_in_not_out = (
            pt.ones((x if self.x is None else self.x).shape[-1], dtype=pt.bool)
            if self.x_is_in_not_out is None
            else self.x_is_in_not_out.clone()
        )
        if self.x is not None:
            # Splice the free values into a copy of the fixed sample.
            x_clone = self.x.clone()
            x_clone[..., self.idx] = x
            x = x_clone
        return x, x_is_in_not_out

    def _call(self, x):
        x, x_is_in_not_out = self.get_x(x)
        x = calc_arma_transform(
            x,
            x_is_in_not_out,
            self.i_tail,
            self.o_tail,
            self.i_coefs,
            self.o_coefs,
            self.drift,
        )
        if self.x is not None:
            # Return only the positions that were actually free.
            x = x[..., self.idx]
        return x

    def _inverse(self, x):
        x, x_is_in_not_out = self.get_x(x)
        # Invert the roles of inputs and outputs. Net effect: with a fixed
        # sample, only the flags at ``self.idx`` are flipped (innovation ->
        # observation); without one, every flag is flipped.
        x_is_in_not_out = ~x_is_in_not_out
        if self.x is not None:
            x_is_in_not_out[self.idx] = ~x_is_in_not_out[self.idx]
            x_is_in_not_out = ~x_is_in_not_out
        x = calc_arma_transform(
            x,
            x_is_in_not_out,
            self.i_tail,
            self.o_tail,
            self.i_coefs,
            self.o_coefs,
            self.drift,
        )
        if self.x is not None:
            x = x[..., self.idx]
        return x
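

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): everything below is an assumption about
# how the transform might be wired up, not part of the module API. It builds
# an ARMA(1, 1) model over ``n`` innovations, using the convention above that
# coefficient vectors end in one. Since the module uses a relative import of
# ``jit_script``, this runs as part of its package rather than standalone.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n = 100
    arma = ARMATransform(
        i_tail=pt.zeros(1),              # one past innovation (MA order 1)
        o_tail=pt.zeros(1),              # one past observation (AR order 1)
        i_coefs=pt.tensor([0.3, 1.0]),   # input coefficients; last entry is one
        o_coefs=pt.tensor([-0.7, 1.0]),  # output coefficients; last entry is one
        drift=pt.tensor(0.1),
    )
    base = pt.distributions.Independent(
        pt.distributions.Normal(pt.zeros(n), pt.ones(n)), 1
    )
    dist = pt.distributions.TransformedDistribution(base, [arma])
    y = dist.sample()                    # observations generated from innovations
    print(y.shape, dist.log_prob(y).shape)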