Added dcrnn_cell

Rough implementation complete - can now run a forward pass through the network

Ensured torch.sparse.mm is used for readability; reduced logging verbosity as well

moving tensors to GPU

moving tensors to GPU [v2]

moving tensors to GPU [v3]

logging and refactor

logging and refactor

logging and refactor

logging and refactor

logging and refactor

logging and refactor

logging and refactor

ensured row-major ordering

fixed log message
Chintan Shah 2019-10-06 11:55:02 -04:00
parent e80c47390d
commit d1964672c2
3 changed files with 74 additions and 85 deletions

model/pytorch/dcrnn_cell.py

@@ -1,21 +1,22 @@
-from typing import Optional
+import numpy as np
 import torch
-from torch import Tensor
 from lib import utils
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class LayerParams:
-    def __init__(self, rnn_network: torch.nn.RNN, type: str):
+    def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
         self._rnn_network = rnn_network
         self._params_dict = {}
         self._biases_dict = {}
-        self._type = type
+        self._type = layer_type
     def get_weights(self, shape):
         if shape not in self._params_dict:
-            nn_param = torch.nn.init.xavier_normal(torch.empty(*shape))
+            nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
+            torch.nn.init.xavier_normal_(nn_param)
             self._params_dict[shape] = nn_param
             self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
                                                  nn_param)
@@ -23,7 +24,8 @@ class LayerParams:
     def get_biases(self, length, bias_start=0.0):
         if length not in self._biases_dict:
-            biases = torch.nn.init.constant(torch.empty(length), bias_start)
+            biases = torch.nn.Parameter(torch.empty(length, device=device))
+            torch.nn.init.constant_(biases, bias_start)
             self._biases_dict[length] = biases
             self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
                                                  biases)
@@ -31,32 +33,24 @@ class LayerParams:
         return self._biases_dict[length]
-class DCGRUCell(torch.nn.RNN):
-    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_size: int,
-                 hidden_size: int,
-                 num_layers: int = 1,
-                 num_proj=None,
-                 nonlinearity='tanh', filter_type="laplacian", use_gc_for_ru=True):
+class DCGRUCell(torch.nn.Module):
+    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
+                 filter_type="laplacian", use_gc_for_ru=True):
         """
         :param num_units:
         :param adj_mx:
         :param max_diffusion_step:
         :param num_nodes:
-        :param input_size:
-        :param num_proj:
         :param nonlinearity:
         :param filter_type: "laplacian", "random_walk", "dual_random_walk".
         :param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
         """
-        super(DCGRUCell, self).__init__(input_size, hidden_size, bias=True,
-                                        # bias param does not exist in tf code?
-                                        num_layers=num_layers,
-                                        nonlinearity=nonlinearity)
+        super().__init__()
         self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
         # support other nonlinearities up here?
         self._num_nodes = num_nodes
-        self._num_proj = num_proj
         self._num_units = num_units
         self._max_diffusion_step = max_diffusion_step
         self._supports = []
@@ -73,23 +67,19 @@ class DCGRUCell(torch.nn.RNN):
             supports.append(utils.calculate_scaled_laplacian(adj_mx))
         for support in supports:
             self._supports.append(self._build_sparse_matrix(support))
-        self._proj_weights = torch.nn.Parameter(torch.randn(self._num_units, self._num_proj))
         self._fc_params = LayerParams(self, 'fc')
         self._gconv_params = LayerParams(self, 'gconv')
-    @property
-    def state_size(self):
-        return self._num_nodes * self._num_units
+    @staticmethod
+    def _build_sparse_matrix(L):
+        L = L.tocoo()
+        indices = np.column_stack((L.row, L.col))
+        # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
+        indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
+        L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
+        return L
-    @property
-    def output_size(self):
-        output_size = self._num_nodes * self._num_units
-        if self._num_proj is not None:
-            output_size = self._num_nodes * self._num_proj
-        return output_size
-    def forward(self, input: Tensor, hx: Optional[Tensor] = ...):
+    def forward(self, inputs, hx):
         """Gated recurrent unit (GRU) with Graph Convolution.
         :param input: (B, num_nodes * input_dim)
@@ -99,28 +89,22 @@ class DCGRUCell(torch.nn.RNN):
                  the arity and shapes of `state`
         """
         output_size = 2 * self._num_units
-        # We start with bias of 1.0 to not reset and not update.
         if self._use_gc_for_ru:
             fn = self._gconv
         else:
             fn = self._fc
-        value = torch.sigmoid(fn(input, hx, output_size, bias_start=1.0))
+        value = torch.sigmoid(fn(inputs, hx, output_size, bias_start=1.0))
         value = torch.reshape(value, (-1, self._num_nodes, output_size))
-        r, u = torch.split(tensor=value, split_size_or_sections=2, dim=-1)
+        r, u = torch.split(tensor=value, split_size_or_sections=self._num_units, dim=-1)
         r = torch.reshape(r, (-1, self._num_nodes * self._num_units))
         u = torch.reshape(u, (-1, self._num_nodes * self._num_units))
-        c = self._gconv(input, r * hx, self._num_units)
+        c = self._gconv(inputs, r * hx, self._num_units)
         if self._activation is not None:
             c = self._activation(c)
-        output = new_state = u * hx + (1 - u) * c
-        if self._num_proj is not None:
-            batch_size = input.shape[0]
-            output = torch.reshape(new_state, shape=(-1, self._num_units))
-            output = torch.reshape(torch.matmul(output, self._proj_weights),
-                                   shape=(batch_size, self.output_size))
-        return output, new_state
+        new_state = u * hx + (1.0 - u) * c
+        return new_state
     @staticmethod
     def _concat(x, x_):
@@ -153,8 +137,7 @@ class DCGRUCell(torch.nn.RNN):
         inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
         state = torch.reshape(state, (batch_size, self._num_nodes, -1))
         inputs_and_state = torch.cat([inputs, state], dim=2)
-        input_size = inputs_and_state.shape[2].value
-        dtype = inputs.dtype
+        input_size = inputs_and_state.size(2)
         x = inputs_and_state
         x0 = x.permute(1, 2, 0)  # (num_nodes, total_arg_size, batch_size)
@@ -165,12 +148,11 @@ class DCGRUCell(torch.nn.RNN):
             pass
         else:
             for support in self._supports:
-                # https://discuss.pytorch.org/t/sparse-x-dense-dense-matrix-multiplication/6116/7
-                x1 = torch.mm(support, x0)
+                x1 = torch.sparse.mm(support, x0)  # this is not reordered, does this work - todo
                 x = self._concat(x, x1)
                 for k in range(2, self._max_diffusion_step + 1):
-                    x2 = 2 * torch.mm(support, x1) - x0
+                    x2 = 2 * torch.sparse.mm(support, x1) - x0
                     x = self._concat(x, x2)
                     x1, x0 = x2, x1
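
For reference, here is a minimal standalone sketch (not part of the commit) of what _build_sparse_matrix plus the torch.sparse.mm diffusion loop are doing: sort the COO indices into row-major order, build a torch sparse tensor, and apply the Chebyshev-style recurrence. The matrix size, density, and feature width are made-up illustrative values; note that np.lexsort treats its last key as the primary sort key.

import numpy as np
import scipy.sparse as sp
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy support matrix standing in for the scaled Laplacian / random-walk matrix.
adj = sp.random(5, 5, density=0.4, format="coo", random_state=0)

# Row-major ordering: sort by row first, then by column
# (np.lexsort uses its last key as the primary key).
order = np.lexsort((adj.col, adj.row))
indices = np.column_stack((adj.row, adj.col))[order]
values = adj.data[order].astype(np.float32)
support = torch.sparse_coo_tensor(indices.T, values, adj.shape, device=device)

x0 = torch.randn(5, 3, device=device)        # (num_nodes, features)
x1 = torch.sparse.mm(support, x0)            # first diffusion step
x2 = 2 * torch.sparse.mm(support, x1) - x0   # Chebyshev-style recurrence
print(x2.shape)                              # torch.Size([5, 3])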

model/pytorch/dcrnn_model.py

@@ -2,6 +2,10 @@ import numpy as np
 import torch
 import torch.nn as nn
+from model.pytorch.dcrnn_cell import DCGRUCell
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class Seq2SeqAttrs:
     def __init__(self, adj_mx, **model_kwargs):
@@ -9,7 +13,6 @@ class Seq2SeqAttrs:
         self.max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
         self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
         self.filter_type = model_kwargs.get('filter_type', 'laplacian')
-        # self.max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0))
         self.num_nodes = int(model_kwargs.get('num_nodes', 1))
         self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
         self.rnn_units = int(model_kwargs.get('rnn_units'))
@@ -18,19 +21,13 @@
 class EncoderModel(nn.Module, Seq2SeqAttrs):
     def __init__(self, adj_mx, **model_kwargs):
-        # super().__init__(is_training, adj_mx, **model_kwargs)
-        # https://pytorch.org/docs/stable/nn.html#gru
         nn.Module.__init__(self)
         Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.input_dim = int(model_kwargs.get('input_dim', 1))
         self.seq_len = int(model_kwargs.get('seq_len'))  # for the encoder
-        self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.input_dim,
-                                                      hidden_size=self.hidden_state_size,
-                                                      bias=True)] + [
-                                              nn.GRUCell(input_size=self.hidden_state_size,
-                                                         hidden_size=self.hidden_state_size,
-                                                         bias=True) for _ in
-                                              range(self.num_rnn_layers - 1)])
+        self.dcgru_layers = nn.ModuleList(
+            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
+                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
     def forward(self, inputs, hidden_state=None):
         """
@@ -45,7 +42,8 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
         """
         batch_size, _ = inputs.size()
         if hidden_state is None:
-            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size))
+            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
+                                       device=device)
         hidden_states = []
         output = inputs
         for layer_num, dcgru_layer in enumerate(self.dcgru_layers):
@@ -63,14 +61,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
         Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.output_dim = int(model_kwargs.get('output_dim', 1))
         self.horizon = int(model_kwargs.get('horizon', 1))  # for the decoder
-        self.projection_layer = nn.Linear(self.hidden_state_size, self.num_nodes * self.output_dim)
-        self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.output_dim,
-                                                      hidden_size=self.hidden_state_size,
-                                                      bias=True)] + [
-                                              nn.GRUCell(input_size=self.hidden_state_size,
-                                                         hidden_size=self.hidden_state_size,
-                                                         bias=True) for _ in
-                                              range(self.num_rnn_layers - 1)])
+        self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
+        self.dcgru_layers = nn.ModuleList(
+            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
+                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
     def forward(self, inputs, hidden_state=None):
         """
@@ -90,7 +84,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
             hidden_states.append(next_hidden_state)
             output = next_hidden_state
-        return self.projection_layer(output), torch.stack(hidden_states)
+        projected = self.projection_layer(output.view(-1, self.rnn_units))
+        output = projected.view(-1, self.num_nodes * self.output_dim)
+        return output, torch.stack(hidden_states)
 class DCRNNModel(nn.Module, Seq2SeqAttrs):
@@ -128,7 +125,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
         :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
         """
         batch_size = encoder_hidden_state.size(1)
-        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim))
+        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim),
+                                device=device)
         decoder_hidden_state = encoder_hidden_state
         decoder_input = go_symbol
@@ -155,7 +153,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
         :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
         """
         encoder_hidden_state = self.encoder(inputs)
-        self._logger.info("Encoder complete, starting decoder")
+        self._logger.debug("Encoder complete, starting decoder")
         outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
-        self._logger.info("Decoder complete")
+        self._logger.debug("Decoder complete")
         return outputs
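
As a quick shape check (again a standalone sketch, not part of the commit), the new per-node projection in DecoderModel.forward can be exercised in isolation; the batch size, node count, rnn_units, and output_dim below are illustrative assumptions:

import torch
import torch.nn as nn

batch_size, num_nodes, rnn_units, output_dim = 8, 207, 64, 1
projection_layer = nn.Linear(rnn_units, output_dim)  # one small layer shared across nodes

output = torch.randn(batch_size, num_nodes * rnn_units)   # decoder hidden state, (B, num_nodes * rnn_units)
projected = projection_layer(output.view(-1, rnn_units))  # (B * num_nodes, output_dim)
output = projected.view(-1, num_nodes * output_dim)       # (B, num_nodes * output_dim)
print(output.shape)                                        # torch.Size([8, 207])

Applying a single Linear(rnn_units, output_dim) per node keeps the parameter count independent of num_nodes, unlike the previous Linear(hidden_state_size, num_nodes * output_dim).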

model/pytorch/dcrnn_supervisor.py

@@ -7,6 +7,8 @@ import torch
 from lib import utils
 from model.pytorch.dcrnn_model import DCRNNModel
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 class DCRNNSupervisor:
     def __init__(self, adj_mx, **kwargs):
@@ -75,7 +77,7 @@ class DCRNNSupervisor:
         config['model_state_dict'] = self.dcrnn_model.state_dict()
         config['epoch'] = epoch
         torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
-        self._logger.info("Loaded model at {}".format(epoch))
+        self._logger.info("Saved model at {}".format(epoch))
         return self._log_dir + 'models/epo%d.tar' % epoch
     def load_model(self, epoch):
@@ -102,8 +104,7 @@
             criterion = torch.nn.L1Loss()
             for _, (x, y) in enumerate(val_iterator):
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
                 output = self.dcrnn_model(x)
                 loss = self._compute_loss(y, output, criterion)
@@ -128,6 +129,7 @@
         self.dcrnn_model = self.dcrnn_model.train()
         self._logger.info('Start training ...')
+        self._logger.info("num_batches:{}".format(self._data['train_loader'].num_batch))
         for epoch_num in range(epochs):
             train_iterator = self._data['train_loader'].get_iterator()
             losses = []
@@ -137,12 +139,13 @@
             for _, (x, y) in enumerate(train_iterator):
                 optimizer.zero_grad()
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
                 output = self.dcrnn_model(x, y, batches_seen)
                 loss = self._compute_loss(y, output, criterion)
-                self._logger.info(loss.item())
+                self._logger.debug(loss.item())
                 losses.append(loss.item())
                 batches_seen += 1
@@ -152,40 +155,46 @@
                 torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
                 optimizer.step()
+            self._logger.info("epoch complete")
             lr_scheduler.step()
+            self._logger.info("evaluating now!")
             val_loss = self.evaluate(dataset='val')
             end_time = time.time()
             if epoch_num % log_every == 0:
-                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}' \
+                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                           '{:.1f}s'.format(epoch_num, epochs, batches_seen,
-                                           np.mean(losses), val_loss, lr_scheduler.get_lr(),
+                                           np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
                                            (end_time - start_time))
                 self._logger.info(message)
             if epoch_num % test_every_n_epochs == 0:
                 test_loss = self.evaluate(dataset='test')
-                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f} ' \
+                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                           '{:.1f}s'.format(epoch_num, epochs, batches_seen,
-                                           np.mean(losses), test_loss, lr_scheduler.get_lr(),
+                                           np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
                                            (end_time - start_time))
                 self._logger.info(message)
             if val_loss < min_val_loss:
                 wait = 0
-                min_val_loss = val_loss
                 if save_model:
                     model_file_name = self.save_model(epoch_num)
                 self._logger.info(
                     'Val loss decrease from {:.4f} to {:.4f}, '
                     'saving to {}'.format(min_val_loss, val_loss, model_file_name))
+                min_val_loss = val_loss
             elif val_loss >= min_val_loss:
                 wait += 1
                 if wait == patience:
                     self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                     break
+    def _prepare_data(self, x, y):
+        x, y = self._get_x_y(x, y)
+        x, y = self._get_x_y_in_correct_dims(x, y)
+        return x.to(device), y.to(device)
     def _get_x_y(self, x, y):
         """
         :param x: shape (batch_size, seq_len, num_sensor, input_dim)
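
One note on the lr_scheduler.get_lr()[0] change in the epoch log messages: PyTorch LR schedulers return a list with one learning rate per parameter group, so indexing with [0] gives the scalar that the {:.6f} format specifier expects. A small standalone check, assuming a MultiStepLR scheduler (the hunks above do not show which scheduler the supervisor actually constructs):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30], gamma=0.1)

print(lr_scheduler.get_lr())     # a list, one entry per param group: [0.01]
print(lr_scheduler.get_lr()[0])  # scalar suitable for '{:.6f}'.format(...)

Newer PyTorch versions prefer get_last_lr() here, but it is also a list with one entry per parameter group.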