From 6292c781c586bf0a6acb2fbadf9503a7bba5343f Mon Sep 17 00:00:00 2001
From: Siju <sijusamuel@gmail.com>
Date: Sat, 22 Sep 2018 23:00:41 +0530
Subject: [PATCH] [NNVM] Keras SimpleRNN and GRU support (#1729)

---
 nnvm/python/nnvm/frontend/keras.py            | 98 +++++++++++++++++--
 .../python/frontend/keras/test_forward.py     | 54 ++++++++++
 2 files changed, 145 insertions(+), 7 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py
index 07f1ce502..a1e089b21 100644
--- a/nnvm/python/nnvm/frontend/keras.py
+++ b/nnvm/python/nnvm/frontend/keras.py
@@ -28,6 +28,10 @@ def _get_elu(insym, alpha):
     """
     return -alpha * _sym.relu(1 - _sym.exp(insym)) + _sym.relu(insym)
 
+def _convert_recurrent_activation(insym, keras_layer):
+    act_type = keras_layer.recurrent_activation.__name__
+    return _convert_activation(insym, act_type, None)
+
 def _convert_activation(insym, keras_layer, _):
     if isinstance(keras_layer, str):
         act_type = keras_layer
@@ -420,16 +424,96 @@ def _convert_lstm(insym, keras_layer, symtab):
     ixh2 = _sym.dense(in_state_h, recurrent_wt, in_bias, use_bias=True, units=units)
     gate = ixh1 + ixh2
     gates = _sym.split(gate, indices_or_sections=4, axis=1)
-    in_gate = _sym.sigmoid(gates[0])
-    in_transform = _sym.sigmoid(gates[1])
-    next_c = in_transform * in_state_c + in_gate * _sym.tanh(gates[2])
-    out_gate = _sym.sigmoid(gates[3])
-    next_h = out_gate * _sym.tanh(next_c)
+    in_gate = _convert_recurrent_activation(gates[0], keras_layer)
+    in_transform = _convert_recurrent_activation(gates[1], keras_layer)
+    next_c = in_transform * in_state_c + in_gate * _convert_activation(gates[2], keras_layer, None)
+    out_gate = _convert_recurrent_activation(gates[3], keras_layer)
+    next_h = out_gate * _convert_activation(next_c, keras_layer, None)
     out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
     out = _sym.reshape(next_h, shape=out_shape)
     return [out, next_h, next_c]
 
+def _convert_simple_rnn(insym, keras_layer, symtab):
+    _check_data_format(keras_layer)
+    if not isinstance(insym, list):
+        buffer = np.zeros((1, keras_layer.units), 'float32')
+        prev_sym = symtab.new_const(buffer)
+        insym = [insym, prev_sym]
+    in_data = insym[0]
+    prev_sym = insym[1]
+
+    weightList = keras_layer.get_weights()
+    kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
+    recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
+    in_bias = symtab.new_const(weightList[2])
+    units = list(weightList[0].shape)[1]
+
+    in_data = _sym.flatten(in_data)
+    ixh = _sym.dense(in_data, kernel_wt, in_bias, use_bias=True, units=units)
+    prev_sym = _sym.flatten(prev_sym)
+    ixh2 = _sym.dense(prev_sym, recurrent_wt, use_bias=False, units=units)
+    output = ixh + ixh2
+    output = _convert_activation(output, keras_layer, None)
+
+    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
+    output = _sym.reshape(output, shape=out_shape)
+
+    return [output, output]
+
+def _convert_gru(insym, keras_layer, symtab):
+    _check_data_format(keras_layer)
+    if not isinstance(insym, list):
+        buffer = np.zeros((1, keras_layer.units), 'float32')
+        h_tm1 = symtab.new_const(buffer)
+        insym = [insym, h_tm1]
+    in_data = insym[0]
+    h_tm1_sym = insym[1]
+
+    weightList = keras_layer.get_weights()
+    kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
+    recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
+    in_bias = symtab.new_const(weightList[2])
+
+    units = list(weightList[0].shape)[1]
+
+    in_data = _sym.flatten(in_data)
+    matrix_x = _sym.dense(in_data, kernel_wt, in_bias, use_bias=True, units=units)
+
+    # inputs projected by all gate matrices at once
+    split_indices = [keras_layer.units, 2 * keras_layer.units]
+    gates = _sym.split(matrix_x, indices_or_sections=split_indices, axis=1)
+    x_z = gates[0]
+    x_r = gates[1]
+    x_h = gates[2]
+
+    # hidden state projected separately for update/reset and new
+    units = 2 * keras_layer.units
+    split_indices = [units]
+    rec_wts = _sym.split(recurrent_wt, indices_or_sections=split_indices, axis=0)
+
+    h_tm1_sym = _sym.flatten(h_tm1_sym)
+    matrix_inner = _sym.dense(h_tm1_sym, rec_wts[0], use_bias=False, units=units)
+
+    split_indices = [keras_layer.units]
+    recurrent = _sym.split(matrix_inner, indices_or_sections=split_indices, axis=1)
+    recurrent_z = recurrent[0]
+    recurrent_r = recurrent[1]
+
+    rec_act_z = _convert_recurrent_activation(x_z + recurrent_z, keras_layer)
+    rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer)
+
+    units = keras_layer.units
+    recurrent_h = _sym.dense(rec_act_r * h_tm1_sym, rec_wts[1], use_bias=False, units=units)
+    act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
+
+    # previous and candidate state mixed by update gate
+    output = rec_act_z * h_tm1_sym + (1 - rec_act_z) * act_hh
+
+    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
+    output = _sym.reshape(output, shape=out_shape)
+    return [output, output]
+
 def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument
     """Layers that can be skipped because they are train time only."""
     return insym
@@ -475,9 +559,9 @@ _convert_map = {
     # 'UpSampling3D'           : _convert_upsample,
     # 'Conv1D'                 : _convert_convolution1d,
 
-    # 'GRU'                    : _convert_gru,
+    'SimpleRNN'                : _convert_simple_rnn,
     'LSTM'                     : _convert_lstm,
-    # 'SimpleRNN'              : _convert_simple_rnn,
+    'GRU'                      : _convert_gru,
 
     # 'Bidirectional'          : _convert_bidirectional,
     # 'TimeDistributed'        : _default_skip,

diff --git a/nnvm/tests/python/frontend/keras/test_forward.py b/nnvm/tests/python/frontend/keras/test_forward.py
index 459be8737..2e1c378d2 100644
--- a/nnvm/tests/python/frontend/keras/test_forward.py
+++ b/nnvm/tests/python/frontend/keras/test_forward.py
@@ -254,6 +254,58 @@ def test_forward_LSTM():
     _test_LSTM(4, 4, return_state=False)
     _test_LSTM_MultiLayer(4, 4)
 
+def _test_RNN(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    rnn_out = keras.layers.SimpleRNN(units, return_state=True,
+                                     activation='tanh')
+    x = rnn_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_RNN_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.SimpleRNN(units, return_state=True, return_sequences=True,
+                                   activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.SimpleRNN(units, activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_RNN():
+    _test_RNN(2, 4)
+    _test_RNN(4, 3)
+    _test_RNN_MultiLayer(4, 12)
+
+def _test_GRU(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    gru_out = keras.layers.GRU(units,
+                               return_state=True,
+                               recurrent_activation='sigmoid',
+                               activation='tanh')
+    x = gru_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_GRU_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.GRU(units,
+                             return_state=True,
+                             return_sequences=True,
+                             recurrent_activation='sigmoid',
+                             activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.GRU(units, recurrent_activation='sigmoid',
+                              activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_GRU():
+    _test_GRU(2, 4)
+    _test_GRU(4, 3)
+    _test_GRU_MultiLayer(4, 4)
+
 if __name__ == '__main__':
     test_forward_elemwise_add()
     test_forward_activations()
@@ -272,3 +324,5 @@ if __name__ == '__main__':
     test_forward_multi_outputs()
     test_forward_reuse_layers()
     test_forward_LSTM()
+    test_forward_RNN()
+    test_forward_GRU()
--
GitLab
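
Note on the gate math (not part of the patch): with Keras' default reset_after=False, a GRU step computes z = recurrent_activation(x_z + U_z.h), r = recurrent_activation(x_r + U_r.h), hh = activation(x_h + U_h.(r*h)) and h' = z*h + (1-z)*hh. That is why _convert_gru splits the input projection three ways, projects the hidden state only through the z/r block of the recurrent kernel, and applies the reset gate to the state before the final dense. A NumPy transcription of that step is sketched below, using the untransposed weight layout keras_layer.get_weights() returns (kernel (in_dim, 3*units), recurrent kernel (units, 3*units), bias (3*units,), gates ordered z, r, h); the helper names are illustrative, not from the patch.

# Reference sketch only -- not part of the patch. Mirrors the gate math
# in _convert_gru, assuming Keras' default reset_after=False layout.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, kernel, recurrent, bias, units):
    # All input projections at once, then split into z/r/h gates,
    # as the converter does with _sym.split on matrix_x.
    m = x.dot(kernel) + bias
    x_z, x_r, x_h = m[:, :units], m[:, units:2 * units], m[:, 2 * units:]
    # Hidden state projected only through the update/reset block.
    inner = h.dot(recurrent[:, :2 * units])
    z = sigmoid(x_z + inner[:, :units])       # update gate
    r = sigmoid(x_r + inner[:, units:])       # reset gate
    # Reset gate applied to the state before the final projection.
    hh = np.tanh(x_h + (r * h).dot(recurrent[:, 2 * units:]))
    # Previous and candidate state mixed by the update gate.
    return z * h + (1 - z) * hh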
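
For completeness, a minimal end-to-end sketch of how the new converters get exercised, mirroring what verify_keras_frontend does in test_forward.py; the layer sizes and the 'llvm' target here are illustrative assumptions, not part of the change.

# Usage sketch only -- not part of the patch.
import numpy as np
import keras
import nnvm
import tvm
from tvm.contrib import graph_runtime

# Single-step GRU model, shaped like the _test_GRU cases above.
data = keras.layers.Input(shape=(1, 4))
out = keras.layers.GRU(3, recurrent_activation='sigmoid',
                       activation='tanh')(data)
keras_model = keras.models.Model(data, out)

# Convert the Keras graph to NNVM symbols plus constant parameters.
sym, params = nnvm.frontend.from_keras(keras_model)

# Compile for CPU; the input name comes from the Keras model itself.
x = np.random.uniform(size=(1, 1, 4)).astype('float32')
shape_dict = {keras_model.input_names[0]: x.shape}
graph, lib, params = nnvm.compiler.build(sym, 'llvm', shape_dict, params=params)

# Run through the graph runtime and check against Keras.
m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.set_input(**params)
m.set_input(keras_model.input_names[0], tvm.nd.array(x))
m.run()
np.testing.assert_allclose(m.get_output(0).asnumpy(),
                           keras_model.predict(x),
                           rtol=1e-5, atol=1e-5)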