From 6292c781c586bf0a6acb2fbadf9503a7bba5343f Mon Sep 17 00:00:00 2001
From: Siju <sijusamuel@gmail.com>
Date: Sat, 22 Sep 2018 23:00:41 +0530
Subject: [PATCH] [NNVM]Keras SimpleRnn and GRU support (#1729)

---
 nnvm/python/nnvm/frontend/keras.py            | 98 +++++++++++++++++--
 .../python/frontend/keras/test_forward.py     | 54 ++++++++++
 2 files changed, 145 insertions(+), 7 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py
index 07f1ce502..a1e089b21 100644
--- a/nnvm/python/nnvm/frontend/keras.py
+++ b/nnvm/python/nnvm/frontend/keras.py
@@ -28,6 +28,10 @@ def _get_elu(insym, alpha):
     """
     return -alpha * _sym.relu(1 - _sym.exp(insym)) + _sym.relu(insym)
 
+def _convert_recurrent_activation(insym, keras_layer):
+    act_type = keras_layer.recurrent_activation.__name__
+    return _convert_activation(insym, act_type, None)
+
 def _convert_activation(insym, keras_layer, _):
     if isinstance(keras_layer, str):
         act_type = keras_layer
@@ -420,16 +424,96 @@ def _convert_lstm(insym, keras_layer, symtab):
     ixh2 = _sym.dense(in_state_h, recurrent_wt, in_bias, use_bias=True, units=units)
     gate = ixh1 + ixh2
     gates = _sym.split(gate, indices_or_sections=4, axis=1)
-    in_gate = _sym.sigmoid(gates[0])
-    in_transform = _sym.sigmoid(gates[1])
-    next_c = in_transform * in_state_c + in_gate * _sym.tanh(gates[2])
-    out_gate = _sym.sigmoid(gates[3])
-    next_h = out_gate * _sym.tanh(next_c)
+    in_gate = _convert_recurrent_activation(gates[0], keras_layer)
+    in_transform = _convert_recurrent_activation(gates[1], keras_layer)
+    next_c = in_transform * in_state_c + in_gate * _convert_activation(gates[2], keras_layer, None)
+    out_gate = _convert_recurrent_activation(gates[3], keras_layer)
+    next_h = out_gate * _convert_activation(next_c, keras_layer, None)
 
     out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
     out = _sym.reshape(next_h, shape=out_shape)
     return [out, next_h, next_c]
 
+def _convert_simple_rnn(insym, keras_layer, symtab):
+    _check_data_format(keras_layer)
+    if not isinstance(insym, list):
+        buffer = np.zeros((1, keras_layer.units), 'float32')
+        prev_sym = symtab.new_const(buffer)
+        insym = [insym, prev_sym]
+    in_data = insym[0]
+    prev_sym = insym[1]
+
+    weightList = keras_layer.get_weights()
+    kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
+    recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
+    in_bias = symtab.new_const(weightList[2])
+    units = list(weightList[0].shape)[1]
+
+    in_data = _sym.flatten(in_data)
+    ixh = _sym.dense(in_data, kernel_wt, in_bias, use_bias=True, units=units)
+    prev_sym = _sym.flatten(prev_sym)
+    ixh2 = _sym.dense(prev_sym, recurrent_wt, use_bias=False, units=units)
+    output = ixh + ixh2
+    output = _convert_activation(output, keras_layer, None)
+
+    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
+    output = _sym.reshape(output, shape=out_shape)
+
+    return [output, output]
+
+def _convert_gru(insym, keras_layer, symtab):
+    _check_data_format(keras_layer)
+    if not isinstance(insym, list):
+        buffer = np.zeros((1, keras_layer.units), 'float32')
+        h_tm1 = symtab.new_const(buffer)
+        insym = [insym, h_tm1]
+    in_data = insym[0]
+    h_tm1_sym = insym[1]
+
+    weightList = keras_layer.get_weights()
+    kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
+    recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
+    in_bias = symtab.new_const(weightList[2])
+
+    units = list(weightList[0].shape)[1]
+
+    in_data = _sym.flatten(in_data)
+    matrix_x = _sym.dense(in_data, kernel_wt, in_bias, use_bias=True, units=units)
+
+    # inputs projected by all gate matrices at once
+    split_indices = [keras_layer.units, 2 * keras_layer.units]
+    gates = _sym.split(matrix_x, indices_or_sections=split_indices, axis=1)
+    x_z = gates[0]
+    x_r = gates[1]
+    x_h = gates[2]
+
+    # hidden state projected separately for update/reset and new
+    units = 2 * keras_layer.units
+    split_indices = [units]
+    rec_wts = _sym.split(recurrent_wt, indices_or_sections=split_indices, axis=0)
+
+    h_tm1_sym = _sym.flatten(h_tm1_sym)
+    matrix_inner = _sym.dense(h_tm1_sym, rec_wts[0], use_bias=False, units=units)
+
+    split_indices = [keras_layer.units]
+    recurrent = _sym.split(matrix_inner, indices_or_sections=split_indices, axis=1)
+    recurrent_z = recurrent[0]
+    recurrent_r = recurrent[1]
+
+    rec_act_z = _convert_recurrent_activation(x_z + recurrent_z, keras_layer)
+    rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer)
+
+    units = keras_layer.units
+    recurrent_h = _sym.dense(rec_act_r * h_tm1_sym, rec_wts[1], use_bias=False, units=units)
+    act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
+
+    # previous and candidate state mixed by update gate
+    output = rec_act_z * h_tm1_sym + (1 - rec_act_z) * act_hh
+
+    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
+    output = _sym.reshape(output, shape=out_shape)
+    return [output, output]
+
 def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument
     """Layers that can be skipped because they are train time only."""
     return insym
@@ -475,9 +559,9 @@ _convert_map = {
     # 'UpSampling3D'           : _convert_upsample,
     # 'Conv1D'                 : _convert_convolution1d,
 
-    # 'GRU'                    : _convert_gru,
+    'SimpleRNN'                : _convert_simple_rnn,
     'LSTM'                     : _convert_lstm,
-    # 'SimpleRNN'              : _convert_simple_rnn,
+    'GRU'                      : _convert_gru,
     # 'Bidirectional'          : _convert_bidirectional,
     # 'TimeDistributed'        : _default_skip,
 
diff --git a/nnvm/tests/python/frontend/keras/test_forward.py b/nnvm/tests/python/frontend/keras/test_forward.py
index 459be8737..2e1c378d2 100644
--- a/nnvm/tests/python/frontend/keras/test_forward.py
+++ b/nnvm/tests/python/frontend/keras/test_forward.py
@@ -254,6 +254,58 @@ def test_forward_LSTM():
     _test_LSTM(4, 4, return_state=False)
     _test_LSTM_MultiLayer(4, 4)
 
+def _test_RNN(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    rnn_out = keras.layers.SimpleRNN(units, return_state=True,
+                                 activation='tanh')
+    x = rnn_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_RNN_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.SimpleRNN(units, return_state=True, return_sequences=True,
+                                   activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.SimpleRNN(units, activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_RNN():
+    _test_RNN(2, 4)
+    _test_RNN(4, 3)
+    _test_RNN_MultiLayer(4, 12)
+
+def _test_GRU(inputs, units):
+    data = keras.layers.Input(shape=(1, inputs))
+    gru_out = keras.layers.GRU(units,
+                               return_state=True,
+                               recurrent_activation='sigmoid',
+                               activation='tanh')
+    x = gru_out(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def _test_GRU_MultiLayer(inputs, units):
+    inputs = keras.layers.Input(shape=(1, inputs))
+    layer = keras.layers.GRU(units,
+                             return_state=True,
+                             return_sequences=True,
+                             recurrent_activation='sigmoid',
+                             activation='tanh')
+    outputs = layer(inputs)
+    output, state = outputs[0], outputs[1:]
+    output = keras.layers.GRU(units, recurrent_activation='sigmoid',
+                              activation='tanh')(output, initial_state=state)
+    keras_model = keras.models.Model(inputs, output)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_GRU():
+    _test_GRU(2, 4)
+    _test_GRU(4, 3)
+    _test_GRU_MultiLayer(4, 4)
+
 if __name__ == '__main__':
     test_forward_elemwise_add()
     test_forward_activations()
@@ -272,3 +324,5 @@ if __name__ == '__main__':
     test_forward_multi_outputs()
     test_forward_reuse_layers()
     test_forward_LSTM()
+    test_forward_RNN()
+    test_forward_GRU()
-- 
GitLab