diff --git a/model/pytorch/dcrnn_cell.py b/model/pytorch/dcrnn_cell.py
index ac7850f..d8f045c 100644
--- a/model/pytorch/dcrnn_cell.py
+++ b/model/pytorch/dcrnn_cell.py
@@ -6,24 +6,27 @@ from torch import Tensor
 
 from lib import utils
 
-class FCLayerParams:
-    def __init__(self, rnn_network: torch.nn.RNN):
+class LayerParams:
+    def __init__(self, rnn_network: torch.nn.RNN, layer_type: str):
         self._rnn_network = rnn_network
         self._params_dict = {}
         self._biases_dict = {}
+        self._type = layer_type
 
     def get_weights(self, shape):
         if shape not in self._params_dict:
             nn_param = torch.nn.init.xavier_normal(torch.empty(*shape))
             self._params_dict[shape] = nn_param
-            self._rnn_network.register_parameter('fc_weight_{}'.format(str(shape)), nn_param)
+            self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
+                                                 nn_param)
         return self._params_dict[shape]
 
     def get_biases(self, length, bias_start=0.0):
         if length not in self._biases_dict:
             biases = torch.nn.init.constant(torch.empty(length), bias_start)
             self._biases_dict[length] = biases
-            self._rnn_network.register_parameter('fc_biases_{}'.format(str(length)), biases)
+            self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
+                                                 biases)
         return self._biases_dict[length]
 
@@ -72,7 +75,8 @@ class DCGRUCell(torch.nn.RNN):
             self._supports.append(self._build_sparse_matrix(support))
 
         self._proj_weights = torch.nn.Parameter(torch.randn(self._num_units, self._num_proj))
-        self._fc_params = FCLayerParams(self)
+        self._fc_params = LayerParams(self, 'fc')
+        self._gconv_params = LayerParams(self, 'gconv')
 
     @property
     def state_size(self):
@@ -133,4 +137,52 @@ class DCGRUCell(torch.nn.RNN):
         value = torch.sigmoid(torch.matmul(inputs_and_state, weights))
         biases = self._fc_params.get_biases(output_size, bias_start)
         value += biases
-        return value
\ No newline at end of file
+        return value
+
+    def _gconv(self, inputs, state, output_size, bias_start=0.0):
+        """Graph convolution between the input/state and the graph support matrices.
+
+        :param inputs: (batch_size, num_nodes * input_dim) Tensor.
+        :param state: (batch_size, num_nodes * num_units) Tensor.
+        :param output_size: output dimension per node.
+        :param bias_start: initial value for the biases.
+        :return: (batch_size, num_nodes * output_size) Tensor.
+        """
+        # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
+        batch_size = inputs.shape[0]
+        inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
+        state = torch.reshape(state, (batch_size, self._num_nodes, -1))
+        inputs_and_state = torch.cat([inputs, state], dim=2)
+        input_size = inputs_and_state.shape[2]
+
+        x = inputs_and_state
+        x0 = x.permute(1, 2, 0)  # (num_nodes, total_arg_size, batch_size)
+        x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
+        # x accumulates [x, S x, T_2(S) x, ...] for each support S along dim 0.
+        x = torch.unsqueeze(x0, 0)
+
+        if self._max_diffusion_step > 0:
+            for support in self._supports:
+                # Sparse x dense matmul, see
+                # https://discuss.pytorch.org/t/sparse-x-dense-dense-matrix-multiplication/6116/7
+                x1 = torch.mm(support, x0)
+                x = self._concat(x, x1)
+
+                # Chebyshev-style recurrence: x_k = 2 * S @ x_{k-1} - x_{k-2}
+                for k in range(2, self._max_diffusion_step + 1):
+                    x2 = 2 * torch.mm(support, x1) - x0
+                    x = self._concat(x, x2)
+                    x1, x0 = x2, x1
+
+        num_matrices = len(self._supports) * self._max_diffusion_step + 1  # +1 for x itself (the k = 0 term).
+        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
+        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, num_matrices)
+        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])
+
+        weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
+        x = torch.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)
+
+        biases = self._gconv_params.get_biases(output_size, bias_start)
+        x += biases
+        # Reshape back to 2D: (batch_size, num_nodes, output_size) -> (batch_size, num_nodes * output_size)
+        return torch.reshape(x, [batch_size, self._num_nodes * output_size])
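
Note on the `LayerParams` refactor: the class lazily creates a weight or bias the first time a given shape is requested, registers it on the owning module so the optimizer and `state_dict` see it, and caches it keyed by shape. Below is a minimal standalone sketch of that pattern, assuming a recent PyTorch, where `register_parameter` requires an `nn.Parameter` and the in-place initializers `xavier_normal_` / `constant_` replace the deprecated `xavier_normal` / `constant` still visible in this hunk's context lines; `TinyModel` and the parameter name are illustrative, not part of this repo.

```python
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._params_dict = {}

    def get_weights(self, shape):
        # Create and register the weight lazily, keyed by shape, on first use.
        if shape not in self._params_dict:
            # register_parameter needs an nn.Parameter; initialize it in place.
            nn_param = torch.nn.Parameter(torch.empty(*shape))
            torch.nn.init.xavier_normal_(nn_param)
            self.register_parameter('gconv_weight_{}'.format(shape), nn_param)
            self._params_dict[shape] = nn_param
        return self._params_dict[shape]


model = TinyModel()
w = model.get_weights((4, 8))
print([name for name, _ in model.named_parameters()])  # ['gconv_weight_(4, 8)']
```

Keying the cache by shape matters here because `_gconv` only knows its weight shape, `(input_size * num_matrices, output_size)`, at the first forward pass.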
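The core of the new `_gconv` is the diffusion loop: for each support matrix S it computes S·x, extends it with the Chebyshev-style recurrence x_k = 2·S·x_{k-1} − x_{k-2} up to `max_diffusion_step`, and stacks every term so one dense weight can mix all `num_matrices` feature blocks. A self-contained sketch of just that recurrence, with a dense random `support` standing in for the sparse random-walk matrices the real cell builds, and made-up sizes:

```python
import torch

torch.manual_seed(0)

num_nodes, feat = 5, 3
max_diffusion_step = 2

# Stand-in dense support; the DCRNN cell builds sparse random-walk matrices.
support = torch.rand(num_nodes, num_nodes)
x0 = torch.rand(num_nodes, feat)  # stands in for the flattened (input_size * batch_size) columns

xs = [x0]            # k = 0 term: x itself
x1 = support @ x0    # k = 1 term: one diffusion step
xs.append(x1)
for k in range(2, max_diffusion_step + 1):
    # Chebyshev-style recurrence: x_k = 2 * S @ x_{k-1} - x_{k-2}
    x2 = 2 * support @ x1 - x0
    xs.append(x2)
    x1, x0 = x2, x1

stacked = torch.stack(xs)  # (num_matrices, num_nodes, feat)
num_matrices = 1 * max_diffusion_step + 1  # one support here, plus the k = 0 term
assert stacked.shape[0] == num_matrices
```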