diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py index bb2ad783000cf4171fa5725cbaabe618a040465a..eb3bb0d01ea56f163172c2e2cf1e7adc91c2ad4a 100644 --- a/nnvm/python/nnvm/frontend/keras.py +++ b/nnvm/python/nnvm/frontend/keras.py @@ -75,6 +75,8 @@ def _convert_activation(insym, keras_layer, _): def _convert_advanced_activation(insym, keras_layer, symtab): act_type = type(keras_layer).__name__ if act_type == 'ReLU': + if keras_layer.max_value: + return _sym.clip(insym, a_min=0, a_max=keras_layer.max_value) return _sym.relu(insym) elif act_type == 'LeakyReLU': return _sym.leaky_relu(insym, alpha=keras_layer.alpha) diff --git a/nnvm/tests/python/frontend/keras/test_forward.py b/nnvm/tests/python/frontend/keras/test_forward.py index c8c9b2c784e85f1342084052fb4640dd43177c6b..a07e69c75f4fc69f2a9a3126df3d56bc5ad699c3 100644 --- a/nnvm/tests/python/frontend/keras/test_forward.py +++ b/nnvm/tests/python/frontend/keras/test_forward.py @@ -141,25 +141,25 @@ def test_forward_crop(): def test_forward_vgg16(): - keras_model = keras.applications.vgg16.VGG16(include_top=True, weights=None, + keras_model = keras.applications.vgg16.VGG16(include_top=True, weights='imagenet', input_shape=(224,224,3), classes=1000) verify_keras_frontend(keras_model) def test_forward_xception(): - keras_model = keras.applications.xception.Xception(include_top=True, weights=None, + keras_model = keras.applications.xception.Xception(include_top=True, weights='imagenet', input_shape=(299,299,3), classes=1000) verify_keras_frontend(keras_model) def test_forward_resnet50(): - keras_model = keras.applications.resnet50.ResNet50(include_top=True, weights=None, + keras_model = keras.applications.resnet50.ResNet50(include_top=True, weights='imagenet', input_shape=(224,224,3), classes=1000) verify_keras_frontend(keras_model) def test_forward_mobilenet(): - keras_model = keras.applications.mobilenet.MobileNet(include_top=True, weights=None, + keras_model = keras.applications.mobilenet.MobileNet(include_top=True, weights='imagenet', input_shape=(224,224,3), classes=1000) verify_keras_frontend(keras_model) @@ -169,6 +169,7 @@ def test_forward_activations(): act_funcs = [keras.layers.Activation('softmax'), keras.layers.Activation('softplus'), keras.layers.ReLU(), + keras.layers.ReLU(max_value=6.), keras.layers.LeakyReLU(alpha=0.3), keras.layers.PReLU(weights=weights, alpha_initializer="zero"), keras.layers.ELU(alpha=0.5),