TrafficWheel/model/STGNRDE/torchcde/misc.py

167 lines
5.8 KiB
Python

import math
import numpy as np
import torch
def cheap_stack(tensors, dim):
    """Stack a sequence of tensors along a new dimension ``dim``.

    When the sequence holds exactly one tensor, ``unsqueeze`` is used instead of ``torch.stack``, which avoids the
    copy that stacking would make. Otherwise this is simply ``torch.stack(tensors, dim=dim)``.
    """
    if len(tensors) != 1:
        return torch.stack(tensors, dim=dim)
    # Single tensor: just insert the new axis.
    return tensors[0].unsqueeze(dim)
def tridiagonal_solve(b, A_upper, A_diagonal, A_lower):
    """Solves a tridiagonal system Ax = b.

    The arguments A_upper, A_diagonal, A_lower correspond to the three diagonals of A. Letting U = A_upper,
    D = A_diagonal and L = A_lower, and assuming for simplicity that there are no batch dimensions, then the matrix A
    is assumed to be of size (k, k), with entries:

        D[0]  U[0]
        L[0]  D[1]  U[1]                                  0
              L[1]  D[2]  U[2]
                    .     .     .
                          .     .     .
                                L[k - 3]  D[k - 2]  U[k - 2]
          0                               L[k - 2]  D[k - 1]

    Arguments:
        b: A tensor of shape (..., k), where '...' is zero or more batch dimensions
        A_upper: A tensor of shape (..., k - 1).
        A_diagonal: A tensor of shape (..., k).
        A_lower: A tensor of shape (..., k - 1).

    Returns:
        A tensor of shape (..., k), corresponding to the x solving Ax = b

    Warning:
        This implementation isn't super fast. You probably want to cache the result, if possible.
        No pivoting is performed (the Thomas algorithm), so a zero pivot produces inf/nan rather than an error.
    """
    # This implementation is very much written for clarity rather than speed.
    # Broadcast the diagonals against b so that per-column indexing below is well defined.
    A_upper, _ = torch.broadcast_tensors(A_upper, b[..., :-1])
    A_lower, _ = torch.broadcast_tensors(A_lower, b[..., :-1])
    A_diagonal, b = torch.broadcast_tensors(A_diagonal, b)

    channels = b.size(-1)

    # Forward sweep: eliminate the sub-diagonal, updating the diagonal and right hand side as we go.
    new_b = [b[..., 0]]
    new_A_diagonal = [A_diagonal[..., 0]]
    for i in range(1, channels):
        w = A_lower[..., i - 1] / new_A_diagonal[i - 1]
        new_A_diagonal.append(A_diagonal[..., i] - w * A_upper[..., i - 1])
        new_b.append(b[..., i] - w * new_b[i - 1])

    # Back substitution on the resulting upper-bidiagonal system.
    outs = [None] * channels
    outs[channels - 1] = new_b[channels - 1] / new_A_diagonal[channels - 1]
    for i in range(channels - 2, -1, -1):
        outs[i] = (new_b[i] - A_upper[..., i] * outs[i + 1]) / new_A_diagonal[i]
    return torch.stack(outs, dim=-1)
def validate_input_path(x, t):
    """Checks that ``x`` is a valid path and ``t`` a valid set of times for it, and returns ``t``.

    Arguments:
        x: floating point tensor with at least two dimensions. Dimension -2 is treated as the time dimension (its
            size must equal ``len(t)``); the error message describes the last two dimensions as time and channels.
        t: one dimensional floating point tensor of strictly increasing times, or None. If None, it defaults to
            ``0, 1, ..., x.size(-2) - 1`` with the dtype and device of ``x``.

    Returns:
        ``t`` (newly created if it was None).

    Raises:
        ValueError: if any of the above conditions fail, or if there are fewer than two time points.
    """
    if not x.is_floating_point():
        raise ValueError("X must be floating point.")
    if x.ndimension() < 2:
        raise ValueError("X must have at least two dimensions, corresponding to time and channels. It instead has "
                         "shape {}.".format(tuple(x.shape)))

    if t is None:
        t = torch.linspace(0, x.size(-2) - 1, x.size(-2), dtype=x.dtype, device=x.device)

    if not t.is_floating_point():
        raise ValueError("t must be floating point.")
    if len(t.shape) != 1:
        raise ValueError("t must be one dimensional. It instead has shape {}.".format(tuple(t.shape)))
    # One vectorized comparison (and a single device sync) rather than a Python loop over every element.
    # Comparisons involving NaN are False, so NaNs pass this check exactly as they did with the elementwise loop.
    if (t[1:] <= t[:-1]).any():
        raise ValueError("t must be monotonically increasing.")

    if x.size(-2) != t.size(0):
        raise ValueError("The time dimension of X must equal the length of t. X has shape {} and t has shape {}, "
                         "corresponding to time dimensions of {} and {} respectively."
                         .format(tuple(x.shape), tuple(t.shape), x.size(-2), t.size(0)))

    if t.size(0) < 2:
        raise ValueError("Must have a time dimension of size at least 2. It instead has shape {}, corresponding to a "
                         "time dimension of size {}.".format(tuple(t.shape), t.size(0)))

    return t
def forward_fill(x, fill_index=-2):
    """Forward fills NaN values in a torch tensor along one dimension.

    Each NaN entry is replaced by the most recent non-NaN value before it along dimension ``fill_index``. NaNs that
    have no earlier non-NaN value along that dimension (leading NaNs) are left as NaN.

    Arguments:
        x: tensor of values with at least two dimensions, conventionally of shape (..., length, input_channels),
            where ... is some number of batch dimensions.
        fill_index: int that denotes the index to fill down. Default is -2 as we tend to use the convention (...,
            length, input_channels) filling down the length dimension.

    Returns:
        A tensor with forward filled data. If ``x`` contains no NaNs, ``x`` itself is returned.

    Raises:
        TypeError: if ``x`` is not a ``torch.Tensor``.
        ValueError: if ``x`` has fewer than two dimensions.
    """
    # Explicit checks rather than ``assert`` so that they still run under ``python -O``.
    if not isinstance(x, torch.Tensor):
        raise TypeError("x must be a torch.Tensor, got {}.".format(type(x).__name__))
    if x.dim() < 2:
        raise ValueError("x must have at least two dimensions. It instead has shape {}.".format(tuple(x.shape)))

    mask = torch.isnan(x)
    if mask.any():
        # Running count of non-NaN entries; it strictly increases at every non-NaN position.
        cumsum_mask = (~mask).cumsum(dim=fill_index)
        # Zero the NaN positions so they can never be the running maximum.
        cumsum_mask[mask] = 0
        # The running-max indices therefore point at the latest non-NaN entry at or before each position.
        _, index = cumsum_mask.cummax(dim=fill_index)
        x = x.gather(dim=fill_index, index=index)
    return x
class TupleControl(torch.nn.Module):
    """Batches several controls together so that they can be evaluated and differentiated as one.

    Every control must expose ``interval``, ``grid_points``, ``evaluate(t)`` and ``derivative(t)``, and be a
    ``torch.nn.Module`` (they are stored in a ``ModuleList``). All controls must share the same ``interval``;
    ``grid_points`` is only available when every control also shares the same grid points.
    """

    def __init__(self, *controls):
        super().__init__()
        if len(controls) == 0:
            raise ValueError("Expected one or more controls to batch together.")

        self._interval = controls[0].interval
        grid_points = controls[0].grid_points
        same_grid_points = True
        for control in controls[1:]:
            if self._differ(control.interval, self._interval):
                raise ValueError("Can only batch together controls over the same interval.")
            if same_grid_points and self._differ(control.grid_points, grid_points):
                same_grid_points = False

        # None marks "controls disagree on grid points"; the ``grid_points`` property turns that into an error.
        self._grid_points = grid_points if same_grid_points else None
        self.controls = torch.nn.ModuleList(controls)

    @staticmethod
    def _differ(a, b):
        """Returns True if tensors ``a`` and ``b`` differ in shape or in any element."""
        # Compare shapes first: an elementwise ``!=`` on mismatched shapes would either raise or silently broadcast.
        return a.shape != b.shape or bool((a != b).any())

    @property
    def interval(self):
        """The interval shared by all batched controls."""
        return self._interval

    @property
    def grid_points(self):
        """The grid points shared by all batched controls.

        Raises:
            RuntimeError: if the batched controls do not all have the same grid points.
        """
        if self._grid_points is None:
            raise RuntimeError("Batch of controls have different grid points.")
        return self._grid_points

    def evaluate(self, t):
        """Returns a tuple of each control's ``evaluate(t)``, in the order the controls were given."""
        return tuple(control.evaluate(t) for control in self.controls)

    def derivative(self, t):
        """Returns a tuple of each control's ``derivative(t)``, in the order the controls were given."""
        return tuple(control.derivative(t) for control in self.controls)