From 083a41843637fbb7cc7e873b0537507920572ceb Mon Sep 17 00:00:00 2001
From: Yuwei Hu <huyuwei1995@gmail.com>
Date: Thu, 21 Dec 2017 22:49:08 +0800
Subject: [PATCH] keras frontend tutorial (#278)

* keras frontend tutorial

* fix
---
 nnvm/docs/api/python/frontend.rst             |   7 +-
 nnvm/python/nnvm/frontend/coreml.py           |   6 -
 nnvm/python/nnvm/frontend/onnx.py             |   2 +-
 nnvm/src/top/nn/nn.cc                         |   2 +-
 nnvm/tests/ci_build/Dockerfile.gpu            |   3 +
 .../ci_build/install/ubuntu_install_keras.sh  |   1 +
 nnvm/tutorials/from_keras.py                  | 114 ++++++++++++++++++
 7 files changed, 126 insertions(+), 9 deletions(-)
 create mode 100644 nnvm/tests/ci_build/install/ubuntu_install_keras.sh
 create mode 100644 nnvm/tutorials/from_keras.py

diff --git a/nnvm/docs/api/python/frontend.rst b/nnvm/docs/api/python/frontend.rst
index d1a0ddbd1..f872a6b87 100644
--- a/nnvm/docs/api/python/frontend.rst
+++ b/nnvm/docs/api/python/frontend.rst
@@ -3,5 +3,10 @@ nnvm.frontend
 
 .. automodule:: nnvm.frontend
 
-
 .. autofunction:: nnvm.frontend.from_mxnet
+
+.. autofunction:: nnvm.frontend.from_onnx
+
+.. autofunction:: nnvm.frontend.from_coreml
+
+.. autofunction:: nnvm.frontend.from_keras
diff --git a/nnvm/python/nnvm/frontend/coreml.py b/nnvm/python/nnvm/frontend/coreml.py
index 8bf949ea5..fb942fb18 100644
--- a/nnvm/python/nnvm/frontend/coreml.py
+++ b/nnvm/python/nnvm/frontend/coreml.py
@@ -293,12 +293,6 @@ def from_coreml(model):
     model:
         coremltools.models.MLModel of a NeuralNetworkClassifier
 
-    arg_params : dict of str to mx.NDArray
-        The argument parameters in mxnet
-
-    aux_params : dict of str to mx.NDArray
-        The auxiliary parameters in mxnet
-
     Returns
     -------
     sym : nnvm.Symbol
diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py
index c4db5599f..c3a40ef89 100644
--- a/nnvm/python/nnvm/frontend/onnx.py
+++ b/nnvm/python/nnvm/frontend/onnx.py
@@ -393,7 +393,7 @@ class GraphProto(object):
 
 
 def from_onnx(graph):
-    """Load onnx graph which is a python protobuf object in to nnvm graph.
+    """Load onnx graph which is a python protobuf object into nnvm graph.
     The companion parameters will be handled automatically.
     The inputs from onnx graph is vague, only providing "1", "2"...
     For convenience, we rename the `real` input names to "input_0",
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index f8c1aa67c..c9dde3ebb 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -266,9 +266,9 @@ inline bool PadInferShape(const nnvm::NodeAttrs& attrs,
   TShape dshape = (*in_shape)[0];
   if (dshape.ndim() == 0) return false;
   CHECK_EQ(param.pad_width.ndim(), dshape.ndim());
-  CHECK_EQ(param.pad_width[0].ndim(), 2U);
   TShape oshape = dshape;
   for (uint32_t i = 0; i < dshape.ndim(); i++) {
+    CHECK_EQ(param.pad_width[i].ndim(), 2U);
     int pad_before = param.pad_width[i][0];
     int pad_after = param.pad_width[i][1];
     oshape[i] = dshape[i] + pad_before + pad_after;
diff --git a/nnvm/tests/ci_build/Dockerfile.gpu b/nnvm/tests/ci_build/Dockerfile.gpu
index 257865f15..bde32322c 100644
--- a/nnvm/tests/ci_build/Dockerfile.gpu
+++ b/nnvm/tests/ci_build/Dockerfile.gpu
@@ -38,6 +38,9 @@ RUN bash /install/ubuntu_install_onnx.sh
 COPY install/ubuntu_install_coreml.sh /install/ubuntu_install_coreml.sh
 RUN bash /install/ubuntu_install_coreml.sh
 
+COPY install/ubuntu_install_keras.sh /install/ubuntu_install_keras.sh
+RUN bash /install/ubuntu_install_keras.sh
+
 RUN pip install Pillow
 
 # Environment variables
diff --git a/nnvm/tests/ci_build/install/ubuntu_install_keras.sh b/nnvm/tests/ci_build/install/ubuntu_install_keras.sh
new file mode 100644
index 000000000..9730d83bf
--- /dev/null
+++ b/nnvm/tests/ci_build/install/ubuntu_install_keras.sh
@@ -0,0 +1 @@
+pip2 install keras tensorflow h5py
diff --git a/nnvm/tutorials/from_keras.py b/nnvm/tutorials/from_keras.py
new file mode 100644
index 000000000..0466d672c
--- /dev/null
+++ b/nnvm/tutorials/from_keras.py
@@ -0,0 +1,114 @@
+"""
+Compile Keras Models
+=====================
+**Author**: `Yuwei Hu <https://Huyuwei.github.io/>`_
+
+This article is an introductory tutorial to deploy keras models with NNVM.
+
+For us to begin with, keras should be installed.
+Tensorflow is also required since it's used as the default backend of keras.
+
+A quick solution is to install via pip
+```
+pip install -U keras --user
+```
+```
+pip install -U tensorflow --user
+```
+or please refer to official site
+https://keras.io/#installation
+"""
+import nnvm
+import tvm
+import keras
+import numpy as np
+
+def download(url, path, overwrite=False):
+    import os
+    if os.path.isfile(path) and not overwrite:
+        print('File {} exists, skip.'.format(path))
+        return
+    print('Downloading from url {} to {}'.format(url, path))
+    try:
+        import urllib.request
+        urllib.request.urlretrieve(url, path)
+    except:
+        import urllib
+        urllib.urlretrieve(url, path)
+
+######################################################################
+# Load pretrained keras model
+# ----------------------------
+# We load a pretrained resnet-50 classification model provided by keras.
+weights_url = ''.join(['https://github.com/fchollet/deep-learning-models/releases/',
+                       'download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'])
+weights_file = 'resnet50_weights.h5'
+download(weights_url, weights_file)
+keras_resnet50 = keras.applications.resnet50.ResNet50(include_top=True, weights=None,
+	input_shape=(224,224,3), classes=1000)
+keras_resnet50.load_weights('resnet50_weights.h5')
+
+######################################################################
+# Load a test image
+# ------------------
+# A single cat dominates the examples!
+from PIL import Image
+from matplotlib import pyplot as plt
+from keras.applications.resnet50 import preprocess_input
+img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
+download(img_url, 'cat.jpg')
+img = Image.open('cat.jpg').resize((224, 224))
+plt.imshow(img)
+plt.show()
+# input preprocess
+data = np.array(img)[np.newaxis, :].astype('float32')
+data = preprocess_input(data).transpose([0, 3, 1, 2])
+print('data', data.shape)
+
+######################################################################
+# Compile the model on NNVM
+# --------------------------
+# We should be familiar with the process now.
+
+# convert the keras model(NHWC layout) to NNVM format(NCHW layout).
+sym, params = nnvm.frontend.from_keras(keras_resnet50)
+# compile the model
+target = 'cuda'
+shape_dict = {'data': data.shape}
+with nnvm.compiler.build_config(opt_level=2):
+	graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)
+
+######################################################################
+# Execute on TVM
+# ---------------
+# The process is no different from other examples.
+from tvm.contrib import graph_runtime
+ctx = tvm.gpu(0)
+m = graph_runtime.create(graph, lib, ctx)
+# set inputs
+m.set_input('data', tvm.nd.array(data.astype('float32')))
+m.set_input(**params)
+# execute
+m.run()
+# get outputs
+out_shape = (1000,)
+tvm_out = m.get_output(0, tvm.nd.empty(out_shape, 'float32')).asnumpy()
+top1_tvm = np.argmax(tvm_out)
+
+#####################################################################
+# Look up synset name
+# -------------------
+# Look up prdiction top 1 index in 1000 class synset.
+synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
+                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
+                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
+                      'imagenet1000_clsid_to_human.txt'])
+synset_name = 'synset.txt'
+download(synset_url, synset_name)
+with open(synset_name) as f:
+    synset = eval(f.read())
+print('NNVM top-1 id: {}, class name: {}'.format(top1_tvm, synset[top1_tvm]))
+# confirm correctness with keras output
+keras_out = keras_resnet50.predict(data.transpose([0, 2, 3, 1]))
+top1_keras = np.argmax(keras_out)
+print('Keras top-1 id: {}, class name: {}'.format(top1_keras, synset[top1_keras]))
-- 
GitLab