mirror of https://github.com/czzhangheng/STDEN.git
first commit
commit 3a5f5e5170
@ -0,0 +1,109 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
logs/
runs/
ckpt/
data/
.vscode/
figures/

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# pycharm
.idea/
@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Echo Ji

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@ -0,0 +1,39 @@
# STDEN

This is the code for the paper Towards Physics-guided Neural Networks for Traffic Flow Prediction.

## Requirements

* scipy>=1.5.2
* numpy>=1.19.1
* pandas>=1.1.5
* pyyaml>=5.3.1
* pytorch>=1.7.1
* future>=0.18.2
* torchdiffeq>=0.2.0

Dependencies can be installed using the following command:

```
pip install -r requirements.txt
```

## Model Training and Evaluation

One can run the code by

```bash
# training for dataset GT-221
python stden_train.py --config_filename=configs/stden_gt.yaml

# testing for dataset GT-221
python stden_eval.py --config_filename=configs/stden_gt.yaml
```

The configuration files of all datasets are as follows:

|dataset|config file|
|:--|:--|
|GT-221|stden_gt.yaml|
|WRS-393|stden_wrs.yaml|
|ZGC-564|stden_zgc.yaml|

PS: The data is not public and I am not allowed to distribute it.
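Since the data itself cannot be distributed, here is a hedged sketch of the on-disk layout that `load_dataset` in `lib/utils.py` expects for the non-BJ branch: a directory with `train.npz`, `val.npz`, and `test.npz`, each holding arrays `x` and `y`. The shapes below are illustrative assumptions (samples, seq_len/horizon of 12, edge count, feature dim), not the real dataset dimensions.

```python
import os
import tempfile

import numpy as np

# Illustrative shapes only: (num_samples, seq_len, num_edges, input_dim).
seq_len, horizon, num_edges = 12, 12, 8

dataset_dir = tempfile.mkdtemp()
for category, n in [('train', 100), ('val', 20), ('test', 20)]:
    x = np.random.rand(n, seq_len, num_edges, 1).astype(np.float32)
    y = np.random.rand(n, horizon, num_edges, 1).astype(np.float32)
    np.savez(os.path.join(dataset_dir, category + '.npz'), x=x, y=y)

# load_dataset(dataset_dir, batch_size=32, val_batch_size=32) would then
# standardize x/y with the training mean/std and build the DataLoaders.
print(sorted(os.listdir(dataset_dir)))  # ['test.npz', 'train.npz', 'val.npz']
```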
@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_GM
log_level: INFO

data:
  batch_size: 32
  dataset_dir: data/BJ_GM
  val_batch_size: 32
  graph_pkl_filename: data/sensor_graph/adj_GM.npy

model:
  l1_decay: 0
  seq_len: 12
  horizon: 12
  input_dim: 1
  output_dim: 1
  latent_dim: 4
  n_traj_samples: 3
  ode_method: dopri5
  odeint_atol: 0.00001
  odeint_rtol: 0.00001
  rnn_units: 64
  num_rnn_layers: 1
  gcn_step: 2
  filter_type: default # unkP IncP default
  recg_type: gru
  save_latent: false
  nfe: false

train:
  base_lr: 0.01
  dropout: 0
  load: 0
  epoch: 0
  epochs: 100
  epsilon: 1.0e-3
  lr_decay_ratio: 0.1
  max_grad_norm: 5
  min_learning_rate: 2.0e-06
  optimizer: adam
  patience: 20
  steps: [20, 30, 40, 50]
  test_every_n_epochs: 5
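These configs are presumably consumed via pyyaml (listed in the requirements). A minimal sketch of parsing the same nested layout, using an inline string rather than one of the repository files:

```python
import yaml

# Same nesting as the configs above; values copied from stden_gt-style files.
config_text = """
model:
  latent_dim: 4
  ode_method: dopri5
train:
  base_lr: 0.01
  epochs: 100
"""

config = yaml.safe_load(config_text)
# Sections arrive as nested dicts, matching kwargs['model'] / kwargs['train']
# accesses in lib/utils.py.
print(config['model']['ode_method'], config['train']['epochs'])
```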
@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_RM
log_level: INFO

data:
  batch_size: 32
  dataset_dir: data/BJ_RM
  val_batch_size: 32
  graph_pkl_filename: data/sensor_graph/adj_RM.npy

model:
  l1_decay: 0
  seq_len: 12
  horizon: 12
  input_dim: 1
  output_dim: 1
  latent_dim: 4
  n_traj_samples: 3
  ode_method: dopri5
  odeint_atol: 0.00001
  odeint_rtol: 0.00001
  rnn_units: 64 # for recognition
  num_rnn_layers: 1
  gcn_step: 2
  filter_type: default # unkP IncP default
  recg_type: gru
  save_latent: false
  nfe: false

train:
  base_lr: 0.01
  dropout: 0
  load: 0 # 0 for not load
  epoch: 0
  epochs: 100
  epsilon: 1.0e-3
  lr_decay_ratio: 0.1
  max_grad_norm: 5
  min_learning_rate: 2.0e-06
  optimizer: adam
  patience: 20
  steps: [20, 30, 40, 50]
  test_every_n_epochs: 5
@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_XZ
log_level: INFO

data:
  batch_size: 32
  dataset_dir: data/BJ_XZ
  val_batch_size: 32
  graph_pkl_filename: data/sensor_graph/adj_XZ.npy

model:
  l1_decay: 0
  seq_len: 12
  horizon: 12
  input_dim: 1
  output_dim: 1
  latent_dim: 4
  n_traj_samples: 3
  ode_method: dopri5
  odeint_atol: 0.00001
  odeint_rtol: 0.00001
  rnn_units: 64
  num_rnn_layers: 1
  gcn_step: 2
  filter_type: default # unkP IncP default
  recg_type: gru
  save_latent: false
  nfe: false

train:
  base_lr: 0.01
  dropout: 0
  load: 0 # 0 for not load
  epoch: 0
  epochs: 100
  epsilon: 1.0e-3
  lr_decay_ratio: 0.1
  max_grad_norm: 5
  min_learning_rate: 2.0e-06
  optimizer: adam
  patience: 20
  steps: [20, 30, 40, 50]
  test_every_n_epochs: 5
@ -0,0 +1,34 @@
import torch


def masked_mae_loss(y_pred, y_true):
    y_true[y_true < 1e-4] = 0
    mask = (y_true != 0).float()
    mask /= mask.mean()  # redistribute the weight of zero entries to the non-zero ones
    loss = torch.abs(y_pred - y_true)
    loss = loss * mask
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    loss[loss != loss] = 0
    return loss.mean()


def masked_mape_loss(y_pred, y_true):
    y_true[y_true < 1e-4] = 0
    mask = (y_true != 0).float()
    mask /= mask.mean()  # redistribute the weight of zero entries to the non-zero ones
    loss = torch.abs((y_pred - y_true) / y_true)
    loss = loss * mask
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    loss[loss != loss] = 0
    return loss.mean()


def masked_rmse_loss(y_pred, y_true):
    y_true[y_true < 1e-4] = 0
    mask = (y_true != 0).float()
    mask /= mask.mean()  # redistribute the weight of zero entries to the non-zero ones
    loss = torch.pow(y_pred - y_true, 2)
    loss = loss * mask
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    loss[loss != loss] = 0
    return torch.sqrt(loss.mean())
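A toy illustration (not from the repository) of the masking scheme these three losses share: zero targets drop out of the loss, and dividing the mask by its mean rescales the remainder so the result equals the plain error averaged over the non-zero entries only.

```python
import torch

# Toy tensors; the third target is zero and should contribute nothing.
y_true = torch.tensor([1.0, 3.0, 0.0, 2.0])
y_pred = torch.tensor([2.0, 3.0, 5.0, 1.0])

mask = (y_true != 0).float()
mask /= mask.mean()                      # reweight: zeros drop out, the rest scale up
loss = torch.abs(y_pred - y_true) * mask
loss[loss != loss] = 0                   # NaN -> 0, as in the functions above
masked_mae = loss.mean()

plain_mae_over_nonzero = torch.abs(y_pred - y_true)[y_true != 0].mean()
print(masked_mae.item(), plain_mae_over_nonzero.item())  # both equal 2/3
```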
@ -0,0 +1,233 @@
import logging
import numpy as np
import os
import time
import pickle
import scipy.sparse as sp
import sys
import torch
import torch.nn as nn

from scipy.sparse import linalg


class DataLoader(object):
    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        :param xs:
        :param ys:
        :param batch_size:
        :param pad_with_last_sample: pad with the last sample to make the number of samples divisible by batch_size.
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
            x_padding = np.repeat(xs[-1:], num_padding, axis=0)
            y_padding = np.repeat(ys[-1:], num_padding, axis=0)
            xs = np.concatenate([xs, x_padding], axis=0)
            ys = np.concatenate([ys, y_padding], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            permutation = np.random.permutation(self.size)
            xs, ys = xs[permutation], ys[permutation]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                start_ind = self.batch_size * self.current_ind
                end_ind = min(self.size, self.batch_size * (self.current_ind + 1))
                x_i = self.xs[start_ind: end_ind, ...]
                y_i = self.ys[start_ind: end_ind, ...]
                yield (x_i, y_i)
                self.current_ind += 1

        return _wrapper()


class StandardScaler:
    """
    Standardize the input.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        return (data * self.std) + self.mean


def calculate_random_walk_matrix(adj_mx):
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1))
    d_inv = np.power(d, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx


def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
    # Add file handler and stdout handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Create the log directory if necessary.
    try:
        os.makedirs(log_dir)
    except OSError:
        pass
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=level)
    # Add console handler.
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(level=level)
    logging.basicConfig(handlers=[file_handler, console_handler], level=level)


def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Add file handler and stdout handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    # Add console handler.
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    logger.info('Log directory: %s', log_dir)
    return logger


def get_log_dir(kwargs):
    log_dir = kwargs['train'].get('log_dir')
    if log_dir is None:
        batch_size = kwargs['data'].get('batch_size')

        filter_type = kwargs['model'].get('filter_type')
        gcn_step = kwargs['model'].get('gcn_step')
        horizon = kwargs['model'].get('horizon')
        latent_dim = kwargs['model'].get('latent_dim')
        n_traj_samples = kwargs['model'].get('n_traj_samples')
        ode_method = kwargs['model'].get('ode_method')

        seq_len = kwargs['model'].get('seq_len')
        rnn_units = kwargs['model'].get('rnn_units')
        recg_type = kwargs['model'].get('recg_type')

        if filter_type == 'unkP':
            filter_type_abbr = 'UP'
        elif filter_type == 'IncP':
            filter_type_abbr = 'NV'
        else:
            filter_type_abbr = 'DF'

        run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
            recg_type, rnn_units, filter_type_abbr, gcn_step, latent_dim, n_traj_samples,
            ode_method, batch_size, seq_len, horizon, time.strftime('%m%d%H%M%S'))
        base_dir = kwargs.get('log_base_dir')
        log_dir = os.path.join(base_dir, run_id)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs):
    if 'BJ' in dataset_dir:
        data = dict(np.load(os.path.join(dataset_dir, 'flow.npz')))  # convert the read-only NpzFile to a writable dict
        for category in ['train', 'val', 'test']:
            data['x_' + category] = data['x_' + category]  # [..., :4] # ignore the time index
    else:
        data = {}
        for category in ['train', 'val', 'test']:
            cat_data = np.load(os.path.join(dataset_dir, category + '.npz'))
            data['x_' + category] = cat_data['x']
            data['y_' + category] = cat_data['y']
    scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std())  # dim 0 is the quantity to predict; what dim 1 holds is unclear
    # Data format
    for category in ['train', 'val', 'test']:
        data['x_' + category] = scaler.transform(data['x_' + category])
        data['y_' + category] = scaler.transform(data['y_' + category])
    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
    data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
    data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False)
    data['scaler'] = scaler

    return data


def load_graph_data(pkl_filename):
    adj_mx = np.load(pkl_filename)
    return adj_mx


def graph_grad(adj_mx):
    """Fetch the graph gradient operator."""
    num_nodes = adj_mx.shape[0]

    num_edges = (adj_mx > 0.).sum()
    grad = torch.zeros(num_nodes, num_edges)
    e = 0
    for i in range(num_nodes):
        for j in range(num_nodes):
            if adj_mx[i, j] == 0:
                continue

            grad[i, e] = 1.
            grad[j, e] = -1.
            e += 1
    return grad


def init_network_weights(net, std=0.1):
    """
    Just for nn.Linear nets.
    """
    for m in net.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0, std=std)
            nn.init.constant_(m.bias, val=0)


def split_last_dim(data):
    last_dim = data.size()[-1]
    last_dim = last_dim // 2

    res = data[..., :last_dim], data[..., last_dim:]
    return res


def get_device(tensor):
    device = torch.device("cpu")
    if tensor.is_cuda:
        device = tensor.get_device()
    return device


def sample_standard_gaussian(mu, sigma):
    device = get_device(mu)

    d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device))
    r = d.sample(mu.size()).squeeze(-1)
    return r * sigma.float() + mu.float()


def create_net(n_inputs, n_outputs, n_layers=0, n_units=100, nonlinear=nn.Tanh):
    layers = [nn.Linear(n_inputs, n_units)]
    for i in range(n_layers):
        layers.append(nonlinear())
        layers.append(nn.Linear(n_units, n_units))

    layers.append(nonlinear())
    layers.append(nn.Linear(n_units, n_outputs))
    return nn.Sequential(*layers)
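`graph_grad` builds a node-by-edge incidence matrix: each edge gets a column with +1 at its source node and -1 at its target. The same logic reimplemented with NumPy on a toy 2-node graph, for illustration only:

```python
import numpy as np

# The incidence-matrix logic of graph_grad (lib/utils.py) on a toy 2-node graph
# with edges 0->1 and 1->0.
adj = np.array([[0., 1.],
                [1., 0.]])
num_nodes = adj.shape[0]
num_edges = int((adj > 0.).sum())
grad = np.zeros((num_nodes, num_edges))
e = 0
for i in range(num_nodes):
    for j in range(num_nodes):
        if adj[i, j] == 0:
            continue
        grad[i, e] = 1.   # edge e leaves node i
        grad[j, e] = -1.  # and enters node j
        e += 1
print(grad)  # [[ 1. -1.] [-1.  1.]]
```

Multiplying a node signal by `grad.T` then gives the signal's difference along each edge, which is what makes this a discrete gradient operator.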
@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import time

from torchdiffeq import odeint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DiffeqSolver(nn.Module):
    def __init__(self, odefunc, method, latent_dim,
                 odeint_rtol=1e-4, odeint_atol=1e-5):
        nn.Module.__init__(self)

        self.ode_method = method
        self.odefunc = odefunc
        self.latent_dim = latent_dim

        self.rtol = odeint_rtol
        self.atol = odeint_atol

    def forward(self, first_point, time_steps_to_pred):
        """
        Decode the trajectory through the ODE solver.

        :param time_steps_to_pred: horizon
        :param first_point: (n_traj_samples, batch_size, num_nodes * latent_dim)
        :return: pred_y: shape (horizon, n_traj_samples, batch_size, self.num_nodes * self.output_dim)
        """
        n_traj_samples, batch_size = first_point.size()[0], first_point.size()[1]
        first_point = first_point.reshape(n_traj_samples * batch_size, -1)  # reduce the complexity by merging dimensions

        # pred_y shape: (horizon, n_traj_samples * batch_size, num_nodes * latent_dim)
        start_time = time.time()
        self.odefunc.nfe = 0
        pred_y = odeint(self.odefunc,
                        first_point,
                        time_steps_to_pred,
                        rtol=self.rtol,
                        atol=self.atol,
                        method=self.ode_method)
        time_fe = time.time() - start_time

        # pred_y shape: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        pred_y = pred_y.reshape(pred_y.size()[0], n_traj_samples, batch_size, -1)
        # assert(pred_y.size()[1] == n_traj_samples)
        # assert(pred_y.size()[2] == batch_size)

        return pred_y, (self.odefunc.nfe, time_fe)
@ -0,0 +1,165 @@
import numpy as np
import torch
import torch.nn as nn

from lib import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LayerParams:
    def __init__(self, rnn_network: nn.Module, layer_type: str):
        self._rnn_network = rnn_network
        self._params_dict = {}
        self._biases_dict = {}
        self._type = layer_type

    def get_weights(self, shape):
        if shape not in self._params_dict:
            nn_param = nn.Parameter(torch.empty(*shape, device=device))
            nn.init.xavier_normal_(nn_param)
            self._params_dict[shape] = nn_param
            self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
                                                 nn_param)
        return self._params_dict[shape]

    def get_biases(self, length, bias_start=0.0):
        if length not in self._biases_dict:
            biases = nn.Parameter(torch.empty(length, device=device))
            nn.init.constant_(biases, bias_start)
            self._biases_dict[length] = biases
            self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
                                                 biases)

        return self._biases_dict[length]


class ODEFunc(nn.Module):
    def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
                 gen_layers=1, nonlinearity='tanh', filter_type="default"):
        """
        :param num_units: dimensionality of the hidden layers
        :param latent_dim: dimensionality used for the ODE (input and output); analog of a continuous latent state
        :param adj_mx:
        :param gcn_step:
        :param num_nodes:
        :param gen_layers: hidden layers in each ODE func.
        :param nonlinearity:
        :param filter_type: default
        :param use_gc_for_ru: whether to use graph convolution to calculate the reset and update gates.
        """
        super(ODEFunc, self).__init__()
        self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu

        self._num_nodes = num_nodes
        self._num_units = num_units  # hidden dimension
        self._latent_dim = latent_dim
        self._gen_layers = gen_layers
        self.nfe = 0

        self._filter_type = filter_type
        if self._filter_type == "unkP":
            ode_func_net = utils.create_net(latent_dim, latent_dim, n_units=num_units)
            utils.init_network_weights(ode_func_net)
            self.gradient_net = ode_func_net
        else:
            self._gcn_step = gcn_step
            self._gconv_params = LayerParams(self, 'gconv')
            self._supports = []
            supports = []
            supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
            supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T)

            for support in supports:
                self._supports.append(self._build_sparse_matrix(support))

    @staticmethod
    def _build_sparse_matrix(L):
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        # this is to ensure row-major ordering, equal to torch.sparse.sparse_reorder(L)
        indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
        L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
        return L

    def forward(self, t_local, y, backwards=False):
        """
        Perform one step in solving the ODE. Given the current data point y and the current time point t_local, returns the gradient dy/dt at this time point.

        t_local: current time point
        y: value at the current time point, shape (B, num_nodes * latent_dim)

        :return
        - Output: A `2-D` tensor with shape `(B, num_nodes * latent_dim)`.
        """
        self.nfe += 1
        grad = self.get_ode_gradient_nn(t_local, y)
        if backwards:
            grad = -grad
        return grad

    def get_ode_gradient_nn(self, t_local, inputs):
        if self._filter_type == "unkP":
            grad = self._fc(inputs)
        elif self._filter_type == "IncP":
            grad = - self.ode_func_net(inputs)
        else:  # default is a diffusion process
            # theta shape: (B, num_nodes * latent_dim)
            theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0))
            grad = - theta * self.ode_func_net(inputs)
        return grad

    def ode_func_net(self, inputs):
        c = inputs
        for i in range(self._gen_layers):
            c = self._gconv(c, self._num_units)
            c = self._activation(c)
        c = self._gconv(c, self._latent_dim)
        c = self._activation(c)
        return c

    def _fc(self, inputs):
        batch_size = inputs.size()[0]
        grad = self.gradient_net(inputs.view(batch_size * self._num_nodes, self._latent_dim))
        return grad.reshape(batch_size, self._num_nodes * self._latent_dim)  # (batch_size, num_nodes * latent_dim)

    @staticmethod
    def _concat(x, x_):
        x_ = x_.unsqueeze(0)
        return torch.cat([x, x_], dim=0)

    def _gconv(self, inputs, output_size, bias_start=0.0):
        # Reshape input to (batch_size, num_nodes, input_dim)
        batch_size = inputs.shape[0]
        inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
        input_size = inputs.size(2)

        x = inputs
        x0 = x.permute(1, 2, 0)  # (num_nodes, total_arg_size, batch_size)
        x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
        x = torch.unsqueeze(x0, 0)

        if self._gcn_step == 0:
            pass
        else:
            for support in self._supports:
                x1 = torch.sparse.mm(support, x0)
                x = self._concat(x, x1)

                for k in range(2, self._gcn_step + 1):
                    x2 = 2 * torch.sparse.mm(support, x1) - x0
                    x = self._concat(x, x2)
                    x1, x0 = x2, x1

        num_matrices = len(self._supports) * self._gcn_step + 1  # adds one for x itself
        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])

        weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
        x = torch.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)

        biases = self._gconv_params.get_biases(output_size, bias_start)
        x += biases
        # Reshape res back to 2D: (batch_size, num_nodes, state_dim) -> (batch_size, num_nodes * state_dim)
        return torch.reshape(x, [batch_size, self._num_nodes * output_size])
@ -0,0 +1,206 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from torch.nn.modules.rnn import GRU
|
||||||
|
from model.ode_func import ODEFunc
|
||||||
|
from model.diffeq_solver import DiffeqSolver
|
||||||
|
|
||||||
|
from lib import utils
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
class EncoderAttrs:
|
||||||
|
def __init__(self, adj_mx, **model_kwargs):
|
||||||
|
self.adj_mx = adj_mx
|
||||||
|
self.num_nodes = adj_mx.shape[0]
|
||||||
|
self.num_edges = (adj_mx > 0.).sum()
|
||||||
|
self.gcn_step = int(model_kwargs.get('gcn_step', 2))
|
||||||
|
self.filter_type = model_kwargs.get('filter_type', 'default')
|
||||||
|
self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
|
||||||
|
self.rnn_units = int(model_kwargs.get('rnn_units'))
|
||||||
|
self.latent_dim = int(model_kwargs.get('latent_dim', 4))
|
||||||
|
|
||||||
|
class STDENModel(nn.Module, EncoderAttrs):
    def __init__(self, adj_mx, logger, **model_kwargs):
        nn.Module.__init__(self)
        EncoderAttrs.__init__(self, adj_mx, **model_kwargs)
        self._logger = logger

        ####################################################
        # recognition net
        ####################################################
        self.encoder_z0 = Encoder_z0_RNN(adj_mx, **model_kwargs)

        ####################################################
        # ode solver
        ####################################################
        self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
        self.ode_method = model_kwargs.get('ode_method', 'dopri5')
        self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
        self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
        self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
        self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))
        ode_set_str = ("ODE setting --latent {} --samples {} --method {} "
                       "--atol {:6f} --rtol {:6f} --gen_layer {} --gen_dim {}").format(
            self.latent_dim, self.n_traj_samples, self.ode_method,
            self.atol, self.rtol, self.num_gen_layer, self.ode_gen_dim)
        odefunc = ODEFunc(self.ode_gen_dim,  # hidden dimension
                          self.latent_dim,
                          adj_mx,
                          self.gcn_step,
                          self.num_nodes,
                          filter_type=self.filter_type
                          ).to(device)
        self.diffeq_solver = DiffeqSolver(odefunc,
                                          self.ode_method,
                                          self.latent_dim,
                                          odeint_rtol=self.rtol,
                                          odeint_atol=self.atol
                                          )
        self._logger.info(ode_set_str)

        self.save_latent = bool(model_kwargs.get('save_latent', False))
        self.latent_feat = None  # used to extract the latent feature

        ####################################################
        # decoder
        ####################################################
        self.horizon = int(model_kwargs.get('horizon', 1))
        self.out_feat = int(model_kwargs.get('output_dim', 1))
        self.decoder = Decoder(
            self.out_feat,
            adj_mx,
            self.num_nodes,
            self.num_edges,
        ).to(device)

    ##########################################
    def forward(self, inputs, labels=None, batches_seen=None):
        """
        seq2seq forward pass
        :param inputs: shape (seq_len, batch_size, num_edges * input_dim)
        :param labels: shape (horizon, batch_size, num_edges * output_dim)
        :param batches_seen: batches seen till now
        :return: outputs: (self.horizon, batch_size, self.num_edges * self.output_dim)
        """
        perf_time = time.time()
        # shape: [1, batch, num_nodes * latent_dim]
        first_point_mu, first_point_std = self.encoder_z0(inputs)
        self._logger.debug("Recognition complete with {:.1f}s".format(time.time() - perf_time))

        # sample 'n_traj_samples' trajectories
        perf_time = time.time()
        means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
        sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
        first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)

        time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float().to(device)
        time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)

        # Shape of sol_ys (horizon, n_traj_samples, batch_size, self.num_nodes * self.latent_dim)
        sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)
        self._logger.debug("ODE solver complete with {:.1f}s".format(time.time() - perf_time))
        if self.save_latent:
            # Shape of latent_feat (horizon, batch_size, self.num_nodes * self.latent_dim)
            self.latent_feat = torch.mean(sol_ys.detach(), axis=1)

        perf_time = time.time()
        outputs = self.decoder(sol_ys)
        self._logger.debug("Decoder complete with {:.1f}s".format(time.time() - perf_time))

        if batches_seen == 0:
            self._logger.info(
                "Total trainable parameters {}".format(count_parameters(self))
            )
        return outputs, fe

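For reference, `forward` hands the solver a horizon grid rescaled into [0, 1): `torch.arange(0, horizon) / horizon`. A minimal pure-Python sketch of that normalization (illustrative only, mirroring the torch version):

```python
# Sketch of the prediction time grid built in STDENModel.forward:
# the integer steps 0..H-1 are rescaled into [0, 1) before being
# passed to the ODE solver.
def prediction_time_grid(horizon):
    """Return the normalized time stamps [0, 1/H, ..., (H-1)/H]."""
    return [k / horizon for k in range(horizon)]

grid = prediction_time_grid(4)
print(grid)  # [0.0, 0.25, 0.5, 0.75]
```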
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
    def __init__(self, adj_mx, **model_kwargs):
        nn.Module.__init__(self)
        EncoderAttrs.__init__(self, adj_mx, **model_kwargs)
        self.recg_type = model_kwargs.get('recg_type', 'gru')  # gru

        if self.recg_type == 'gru':
            # gru settings
            self.input_dim = int(model_kwargs.get('input_dim', 1))
            self.gru_rnn = GRU(self.input_dim, self.rnn_units).to(device)
        else:
            raise NotImplementedError("The recognition net only supports 'gru'.")

        # hidden to z0 settings
        self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
        self.inv_grad[self.inv_grad != 0.] = 0.5
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.rnn_units, 50),
            nn.Tanh(),
            nn.Linear(50, self.latent_dim * 2),
        )

        utils.init_network_weights(self.hiddens_to_z0)

    def forward(self, inputs):
        """
        encoder forward pass on t time steps
        :param inputs: shape (seq_len, batch_size, num_edges * input_dim)
        :return: mean, std: # shape (n_samples=1, batch_size, self.latent_dim)
        """
        if self.recg_type == 'gru':
            # shape of outputs: (seq_len, batch, num_sensor * rnn_units)
            seq_len, batch_size = inputs.size(0), inputs.size(1)
            inputs = inputs.reshape(seq_len, batch_size, self.num_edges, self.input_dim)
            inputs = inputs.reshape(seq_len, batch_size * self.num_edges, self.input_dim)

            outputs, _ = self.gru_rnn(inputs)
            last_output = outputs[-1]
            # (batch_size, num_edges, rnn_units)
            last_output = torch.reshape(last_output, (batch_size, self.num_edges, -1))
            # note: torch.transpose takes two dim ints, not a tuple
            last_output = torch.transpose(last_output, -2, -1)
            # (batch_size, num_nodes, rnn_units)
            last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
        else:
            raise NotImplementedError("The recognition net only supports 'gru'.")

        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        mean = mean.reshape(batch_size, -1)  # (batch_size, num_nodes * latent_dim)
        std = std.reshape(batch_size, -1)  # (batch_size, num_nodes * latent_dim)
        std = std.abs()

        assert not torch.isnan(mean).any()
        assert not torch.isnan(std).any()

        return mean.unsqueeze(0), std.unsqueeze(0)  # for n_sample traj

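`utils.split_last_dim` and `utils.sample_standard_gaussian` live outside this hunk; assuming their usual latent-ODE semantics (halve the last axis into mean/std, then draw a reparameterized Gaussian sample), the encoder head can be sketched with NumPy. The helper below is a stand-in for illustration, not the repository's actual code:

```python
import numpy as np

def split_last_dim(h):
    # assumed behavior of utils.split_last_dim: halve the final axis
    half = h.shape[-1] // 2
    return h[..., :half], h[..., half:]

rng = np.random.default_rng(0)
h = rng.normal(size=(2, 3, 8))      # (batch, nodes, 2 * latent_dim)
mean, std = split_last_dim(h)
std = np.abs(std)                   # std is forced non-negative, as in the encoder
z0 = mean + std * rng.normal(size=mean.shape)  # reparameterized draw of z0
print(mean.shape, z0.shape)         # (2, 3, 4) (2, 3, 4)
```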
class Decoder(nn.Module):
    def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
        super(Decoder, self).__init__()

        self.num_nodes = num_nodes
        self.num_edges = num_edges
        self.graph_grad = utils.graph_grad(adj_mx)

        self.output_dim = output_dim

    def forward(self, inputs):
        """
        :param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        :return outputs: (horizon, batch_size, num_edges * output_dim), the average over n_traj_samples.
        """
        assert len(inputs.size()) == 4
        horizon, n_traj_samples, batch_size = inputs.size()[:3]

        inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
        latent_dim = inputs.size(-2)
        # transform z with shape `(..., num_nodes)` to f with shape `(..., num_edges)`.
        outputs = torch.matmul(inputs, self.graph_grad)

        outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_edges, self.output_dim)
        outputs = torch.mean(
            torch.mean(outputs, axis=3),
            axis=1
        )
        outputs = outputs.reshape(horizon, batch_size, -1)
        return outputs

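`utils.graph_grad` is not shown in this hunk; assuming it builds an incidence-style gradient matrix of the graph, the node-to-edge mapping the decoder performs with `torch.matmul` can be illustrated with a toy NumPy sketch (the 3-node path graph below is hypothetical):

```python
import numpy as np

# Hypothetical 3-node path graph with directed edges 0->1 and 1->2.
# Each row is a node, each column an edge; multiplying node potentials
# z by this gradient matrix yields the potential difference per edge.
G = np.array([[ 1.,  0.],
              [-1.,  1.],
              [ 0., -1.]])

z = np.array([3., 2., 0.])   # one latent scalar per node
f = z @ G                    # one flow value per edge
print(f)                     # [1. 2.]
```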
@ -0,0 +1,415 @@
import os
import time
from random import SystemRandom

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter

from lib import utils
from model.stden_model import STDENModel
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_mse_loss, masked_rmse_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class STDENSupervisor:
    def __init__(self, adj_mx, **kwargs):
        self._kwargs = kwargs
        self._data_kwargs = kwargs.get('data')
        self._model_kwargs = kwargs.get('model')
        self._train_kwargs = kwargs.get('train')

        self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.)

        # logging.
        self._log_dir = utils.get_log_dir(kwargs)
        self._writer = SummaryWriter('runs/' + self._log_dir)

        log_level = self._kwargs.get('log_level', 'INFO')
        self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)

        # data set
        self._data = utils.load_dataset(**self._data_kwargs)
        self.standard_scaler = self._data['scaler']
        self._logger.info('Scaler mean: {:.6f}, std {:.6f}.'.format(self.standard_scaler.mean, self.standard_scaler.std))

        self.num_edges = (adj_mx > 0.).sum()
        self.input_dim = int(self._model_kwargs.get('input_dim', 1))
        self.seq_len = int(self._model_kwargs.get('seq_len'))  # for the encoder
        self.output_dim = int(self._model_kwargs.get('output_dim', 1))
        self.use_curriculum_learning = bool(
            self._model_kwargs.get('use_curriculum_learning', False))
        self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder

        # setup model
        stden_model = STDENModel(adj_mx, self._logger, **self._model_kwargs)
        self.stden_model = stden_model.cuda() if torch.cuda.is_available() else stden_model
        self._logger.info("Model created")

        self.experimentID = self._train_kwargs.get('load', 0)
        if self.experimentID == 0:
            # Make a new experiment ID
            self.experimentID = int(SystemRandom().random() * 100000)
        self.ckpt_path = os.path.join("ckpt/", "experiment_" + str(self.experimentID))

        self._epoch_num = self._train_kwargs.get('epoch', 0)
        if self._epoch_num > 0:
            self._logger.info('Loading model...')
            self.load_model()

    def save_model(self, epoch):
        model_dir = self.ckpt_path
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        config = dict(self._kwargs)
        config['model_state_dict'] = self.stden_model.state_dict()
        config['epoch'] = epoch
        model_path = os.path.join(model_dir, 'epo{}.tar'.format(epoch))
        torch.save(config, model_path)
        self._logger.info("Saved model at {}".format(epoch))
        return model_path

    def load_model(self):
        self._setup_graph()
        model_path = os.path.join(self.ckpt_path, 'epo{}.tar'.format(self._epoch_num))
        assert os.path.exists(model_path), 'Weights at epoch %d not found' % self._epoch_num

        checkpoint = torch.load(model_path, map_location='cpu')
        self.stden_model.load_state_dict(checkpoint['model_state_dict'])
        self._logger.info("Loaded model at {}".format(self._epoch_num))

    def _setup_graph(self):
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['val_loader'].get_iterator()

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)
                output = self.stden_model(x)
                break

    def train(self, **kwargs):
        self._logger.info('Model mode: train')
        kwargs.update(self._train_kwargs)
        return self._train(**kwargs)

    def _train(self, base_lr,
               steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
               test_every_n_epochs=10, epsilon=1e-8, **kwargs):
        # `steps` lists the milestone epochs at which the learning rate is decayed
        min_val_loss = float('inf')
        wait = 0
        optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)

        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                            gamma=lr_decay_ratio)

        self._logger.info('Start training ...')

        # this will fail if the model is loaded with a changed batch_size
        num_batches = self._data['train_loader'].num_batch
        self._logger.info("num_batches: {}".format(num_batches))

        batches_seen = num_batches * self._epoch_num

        # used for nfe (number of function evaluations)
        c = []
        res, keys = [], []

        for epoch_num in range(self._epoch_num, epochs):

            self.stden_model.train()

            train_iterator = self._data['train_loader'].get_iterator()
            losses = []

            start_time = time.time()

            c.clear()  # nfe
            for i, (x, y) in enumerate(train_iterator):
                if i >= num_batches:
                    break
                optimizer.zero_grad()

                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x, y, batches_seen)

                if batches_seen == 0:
                    # this is a workaround to accommodate dynamically registered parameters
                    optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)

                loss = self._compute_loss(y, output)
                self._logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
                c.append([*fe, loss.item()])

                self._logger.debug(loss.item())
                losses.append(loss.item())

                batches_seen += 1  # global step in tensorboard
                loss.backward()

                # gradient clipping
                torch.nn.utils.clip_grad_norm_(self.stden_model.parameters(), self.max_grad_norm)

                optimizer.step()

                del x, y, output, loss  # drop the references so these tensors can be freed
                torch.cuda.empty_cache()  # release the freed memory back to the GPU allocator

            # used for nfe
            res.append(pd.DataFrame(c, columns=['nfe', 'time', 'err']))
            keys.append(epoch_num)

            self._logger.info("epoch complete")
            lr_scheduler.step()
            self._logger.info("evaluating now!")

            val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen)

            end_time = time.time()

            self._writer.add_scalar('training loss',
                                    np.mean(losses),
                                    batches_seen)

            if (epoch_num % log_every) == log_every - 1:
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
                test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen)
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if val_loss < min_val_loss:
                wait = 0
                if save_model:
                    model_file_name = self.save_model(epoch_num)
                    self._logger.info(
                        'Val loss decreased from {:.4f} to {:.4f}, '
                        'saving to {}'.format(min_val_loss, val_loss, model_file_name))
                min_val_loss = val_loss

            elif val_loss >= min_val_loss:
                wait += 1
                if wait == patience:
                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                    break

        if bool(self._model_kwargs.get('nfe', False)):
            res = pd.concat(res, keys=keys)
            # self._logger.info("res.shape: ", res.shape)
            res.index.names = ['epoch', 'iter']
            filter_type = self._model_kwargs.get('filter_type', 'unknown')
            atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
            rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
            nfe_file = os.path.join(
                self._data_kwargs.get('dataset_dir', 'data'),
                'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol * 1e5), int(rtol * 1e5)))
            res.to_pickle(nfe_file)
            # res.to_csv(nfe_file)

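The `MultiStepLR` scheduler above multiplies the learning rate by `gamma` once at each milestone epoch. A pure-Python sketch of the resulting rate (the milestone values below are illustrative, not the repo's config):

```python
# Sketch of the MultiStepLR decay arithmetic used in _train: the rate
# is base_lr * gamma^k, where k counts the milestones already passed.
def lr_at_epoch(base_lr, milestones, gamma, epoch):
    decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * (gamma ** decays)

# e.g. base_lr=0.01, decay by 10x at epochs 20, 30, 40
print(lr_at_epoch(0.01, [20, 30, 40], 0.1, 25))
```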
    def _prepare_data(self, x, y):
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(device), y.to(device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_edges, input_dim)
        :param y: shape (batch_size, horizon, num_edges, input_dim)
        :returns x shape (seq_len, batch_size, num_edges, input_dim)
                 y shape (horizon, batch_size, num_edges, input_dim)
        """
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        self._logger.debug("X: {}".format(x.size()))
        self._logger.debug("y: {}".format(y.size()))
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        :param x: shape (seq_len, batch_size, num_edges, input_dim)
        :param y: shape (horizon, batch_size, num_edges, input_dim)
        :return: x: shape (seq_len, batch_size, num_edges * input_dim)
                 y: shape (horizon, batch_size, num_edges * output_dim)
        """
        batch_size = x.size(1)
        self._logger.debug("size of x {}".format(x.size()))
        x = x.view(self.seq_len, batch_size, self.num_edges * self.input_dim)
        y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                          self.num_edges * self.output_dim)
        return x, y

    def _compute_loss(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return masked_mae_loss(y_predicted, y_true)

    def _compute_loss_eval(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return masked_mae_loss(y_predicted, y_true).item(), masked_mape_loss(y_predicted, y_true).item(), masked_rmse_loss(y_predicted, y_true).item()

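`masked_mae_loss` lives in `lib.metrics`, which is outside this hunk; the sketch below assumes the common DCRNN-style convention of treating zero ground-truth entries as missing, and is offered as an assumed illustration, not the repository's exact implementation:

```python
import numpy as np

# Assumed stand-in for masked_mae_loss: entries where the ground truth
# is zero are treated as missing and excluded from the mean.
def masked_mae(y_pred, y_true):
    mask = (y_true != 0).astype(float)
    mask /= mask.mean()                    # re-weight so valid entries average to 1
    err = np.abs(y_pred - y_true) * mask
    return np.nan_to_num(err).mean()

y_true = np.array([2., 0., 4.])            # the 0 is treated as missing
y_pred = np.array([3., 5., 6.])
print(masked_mae(y_pred, y_true))          # 1.5: mean of |3-2| and |6-4|
```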
    def evaluate(self, dataset='val', batches_seen=0, save=False):
        """
        Computes mae, mape, and rmse losses, and returns the predictions if `save` is True.
        :return: mean L1 loss
        """
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []
            y_dict = None

            if save:
                y_truths = []
                y_preds = []

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x)
                mae, mape, rmse = self._compute_loss_eval(y, output)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)

                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses),
                'mape': np.mean(mape_losses),
                'rmse': np.mean(rmse_losses)
            }

            self._logger.info('Evaluation: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(mean_loss['mae'], mean_loss['mape'], mean_loss['rmse']))
            self._writer.add_scalar('{} loss'.format(dataset), mean_loss['mae'], batches_seen)

            if save:
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension

                y_truths_scaled = []
                y_preds_scaled = []
                # self._logger.debug("y_preds shape: {}, y_truth shape {}".format(y_preds.shape, y_truths.shape))
                for t in range(y_preds.shape[0]):
                    y_truth = self.standard_scaler.inverse_transform(y_truths[t])
                    y_pred = self.standard_scaler.inverse_transform(y_preds[t])
                    y_truths_scaled.append(y_truth)
                    y_preds_scaled.append(y_pred)

                y_preds_scaled = np.stack(y_preds_scaled)
                y_truths_scaled = np.stack(y_truths_scaled)

                y_dict = {'prediction': y_preds_scaled, 'truth': y_truths_scaled}

                # save_dir = self._data_kwargs.get('dataset_dir', 'data')
                # save_path = os.path.join(save_dir, 'pred.npz')
                # np.savez(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)

            return mean_loss['mae'], y_dict

    def eval_more(self, dataset='val', save=False, seq_len=[3, 6, 9, 12], extract_latent=False):
        """
        Computes mae, mape, and rmse losses, and saves the predictions if `save` is set True.
        """
        self._logger.info('Model mode: Evaluation')
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []

            if save:
                y_truths = []
                y_preds = []

            if extract_latent:
                latents = []

            # used for nfe
            c = []
            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x)
                mae, mape, rmse = [], [], []
                for seq in seq_len:
                    _mae, _mape, _rmse = self._compute_loss_eval(y[seq - 1], output[seq - 1])
                    mae.append(_mae)
                    mape.append(_mape)
                    rmse.append(_rmse)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)
                c.append([*fe, np.mean(mae)])

                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

                if extract_latent:
                    latents.append(self.stden_model.latent_feat.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses, axis=0),
                'mape': np.mean(mape_losses, axis=0),
                'rmse': np.mean(rmse_losses, axis=0)
            }

            for i, seq in enumerate(seq_len):
                self._logger.info('Evaluation seq {}: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(
                    seq, mean_loss['mae'][i], mean_loss['mape'][i], mean_loss['rmse'][i]))

            if save:
                # shape (horizon, num_samples, feat_dim)
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension
                y_preds_scaled = self.standard_scaler.inverse_transform(y_preds)
                y_truths_scaled = self.standard_scaler.inverse_transform(y_truths)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                save_path = os.path.join(save_dir, 'pred_{}_{}.npz'.format(self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)

            if extract_latent:
                # concatenate on batch dimension
                latents = np.concatenate(latents, axis=1)
                # Shape of latents (horizon, num_samples, self.num_edges * self.output_dim)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                save_path = os.path.join(save_dir, '{}_latent_{}_{}.npz'.format(filter_type, self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, latent=latents)

            if bool(self._model_kwargs.get('nfe', False)):
                res = pd.DataFrame(c, columns=['nfe', 'time', 'err'])
                res.index.name = 'iter'
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
                rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
                nfe_file = os.path.join(
                    self._data_kwargs.get('dataset_dir', 'data'),
                    'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol * 1e5), int(rtol * 1e5)))
                res.to_pickle(nfe_file)

@ -0,0 +1,7 @@
scipy>=1.5.2
numpy>=1.19.1
pandas>=1.1.5
pyyaml>=5.3.1
torch>=1.7.1
future>=0.18.2
torchdiffeq>=0.2.0

@ -0,0 +1,43 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import yaml

from lib.utils import load_graph_data
from model.stden_supervisor import STDENSupervisor

import numpy as np
import torch


def main(args):
    with open(args.config_filename) as f:
        # an explicit Loader is required by PyYAML >= 5.1
        supervisor_config = yaml.load(f, Loader=yaml.FullLoader)

    graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
    adj_mx = load_graph_data(graph_pkl_filename)

    supervisor = STDENSupervisor(adj_mx=adj_mx, **supervisor_config)

    horizon = supervisor_config['model'].get('horizon')
    extract_latent = supervisor_config['model'].get('save_latent')
    supervisor.eval_more(dataset='test',
                         save=args.save_pred,
                         seq_len=np.arange(1, horizon + 1, 1),
                         extract_latent=extract_latent)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_filename', default=None, type=str,
                        help='Configuration filename for restoring the model.')
    parser.add_argument('--use_cpu_only', default=False, type=bool, help='Set to true to only use cpu.')
    parser.add_argument('-r', '--random_seed', type=int, default=2021, help="Random seed for reproduction.")
    parser.add_argument('--save_pred', action='store_true', help='Save the prediction.')
    args = parser.parse_args()

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    main(args)

@ -0,0 +1,37 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import yaml

from lib.utils import load_graph_data
from model.stden_supervisor import STDENSupervisor

import numpy as np
import torch


def main(args):
    with open(args.config_filename) as f:
        # an explicit Loader is required by PyYAML >= 5.1
        supervisor_config = yaml.load(f, Loader=yaml.FullLoader)

    graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
    adj_mx = load_graph_data(graph_pkl_filename)

    supervisor = STDENSupervisor(adj_mx=adj_mx, **supervisor_config)

    supervisor.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_filename', default=None, type=str,
                        help='Configuration filename for restoring the model.')
    parser.add_argument('--use_cpu_only', default=False, type=bool, help='Set to true to only use cpu.')
    parser.add_argument('-r', '--random_seed', type=int, default=2021, help="Random seed for reproduction.")
    args = parser.parse_args()

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    main(args)

@ -0,0 +1,535 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import numpy as np "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"(array([[ 0, 1, 2, 3, 4],\n",
|
||||||
|
" [ 0, 6, 7, 8, 9],\n",
|
||||||
|
" [ 0, 0, 12, 13, 14],\n",
|
||||||
|
" [ 0, 0, 0, 18, 19],\n",
|
||||||
|
" [ 0, 0, 0, 0, 24]]),\n",
|
||||||
|
" array([[ 0, 0, 0, 0, 0],\n",
|
||||||
|
" [ 5, 6, 0, 0, 0],\n",
|
||||||
|
" [10, 11, 12, 0, 0],\n",
|
||||||
|
" [15, 16, 17, 18, 0],\n",
|
||||||
|
" [20, 21, 22, 23, 24]]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"m = np.arange(0, 25).reshape((5, 5))\n",
|
||||||
|
"\n",
|
||||||
|
"out = np.triu(m)\n",
|
||||||
|
"inp = np.tril(m)\n",
|
||||||
|
"out, inp"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"array([[1, 0, 1, 0, 0],\n",
|
||||||
|
" [1, 0, 0, 1, 1],\n",
|
||||||
|
" [1, 0, 0, 1, 0],\n",
|
||||||
|
" [1, 1, 1, 0, 0],\n",
|
||||||
|
" [1, 0, 1, 1, 1]])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
],
"source": [
"adj = np.random.randint(0, 2, size=(5, 5))\n",
"adj"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 2, 0, 0],\n",
" [ 5, 0, 0, 8, 9],\n",
" [10, 0, 0, 13, 0],\n",
" [15, 16, 17, 0, 0],\n",
" [20, 0, 22, 23, 48]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"((inp + out) * adj)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1],\n",
" [2],\n",
" [3]])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = torch.tensor([1, 2, 3])\n",
"a.unsqueeze_(-1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 5, 5])\n"
]
},
{
"ename": "IndexError",
"evalue": "Dimension out of range (expected to be in range of [-1, 0], but got 1)",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-29-95457010b19d>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mr\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mr\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstart_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-1, 0], but got 1)"
]
}
],
"source": [
"r = torch.tensor(((inp + out) * adj))\n",
"r.unsqueeze_(0)\n",
"print(r.shape)\n",
"r[r > 0].flatten(start_dim=0, end_dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"r = r.repeat((2, 1, 1))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 2, 0, 0, 5, 0, 0, 8, 9, 10, 0, 0, 13, 0, 15, 16, 17,\n",
" 0, 0, 20, 0, 22, 23, 48],\n",
" [ 0, 0, 2, 0, 0, 5, 0, 0, 8, 9, 10, 0, 0, 13, 0, 15, 16, 17,\n",
" 0, 0, 20, 0, 22, 23, 48]], dtype=torch.int32)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r = torch.flatten(r, start_dim=1, end_dim=-1)\n",
"r[r>0]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 2, 5, 8, 9, 10, 13, 15, 16, 17, 20, 22, 23, 48],\n",
" [ 2, 5, 8, 9, 10, 13, 15, 16, 17, 20, 22, 23, 48]],\n",
" dtype=torch.int32)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r[r > 0].reshape(2, -1)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"class GraphGrad(torch.nn.Module):\n",
" def __init__(self, adj_mx):\n",
" \"\"\"Graph gradient operator that transforms functions on nodes to functions on edges.\n",
" \"\"\"\n",
" super(GraphGrad, self).__init__()\n",
" self.adj_mx = adj_mx\n",
" self.grad = self._grad(adj_mx)\n",
" \n",
" @staticmethod\n",
" def _grad(adj_mx):\n",
" \"\"\"Fetch the graph gradient operator.\"\"\"\n",
" num_nodes = adj_mx.size()[-1]\n",
"\n",
" num_edges = (adj_mx > 0.).sum()\n",
" grad = torch.zeros(num_nodes, num_edges)\n",
" e = 0\n",
" for i in range(num_nodes):\n",
" for j in range(num_nodes):\n",
" if adj_mx[i, j] == 0:\n",
" continue\n",
"\n",
" grad[i, e] = 1.\n",
" grad[j, e] = -1.\n",
" e += 1\n",
" return grad\n",
"\n",
" def forward(self, z):\n",
" \"\"\"Transform z with shape `(..., num_nodes)` to f with shape `(..., num_edges)`.\n",
" \"\"\"\n",
" return torch.matmul(z, self.grad)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 0, 1, 0, 0],\n",
" [1, 0, 0, 1, 1],\n",
" [1, 0, 0, 1, 0],\n",
" [1, 1, 1, 0, 0],\n",
" [1, 0, 1, 1, 1]])"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"adj"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 14])"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gg = GraphGrad(torch.tensor(adj))\n",
"grad = gg.grad\n",
"grad.shape"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([14, 5])"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad.transpose(-1, -2).shape"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"14"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad.size(-1)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5, 5)"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inp.shape"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.],\n",
" [ -5., 5., 1., 6., 6., -5., 0., -5., -6., 0., -5., 0.,\n",
" 0., 0.],\n",
" [-10., -2., 1., 11., 11., 2., 12., -10., -11., -12., -10., -12.,\n",
" 0., 0.],\n",
" [-15., -2., 1., -2., 16., 2., -1., 3., 2., 1., -15., -17.,\n",
" -18., 0.],\n",
" [-20., -2., 1., -2., -3., 2., -1., 3., 2., 1., 4., 2.,\n",
" 1., -24.]])"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gg(torch.tensor(inp, dtype=torch.float32))"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 5])"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.tensor(inp).shape \n",
"# (grad_T.T)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1;31mDocstring:\u001b[0m\n",
"matmul(input, other, *, out=None) -> Tensor\n",
"\n",
"Matrix product of two tensors.\n",
"\n",
"The behavior depends on the dimensionality of the tensors as follows:\n",
"\n",
"- If both tensors are 1-dimensional, the dot product (scalar) is returned.\n",
"- If both arguments are 2-dimensional, the matrix-matrix product is returned.\n",
"- If the first argument is 1-dimensional and the second argument is 2-dimensional,\n",
" a 1 is prepended to its dimension for the purpose of the matrix multiply.\n",
" After the matrix multiply, the prepended dimension is removed.\n",
"- If the first argument is 2-dimensional and the second argument is 1-dimensional,\n",
" the matrix-vector product is returned.\n",
"- If both arguments are at least 1-dimensional and at least one argument is\n",
" N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first\n",
" argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the\n",
" batched matrix multiply and removed after. If the second argument is 1-dimensional, a\n",
" 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after.\n",
" The non-matrix (i.e. batch) dimensions are :ref:`broadcasted <broadcasting-semantics>` (and thus\n",
" must be broadcastable). For example, if :attr:`input` is a\n",
" :math:`(j \\times 1 \\times n \\times n)` tensor and :attr:`other` is a :math:`(k \\times n \\times n)`\n",
" tensor, :attr:`out` will be a :math:`(j \\times k \\times n \\times n)` tensor.\n",
"\n",
" Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs\n",
" are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a\n",
" :math:`(j \\times 1 \\times n \\times m)` tensor and :attr:`other` is a :math:`(k \\times m \\times p)`\n",
" tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the\n",
" matrix dimensions) are different. :attr:`out` will be a :math:`(j \\times k \\times n \\times p)` tensor.\n",
"\n",
"This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.\n",
"\n",
".. note::\n",
"\n",
" The 1-dimensional dot product version of this function does not support an :attr:`out` parameter.\n",
"\n",
"Arguments:\n",
" input (Tensor): the first tensor to be multiplied\n",
" other (Tensor): the second tensor to be multiplied\n",
"\n",
"Keyword args:\n",
" out (Tensor, optional): the output tensor.\n",
"\n",
"Example::\n",
"\n",
" >>> # vector x vector\n",
" >>> tensor1 = torch.randn(3)\n",
" >>> tensor2 = torch.randn(3)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([])\n",
" >>> # matrix x vector\n",
" >>> tensor1 = torch.randn(3, 4)\n",
" >>> tensor2 = torch.randn(4)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([3])\n",
" >>> # batched matrix x broadcasted vector\n",
" >>> tensor1 = torch.randn(10, 3, 4)\n",
" >>> tensor2 = torch.randn(4)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([10, 3])\n",
" >>> # batched matrix x batched matrix\n",
" >>> tensor1 = torch.randn(10, 3, 4)\n",
" >>> tensor2 = torch.randn(10, 4, 5)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([10, 3, 5])\n",
" >>> # batched matrix x broadcasted matrix\n",
" >>> tensor1 = torch.randn(10, 3, 4)\n",
" >>> tensor2 = torch.randn(4, 5)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([10, 3, 5])\n",
"\u001b[1;31mType:\u001b[0m builtin_function_or_method\n"
]
}
],
"source": [
"torch.matmul?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "1b89aa55be347d0b8cc51b3a166e8002614a385bd8cff32165269c80e70c12a7"
},
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}