Added dcrnn_cell
Rough implementation complete - could forward pass it through the network
Ensured sparse mm for readability, logging sparsely as well
moving tensors to GPU
moving tensors to GPU [v2]
moving tensors to GPU [v3]
logging and refactor
logging and refactor
logging and refactor
logging and refactor
logging and refactor
logging and refactor
logging and refactor
ensured row major ordering
fixed log message
This commit is contained in:
parent e80c47390d
commit d1964672c2
@@ -1,21 +1,22 @@
from typing import Optional

import numpy as np
import torch
from torch import Tensor

from lib import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LayerParams:
    def __init__(self, rnn_network: torch.nn.RNN, type: str):
    def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
        self._rnn_network = rnn_network
        self._params_dict = {}
        self._biases_dict = {}
        self._type = type
        self._type = layer_type

    def get_weights(self, shape):
        if shape not in self._params_dict:
            nn_param = torch.nn.init.xavier_normal(torch.empty(*shape))
            nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
            torch.nn.init.xavier_normal_(nn_param)
            self._params_dict[shape] = nn_param
            self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
                                                 nn_param)
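Note (not part of the diff): a minimal sketch of how the lazily registered weights above behave, assuming LayerParams from this file is in scope; ToyCell is a hypothetical stand-in for DCGRUCell, and the shape is illustrative.

import torch

# Hypothetical host module standing in for DCGRUCell.
class ToyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._fc_params = LayerParams(self, 'fc')  # LayerParams as defined above

cell = ToyCell()
cell._fc_params.get_weights((4, 8))  # first call: Parameter created, initialised, registered
cell._fc_params.get_weights((4, 8))  # second call: served from the cache, nothing new registered
print([name for name, _ in cell.named_parameters()])
# -> ['fc_weight_(4, 8)']  (one entry despite two calls)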
@@ -23,7 +24,8 @@ class LayerParams:

    def get_biases(self, length, bias_start=0.0):
        if length not in self._biases_dict:
            biases = torch.nn.init.constant(torch.empty(length), bias_start)
            biases = torch.nn.Parameter(torch.empty(length, device=device))
            torch.nn.init.constant_(biases, bias_start)
            self._biases_dict[length] = biases
            self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
                                                 biases)
@@ -31,32 +33,24 @@ class LayerParams:
        return self._biases_dict[length]


class DCGRUCell(torch.nn.RNN):
    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_size: int,
                 hidden_size: int,
                 num_layers: int = 1,
                 num_proj=None,
                 nonlinearity='tanh', filter_type="laplacian", use_gc_for_ru=True):
class DCGRUCell(torch.nn.Module):
    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
                 filter_type="laplacian", use_gc_for_ru=True):
        """

        :param num_units:
        :param adj_mx:
        :param max_diffusion_step:
        :param num_nodes:
        :param input_size:
        :param num_proj:
        :param nonlinearity:
        :param filter_type: "laplacian", "random_walk", "dual_random_walk".
        :param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
        """
        super(DCGRUCell, self).__init__(input_size, hidden_size, bias=True,
                                        # bias param does not exist in tf code?
                                        num_layers=num_layers,
                                        nonlinearity=nonlinearity)

        super().__init__()
        self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
        # support other nonlinearities up here?
        self._num_nodes = num_nodes
        self._num_proj = num_proj
        self._num_units = num_units
        self._max_diffusion_step = max_diffusion_step
        self._supports = []
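Note (not part of the diff): a hedged construction sketch showing the new, slimmer constructor with a toy adjacency matrix; it assumes the cell and lib.utils are importable and that, as in the upstream DCRNN code, "dual_random_walk" builds forward and backward transition matrices (two sparse supports).

import numpy as np

num_nodes, num_units = 3, 8
adj_mx = np.random.rand(num_nodes, num_nodes).astype(np.float32)  # placeholder adjacency

# Plain nn.Module now, so the cell composes with nn.ModuleList in the encoder/decoder.
cell = DCGRUCell(num_units, adj_mx, max_diffusion_step=2, num_nodes=num_nodes,
                 filter_type="dual_random_walk")
print(len(cell._supports))  # expected 2 for dual_random_walk, 1 for laplacian/random_walk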
@@ -73,23 +67,19 @@ class DCGRUCell(torch.nn.RNN):
            supports.append(utils.calculate_scaled_laplacian(adj_mx))
        for support in supports:
            self._supports.append(self._build_sparse_matrix(support))

        self._proj_weights = torch.nn.Parameter(torch.randn(self._num_units, self._num_proj))
        self._fc_params = LayerParams(self, 'fc')
        self._gconv_params = LayerParams(self, 'gconv')

    @property
    def state_size(self):
        return self._num_nodes * self._num_units
    @staticmethod
    def _build_sparse_matrix(L):
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
        indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
        L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
        return L

    @property
    def output_size(self):
        output_size = self._num_nodes * self._num_units
        if self._num_proj is not None:
            output_size = self._num_nodes * self._num_proj
        return output_size

    def forward(self, input: Tensor, hx: Optional[Tensor] = ...):
    def forward(self, inputs, hx):
        """Gated recurrent unit (GRU) with Graph Convolution.
        :param input: (B, num_nodes * input_dim)
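Note (not part of the diff): a standalone check of the np.lexsort behaviour behind the reordering in _build_sparse_matrix; np.lexsort treats its last key as the primary sort key, so the key order decides whether rows or columns dominate. The hunk above passes (indices[:, 0], indices[:, 1]), which makes the column index primary; the sketch below shows the key order that yields row-major ordering.

import numpy as np

# Toy COO (row, col) index pairs, deliberately out of order.
indices = np.array([[1, 0],
                    [0, 2],
                    [0, 1],
                    [1, 1]])

# Rows as the *last* key -> rows are the primary sort key -> row-major order.
row_major = indices[np.lexsort((indices[:, 1], indices[:, 0]))]
print(row_major)
# [[0 1]
#  [0 2]
#  [1 0]
#  [1 1]]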
@@ -99,28 +89,22 @@ class DCGRUCell(torch.nn.RNN):
        the arity and shapes of `state`
        """
        output_size = 2 * self._num_units
        # We start with bias of 1.0 to not reset and not update.
        if self._use_gc_for_ru:
            fn = self._gconv
        else:
            fn = self._fc
        value = torch.sigmoid(fn(input, hx, output_size, bias_start=1.0))
        value = torch.sigmoid(fn(inputs, hx, output_size, bias_start=1.0))
        value = torch.reshape(value, (-1, self._num_nodes, output_size))
        r, u = torch.split(tensor=value, split_size_or_sections=2, dim=-1)
        r, u = torch.split(tensor=value, split_size_or_sections=self._num_units, dim=-1)
        r = torch.reshape(r, (-1, self._num_nodes * self._num_units))
        u = torch.reshape(u, (-1, self._num_nodes * self._num_units))

        c = self._gconv(input, r * hx, self._num_units)
        c = self._gconv(inputs, r * hx, self._num_units)
        if self._activation is not None:
            c = self._activation(c)

        output = new_state = u * hx + (1 - u) * c
        if self._num_proj is not None:
            batch_size = input.shape[0]
            output = torch.reshape(new_state, shape=(-1, self._num_units))
            output = torch.reshape(torch.matmul(output, self._proj_weights),
                                   shape=(batch_size, self.output_size))
        return output, new_state
        new_state = u * hx + (1.0 - u) * c
        return new_state

    @staticmethod
    def _concat(x, x_):
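Note (not part of the diff): a quick check of the torch.split semantics behind the r/u fix above; split_size_or_sections is a chunk size, not a number of chunks, so self._num_units splits the 2 * num_units gate tensor into exactly two halves, whereas 2 produces width-2 chunks.

import torch

num_units = 4
value = torch.arange(2 * num_units).reshape(1, 1, 2 * num_units)  # (B, num_nodes, 2 * num_units)

r, u = torch.split(value, split_size_or_sections=num_units, dim=-1)
print(r.shape, u.shape)  # torch.Size([1, 1, 4]) torch.Size([1, 1, 4])

# Width-2 chunks: four pieces here, so unpacking into (r, u) would fail for num_units != 2.
chunks = torch.split(value, split_size_or_sections=2, dim=-1)
print(len(chunks))  # 4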
@@ -153,8 +137,7 @@ class DCGRUCell(torch.nn.RNN):
        inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
        state = torch.reshape(state, (batch_size, self._num_nodes, -1))
        inputs_and_state = torch.cat([inputs, state], dim=2)
        input_size = inputs_and_state.shape[2].value
        dtype = inputs.dtype
        input_size = inputs_and_state.size(2)

        x = inputs_and_state
        x0 = x.permute(1, 2, 0)  # (num_nodes, total_arg_size, batch_size)
@@ -165,12 +148,11 @@ class DCGRUCell(torch.nn.RNN):
            pass
        else:
            for support in self._supports:
                # https://discuss.pytorch.org/t/sparse-x-dense-dense-matrix-multiplication/6116/7
                x1 = torch.mm(support, x0)
                x1 = torch.sparse.mm(support, x0)  # this is not reordered, does this work - todo
                x = self._concat(x, x1)

                for k in range(2, self._max_diffusion_step + 1):
                    x2 = 2 * torch.mm(support, x1) - x0
                    x2 = 2 * torch.sparse.mm(support, x1) - x0
                    x = self._concat(x, x2)
                    x1, x0 = x2, x1
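Note (not part of the diff): a standalone sketch of the sparse-dense product used in the diffusion steps above; torch.sparse.mm multiplies a sparse COO matrix by a dense matrix and returns a dense result. The toy support and operand shapes below are illustrative only.

import torch

# Toy 3x3 sparse support in COO format.
indices = torch.tensor([[0, 1, 2],
                        [1, 2, 0]])
values = torch.tensor([0.5, 0.5, 1.0])
support = torch.sparse_coo_tensor(indices, values, (3, 3))

# Dense operand playing the role of x0 (num_nodes rows; the width is arbitrary here).
x0 = torch.randn(3, 4)

x1 = torch.sparse.mm(support, x0)           # sparse @ dense -> dense, shape (3, 4)
x2 = 2 * torch.sparse.mm(support, x1) - x0  # one diffusion-step recurrence as in the loop above
print(x1.shape, x2.shape)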
@@ -2,6 +2,10 @@ import numpy as np
import torch
import torch.nn as nn

from model.pytorch.dcrnn_cell import DCGRUCell

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Seq2SeqAttrs:
    def __init__(self, adj_mx, **model_kwargs):
@@ -9,7 +13,6 @@ class Seq2SeqAttrs:
        self.max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
        self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
        self.filter_type = model_kwargs.get('filter_type', 'laplacian')
        # self.max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0))
        self.num_nodes = int(model_kwargs.get('num_nodes', 1))
        self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
        self.rnn_units = int(model_kwargs.get('rnn_units'))
@@ -18,19 +21,13 @@ class Seq2SeqAttrs:

class EncoderModel(nn.Module, Seq2SeqAttrs):
    def __init__(self, adj_mx, **model_kwargs):
        # super().__init__(is_training, adj_mx, **model_kwargs)
        # https://pytorch.org/docs/stable/nn.html#gru
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
        self.input_dim = int(model_kwargs.get('input_dim', 1))
        self.seq_len = int(model_kwargs.get('seq_len'))  # for the encoder
        self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.input_dim,
                                                      hidden_size=self.hidden_state_size,
                                                      bias=True)] + [
            nn.GRUCell(input_size=self.hidden_state_size,
                       hidden_size=self.hidden_state_size,
                       bias=True) for _ in
            range(self.num_rnn_layers - 1)])
        self.dcgru_layers = nn.ModuleList(
            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])

    def forward(self, inputs, hidden_state=None):
        """
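Note (not part of the diff): a hedged sketch of driving the DCGRUCell-based encoder for a single step, assuming the usual kwargs, that Seq2SeqAttrs defines hidden_state_size = num_nodes * rnn_units in its unchanged lines, and that forward returns (output, stacked hidden states) as the decoder's does; the adjacency matrix is a random placeholder.

import numpy as np
import torch

num_nodes, rnn_units, input_dim, batch_size = 3, 8, 2, 4
adj_mx = np.random.rand(num_nodes, num_nodes).astype(np.float32)

encoder = EncoderModel(adj_mx,
                       num_nodes=num_nodes, rnn_units=rnn_units, input_dim=input_dim,
                       seq_len=12, max_diffusion_step=2, num_rnn_layers=2,
                       filter_type='random_walk')

inputs = torch.zeros(batch_size, num_nodes * input_dim, device=device)
output, hidden_state = encoder(inputs)  # hidden_state is initialised to zeros when None
print(output.shape)        # expected (batch_size, num_nodes * rnn_units)
print(hidden_state.shape)  # expected (num_rnn_layers, batch_size, num_nodes * rnn_units)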
@@ -45,7 +42,8 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
        """
        batch_size, _ = inputs.size()
        if hidden_state is None:
            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size))
            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
                                       device=device)
        hidden_states = []
        output = inputs
        for layer_num, dcgru_layer in enumerate(self.dcgru_layers):
@@ -63,14 +61,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
        Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
        self.output_dim = int(model_kwargs.get('output_dim', 1))
        self.horizon = int(model_kwargs.get('horizon', 1))  # for the decoder
        self.projection_layer = nn.Linear(self.hidden_state_size, self.num_nodes * self.output_dim)
        self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.output_dim,
                                                      hidden_size=self.hidden_state_size,
                                                      bias=True)] + [
            nn.GRUCell(input_size=self.hidden_state_size,
                       hidden_size=self.hidden_state_size,
                       bias=True) for _ in
            range(self.num_rnn_layers - 1)])
        self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
        self.dcgru_layers = nn.ModuleList(
            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])

    def forward(self, inputs, hidden_state=None):
        """
@@ -90,7 +84,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
            hidden_states.append(next_hidden_state)
            output = next_hidden_state

        return self.projection_layer(output), torch.stack(hidden_states)
        projected = self.projection_layer(output.view(-1, self.rnn_units))
        output = projected.view(-1, self.num_nodes * self.output_dim)

        return output, torch.stack(hidden_states)


class DCRNNModel(nn.Module, Seq2SeqAttrs):
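Note (not part of the diff): a shape-only sketch of the per-node projection introduced above; the (batch, num_nodes * rnn_units) decoder output is flattened so one Linear(rnn_units, output_dim) is shared across nodes, then folded back to (batch, num_nodes * output_dim). Sizes are illustrative.

import torch
import torch.nn as nn

batch_size, num_nodes, rnn_units, output_dim = 4, 3, 8, 1
projection_layer = nn.Linear(rnn_units, output_dim)

output = torch.randn(batch_size, num_nodes * rnn_units)    # decoder cell output
projected = projection_layer(output.view(-1, rnn_units))   # (batch * num_nodes, output_dim)
output = projected.view(-1, num_nodes * output_dim)        # (batch, num_nodes * output_dim)
print(output.shape)  # torch.Size([4, 3])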
@@ -128,7 +125,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
        """
        batch_size = encoder_hidden_state.size(1)
        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim))
        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim),
                                device=device)
        decoder_hidden_state = encoder_hidden_state
        decoder_input = go_symbol
@@ -155,7 +153,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
        """
        encoder_hidden_state = self.encoder(inputs)
        self._logger.info("Encoder complete, starting decoder")
        self._logger.debug("Encoder complete, starting decoder")
        outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
        self._logger.info("Decoder complete")
        self._logger.debug("Decoder complete")
        return outputs
@@ -7,6 +7,8 @@ import torch
from lib import utils
from model.pytorch.dcrnn_model import DCRNNModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DCRNNSupervisor:
    def __init__(self, adj_mx, **kwargs):
@@ -75,7 +77,7 @@ class DCRNNSupervisor:
        config['model_state_dict'] = self.dcrnn_model.state_dict()
        config['epoch'] = epoch
        torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
        self._logger.info("Loaded model at {}".format(epoch))
        self._logger.info("Saved model at {}".format(epoch))
        return self._log_dir + 'models/epo%d.tar' % epoch

    def load_model(self, epoch):
@@ -102,8 +104,7 @@ class DCRNNSupervisor:
        criterion = torch.nn.L1Loss()

        for _, (x, y) in enumerate(val_iterator):
            x, y = self._get_x_y(x, y)
            x, y = self._get_x_y_in_correct_dims(x, y)
            x, y = self._prepare_data(x, y)

            output = self.dcrnn_model(x)
            loss = self._compute_loss(y, output, criterion)
@@ -128,6 +129,7 @@ class DCRNNSupervisor:
        self.dcrnn_model = self.dcrnn_model.train()

        self._logger.info('Start training ...')
        self._logger.info("num_batches:{}".format(self._data['train_loader'].num_batch))
        for epoch_num in range(epochs):
            train_iterator = self._data['train_loader'].get_iterator()
            losses = []
@@ -137,12 +139,13 @@ class DCRNNSupervisor:
            for _, (x, y) in enumerate(train_iterator):
                optimizer.zero_grad()

                x, y = self._get_x_y(x, y)
                x, y = self._get_x_y_in_correct_dims(x, y)
                x, y = self._prepare_data(x, y)

                output = self.dcrnn_model(x, y, batches_seen)
                loss = self._compute_loss(y, output, criterion)
                self._logger.info(loss.item())

                self._logger.debug(loss.item())

                losses.append(loss.item())

                batches_seen += 1
@@ -152,40 +155,46 @@ class DCRNNSupervisor:
                torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)

                optimizer.step()

            self._logger.info("epoch complete")
            lr_scheduler.step()

            self._logger.info("evaluating now!")
            val_loss = self.evaluate(dataset='val')
            end_time = time.time()
            if epoch_num % log_every == 0:
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}' \
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), val_loss, lr_scheduler.get_lr(),
                                           np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if epoch_num % test_every_n_epochs == 0:
                test_loss = self.evaluate(dataset='test')
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f} ' \
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), test_loss, lr_scheduler.get_lr(),
                                           np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if val_loss < min_val_loss:
                wait = 0
                min_val_loss = val_loss
                if save_model:
                    model_file_name = self.save_model(epoch_num)
                    self._logger.info(
                        'Val loss decrease from {:.4f} to {:.4f}, '
                        'saving to {}'.format(min_val_loss, val_loss, model_file_name))
                min_val_loss = val_loss

            elif val_loss >= min_val_loss:
                wait += 1
                if wait == patience:
                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                    break

    def _prepare_data(self, x, y):
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(device), y.to(device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_sensor, input_dim)
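Note (not part of the diff): a quick standalone check of why the learning rate needs indexing in the log messages above; _LRScheduler subclasses report one rate per parameter group as a list, so '{:.6f}' needs the first element. The optimizer and scheduler below are placeholders, not necessarily the ones the supervisor uses.

import torch

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30], gamma=0.1)

print(lr_scheduler.get_lr())                      # a list, one entry per param group, e.g. [0.01]
print('{:.6f}'.format(lr_scheduler.get_lr()[0]))  # 0.010000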