From 34d74282ec0adce60eda4298b82e411e3dd17543 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 24 Sep 2017 14:17:06 -0700
Subject: [PATCH] [TUTORIAL] Move mobilenet to tutorial, fix precompute_prune
 (#35)

* [TUTORIAL] Move mobilenet to tutorial, fix precompute_prune

* Some language improvements
---
 nnvm/docs/.gitignore                      |   2 +
 nnvm/docs/README.txt                      |   4 +
 nnvm/docs/conf.py                         |  30 +++++-
 nnvm/docs/dev/index.rst                   |   4 +-
 nnvm/docs/index.rst                       |   1 +
 nnvm/docs/top.rst                         |   4 +-
 nnvm/example/mobilenet_inference_gpu.py   | 117 --------------------
 nnvm/examples/README.md                   |   5 +
 nnvm/python/nnvm/testing/__init__.py      |   4 +-
 nnvm/python/nnvm/testing/config.py        |   2 +
 nnvm/python/nnvm/testing/mobilenet.py     | 125 ++++++++++++++++++++++
 nnvm/src/compiler/precompute_prune.cc     |  11 +-
 nnvm/tests/python/compiler/test_build.py  |  12 ++-
 nnvm/tutorials/README.txt                 |   3 +
 nnvm/tutorials/mobilenet_inference_gpu.py |  82 ++++++++++++++
 15 files changed, 271 insertions(+), 135 deletions(-)
 create mode 100644 nnvm/docs/README.txt
 delete mode 100644 nnvm/example/mobilenet_inference_gpu.py
 create mode 100644 nnvm/examples/README.md
 create mode 100644 nnvm/python/nnvm/testing/mobilenet.py
 create mode 100644 nnvm/tutorials/README.txt
 create mode 100644 nnvm/tutorials/mobilenet_inference_gpu.py

diff --git a/nnvm/docs/.gitignore b/nnvm/docs/.gitignore
index 024fbfbe7..d5d021127 100644
--- a/nnvm/docs/.gitignore
+++ b/nnvm/docs/.gitignore
@@ -1,2 +1,4 @@
 doxygen
 _build
+gen_modules
+tutorials
diff --git a/nnvm/docs/README.txt b/nnvm/docs/README.txt
new file mode 100644
index 000000000..8b8c75082
--- /dev/null
+++ b/nnvm/docs/README.txt
@@ -0,0 +1,4 @@
+The documentation of NNVM is generated with Sphinx and recommonmark.
+
+- pip install "sphinx>=1.5.5" sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark
+- Build TVM in the root folder first.
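+- Then, assuming the standard Sphinx Makefile is present in this
+  folder, build the HTML pages with "make html".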
diff --git a/nnvm/docs/conf.py b/nnvm/docs/conf.py
index af089466e..517516718 100644
--- a/nnvm/docs/conf.py
+++ b/nnvm/docs/conf.py
@@ -15,6 +15,7 @@ import sys
 import os, subprocess
 import shlex
 import recommonmark
+import sphinx_gallery
 from recommonmark.parser import CommonMarkParser
 from recommonmark.transform import AutoStructify
 
@@ -50,7 +51,8 @@ extensions = [
     'sphinx.ext.autosummary',
     'sphinx.ext.intersphinx',
     'sphinx.ext.napoleon',
-    'sphinx.ext.mathjax'
+    'sphinx.ext.mathjax',
+    'sphinx_gallery.gen_gallery',
 ]
 
 # Add any paths that contain templates here, relative to this directory.
@@ -129,7 +131,7 @@ if not on_rtd and html_theme == 'rtd':
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+# html_static_path = ['_static']
 
 # Output file base name for HTML help builder.
 htmlhelp_basename = project + 'doc'
@@ -164,9 +166,17 @@ intersphinx_mapping = {
     'numpy': ('http://docs.scipy.org/doc/numpy/', None),
     'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
     'matplotlib': ('http://matplotlib.org/', None),
+    'tvm': ('http://docs.tvmlang.org/', None),
 }
 
 
+from sphinx_gallery.sorting import ExplicitOrder
+
+examples_dirs = ['../tutorials/']
+gallery_dirs = ['tutorials']
+
+subsection_order = ExplicitOrder([])
+
 def generate_doxygen_xml(app):
     """Run the doxygen make commands if we're on the ReadTheDocs server"""
     run_doxygen('..')
@@ -180,3 +190,19 @@ def setup(app):
         'auto_doc_ref': True
             }, True)
     app.add_transform(AutoStructify)
+
+
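+# sphinx-gallery settings: every script under examples_dirs whose file
+# name matches filename_pattern is executed at build time and rendered
+# into gallery_dirs.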
+sphinx_gallery_conf = {
+    'backreferences_dir': 'gen_modules/backreferences',
+    'doc_module': ('tvm', 'nnvm', 'numpy'),
+    'reference_url': {
+        'nnvm': None,
+        'tvm': 'http://docs.tvmlang.org',
+        'numpy': 'http://docs.scipy.org/doc/numpy-1.9.1'},
+    'examples_dirs': examples_dirs,
+    'gallery_dirs': gallery_dirs,
+    'subsection_order': subsection_order,
+    'find_mayavi_figures': False,
+    'filename_pattern': '.py',
+    'expected_failing_examples': []
+}
diff --git a/nnvm/docs/dev/index.rst b/nnvm/docs/dev/index.rst
index ecee6889d..0647c9cce 100644
--- a/nnvm/docs/dev/index.rst
+++ b/nnvm/docs/dev/index.rst
@@ -1,5 +1,5 @@
-NNVM Design Note
-================
+Design Note
+===========
 
 In this part of the documentation, we share the rationale for the specific choices made when designing NNVM.
 
diff --git a/nnvm/docs/index.rst b/nnvm/docs/index.rst
index 9011bacc9..14db71902 100644
--- a/nnvm/docs/index.rst
+++ b/nnvm/docs/index.rst
@@ -10,4 +10,5 @@ Contents
 
    self
    top
+   tutorials/index
    dev/index
diff --git a/nnvm/docs/top.rst b/nnvm/docs/top.rst
index adbafdc99..89af46509 100644
--- a/nnvm/docs/top.rst
+++ b/nnvm/docs/top.rst
@@ -1,5 +1,5 @@
-NNVM Core Tensor Operators
-==========================
+Core Tensor Operators
+=====================
 
 This page contains the list of core tensor operator primitives re-defined in NNVM.
 The core tensor operator primitives (``nnvm.top``) cover typical workloads in deep learning.
diff --git a/nnvm/example/mobilenet_inference_gpu.py b/nnvm/example/mobilenet_inference_gpu.py
deleted file mode 100644
index 4331ec4c0..000000000
--- a/nnvm/example/mobilenet_inference_gpu.py
+++ /dev/null
@@ -1,117 +0,0 @@
-"""Forward propagation of MobileNet on GPU."""
-import numpy as np
-import time
-import os
-
-import tvm
-import topi
-import nnvm.symbol as sym
-import nnvm.compiler
-import nnvm.runtime
-from tvm.contrib import nvcc
-
-TASK="mobilenet"
-
-target = 'cuda'
-ctx = tvm.gpu(0)
-
-@tvm.register_func
-def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_60"])
-    return ptx
-
-def write_code(code, fname):
-    with open(fname, "w") as f:
-        f.write(code)
-
-@tvm.register_func
-def tvm_callback_cuda_postproc(code):
-    if not os.path.exists("perf"):
-        os.mkdir("perf")
-    write_code(code, "perf/%s_generated.cu" % TASK)
-    return code
-
-dtype = 'float32'
-epsilon = 1e-10 + 1e-5
-
-def conv_block(data, name, channels, kernel_size=(3,3), strides=(1,1), padding=(1,1)):
-    # convolution + bn + relu
-    conv = sym.conv2d(data=data, channels=channels, kernel_size=kernel_size, strides=strides,
-        padding=padding, use_bias=False, layout='NCHW', name=name + '_conv')
-    bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + '_bn')
-    act = sym.relu(data=bn, name=name + '_relu')
-    return act
-
-def separable_conv_block(data, name, depthwise_channels, pointwise_channels, kernel_size=(3,3), downsample=False, padding=(1,1)):
-    if downsample:
-        strides = (2,2)
-    else:
-        strides = (1,1)
-    # depthwise convolution + bn + relu
-    conv1 = sym.conv2d(data=data, channels=depthwise_channels, groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
-        padding=padding, use_bias=False, layout='NCHW', name=name + '_conv1')
-    bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + '_bn1')
-    act1 = sym.relu(data=bn1, name=name + '_relu1')
-    # pointwise convolution + bn + relu
-    conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1,1), strides=(1,1),
-        padding=(0,0), use_bias=False, layout='NCHW', name=name + '_conv2')
-    bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + '_bn2')
-    act2 = sym.relu(data=bn2, name=name + '_relu2')
-    return act2
-
-def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
-    data = sym.Variable("data")
-    body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2,2))
-    body = separable_conv_block(body, 'separable_conv_block_1', int(32*alpha), int(64*alpha))
-    body = separable_conv_block(body, 'separable_conv_block_2', int(64*alpha), int(128*alpha), downsample=True)
-    body = separable_conv_block(body, 'separable_conv_block_3', int(128*alpha), int(128*alpha))
-    body = separable_conv_block(body, 'separable_conv_block_4', int(128*alpha), int(256*alpha), downsample=True)
-    body = separable_conv_block(body, 'separable_conv_block_5', int(256*alpha), int(256*alpha))
-    body = separable_conv_block(body, 'separable_conv_block_6', int(256*alpha), int(512*alpha), downsample=True)
-    if is_shallow:
-        body = separable_conv_block(body, 'separable_conv_block_7', int(512*alpha), int(1024*alpha), downsample=True)
-        body = separable_conv_block(body, 'separable_conv_block_8', int(1024*alpha), int(1024*alpha))
-    else:
-        for i in range(7, 12):
-            body = separable_conv_block(body, 'separable_conv_block_%d' % i, int(512*alpha), int(512*alpha))
-        body = separable_conv_block(body, 'separable_conv_block_12', int(512*alpha), int(1024*alpha), downsample=True)
-        body = separable_conv_block(body, 'separable_conv_block_13', int(1024*alpha), int(1024*alpha))
-    pool = sym.global_avg_pool2d(data=body, name='pool')
-    flatten = sym.flatten(data=pool, name='flatten')
-    fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name='fc')
-    softmax = sym.softmax(data=fc, name='softmax')
-    return softmax
-
-
-batch_size = 1
-num_classes = 1000
-image_shape = (3,224,224)
-data_shape = (batch_size,) + image_shape
-out_shape = (batch_size, num_classes)
-
-net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
-
-# build graph
-with nnvm.compiler.build_config(opt_level=2):
-    graph, lib, _ = nnvm.compiler.build(net, target, {'data': data_shape})
-# prepare params
-params = {}
-names = graph.index.input_names
-shapes = [graph.json_attr("shape")[graph.index.entry_id(x)] for x in names]
-for i in range(len(names)):
-    params[names[i]] = tvm.nd.array(np.random.uniform(-0.1, 0.1, size=shapes[i]).astype(dtype), ctx=ctx)
-# create runtime module
-module = nnvm.runtime.create(graph, lib, ctx)
-# set input
-module.set_input(**params)
-# run
-print("run")
-module.run()
-ctx.sync()
-start = time.time()
-for i in range(1000):
-    module.run()
-    ctx.sync()
-print("average time cost of 1000 runs = %g ms" % ((time.time() - start)))
-# get output
-out = module.get_output(0, tvm.nd.empty(out_shape, dtype))
diff --git a/nnvm/examples/README.md b/nnvm/examples/README.md
new file mode 100644
index 000000000..123007b55
--- /dev/null
+++ b/nnvm/examples/README.md
@@ -0,0 +1,5 @@
+NNVM Examples
+=============
+This folder contains example snippets of running NNVM compilation.
+
+- See also [Tutorials](tutorials) for tutorials with detailed explanations.
diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py
index 6dd015d87..27aaad8de 100644
--- a/nnvm/python/nnvm/testing/__init__.py
+++ b/nnvm/python/nnvm/testing/__init__.py
@@ -1,3 +1,5 @@
-"""Utilities for testcase"""
+"""Utilities for testing and benchmarks"""
+from __future__ import absolute_import as _abs
 
 from .config import ctx_list
+from . import mobilenet
diff --git a/nnvm/python/nnvm/testing/config.py b/nnvm/python/nnvm/testing/config.py
index a96e4b4ea..0eab3e6b3 100644
--- a/nnvm/python/nnvm/testing/config.py
+++ b/nnvm/python/nnvm/testing/config.py
@@ -1,4 +1,6 @@
 """Configuration about tests"""
+from __future__ import absolute_import as _abs
+
 import os
 import tvm
 
diff --git a/nnvm/python/nnvm/testing/mobilenet.py b/nnvm/python/nnvm/testing/mobilenet.py
new file mode 100644
index 000000000..4a0838031
--- /dev/null
+++ b/nnvm/python/nnvm/testing/mobilenet.py
@@ -0,0 +1,125 @@
+"""Helper utility to get mobilenet workload for testing."""
+# pylint: disable=invalid-name
+from __future__ import absolute_import as _abs
+
+import numpy as np
+import tvm
+from .. compiler import graph_util
+from .. import graph
+from .. import symbol as sym
+
+def conv_block(data, name, channels,
+               kernel_size=(3, 3), strides=(1, 1), padding=(1, 1),
+               epsilon=1e-5):
+    """Helper function to construct conv-bn-relu"""
+    # convolution + bn + relu
+    conv = sym.conv2d(data=data, channels=channels,
+                      kernel_size=kernel_size, strides=strides,
+                      padding=padding, use_bias=False,
+                      layout="NCHW", name=name + "_conv")
+    bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + "_bn")
+    act = sym.relu(data=bn, name=name + "_relu")
+    return act
+
+def separable_conv_block(data, name, depthwise_channels,
+                         pointwise_channels, kernel_size=(3, 3),
+                         downsample=False, padding=(1, 1),
+                         epsilon=1e-5):
+    """Helper function to get a separable conv block"""
+    if downsample:
+        strides = (2, 2)
+    else:
+        strides = (1, 1)
+    # depthwise convolution + bn + relu
+    conv1 = sym.conv2d(data=data, channels=depthwise_channels,
+                       groups=depthwise_channels, kernel_size=kernel_size, strides=strides,
+                       padding=padding, use_bias=False, layout="NCHW", name=name + "_conv1")
+    bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1")
+    act1 = sym.relu(data=bn1, name=name + "_relu1")
+    # pointwise convolution + bn + relu
+    conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1, 1), strides=(1, 1),
+                       padding=(0, 0), use_bias=False, layout="NCHW", name=name + "_conv2")
+    bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + "_bn2")
+    act2 = sym.relu(data=bn2, name=name + "_relu2")
+    return act2
+
+def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False):
+    """Function to construct a MobileNet"""
+    data = sym.Variable("data")
+    body = conv_block(data, "conv_block_1", int(32*alpha), strides=(2, 2))
+    body = separable_conv_block(body, "separable_conv_block_1",
+                                int(32*alpha), int(64*alpha))
+    body = separable_conv_block(body, "separable_conv_block_2",
+                                int(64*alpha), int(128*alpha), downsample=True)
+    body = separable_conv_block(body, "separable_conv_block_3",
+                                int(128*alpha), int(128*alpha))
+    body = separable_conv_block(body, "separable_conv_block_4",
+                                int(128*alpha), int(256*alpha), downsample=True)
+    body = separable_conv_block(body, "separable_conv_block_5",
+                                int(256*alpha), int(256*alpha))
+    body = separable_conv_block(body, "separable_conv_block_6",
+                                int(256*alpha), int(512*alpha), downsample=True)
+    if is_shallow:
+        body = separable_conv_block(body, "separable_conv_block_7",
+                                    int(512*alpha), int(1024*alpha), downsample=True)
+        body = separable_conv_block(body, "separable_conv_block_8",
+                                    int(1024*alpha), int(1024*alpha))
+    else:
+        for i in range(7, 12):
+            body = separable_conv_block(body, "separable_conv_block_%d" % i,
+                                        int(512*alpha), int(512*alpha))
+        body = separable_conv_block(body, "separable_conv_block_12",
+                                    int(512*alpha), int(1024*alpha), downsample=True)
+        body = separable_conv_block(body, "separable_conv_block_13",
+                                    int(1024*alpha), int(1024*alpha))
+    pool = sym.global_avg_pool2d(data=body, name="pool")
+    flatten = sym.flatten(data=pool, name="flatten")
+    fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name="fc")
+    softmax = sym.softmax(data=fc, name="softmax")
+    return softmax
+
+
+def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"):
+    """Get benchmark workload for mobilenet
+
+    Parameters
+    ----------
+    batch_size : int
+        The batch size used in the model
+
+    num_classes : int, optional
+        Number of classes
+
+    image_shape : tuple, optional
+        The input image shape
+
+    dtype : str, optional
+        The data type
+
+    Returns
+    -------
+    net : nnvm.Symbol
+        The computational graph
+
+    params : dict of str to NDArray
+        The parameters.
+    """
+    data_shape = (batch_size,) + image_shape
+    net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False)
+    params = {}
+    g = graph.create(net)
+    input_shapes, _ = graph_util.infer_shape(g, data=data_shape)
+    shape_dict = dict(zip(g.index.input_names, input_shapes))
+    for k, v in shape_dict.items():
+        if k == "data":
+            continue
+        # Batch-norm gamma and running variance must stay positive,
+        # so draw them from a strictly positive range.
+        if k.endswith("gamma") or k.endswith("var"):
+            init = np.random.uniform(0.9, 1, size=v)
+        else:
+            init = np.random.uniform(-0.1, 0.1, size=v)
+        params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0))
+    return net, params
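+
+# A minimal usage sketch (assuming an LLVM-enabled TVM build); the
+# returned pair plugs directly into nnvm.compiler.build:
+#
+#   net, params = get_workload(batch_size=1)
+#   graph, lib, params = nnvm.compiler.build(
+#       net, "llvm", shape={"data": (1, 3, 224, 224)}, params=params)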
diff --git a/nnvm/src/compiler/precompute_prune.cc b/nnvm/src/compiler/precompute_prune.cc
index a0159757c..a56a03986 100644
--- a/nnvm/src/compiler/precompute_prune.cc
+++ b/nnvm/src/compiler/precompute_prune.cc
@@ -44,17 +44,17 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
     } else {
       // scan again to find edge nodes, skip variables
       for (auto& e : n->inputs) {
-        if (!e.node->is_variable() && pruned.count(e.node.get())) {
+        if (pruned.count(e.node.get())) {
           if (!entry_var.count(e)) {
             nnvm::NodePtr var = nnvm::Node::Create();
-            var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
+            var->attrs.name = e.node->attrs.name;
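+            // Append an "_output<i>" suffix only when the pruned node
+            // has multiple outputs; a single-output node keeps its
+            // original name.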
+            if (e.node->num_outputs() != 1) {
+              var->attrs.name += "_output" + std::to_string(e.index);
+            }
             entry_var.emplace(e, var);
             CHECK(!unique_name.count(var->attrs.name));
             unique_name.insert(var->attrs.name);
           }
-          // TODO(ziheng): this pass now mutates the original graph structure
-          // This might not be a good thing, change to copy the structure instead
-          //
           e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
         }
       }
@@ -67,7 +67,6 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
   output_names.reserve(entry_var.size());
 
   for (auto kv : entry_var) {
-    if (kv.first.node->is_variable()) continue;
     pre_graph.outputs.emplace_back(kv.first);
     output_names.emplace_back(kv.second->attrs.name);
   }
diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py
index 379975d2d..59220a7ca 100644
--- a/nnvm/tests/python/compiler/test_build.py
+++ b/nnvm/tests/python/compiler/test_build.py
@@ -55,26 +55,28 @@ def test_run():
 
 def test_precompute_prune():
     x = sym.Variable("x") + 1
+    a = sym.Variable("a")
     y = sym.Variable("y")
-    z = y + x
+    z = y + x + a
     shape = (10, 10)
     dtype = tvm.float32
     nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
     ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
-    params = {"x": nx}
+    params = {"x": nx, "a": na}
     graph, lib, params = nnvm.compiler.build(
         z, "llvm", shape={"y": ny.shape}, params=params)
-    assert graph.index.num_nodes == 3
+    assert graph.index.num_nodes == 4
     m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
     params["y"] = ny
     res = tvm.nd.empty(shape)
     m.run(**params)
     out = m.get_output(0, out=res)
     np.testing.assert_allclose(
-        res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())
+        res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy() + na.asnumpy())
 
 
 if __name__ == "__main__":
+    test_precompute_prune()
     test_compile()
     test_run()
-    test_precompute_prune()
diff --git a/nnvm/tutorials/README.txt b/nnvm/tutorials/README.txt
new file mode 100644
index 000000000..72f772fa6
--- /dev/null
+++ b/nnvm/tutorials/README.txt
@@ -0,0 +1,3 @@
+Tutorials
+=========
+This page contains tutorials about NNVM.
diff --git a/nnvm/tutorials/mobilenet_inference_gpu.py b/nnvm/tutorials/mobilenet_inference_gpu.py
new file mode 100644
index 000000000..9343316b3
--- /dev/null
+++ b/nnvm/tutorials/mobilenet_inference_gpu.py
@@ -0,0 +1,82 @@
+"""
+Compile MobileNet Inference on GPU
+==================================
+**Author**: `Yuwei Hu <https://huyuwei.github.io/>`_
+
+This is an example of using NNVM to compile a MobileNet model and deploy its inference on the GPU.
+
+To begin with, we import nnvm (for compilation) and tvm (for deployment).
+"""
+import tvm
+import nnvm.compiler
+import nnvm.runtime
+import nnvm.testing
+from tvm.contrib import nvcc
+
+######################################################################
+# Register the NVCC Compiler Option
+# ---------------------------------
+# NNVM optimizes the graph and relies on TVM to generate fast
+# GPU code. To get the maximum performance, we need to enable
+# nvcc's compiler hook, which gives better performance than the nvrtc mode.
+
+@tvm.register_func
+def tvm_callback_cuda_compile(code):
+    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
+    return ptx
+
+######################################################################
+# Prepare the Benchmark
+# ---------------------
+# We construct a standard ImageNet inference benchmark.
+# We use nnvm's testing utility to produce the model description and
+# random parameters, so that the example does not depend on a specific
+# front-end framework.
+#
+# .. note::
+#
+#   In a typical workflow, we can get this pair from :any:`nnvm.frontend`
+#
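+# For example, a hypothetical sketch assuming a trained MXNet model
+# (``mx_sym`` with its ``args``/``auxs`` parameter dicts) is at hand::
+#
+#   net, params = nnvm.frontend.from_mxnet(mx_sym, args, auxs)
+#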
+target = "cuda"
+ctx = tvm.gpu(0)
+batch_size = 1
+num_classes = 1000
+image_shape = (3, 224, 224)
+data_shape = (batch_size,) + image_shape
+out_shape = (batch_size, num_classes)
+net, params = nnvm.testing.mobilenet.get_workload(
+    batch_size=batch_size, image_shape=image_shape)
+
+######################################################################
+# Compile The Graph
+# -----------------
+# NNVM needs two things to compile a deep learning model:
+#
+# - ``net``, the graph representation of the computation
+# - ``params``, a dictionary of str to parameters
+#
+# To compile the graph, we call the build function with the graph
+# configuration and parameters.
+# When parameters are provided, NNVM will pre-compute certain parts of
+# the graph when possible, and return the new parameter set as the
+# third return value.
+
+graph, lib, params = nnvm.compiler.build(
+    net, target, shape={"data": data_shape}, params=params)
+
+######################################################################
+# Run the Compiled Module
+# -----------------------
+#
+# To deploy the module, we call :any:`nnvm.runtime.create`, passing in
+# the graph, the lib, and the context.
+# Thanks to TVM, we can deploy the compiled module to many platforms
+# and languages.
+# The deployment module is designed to contain minimal dependencies.
+# This example runs on the same machine.
+
+module = nnvm.runtime.create(graph, lib, ctx)
+# set input
+module.set_input(**params)
+# run
+module.run()
+# get output
+out = module.get_output(0, tvm.nd.empty(out_shape))
+# convert to numpy
+out.asnumpy()
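+
+######################################################################
+# Benchmark the Inference Speed
+# -----------------------------
+# Below is a minimal timing sketch, left as a comment so the document
+# build stays fast. It assumes the module created above and uses only
+# the standard-library ``time`` module; the context is synchronized
+# around the loop so the wall clock covers the actual GPU execution::
+#
+#   import time
+#
+#   ctx.sync()
+#   start = time.time()
+#   n_repeat = 100
+#   for _ in range(n_repeat):
+#       module.run()
+#   ctx.sync()
+#   print("average time per run = %g ms" %
+#         ((time.time() - start) * 1000.0 / n_repeat))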
-- 
GitLab