From 950aa1a0c2b340c3ccc5bbfa326566a5e2991df4 Mon Sep 17 00:00:00 2001
From: "Joshua Z. Zhang" <cheungchih@gmail.com>
Date: Tue, 26 Sep 2017 21:18:35 -0700
Subject: [PATCH] [Tutorial] mxnet (#47)

* [Tutorial] mxnet

update

add from_gluon

add to __init__

fix tutorial and from_gluon

fix doc lint

merge from_mxnet

fix

fix

fix tutorial

fix

fix header

* fix tutorial

* fix data

* fix
---
 nnvm/python/nnvm/frontend/mxnet.py |  46 +++++++-----
 nnvm/python/nnvm/testing/resnet.py |   1 +
 nnvm/tutorials/from_mxnet.py       | 114 +++++++++++++++++++++++++++++
 3 files changed, 143 insertions(+), 18 deletions(-)
 create mode 100644 nnvm/tutorials/from_mxnet.py

diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py
index 16419e737..168d3c67b 100644
--- a/nnvm/python/nnvm/frontend/mxnet.py
+++ b/nnvm/python/nnvm/frontend/mxnet.py
@@ -256,14 +256,6 @@ def _from_mxnet_impl(symbol, graph):
     nnvm.sym.Symbol
         Converted symbol
     """
-    try:
-        from mxnet import sym as mx_sym  # pylint: disable=import-self
-    except ImportError as e:
-        raise ImportError('{}. MXNet is required to parse symbols.'.format(e))
-
-    if not isinstance(symbol, mx_sym.Symbol):
-        raise ValueError("Provided {}, while MXNet symbol is expected", type(symbol))
-
     if _is_mxnet_group_symbol(symbol):
         return [_from_mxnet_impl(s, graph) for s in symbol]
 
@@ -294,7 +286,7 @@ def from_mxnet(symbol, arg_params=None, aux_params=None):
 
     Parameters
     ----------
-    symbol : mxnet.Symbol
+    symbol : mxnet.Symbol or mxnet.gluon.HybridBlock
         MXNet symbol
 
     arg_params : dict of str to mx.NDArray
@@ -305,18 +297,36 @@ def from_mxnet(symbol, arg_params=None, aux_params=None):
 
     Returns
     -------
-    net: nnvm.Symbol
+    sym : nnvm.Symbol
         Compatible nnvm symbol
 
     params : dict of str to tvm.NDArray
         The parameter dict to be used by nnvm
     """
-    sym = _from_mxnet_impl(symbol, {})
-    params = {}
-    arg_params = arg_params if arg_params else {}
-    aux_params = aux_params if aux_params else {}
-    for k, v in arg_params.items():
-        params[k] = tvm.nd.array(v.asnumpy())
-    for k, v in aux_params.items():
-        params[k] = tvm.nd.array(v.asnumpy())
+    try:
+        import mxnet as mx  # pylint: disable=import-self
+    except ImportError as e:
+        raise ImportError('{}. MXNet is required to parse symbols.'.format(e))
+
+    if isinstance(symbol, mx.sym.Symbol):
+        sym = _from_mxnet_impl(symbol, {})
+        params = {}
+        arg_params = arg_params if arg_params else {}
+        aux_params = aux_params if aux_params else {}
+        for k, v in arg_params.items():
+            params[k] = tvm.nd.array(v.asnumpy())
+        for k, v in aux_params.items():
+            params[k] = tvm.nd.array(v.asnumpy())
+    elif isinstance(symbol, mx.gluon.HybridBlock):
+        data = mx.sym.Variable('data')
+        sym = symbol(data)
+        sym = _from_mxnet_impl(sym, {})
+        params = {}
+        for k, v in symbol.collect_params().items():
+            params[k] = tvm.nd.array(v.data().asnumpy())
+    elif isinstance(symbol, mx.gluon.Block):
+        raise NotImplementedError("The dynamic Block is not supported yet.")
+    else:
+        msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
+        raise ValueError(msg)
     return sym, params
diff --git a/nnvm/python/nnvm/testing/resnet.py b/nnvm/python/nnvm/testing/resnet.py
index 0e9c81232..76b5c1d89 100644
--- a/nnvm/python/nnvm/testing/resnet.py
+++ b/nnvm/python/nnvm/testing/resnet.py
@@ -23,6 +23,7 @@ Implemented the following paper:
 
 Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
 '''
+# pylint: disable=unused-argument
 import numpy as np
 from .. import symbol as sym
 from . utils import create_workload
diff --git a/nnvm/tutorials/from_mxnet.py b/nnvm/tutorials/from_mxnet.py
new file mode 100644
index 000000000..0b6b70ef2
--- /dev/null
+++ b/nnvm/tutorials/from_mxnet.py
@@ -0,0 +1,114 @@
+"""
+Compiling MXNet Models with NNVM
+================================
+**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
+
+This article is an introductory tutorial to deploy mxnet models with NNVM.
+
+For us to begin with, mxnet module is required to be installed.
+
+A quick solution is
+```
+pip install mxnet --user
+```
+or please refer to offical installation guide.
+https://mxnet.incubator.apache.org/versions/master/install/index.html
+"""
+# some standard imports
+import mxnet as mx
+import nnvm
+import tvm
+import numpy as np
+
+######################################################################
+# Download Resnet18 model from Gluon Model Zoo
+# ---------------------------------------------
+# In this section, we download a pretrained imagenet model and classify an image.
+from mxnet.gluon.model_zoo.vision import get_model
+from mxnet.gluon.utils import download
+import Image
+from matplotlib import pyplot as plt
+block = get_model('resnet18_v1', pretrained=True)
+img_name = 'cat.jpg'
+synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
+                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
+                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
+                      'imagenet1000_clsid_to_human.txt'])
+synset_name = 'synset.txt'
+download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
+download(synset_url, synset_name)
+with open(synset_name) as f:
+    synset = eval(f.read())
+image = Image.open(img_name).resize((224, 224))
+plt.imshow(image)
+plt.show()
+
+def transform_image(image):
+    image = np.array(image) - np.array([123., 117., 104.])
+    image /= np.array([58.395, 57.12, 57.375])
+    image = image.transpose((2, 0, 1))
+    image = image[np.newaxis, :]
+    return image
+
+x = transform_image(image)
+print('x', x.shape)
+
+######################################################################
+# Compile the Graph
+# -----------------
+# Now we would like to port the Gluon model to a portable computational graph.
+# It's as easy as several lines.
+# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
+sym, params = nnvm.frontend.from_mxnet(block)
+# we want a probability so add a softmax operator
+sym = nnvm.sym.softmax(sym)
+
+######################################################################
+# now compile the graph
+import nnvm.compiler
+target = 'cuda'
+shape_dict = {'data': x.shape}
+graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)
+
+######################################################################
+# Execute the portable graph on TVM
+# ---------------------------------
+# Now, we would like to reproduce the same forward computation using TVM.
+from tvm.contrib import graph_runtime
+ctx = tvm.gpu(0)
+dtype = 'float32'
+m = graph_runtime.create(graph, lib, ctx)
+# set inputs
+m.set_input('data', tvm.nd.array(x.astype(dtype)))
+m.set_input(**params)
+# execute
+m.run()
+# get outputs
+tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype))
+top1 = np.argmax(tvm_output)
+print('TVM prediction top-1:', top1, synset[top1])
+
+######################################################################
+# Use MXNet symbol with pretrained weights
+# ----------------------------------------
+# MXNet often use `arg_prams` and `aux_params` to store network parameters
+# separately, here we show how to use these weights with existing API
+def block2symbol(block):
+    data = mx.sym.Variable('data')
+    sym = block(data)
+    args = {}
+    auxs = {}
+    for k, v in block.collect_params().items():
+        args[k] = mx.nd.array(v.data().asnumpy())
+    return sym, args, auxs
+mx_sym, args, auxs = block2symbol(block)
+# usually we would save/load it as checkpoint
+mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs)
+# there are 'resnet18_v1-0000.params' and 'resnet18_v1-symbol.json' on disk
+
+######################################################################
+# for a normal mxnet model, we start from here
+mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0)
+# now we use the same API to get NNVM compatible symbol
+nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(mx_sym, args, auxs)
+# repeat the same steps to run this model using TVM
-- 
GitLab