Added dcrnn_cell

Rough implementation complete - able to run a forward pass through the network

Ensured sparse mm is used for readability; made logging sparser as well

moving tensors to GPU

moving tensors to GPU [v2]

moving tensors to GPU [v3]

logging and refactor

logging and refactor

logging and refactor

logging and refactor

logging and refactor

logging and refactor

logging and refactor

ensured row major ordering

fixed log message
Chintan Shah 2019-10-06 11:55:02 -04:00
parent e80c47390d
commit d1964672c2
3 changed files with 74 additions and 85 deletions

model/pytorch/dcrnn_cell.py

@@ -1,21 +1,22 @@
from typing import Optional
import numpy as np
import torch
from torch import Tensor
from lib import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LayerParams:
def __init__(self, rnn_network: torch.nn.RNN, type: str):
def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
self._rnn_network = rnn_network
self._params_dict = {}
self._biases_dict = {}
self._type = type
self._type = layer_type
def get_weights(self, shape):
if shape not in self._params_dict:
nn_param = torch.nn.init.xavier_normal(torch.empty(*shape))
nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
torch.nn.init.xavier_normal_(nn_param)
self._params_dict[shape] = nn_param
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
nn_param)
@@ -23,7 +24,8 @@ class LayerParams:
def get_biases(self, length, bias_start=0.0):
if length not in self._biases_dict:
biases = torch.nn.init.constant(torch.empty(length), bias_start)
biases = torch.nn.Parameter(torch.empty(length, device=device))
torch.nn.init.constant_(biases, bias_start)
self._biases_dict[length] = biases
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
biases)
@@ -31,32 +33,24 @@ class LayerParams:
return self._biases_dict[length]
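
The LayerParams changes above wrap each lazily created weight in torch.nn.Parameter before registering it (register_parameter only accepts Parameter or None), allocate it directly on the selected device, and switch to the in-place initializers (xavier_normal_, constant_). A minimal sketch of that lazy-registration pattern, with illustrative names rather than the repo's:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LazyParamModule(torch.nn.Module):
    """Illustrative module that creates weights on first use and registers them."""
    def __init__(self):
        super().__init__()
        self._cache = {}

    def weight(self, shape):
        if shape not in self._cache:
            p = torch.nn.Parameter(torch.empty(*shape, device=device))
            torch.nn.init.xavier_normal_(p)   # in-place init on the Parameter
            self.register_parameter('w_{}'.format('x'.join(map(str, shape))), p)
            self._cache[shape] = p
        return self._cache[shape]

m = LazyParamModule()
w = m.weight((4, 3))
print(sum(p.numel() for p in m.parameters()))  # 12 -> the lazily created weight is trainable
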
class DCGRUCell(torch.nn.RNN):
def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_size: int,
hidden_size: int,
num_layers: int = 1,
num_proj=None,
nonlinearity='tanh', filter_type="laplacian", use_gc_for_ru=True):
class DCGRUCell(torch.nn.Module):
def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
filter_type="laplacian", use_gc_for_ru=True):
"""
:param num_units:
:param adj_mx:
:param max_diffusion_step:
:param num_nodes:
:param input_size:
:param num_proj:
:param nonlinearity:
:param filter_type: "laplacian", "random_walk", "dual_random_walk".
:param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
"""
super(DCGRUCell, self).__init__(input_size, hidden_size, bias=True,
# bias param does not exist in tf code?
num_layers=num_layers,
nonlinearity=nonlinearity)
super().__init__()
self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
# support other nonlinearities up here?
self._num_nodes = num_nodes
self._num_proj = num_proj
self._num_units = num_units
self._max_diffusion_step = max_diffusion_step
self._supports = []
@@ -73,23 +67,19 @@ class DCGRUCell(torch.nn.RNN):
supports.append(utils.calculate_scaled_laplacian(adj_mx))
for support in supports:
self._supports.append(self._build_sparse_matrix(support))
self._proj_weights = torch.nn.Parameter(torch.randn(self._num_units, self._num_proj))
self._fc_params = LayerParams(self, 'fc')
self._gconv_params = LayerParams(self, 'gconv')
@property
def state_size(self):
return self._num_nodes * self._num_units
@staticmethod
def _build_sparse_matrix(L):
L = L.tocoo()
indices = np.column_stack((L.row, L.col))
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
return L
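
_build_sparse_matrix above replaces TensorFlow's tf.sparse_reorder by sorting the COO indices before constructing the torch sparse tensor. One caveat worth remembering is that np.lexsort treats its last key as the primary sort key. A self-contained sketch of the conversion under that reading, assuming scipy is available and keeping the data values aligned with the reordered indices:

import numpy as np
import scipy.sparse as sp
import torch

def scipy_to_torch_sparse(mat, device='cpu'):
    """Convert a scipy sparse matrix to a torch COO tensor with canonically ordered indices."""
    mat = mat.tocoo()
    indices = np.column_stack((mat.row, mat.col))
    # np.lexsort treats its *last* key as the primary key, so putting the row index
    # last sorts row-major (row, then column), matching tf.sparse_reorder's canonical order
    order = np.lexsort((indices[:, 1], indices[:, 0]))
    values = mat.data[order]       # keep the values aligned with the reordered indices
    indices = indices[order]
    return torch.sparse_coo_tensor(indices.T, values, mat.shape, device=device)

support = scipy_to_torch_sparse(sp.random(5, 5, density=0.3))
print(support.to_dense().shape)    # torch.Size([5, 5])
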
@property
def output_size(self):
output_size = self._num_nodes * self._num_units
if self._num_proj is not None:
output_size = self._num_nodes * self._num_proj
return output_size
def forward(self, input: Tensor, hx: Optional[Tensor] = ...):
def forward(self, inputs, hx):
"""Gated recurrent unit (GRU) with Graph Convolution.
:param input: (B, num_nodes * input_dim)
@@ -99,28 +89,22 @@ class DCGRUCell(torch.nn.RNN):
the arity and shapes of `state`
"""
output_size = 2 * self._num_units
# We start with bias of 1.0 to not reset and not update.
if self._use_gc_for_ru:
fn = self._gconv
else:
fn = self._fc
value = torch.sigmoid(fn(input, hx, output_size, bias_start=1.0))
value = torch.sigmoid(fn(inputs, hx, output_size, bias_start=1.0))
value = torch.reshape(value, (-1, self._num_nodes, output_size))
r, u = torch.split(tensor=value, split_size_or_sections=2, dim=-1)
r, u = torch.split(tensor=value, split_size_or_sections=self._num_units, dim=-1)
r = torch.reshape(r, (-1, self._num_nodes * self._num_units))
u = torch.reshape(u, (-1, self._num_nodes * self._num_units))
c = self._gconv(input, r * hx, self._num_units)
c = self._gconv(inputs, r * hx, self._num_units)
if self._activation is not None:
c = self._activation(c)
output = new_state = u * hx + (1 - u) * c
if self._num_proj is not None:
batch_size = input.shape[0]
output = torch.reshape(new_state, shape=(-1, self._num_units))
output = torch.reshape(torch.matmul(output, self._proj_weights),
shape=(batch_size, self.output_size))
return output, new_state
new_state = u * hx + (1.0 - u) * c
return new_state
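
For reference, the rewritten forward above is the standard GRU update with the dense matmuls replaced by diffusion graph convolutions: r and u are the reset and update gates, c is the candidate state, and the new hidden state is u * hx + (1 - u) * c. A toy sketch of that gating with the graph convolution stubbed out by plain linear layers (illustrative only, not the repo's code):

import torch

def gru_step(x, h, gate_fn, cand_fn):
    """One GRU step; in DCGRUCell the gate_fn / cand_fn are graph convolutions."""
    ru = torch.sigmoid(gate_fn(torch.cat([x, h], dim=-1)))   # reset/update gates (bias_start=1.0 in the repo)
    r, u = ru.chunk(2, dim=-1)
    c = torch.tanh(cand_fn(torch.cat([x, r * h], dim=-1)))   # candidate state
    return u * h + (1.0 - u) * c                             # new hidden state

hidden = 8
gate = torch.nn.Linear(4 + hidden, 2 * hidden)
cand = torch.nn.Linear(4 + hidden, hidden)
x = torch.randn(2, 4)
h = torch.zeros(2, hidden)
print(gru_step(x, h, gate, cand).shape)   # torch.Size([2, 8])
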
@staticmethod
def _concat(x, x_):
@@ -153,8 +137,7 @@ class DCGRUCell(torch.nn.RNN):
inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
state = torch.reshape(state, (batch_size, self._num_nodes, -1))
inputs_and_state = torch.cat([inputs, state], dim=2)
input_size = inputs_and_state.shape[2].value
dtype = inputs.dtype
input_size = inputs_and_state.size(2)
x = inputs_and_state
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
@@ -165,12 +148,11 @@ class DCGRUCell(torch.nn.RNN):
pass
else:
for support in self._supports:
# https://discuss.pytorch.org/t/sparse-x-dense-dense-matrix-multiplication/6116/7
x1 = torch.mm(support, x0)
x1 = torch.sparse.mm(support, x0) # this is not reordered, does this work - todo
x = self._concat(x, x1)
for k in range(2, self._max_diffusion_step + 1):
x2 = 2 * torch.mm(support, x1) - x0
x2 = 2 * torch.sparse.mm(support, x1) - x0
x = self._concat(x, x2)
x1, x0 = x2, x1
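
The torch.mm -> torch.sparse.mm change above is the sparse-dense product discussed in the linked forum thread; the surrounding loop computes the K-step diffusion terms with the Chebyshev-style recurrence x_k = 2 * S x_{k-1} - x_{k-2}. A minimal sketch of that recurrence on a toy sparse support (shapes are illustrative, not the repo's):

import torch

num_nodes, feat = 4, 3
indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
values = torch.ones(4)
support = torch.sparse_coo_tensor(indices, values, (num_nodes, num_nodes))

x0 = torch.randn(num_nodes, feat)
outs = [x0]
x1 = torch.sparse.mm(support, x0)     # sparse (N, N) x dense (N, F) -> dense (N, F)
outs.append(x1)
for _ in range(2, 3 + 1):             # max_diffusion_step = 3 here
    x2 = 2 * torch.sparse.mm(support, x1) - x0
    outs.append(x2)
    x1, x0 = x2, x1

x = torch.stack(outs)                 # (K + 1, N, F) stacked diffusion terms
print(x.shape)                        # torch.Size([4, 4, 3])
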

model/pytorch/dcrnn_model.py

@@ -2,6 +2,10 @@ import numpy as np
import torch
import torch.nn as nn
from model.pytorch.dcrnn_cell import DCGRUCell
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Seq2SeqAttrs:
def __init__(self, adj_mx, **model_kwargs):
@@ -9,7 +13,6 @@ class Seq2SeqAttrs:
self.max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
self.filter_type = model_kwargs.get('filter_type', 'laplacian')
# self.max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0))
self.num_nodes = int(model_kwargs.get('num_nodes', 1))
self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
self.rnn_units = int(model_kwargs.get('rnn_units'))
@@ -18,19 +21,13 @@
class EncoderModel(nn.Module, Seq2SeqAttrs):
def __init__(self, adj_mx, **model_kwargs):
# super().__init__(is_training, adj_mx, **model_kwargs)
# https://pytorch.org/docs/stable/nn.html#gru
nn.Module.__init__(self)
Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
self.input_dim = int(model_kwargs.get('input_dim', 1))
self.seq_len = int(model_kwargs.get('seq_len')) # for the encoder
self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.input_dim,
hidden_size=self.hidden_state_size,
bias=True)] + [
nn.GRUCell(input_size=self.hidden_state_size,
hidden_size=self.hidden_state_size,
bias=True) for _ in
range(self.num_rnn_layers - 1)])
self.dcgru_layers = nn.ModuleList(
[DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
def forward(self, inputs, hidden_state=None):
"""
@@ -45,7 +42,8 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
"""
batch_size, _ = inputs.size()
if hidden_state is None:
hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size))
hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
device=device)
hidden_states = []
output = inputs
for layer_num, dcgru_layer in enumerate(self.dcgru_layers):
@@ -63,14 +61,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
self.output_dim = int(model_kwargs.get('output_dim', 1))
self.horizon = int(model_kwargs.get('horizon', 1)) # for the decoder
self.projection_layer = nn.Linear(self.hidden_state_size, self.num_nodes * self.output_dim)
self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.output_dim,
hidden_size=self.hidden_state_size,
bias=True)] + [
nn.GRUCell(input_size=self.hidden_state_size,
hidden_size=self.hidden_state_size,
bias=True) for _ in
range(self.num_rnn_layers - 1)])
self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
self.dcgru_layers = nn.ModuleList(
[DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
def forward(self, inputs, hidden_state=None):
"""
@@ -90,7 +84,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
hidden_states.append(next_hidden_state)
output = next_hidden_state
return self.projection_layer(output), torch.stack(hidden_states)
projected = self.projection_layer(output.view(-1, self.rnn_units))
output = projected.view(-1, self.num_nodes * self.output_dim)
return output, torch.stack(hidden_states)
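
The projection change above assumes the decoder hidden state is laid out as (batch_size, num_nodes * rnn_units), so the Linear(rnn_units, output_dim) layer is applied per node by flattening to (batch_size * num_nodes, rnn_units) and reshaping back. A short shape walk-through under that assumption:

import torch
import torch.nn as nn

batch_size, num_nodes, rnn_units, output_dim = 2, 5, 8, 1
projection = nn.Linear(rnn_units, output_dim)

hidden = torch.randn(batch_size, num_nodes * rnn_units)   # decoder hidden state
per_node = hidden.view(-1, rnn_units)                      # (batch * num_nodes, rnn_units)
projected = projection(per_node)                           # (batch * num_nodes, output_dim)
output = projected.view(-1, num_nodes * output_dim)        # (batch, num_nodes * output_dim)
print(output.shape)                                        # torch.Size([2, 5])
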
class DCRNNModel(nn.Module, Seq2SeqAttrs):
@@ -128,7 +125,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
"""
batch_size = encoder_hidden_state.size(1)
go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim))
go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim),
device=device)
decoder_hidden_state = encoder_hidden_state
decoder_input = go_symbol
@@ -155,7 +153,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
"""
encoder_hidden_state = self.encoder(inputs)
self._logger.info("Encoder complete, starting decoder")
self._logger.debug("Encoder complete, starting decoder")
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
self._logger.info("Decoder complete")
self._logger.debug("Decoder complete")
return outputs

model/pytorch/dcrnn_supervisor.py

@@ -7,6 +7,8 @@ import torch
from lib import utils
from model.pytorch.dcrnn_model import DCRNNModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DCRNNSupervisor:
def __init__(self, adj_mx, **kwargs):
@@ -75,7 +77,7 @@ class DCRNNSupervisor:
config['model_state_dict'] = self.dcrnn_model.state_dict()
config['epoch'] = epoch
torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
self._logger.info("Loaded model at {}".format(epoch))
self._logger.info("Saved model at {}".format(epoch))
return self._log_dir + 'models/epo%d.tar' % epoch
def load_model(self, epoch):
@@ -102,8 +104,7 @@ class DCRNNSupervisor:
criterion = torch.nn.L1Loss()
for _, (x, y) in enumerate(val_iterator):
x, y = self._get_x_y(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
x, y = self._prepare_data(x, y)
output = self.dcrnn_model(x)
loss = self._compute_loss(y, output, criterion)
@@ -128,6 +129,7 @@ class DCRNNSupervisor:
self.dcrnn_model = self.dcrnn_model.train()
self._logger.info('Start training ...')
self._logger.info("num_batches:{}".format(self._data['train_loader'].num_batch))
for epoch_num in range(epochs):
train_iterator = self._data['train_loader'].get_iterator()
losses = []
@@ -137,12 +139,13 @@ class DCRNNSupervisor:
for _, (x, y) in enumerate(train_iterator):
optimizer.zero_grad()
x, y = self._get_x_y(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
x, y = self._prepare_data(x, y)
output = self.dcrnn_model(x, y, batches_seen)
loss = self._compute_loss(y, output, criterion)
self._logger.info(loss.item())
self._logger.debug(loss.item())
losses.append(loss.item())
batches_seen += 1
@@ -152,40 +155,46 @@ class DCRNNSupervisor:
torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
optimizer.step()
self._logger.info("epoch complete")
lr_scheduler.step()
self._logger.info("evaluating now!")
val_loss = self.evaluate(dataset='val')
end_time = time.time()
if epoch_num % log_every == 0:
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}' \
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
np.mean(losses), val_loss, lr_scheduler.get_lr(),
np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
(end_time - start_time))
self._logger.info(message)
if epoch_num % test_every_n_epochs == 0:
test_loss = self.evaluate(dataset='test')
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f} ' \
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
np.mean(losses), test_loss, lr_scheduler.get_lr(),
np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
(end_time - start_time))
self._logger.info(message)
if val_loss < min_val_loss:
wait = 0
min_val_loss = val_loss
if save_model:
model_file_name = self.save_model(epoch_num)
self._logger.info(
'Val loss decrease from {:.4f} to {:.4f}, '
'saving to {}'.format(min_val_loss, val_loss, model_file_name))
min_val_loss = val_loss
elif val_loss >= min_val_loss:
wait += 1
if wait == patience:
self._logger.warning('Early stopping at epoch: %d' % epoch_num)
break
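
On the lr logging fix above: PyTorch learning-rate schedulers return a list from get_lr(), one entry per parameter group, which is why the messages now format get_lr()[0]. A quick illustration (newer PyTorch versions prefer get_last_lr() for this):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)
# get_lr() returns one learning rate per parameter group
print(scheduler.get_lr()[0])    # 0.01
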
def _prepare_data(self, x, y):
x, y = self._get_x_y(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
return x.to(device), y.to(device)
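
_prepare_data above chains the two reshaping helpers and moves each batch onto the module-level device, which is the pattern behind the "moving tensors to GPU" commits: tensors created inside the model (hidden states, go_symbol, lazily registered parameters) are allocated with device=device, while input batches are moved with .to(device) before the forward pass. A minimal sketch of that pattern with a hypothetical model (not the repo's):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(10, 1).to(device)   # model parameters live on the chosen device
x = torch.randn(32, 10)                     # a batch as produced on the CPU by the data loader
y = torch.randn(32, 1)
x, y = x.to(device), y.to(device)           # move the batch before the forward pass
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
print(loss.item())
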
def _get_x_y(self, x, y):
"""
:param x: shape (batch_size, seq_len, num_sensor, input_dim)