From 083a41843637fbb7cc7e873b0537507920572ceb Mon Sep 17 00:00:00 2001 From: Yuwei Hu <huyuwei1995@gmail.com> Date: Thu, 21 Dec 2017 22:49:08 +0800 Subject: [PATCH] keras frontend tutorial (#278) * keras frontend tutorial * fix --- nnvm/docs/api/python/frontend.rst | 7 +- nnvm/python/nnvm/frontend/coreml.py | 6 - nnvm/python/nnvm/frontend/onnx.py | 2 +- nnvm/src/top/nn/nn.cc | 2 +- nnvm/tests/ci_build/Dockerfile.gpu | 3 + .../ci_build/install/ubuntu_install_keras.sh | 1 + nnvm/tutorials/from_keras.py | 114 ++++++++++++++++++ 7 files changed, 126 insertions(+), 9 deletions(-) create mode 100644 nnvm/tests/ci_build/install/ubuntu_install_keras.sh create mode 100644 nnvm/tutorials/from_keras.py diff --git a/nnvm/docs/api/python/frontend.rst b/nnvm/docs/api/python/frontend.rst index d1a0ddbd1..f872a6b87 100644 --- a/nnvm/docs/api/python/frontend.rst +++ b/nnvm/docs/api/python/frontend.rst @@ -3,5 +3,10 @@ nnvm.frontend .. automodule:: nnvm.frontend - .. autofunction:: nnvm.frontend.from_mxnet + +.. autofunction:: nnvm.frontend.from_onnx + +.. autofunction:: nnvm.frontend.from_coreml + +.. autofunction:: nnvm.frontend.from_keras diff --git a/nnvm/python/nnvm/frontend/coreml.py b/nnvm/python/nnvm/frontend/coreml.py index 8bf949ea5..fb942fb18 100644 --- a/nnvm/python/nnvm/frontend/coreml.py +++ b/nnvm/python/nnvm/frontend/coreml.py @@ -293,12 +293,6 @@ def from_coreml(model): model: coremltools.models.MLModel of a NeuralNetworkClassifier - arg_params : dict of str to mx.NDArray - The argument parameters in mxnet - - aux_params : dict of str to mx.NDArray - The auxiliary parameters in mxnet - Returns ------- sym : nnvm.Symbol diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index c4db5599f..c3a40ef89 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -393,7 +393,7 @@ class GraphProto(object): def from_onnx(graph): - """Load onnx graph which is a python protobuf object in to nnvm graph. + """Load onnx graph which is a python protobuf object into nnvm graph. The companion parameters will be handled automatically. The inputs from onnx graph is vague, only providing "1", "2"... For convenience, we rename the `real` input names to "input_0", diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index f8c1aa67c..c9dde3ebb 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -266,9 +266,9 @@ inline bool PadInferShape(const nnvm::NodeAttrs& attrs, TShape dshape = (*in_shape)[0]; if (dshape.ndim() == 0) return false; CHECK_EQ(param.pad_width.ndim(), dshape.ndim()); - CHECK_EQ(param.pad_width[0].ndim(), 2U); TShape oshape = dshape; for (uint32_t i = 0; i < dshape.ndim(); i++) { + CHECK_EQ(param.pad_width[i].ndim(), 2U); int pad_before = param.pad_width[i][0]; int pad_after = param.pad_width[i][1]; oshape[i] = dshape[i] + pad_before + pad_after; diff --git a/nnvm/tests/ci_build/Dockerfile.gpu b/nnvm/tests/ci_build/Dockerfile.gpu index 257865f15..bde32322c 100644 --- a/nnvm/tests/ci_build/Dockerfile.gpu +++ b/nnvm/tests/ci_build/Dockerfile.gpu @@ -38,6 +38,9 @@ RUN bash /install/ubuntu_install_onnx.sh COPY install/ubuntu_install_coreml.sh /install/ubuntu_install_coreml.sh RUN bash /install/ubuntu_install_coreml.sh +COPY install/ubuntu_install_keras.sh /install/ubuntu_install_keras.sh +RUN bash /install/ubuntu_install_keras.sh + RUN pip install Pillow # Environment variables diff --git a/nnvm/tests/ci_build/install/ubuntu_install_keras.sh b/nnvm/tests/ci_build/install/ubuntu_install_keras.sh new file mode 100644 index 000000000..9730d83bf --- /dev/null +++ b/nnvm/tests/ci_build/install/ubuntu_install_keras.sh @@ -0,0 +1 @@ +pip2 install keras tensorflow h5py diff --git a/nnvm/tutorials/from_keras.py b/nnvm/tutorials/from_keras.py new file mode 100644 index 000000000..0466d672c --- /dev/null +++ b/nnvm/tutorials/from_keras.py @@ -0,0 +1,114 @@ +""" +Compile Keras Models +===================== +**Author**: `Yuwei Hu <https://Huyuwei.github.io/>`_ + +This article is an introductory tutorial to deploy keras models with NNVM. + +For us to begin with, keras should be installed. +Tensorflow is also required since it's used as the default backend of keras. + +A quick solution is to install via pip +``` +pip install -U keras --user +``` +``` +pip install -U tensorflow --user +``` +or please refer to official site +https://keras.io/#installation +""" +import nnvm +import tvm +import keras +import numpy as np + +def download(url, path, overwrite=False): + import os + if os.path.isfile(path) and not overwrite: + print('File {} exists, skip.'.format(path)) + return + print('Downloading from url {} to {}'.format(url, path)) + try: + import urllib.request + urllib.request.urlretrieve(url, path) + except: + import urllib + urllib.urlretrieve(url, path) + +###################################################################### +# Load pretrained keras model +# ---------------------------- +# We load a pretrained resnet-50 classification model provided by keras. +weights_url = ''.join(['https://github.com/fchollet/deep-learning-models/releases/', + 'download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5']) +weights_file = 'resnet50_weights.h5' +download(weights_url, weights_file) +keras_resnet50 = keras.applications.resnet50.ResNet50(include_top=True, weights=None, + input_shape=(224,224,3), classes=1000) +keras_resnet50.load_weights('resnet50_weights.h5') + +###################################################################### +# Load a test image +# ------------------ +# A single cat dominates the examples! +from PIL import Image +from matplotlib import pyplot as plt +from keras.applications.resnet50 import preprocess_input +img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' +download(img_url, 'cat.jpg') +img = Image.open('cat.jpg').resize((224, 224)) +plt.imshow(img) +plt.show() +# input preprocess +data = np.array(img)[np.newaxis, :].astype('float32') +data = preprocess_input(data).transpose([0, 3, 1, 2]) +print('data', data.shape) + +###################################################################### +# Compile the model on NNVM +# -------------------------- +# We should be familiar with the process now. + +# convert the keras model(NHWC layout) to NNVM format(NCHW layout). +sym, params = nnvm.frontend.from_keras(keras_resnet50) +# compile the model +target = 'cuda' +shape_dict = {'data': data.shape} +with nnvm.compiler.build_config(opt_level=2): + graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params) + +###################################################################### +# Execute on TVM +# --------------- +# The process is no different from other examples. +from tvm.contrib import graph_runtime +ctx = tvm.gpu(0) +m = graph_runtime.create(graph, lib, ctx) +# set inputs +m.set_input('data', tvm.nd.array(data.astype('float32'))) +m.set_input(**params) +# execute +m.run() +# get outputs +out_shape = (1000,) +tvm_out = m.get_output(0, tvm.nd.empty(out_shape, 'float32')).asnumpy() +top1_tvm = np.argmax(tvm_out) + +##################################################################### +# Look up synset name +# ------------------- +# Look up prdiction top 1 index in 1000 class synset. +synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) +synset_name = 'synset.txt' +download(synset_url, synset_name) +with open(synset_name) as f: + synset = eval(f.read()) +print('NNVM top-1 id: {}, class name: {}'.format(top1_tvm, synset[top1_tvm])) +# confirm correctness with keras output +keras_out = keras_resnet50.predict(data.transpose([0, 2, 3, 1])) +top1_keras = np.argmax(keras_out) +print('Keras top-1 id: {}, class name: {}'.format(top1_keras, synset[top1_keras])) -- GitLab