Added dcrnn_cell
Rough implementation complete - could forward pass it through the network
Ensured sparse mm for readability, logging sparsely as well
moving tensors to GPU
moving tensors to GPU [v2]
moving tensors to GPU [v3]
logging and refactor
logging and refactor
logging and refactor
logging and refactor
logging and refactor
logging and refactor
logging and refactor
ensured row major ordering
fixed log message
This commit is contained in:
parent
e80c47390d
commit
d1964672c2
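Not part of the diff: a minimal sketch of the "moving tensors to GPU" pattern the commit applies throughout - a module-level device is picked once, parameters are allocated directly on it, and incoming batches are moved there with .to(device). The sizes below are made up.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parameters are created directly on the target device ...
weight = torch.nn.Parameter(torch.empty(16, 8, device=device))
torch.nn.init.xavier_normal_(weight)

# ... and input batches are moved there before the forward pass.
x = torch.randn(32, 16)   # e.g. a batch coming off a data loader
x = x.to(device)
out = x @ weight          # both operands now live on the same device
print(out.device)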
@@ -1,21 +1,22 @@
-from typing import Optional
-
+import numpy as np
 import torch
-from torch import Tensor
 
 from lib import utils
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 class LayerParams:
-    def __init__(self, rnn_network: torch.nn.RNN, type: str):
+    def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
         self._rnn_network = rnn_network
         self._params_dict = {}
         self._biases_dict = {}
-        self._type = type
+        self._type = layer_type
 
     def get_weights(self, shape):
         if shape not in self._params_dict:
-            nn_param = torch.nn.init.xavier_normal(torch.empty(*shape))
+            nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
+            torch.nn.init.xavier_normal_(nn_param)
             self._params_dict[shape] = nn_param
             self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
                                                  nn_param)
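Not part of the diff: a small self-contained sketch of the create-then-register pattern LayerParams uses above, showing why register_parameter matters - without it, lazily created weights would be invisible to model.parameters() and thus to the optimizer. TinyModule and weight_for are made-up names.

import torch

class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._cache = {}

    def weight_for(self, shape):
        # Create each weight lazily the first time a shape is requested,
        # then register it on the module so it is tracked like any other parameter.
        if shape not in self._cache:
            p = torch.nn.Parameter(torch.empty(*shape))
            torch.nn.init.xavier_normal_(p)
            self._cache[shape] = p
            self.register_parameter('weight_{}'.format('x'.join(map(str, shape))), p)
        return self._cache[shape]

m = TinyModule()
w = m.weight_for((4, 3))
print([name for name, _ in m.named_parameters()])  # ['weight_4x3']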
@@ -23,7 +24,8 @@ class LayerParams:
 
     def get_biases(self, length, bias_start=0.0):
         if length not in self._biases_dict:
-            biases = torch.nn.init.constant(torch.empty(length), bias_start)
+            biases = torch.nn.Parameter(torch.empty(length, device=device))
+            torch.nn.init.constant_(biases, bias_start)
             self._biases_dict[length] = biases
             self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
                                                  biases)
@@ -31,32 +33,24 @@ class LayerParams:
         return self._biases_dict[length]
 
 
-class DCGRUCell(torch.nn.RNN):
-    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_size: int,
-                 hidden_size: int,
-                 num_layers: int = 1,
-                 num_proj=None,
-                 nonlinearity='tanh', filter_type="laplacian", use_gc_for_ru=True):
+class DCGRUCell(torch.nn.Module):
+    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
+                 filter_type="laplacian", use_gc_for_ru=True):
         """
 
         :param num_units:
         :param adj_mx:
         :param max_diffusion_step:
         :param num_nodes:
-        :param input_size:
-        :param num_proj:
         :param nonlinearity:
         :param filter_type: "laplacian", "random_walk", "dual_random_walk".
         :param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
         """
-        super(DCGRUCell, self).__init__(input_size, hidden_size, bias=True,
-                                        # bias param does not exist in tf code?
-                                        num_layers=num_layers,
-                                        nonlinearity=nonlinearity)
+        super().__init__()
         self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
         # support other nonlinearities up here?
         self._num_nodes = num_nodes
-        self._num_proj = num_proj
         self._num_units = num_units
         self._max_diffusion_step = max_diffusion_step
         self._supports = []
@@ -73,23 +67,19 @@ class DCGRUCell(torch.nn.RNN):
             supports.append(utils.calculate_scaled_laplacian(adj_mx))
         for support in supports:
             self._supports.append(self._build_sparse_matrix(support))
 
-        self._proj_weights = torch.nn.Parameter(torch.randn(self._num_units, self._num_proj))
         self._fc_params = LayerParams(self, 'fc')
         self._gconv_params = LayerParams(self, 'gconv')
 
-    @property
-    def state_size(self):
-        return self._num_nodes * self._num_units
-
-    @property
-    def output_size(self):
-        output_size = self._num_nodes * self._num_units
-        if self._num_proj is not None:
-            output_size = self._num_nodes * self._num_proj
-        return output_size
-
-    def forward(self, input: Tensor, hx: Optional[Tensor] = ...):
+    @staticmethod
+    def _build_sparse_matrix(L):
+        L = L.tocoo()
+        indices = np.column_stack((L.row, L.col))
+        # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
+        indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
+        L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
+        return L
+
+    def forward(self, inputs, hx):
         """Gated recurrent unit (GRU) with Graph Convolution.
         :param input: (B, num_nodes * input_dim)
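For reference (not part of the commit): a standalone sketch of converting a scipy COO matrix into a torch sparse tensor in the spirit of _build_sparse_matrix, using a made-up 3-node support matrix; in this sketch the values are reordered together with the indices so the pairing stays intact.

import numpy as np
import scipy.sparse as sp
import torch

# Toy support matrix; the real code derives these from adj_mx via lib.utils.
L = sp.coo_matrix(np.array([[0.0, 1.0, 0.0],
                            [1.0, 0.0, 2.0],
                            [0.0, 2.0, 0.0]]))

indices = np.column_stack((L.row, L.col))
order = np.lexsort((indices[:, 1], indices[:, 0]))  # row-major: sort by row, then by column
indices, values = indices[order], L.data[order]     # keep values aligned with their indices

support = torch.sparse_coo_tensor(indices.T, values, L.shape, dtype=torch.float32)
print(torch.sparse.mm(support, torch.randn(3, 4)).shape)  # torch.Size([3, 4])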
@@ -99,28 +89,22 @@ class DCGRUCell(torch.nn.RNN):
                  the arity and shapes of `state`
         """
         output_size = 2 * self._num_units
-        # We start with bias of 1.0 to not reset and not update.
         if self._use_gc_for_ru:
             fn = self._gconv
         else:
             fn = self._fc
-        value = torch.sigmoid(fn(input, hx, output_size, bias_start=1.0))
+        value = torch.sigmoid(fn(inputs, hx, output_size, bias_start=1.0))
         value = torch.reshape(value, (-1, self._num_nodes, output_size))
-        r, u = torch.split(tensor=value, split_size_or_sections=2, dim=-1)
+        r, u = torch.split(tensor=value, split_size_or_sections=self._num_units, dim=-1)
         r = torch.reshape(r, (-1, self._num_nodes * self._num_units))
         u = torch.reshape(u, (-1, self._num_nodes * self._num_units))
 
-        c = self._gconv(input, r * hx, self._num_units)
+        c = self._gconv(inputs, r * hx, self._num_units)
         if self._activation is not None:
             c = self._activation(c)
 
-        output = new_state = u * hx + (1 - u) * c
-        if self._num_proj is not None:
-            batch_size = input.shape[0]
-            output = torch.reshape(new_state, shape=(-1, self._num_units))
-            output = torch.reshape(torch.matmul(output, self._proj_weights),
-                                   shape=(batch_size, self.output_size))
-        return output, new_state
+        new_state = u * hx + (1.0 - u) * c
+        return new_state
 
     @staticmethod
     def _concat(x, x_):
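Not part of the diff: a tiny sketch of why split_size_or_sections changed from 2 to self._num_units - torch.split takes a chunk size, so passing num_units yields exactly the two gate tensors r and u. Sizes here are made up.

import torch

num_nodes, num_units = 5, 4
value = torch.randn(2, num_nodes, 2 * num_units)  # gate pre-activations, last dim = 2 * num_units

# Splitting with chunk size num_units gives two tensors of shape (batch, num_nodes, num_units),
# one for the reset gate and one for the update gate.
r, u = torch.split(value, split_size_or_sections=num_units, dim=-1)
print(r.shape, u.shape)  # torch.Size([2, 5, 4]) torch.Size([2, 5, 4])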
@@ -153,8 +137,7 @@ class DCGRUCell(torch.nn.RNN):
         inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
         state = torch.reshape(state, (batch_size, self._num_nodes, -1))
         inputs_and_state = torch.cat([inputs, state], dim=2)
-        input_size = inputs_and_state.shape[2].value
-        dtype = inputs.dtype
+        input_size = inputs_and_state.size(2)
 
         x = inputs_and_state
         x0 = x.permute(1, 2, 0)  # (num_nodes, total_arg_size, batch_size)
@@ -165,12 +148,11 @@ class DCGRUCell(torch.nn.RNN):
                 pass
         else:
             for support in self._supports:
-                # https://discuss.pytorch.org/t/sparse-x-dense-dense-matrix-multiplication/6116/7
-                x1 = torch.mm(support, x0)
+                x1 = torch.sparse.mm(support, x0)  # this is not reordered, does this work - todo
                 x = self._concat(x, x1)
 
                 for k in range(2, self._max_diffusion_step + 1):
-                    x2 = 2 * torch.mm(support, x1) - x0
+                    x2 = 2 * torch.sparse.mm(support, x1) - x0
                     x = self._concat(x, x2)
                     x1, x0 = x2, x1
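For reference (not part of the commit): a standalone sketch of the K-step diffusion recurrence that _gconv runs with torch.sparse.mm, using a toy identity support so only the shapes matter.

import torch

num_nodes, feat = 3, 4
idx = torch.tensor([[0, 1, 2], [0, 1, 2]])
val = torch.ones(3)
support = torch.sparse_coo_tensor(idx, val, (num_nodes, num_nodes))  # toy sparse support

x0 = torch.randn(num_nodes, feat)       # (num_nodes, total_arg_size * batch_size) in the real cell
x = x0.unsqueeze(0)                     # running stack of diffusion terms
x1 = torch.sparse.mm(support, x0)       # first diffusion step
x = torch.cat([x, x1.unsqueeze(0)], dim=0)

max_diffusion_step = 2
for k in range(2, max_diffusion_step + 1):
    x2 = 2 * torch.sparse.mm(support, x1) - x0  # same recurrence as in _gconv
    x = torch.cat([x, x2.unsqueeze(0)], dim=0)
    x1, x0 = x2, x1

print(x.shape)  # (max_diffusion_step + 1, num_nodes, feat)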
@@ -2,6 +2,10 @@ import numpy as np
 import torch
 import torch.nn as nn
 
+from model.pytorch.dcrnn_cell import DCGRUCell
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 class Seq2SeqAttrs:
     def __init__(self, adj_mx, **model_kwargs):
@@ -9,7 +13,6 @@ class Seq2SeqAttrs:
         self.max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
         self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
         self.filter_type = model_kwargs.get('filter_type', 'laplacian')
-        # self.max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0))
         self.num_nodes = int(model_kwargs.get('num_nodes', 1))
         self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
         self.rnn_units = int(model_kwargs.get('rnn_units'))
@@ -18,19 +21,13 @@ class Seq2SeqAttrs:
 
 class EncoderModel(nn.Module, Seq2SeqAttrs):
     def __init__(self, adj_mx, **model_kwargs):
-        # super().__init__(is_training, adj_mx, **model_kwargs)
-        # https://pytorch.org/docs/stable/nn.html#gru
         nn.Module.__init__(self)
         Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.input_dim = int(model_kwargs.get('input_dim', 1))
         self.seq_len = int(model_kwargs.get('seq_len'))  # for the encoder
-        self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.input_dim,
-                                                      hidden_size=self.hidden_state_size,
-                                                      bias=True)] + [
-            nn.GRUCell(input_size=self.hidden_state_size,
-                       hidden_size=self.hidden_state_size,
-                       bias=True) for _ in
-            range(self.num_rnn_layers - 1)])
+        self.dcgru_layers = nn.ModuleList(
+            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
+                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
 
     def forward(self, inputs, hidden_state=None):
         """
@@ -45,7 +42,8 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
         """
         batch_size, _ = inputs.size()
         if hidden_state is None:
-            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size))
+            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
+                                       device=device)
         hidden_states = []
         output = inputs
         for layer_num, dcgru_layer in enumerate(self.dcgru_layers):
@@ -63,14 +61,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
         Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.output_dim = int(model_kwargs.get('output_dim', 1))
         self.horizon = int(model_kwargs.get('horizon', 1))  # for the decoder
-        self.projection_layer = nn.Linear(self.hidden_state_size, self.num_nodes * self.output_dim)
-        self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.output_dim,
-                                                      hidden_size=self.hidden_state_size,
-                                                      bias=True)] + [
-            nn.GRUCell(input_size=self.hidden_state_size,
-                       hidden_size=self.hidden_state_size,
-                       bias=True) for _ in
-            range(self.num_rnn_layers - 1)])
+        self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
+        self.dcgru_layers = nn.ModuleList(
+            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
+                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
 
     def forward(self, inputs, hidden_state=None):
         """
@@ -90,7 +84,10 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
             hidden_states.append(next_hidden_state)
             output = next_hidden_state
 
-        return self.projection_layer(output), torch.stack(hidden_states)
+        projected = self.projection_layer(output.view(-1, self.rnn_units))
+        output = projected.view(-1, self.num_nodes * self.output_dim)
+
+        return output, torch.stack(hidden_states)
 
 
 class DCRNNModel(nn.Module, Seq2SeqAttrs):
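Not part of the diff: a minimal sketch of the reshaping around the new per-node projection layer - the decoder output is flattened to (batch * num_nodes, rnn_units), projected, then folded back. Sizes are made up.

import torch
import torch.nn as nn

batch_size, num_nodes, rnn_units, output_dim = 8, 5, 16, 1
projection_layer = nn.Linear(rnn_units, output_dim)

output = torch.randn(batch_size, num_nodes * rnn_units)    # last decoder hidden state
projected = projection_layer(output.view(-1, rnn_units))   # (batch * num_nodes, output_dim)
output = projected.view(-1, num_nodes * output_dim)        # (batch, num_nodes * output_dim)
print(output.shape)  # torch.Size([8, 5])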
@@ -128,7 +125,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
         :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
         """
         batch_size = encoder_hidden_state.size(1)
-        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim))
+        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim),
+                                device=device)
         decoder_hidden_state = encoder_hidden_state
         decoder_input = go_symbol
@@ -155,7 +153,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
         :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
         """
         encoder_hidden_state = self.encoder(inputs)
-        self._logger.info("Encoder complete, starting decoder")
+        self._logger.debug("Encoder complete, starting decoder")
         outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
-        self._logger.info("Decoder complete")
+        self._logger.debug("Decoder complete")
         return outputs
@@ -7,6 +7,8 @@ import torch
 from lib import utils
 from model.pytorch.dcrnn_model import DCRNNModel
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 class DCRNNSupervisor:
     def __init__(self, adj_mx, **kwargs):
@@ -75,7 +77,7 @@ class DCRNNSupervisor:
         config['model_state_dict'] = self.dcrnn_model.state_dict()
         config['epoch'] = epoch
         torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
-        self._logger.info("Loaded model at {}".format(epoch))
+        self._logger.info("Saved model at {}".format(epoch))
         return self._log_dir + 'models/epo%d.tar' % epoch
 
     def load_model(self, epoch):
@@ -102,8 +104,7 @@ class DCRNNSupervisor:
             criterion = torch.nn.L1Loss()
 
             for _, (x, y) in enumerate(val_iterator):
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
 
                 output = self.dcrnn_model(x)
                 loss = self._compute_loss(y, output, criterion)
@@ -128,6 +129,7 @@ class DCRNNSupervisor:
         self.dcrnn_model = self.dcrnn_model.train()
 
         self._logger.info('Start training ...')
+        self._logger.info("num_batches:{}".format(self._data['train_loader'].num_batch))
         for epoch_num in range(epochs):
             train_iterator = self._data['train_loader'].get_iterator()
             losses = []
@@ -137,12 +139,13 @@ class DCRNNSupervisor:
             for _, (x, y) in enumerate(train_iterator):
                 optimizer.zero_grad()
 
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
 
                 output = self.dcrnn_model(x, y, batches_seen)
                 loss = self._compute_loss(y, output, criterion)
-                self._logger.info(loss.item())
+
+                self._logger.debug(loss.item())
 
                 losses.append(loss.item())
 
                 batches_seen += 1
@@ -152,40 +155,46 @@ class DCRNNSupervisor:
                 torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
 
                 optimizer.step()
+            self._logger.info("epoch complete")
             lr_scheduler.step()
+            self._logger.info("evaluating now!")
             val_loss = self.evaluate(dataset='val')
             end_time = time.time()
             if epoch_num % log_every == 0:
-                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}' \
+                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                           '{:.1f}s'.format(epoch_num, epochs, batches_seen,
-                                           np.mean(losses), val_loss, lr_scheduler.get_lr(),
+                                           np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
                                            (end_time - start_time))
                 self._logger.info(message)
 
             if epoch_num % test_every_n_epochs == 0:
                 test_loss = self.evaluate(dataset='test')
-                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f} ' \
+                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                           '{:.1f}s'.format(epoch_num, epochs, batches_seen,
-                                           np.mean(losses), test_loss, lr_scheduler.get_lr(),
+                                           np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
                                            (end_time - start_time))
                 self._logger.info(message)
 
             if val_loss < min_val_loss:
                 wait = 0
-                min_val_loss = val_loss
                 if save_model:
                     model_file_name = self.save_model(epoch_num)
                     self._logger.info(
                         'Val loss decrease from {:.4f} to {:.4f}, '
                         'saving to {}'.format(min_val_loss, val_loss, model_file_name))
+                min_val_loss = val_loss
 
             elif val_loss >= min_val_loss:
                 wait += 1
                 if wait == patience:
                     self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                     break
 
+    def _prepare_data(self, x, y):
+        x, y = self._get_x_y(x, y)
+        x, y = self._get_x_y_in_correct_dims(x, y)
+        return x.to(device), y.to(device)
+
     def _get_x_y(self, x, y):
         """
         :param x: shape (batch_size, seq_len, num_sensor, input_dim)