From 30a5a6007de3d3ab2a4f7fb6ed637021a42f1322 Mon Sep 17 00:00:00 2001
From: "Joshua Z. Zhang" <cheungchih@gmail.com>
Date: Fri, 11 Jan 2019 14:36:52 -0800
Subject: [PATCH] [RELAY][FRONTEND]Onnx to relay frontend (#2302)

---
 docs/install/from_source.rst                  |    2 +-
 .../python/frontend/onnx/test_forward.py      |   16 +-
 python/tvm/relay/expr.py                      |    8 +
 python/tvm/relay/frontend/__init__.py         |    1 +
 python/tvm/relay/frontend/common.py           |  182 +++
 python/tvm/relay/frontend/nnvm_common.py      |   12 +-
 python/tvm/relay/frontend/onnx.py             | 1090 +++++++++++++++++
 tests/python/frontend/onnx/test_forward.py    | 1033 ++++++++++++++++
 tests/scripts/task_python_frontend.sh         |    3 +
 tutorials/relay/from_onnx.py                  |   93 ++
 10 files changed, 2421 insertions(+), 19 deletions(-)
 create mode 100644 python/tvm/relay/frontend/onnx.py
 create mode 100644 tests/python/frontend/onnx/test_forward.py
 create mode 100644 tutorials/relay/from_onnx.py

diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst
index 00cb96fe4..81d06f1dc 100644
--- a/docs/install/from_source.rst
+++ b/docs/install/from_source.rst
@@ -35,7 +35,7 @@ Our goal is to build the shared libraries:
 .. code:: bash
 
     sudo apt-get update
-    sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev
+    sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake
 
 The minimal building requirements are
 
diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py
index 143ae8583..a98ef297f 100644
--- a/nnvm/tests/python/frontend/onnx/test_forward.py
+++ b/nnvm/tests/python/frontend/onnx/test_forward.py
@@ -910,7 +910,7 @@ def test_single_ops():
         model = helper.make_model(graph, producer_name='_test')
         for target, ctx in ctx_list():
             tvm_out = get_tvm_output(model, [x], target, ctx)
-            tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+            tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
 
     x = np.random.uniform(size=in_shape).astype(dtype)
     verify_single_ops("Neg",x, -x)
@@ -918,13 +918,13 @@ def test_single_ops():
     verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5)
     verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5)
     verify_single_ops("Relu",x, np.maximum(x, 0))
-    verify_single_ops("Exp",x, np.exp(x))
-    verify_single_ops("Log",x, np.log(x))
-    verify_single_ops("Log",x, np.log(x))
-    verify_single_ops("Tanh",x, np.tanh(x))
-    verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
-    verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
-    verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
+    verify_single_ops("Exp",x, np.exp(x), rtol=1e-5, atol=1e-5)
+    verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
+    verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5)
+    verify_single_ops("Tanh",x, np.tanh(x), rtol=1e-5, atol=1e-5)
+    verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)), rtol=1e-5, atol=1e-5)
+    verify_single_ops("Softsign",x, x / (1 + np.abs(x)), rtol=1e-5, atol=1e-5)
+    verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)), rtol=1e-5, atol=1e-5)
 
 def test_leaky_relu():
     def leaky_relu_x(x, alpha):
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 9de0344bf..f510d6195 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -465,6 +465,14 @@ def const(value, dtype=None):
     """
     if isinstance(value, (_base.numeric_types, (bool, list))):
         value = _np.array(value, dtype=dtype)
+    if not dtype:
+        # when dtype is None: int maps to "int32", float maps to "float32"
+        map_dtype = {
+            _np.dtype('int64'): _np.int32,
+            _np.dtype('float64'): _np.float32
+            }.get(value.dtype, None)
+        if map_dtype:
+            value = value.astype(map_dtype)
     if isinstance(value, (_np.ndarray, _np.generic)):
         value = _nd.array(value)
 
diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index fd0199aaf..30d544694 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -9,3 +9,4 @@ from __future__ import absolute_import
 
 from .mxnet import from_mxnet
 from .keras import from_keras
+from .onnx import from_onnx
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 2d5817734..0464f0b8b 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -1,6 +1,11 @@
 """Common utilities"""
 from __future__ import absolute_import as _abs
+import logging
+from topi.util import get_const_tuple
 from .. import expr as _expr
+from .. import expr as _expr
+from .. import ir_pass
+from .. import op as _op
 
 
 class RequiredAttr(object):
@@ -204,6 +209,30 @@ class StrAttrsDict(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
 
+def get_relay_op(op_name):
+    """Get the callable function from Relay based on operator name.
+    Parameters
+    ----------
+    op_name : str
+        The Relay operator name.
+    """
+    if '.' in op_name:
+        # explicit hierachical modules
+        op = _op
+        try:
+            for opn in op_name.split('.'):
+                op = getattr(op, opn)
+        except AttributeError:
+            op = None
+    else:
+        # try search op in various modules
+        for candidate in (_op, _op.nn, _op.image):
+            op = getattr(candidate, op_name, None)
+            if op is not None:
+                break
+    if not op:
+        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
+    return op
 
 class ExprTable(object):
     """Table storing Relay expressions by names."""
@@ -227,3 +256,156 @@ class ExprTable(object):
     def set_expr(self, name, expr):
         assert isinstance(expr, _expr.Expr)
         self.exprs[name] = expr
+
+
+class AttrCvt(object):
+    """Common attribute conveter. An AttrConverter instance is a callable:
+    ```
+    attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
+    new_op_name, new_attr = attr_converter(attrs)
+    ```
+
+    Parameters
+    ----------
+    op_name : str or callable
+        If set as str, returned operator name is the str.
+        If set as callable, returned operator is the str returned by calling:
+        `op_name = func(attr)`
+    transforms : dict of `new_name, or (new_name, default_value, transform function)`
+        If only a new_name is provided, it's like renaming the attribute name.
+        If default_value if provded, then the attribute is considered as optional.
+        If transform function is provided, the original attribute value is handled
+        by transform function.
+    excludes : list
+        A list of excluded attributes that should `NOT` appear.
+        Raise NotImplementedError if occured.
+    disables : list
+        A list of attributes that is disabled in relay. Log warnings.
+    ignores : list
+        A list of attributes that is ignored in relay. Debug level logging.
+    extras : dict
+        A series of additional attributes should be added anyway to the returned
+        attribute dict.
+    custom_check : callable
+        A custom function takes attribute, and return True/False.
+        Raise RuntimeError if not bool(True) returned.
+    """
+    def __init__(self, op_name, transforms=None,
+                 excludes=None, disables=None, ignores=None,
+                 extras=None, custom_check=None):
+        self._op_name = op_name
+        self._transforms = transforms if transforms else {}
+        self._excludes = excludes if excludes else []
+        self._disables = disables if disables else []
+        self._ignores = ignores if ignores else []
+        self._extras = extras if extras else {}
+        self._custom_check = custom_check
+
+    def __call__(self, inputs, attrs, *args):
+        # apply custom check
+        if self._custom_check:
+            func, msg = self._custom_check
+            if not func(attrs):
+                raise RuntimeError("Check failed: {}".format(msg))
+        # get new op_name
+        if isinstance(self._op_name, str):
+            op_name = self._op_name
+        else:
+            assert callable(self._op_name), "op_name can either be string or callable"
+            op_name = self._op_name(attrs)
+        # convert attributes
+        new_attrs = {}
+        for k in attrs.keys():
+            if k in self._excludes:
+                raise NotImplementedError("Attribute {} not supported yet.".format(k))
+            elif k in self._disables:
+                logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
+            elif k in self._ignores:
+                logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)
+            elif k in self._transforms:
+                new_name, defaults, transform = self._parse_default(self._transforms[k])
+                if defaults is None:
+                    new_attr = self._required_attr(attrs, k)
+                else:
+                    new_attr = attrs.get(k, None)
+                if new_attr is None:
+                    new_attrs[new_name] = defaults
+                else:
+                    new_attrs[new_name] = transform(new_attr)
+            else:
+                # copy
+                new_attrs[k] = attrs[k]
+        # add extras
+        new_attrs.update(self._extras)
+        return get_relay_op(op_name)(*inputs, **new_attrs)
+
+    def _parse_default(self, target):
+        """Helper function to parse default values."""
+        if not isinstance(target, (list, tuple)):
+            k, v, t = target, None, lambda x: x
+        elif len(target) == 1:
+            k, v, t = target[0], None, lambda x: x
+        elif len(target) == 2:
+            k, v, t = target[0], target[1], lambda x: x
+        elif len(target) > 2:
+            k, v, t = target[0], target[1], target[2]
+        else:
+            k = None  # should raise
+        if not isinstance(k, str):
+            msg = "{} is not a valid target, (name, default) expected.".format(target)
+            raise ValueError(msg)
+        return k, v, t
+
+    def _parse_bool(self, value):
+        """Helper function to parse default boolean values."""
+        if isinstance(value, str):
+            return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
+        return bool(value)
+
+    def _required_attr(self, attr, key):
+        """Wrapper for getting required attributes."""
+        assert isinstance(attr, dict)
+        if key not in attr:
+            raise AttributeError("Required attribute {} not found.".format(key))
+        return attr[key]
+
+def get_name(node):
+    name = ''
+    if hasattr(node, "name_hint"):
+        name = node.name_hint
+    return name
+
+def infer_shape(inputs):
+    """A method to get the output shape of an intermediate node in the graph."""
+    out_type = ir_pass.infer_type(inputs)
+    out_shapes = get_const_tuple(out_type.checked_type.shape)
+    return out_shapes
+
+def infer_channels(inputs, transpose=False):
+    """A hack for getting 'channels' or 'units' since caffe2 does not provide
+    these attributes. We check the shape of weights provided to get the number.
+    """
+    out_type = ir_pass.infer_type(inputs)
+    out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
+    return channels
+
+def new_var(name_hint,
+            type_annotation=None,
+            shape=None,
+            dtype="float32"):
+    return _expr.var(name_hint, type_annotation, shape, dtype)
+
+class Renamer(object):
+    """A simply renamer for operators.
+
+    Parameters
+    ----------
+    new_name : str
+        The new name for the operator
+    """
+    def __init__(self, new_name):
+        self._new_name = new_name
+
+    def __call__(self, inputs, attrs, *args):
+        return get_relay_op(self._new_name)(*inputs, **attrs)
diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py
index 17502dbaa..d75b0cac6 100644
--- a/python/tvm/relay/frontend/nnvm_common.py
+++ b/python/tvm/relay/frontend/nnvm_common.py
@@ -4,15 +4,7 @@ from __future__ import absolute_import as _abs
 
 from .. import expr as _expr
 from .. import op as _op
-
-def _get_relay_op(op_name):
-    op = _op
-    for path in op_name.split("."):
-        op = getattr(op, path)
-    if not op:
-        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
-    return op
-
+from .common import get_relay_op
 
 def _warn_not_used(attr, op='nnvm'):
     import warnings
@@ -22,7 +14,7 @@ def _warn_not_used(attr, op='nnvm'):
 
 def _rename(new_op):
     if isinstance(new_op, str):
-        new_op = _get_relay_op(new_op)
+        new_op = get_relay_op(new_op)
     # attrs are ignored.
     def impl(inputs, _, _dtype='float32'):
         return new_op(*inputs)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
new file mode 100644
index 000000000..effe50e06
--- /dev/null
+++ b/python/tvm/relay/frontend/onnx.py
@@ -0,0 +1,1090 @@
+# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
+"""ONNX: Open Neural Network Exchange frontend for Relay."""
+from __future__ import absolute_import as _abs
+
+import logging
+import numpy as np
+from ... import nd as _nd
+from .. import ir_pass
+from .. import expr as _expr
+from .. import op as _op
+from .common import AttrCvt, Renamer
+from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name
+
+__all__ = ['from_onnx']
+
+def dimension_picker(prefix, surfix=''):
+    def _impl(attr):
+        kernel = attr['kernel_shape']
+        if len(kernel) == 2:
+            return prefix + '2d' + surfix
+        else:
+            raise NotImplementedError("Only 2d kernel supported.")
+
+    return _impl
+
+def revert_caffe2_pad(pads):
+    """Caffe2 requires two times the normal padding."""
+    if len(pads) == 4:
+        pads = pads[:2]
+    elif len(pads) == 2:
+        pass
+    else:
+        raise ValueError("Invalid caffe2 type padding: {}".format(pads))
+    return pads
+
+def dimension_constraint():
+    def _dim_check(attrs):
+        if len(attrs['kernel_shape']) == 2:
+            return True
+        return False
+
+    return _dim_check, "Only 2d kernel supported."
+
+class OnnxOpConverter(object):
+    """ A helper class for holding onnx op converters.
+    """
+
+    @classmethod
+    def get_converter(cls, opset):
+        """ Get converter matches given opset.
+
+        Parameters
+        ----------
+        opset: int
+            opset from model.
+
+        Returns
+        -------
+        converter, which should be `_impl_vx`. Number x is the biggest
+            number smaller than or equal to opset belongs to all support versions.
+        """
+        versions = [
+            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
+        ]
+        versions = sorted(versions + [opset])
+        version = versions[
+            max([i for i, v in enumerate(versions) if v == opset]) - 1]
+        if hasattr(cls, '_impl_v{}'.format(version)):
+            return getattr(cls, '_impl_v{}'.format(version))
+        raise NotImplementedError(
+            'opset version {} of {} not implemented'.format(
+                version, cls.__name__))
+
+
+class Elemwise(OnnxOpConverter):
+    """ A helper class for elemwise op converters.
+    """
+    name = ''
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
+            len(inputs))
+        op_name = cls.name
+        conv_ops = ["conv2d", "conv2d_transpose"]
+        if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops):
+            # TODO(zhreshold): remove hard coded infershape
+            axis = int(attr.get('axis', 0))
+            inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)
+        return get_relay_op(op_name)(*inputs)
+
+class Pool(OnnxOpConverter):
+    """ A helper class for pool op converters.
+    """
+    name = ''
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(
+            op_name=dimension_picker(cls.name),
+            transforms={
+                'kernel_shape': 'pool_size',
+                'pads': ('padding', (0, 0), revert_caffe2_pad)
+            },
+            # very weird attributes here in onnx, force check
+            ignores=['dilations'],
+            # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
+            extras={'ceil_mode': False},
+            custom_check=dimension_constraint())(inputs, attr, params)
+
+
+class Absolute(OnnxOpConverter):
+    """ Operator converter for Absolute.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return _op.nn.relu(inputs[0]) + _op.nn.relu(_op.negative(inputs[0]))
+
+
+class Add(Elemwise):
+    """ Operator converter for Add.
+    """
+    name = 'add'
+
+
+class AveragePool(Pool):
+    """ Operator converter for AveragePool.
+    """
+    name = 'avg_pool'
+
+
+class BatchNorm(OnnxOpConverter):
+    """ Operator converter for BatchNorm.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # TODO(zhreshold): 'spatial' is not properly handled here.
+        out = AttrCvt(
+            op_name='batch_norm',
+            ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr,
+                                                                           params)
+        return out[0]
+
+
+class Conv(OnnxOpConverter):
+    """ Operator converter for Conv.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # get number of channels
+        out = AttrCvt(op_name=dimension_picker('conv'),
+                      transforms={
+                          'kernel_shape': 'kernel_size',
+                          'dilations': ('dilation', (0, 0)),
+                          'pads': ('padding', (0, 0), revert_caffe2_pad),
+                          'group': ('groups', 1)},
+                      custom_check=dimension_constraint())(inputs[:2], attr, params)
+        use_bias = len(inputs) == 3
+        if use_bias:
+            out = _op.nn.bias_add(out, inputs[2])
+        return out
+
+
+class ConvTranspose(OnnxOpConverter):
+    """ Operator converter for ConvTranspose.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # get number of channels
+        channels = infer_channels(inputs[1], True)
+        attr['channels'] = channels
+        groups = attr.pop('group')
+        attr['groups'] = groups
+        out = AttrCvt(
+            op_name=dimension_picker('conv', '_transpose'),
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'dilations': ('dilation', (0, 0)),
+                'pads': ('padding', (0, 0), revert_caffe2_pad)
+            },
+            disables=['output_shape'],
+            custom_check=dimension_constraint())(inputs[:2], attr, params)
+        use_bias = len(inputs) == 3
+        if use_bias:
+            out = _op.nn.bias_add(out, inputs[2])
+        return out
+
+
+class Div(Elemwise):
+    name = 'divide'
+
+
+class Elu(OnnxOpConverter):
+    """ Operator converter for Elu.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = float(attr.get('alpha', 1.0))
+        return _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + \
+                                     _op.nn.relu(inputs[0])
+
+
+class Gemm(OnnxOpConverter):
+    """ Operator converter for Gemm.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
+            len(inputs))
+        # Y = alpha * A * B + beta * C
+        alpha = float(attr.get('alpha', 1.0))
+        beta = float(attr.get('beta', 1.0))
+        transA = int(attr.get('transA', 0))
+        transB = int(attr.get('transB', 0))
+        # get number of channels
+        channels = infer_channels(inputs[1], not transB)
+        if transA:
+            inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
+        if not transB:
+            inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
+        inputs[0] = _op.nn.batch_flatten(inputs[0])
+        out = _op.nn.dense(_expr.const(alpha) * inputs[0],
+                           inputs[1], units=channels)
+        return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
+
+class MatMul(OnnxOpConverter):
+    """ Operator converter for MatMul.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
+        input_1_t = _op.transpose(inputs[1], axes=(1, 0))
+        return _op.nn.dense(inputs[0], input_1_t)
+
+class MaxPool(Pool):
+    name = 'max_pool'
+
+
+class Mul(Elemwise):
+    name = 'multiply'
+
+
+class Pad(OnnxOpConverter):
+    """ Operator converter for Pad.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        pad_width = []
+        pads = attr.pop('paddings')
+        dims = int(len(pads) / 2)
+        for i in range(dims):
+            pad_width.append((pads[i], pads[i+dims]))
+        attr['pad_width'] = pad_width
+
+        return AttrCvt(
+            _op.nn.pad,
+            transforms={
+                'value': 'pad_value',
+            },
+            ignores=['mode'],
+            custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
+                          'split mode != constant'))(inputs, attr, params)
+
+    @classmethod
+    def _impl_v2(cls, inputs, attr, params):
+        pad_width = []
+        pads = attr.pop('pads')
+        dims = int(len(pads) / 2)
+        for i in range(dims):
+            pad_width.append((pads[i], pads[i+dims]))
+        attr['pad_width'] = pad_width
+
+        return AttrCvt(
+            'pad',
+            transforms={
+                'value': 'pad_value',
+            },
+            ignores=['mode'],
+            custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
+                          'split mode != constant'))(inputs, attr, params)
+
+
+class ParametricSoftPlus(OnnxOpConverter):
+    """ Operator converter for ParametricSoftPlus.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = _expr.const(float(attr.get('alpha', 1.0)))
+        beta = _expr.const(float(attr.get('beta', 1.0)))
+        return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.)) * alpha
+
+
+class Prelu(OnnxOpConverter):
+    """ Operator converter for Prelu.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
+        return _op.nn.prelu(inputs[0], inputs[1])
+
+
+class Reciprocal(OnnxOpConverter):
+    """ Operator converter for Reciprocal.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return _expr.const(1.0) / inputs[0]
+
+class Reshape(OnnxOpConverter):
+    """ Operator converter for Reshape.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if 'shape' in attr:
+            return _op.reshape(inputs[0], attr['shape'])
+
+        if get_name(inputs[1]) in params:
+            shape = tuple(params[inputs[1].name_hint].asnumpy())
+            out = _op.reshape(inputs[0], shape)
+        else:
+            out = _op.reshape_like(inputs[0], inputs[1])
+
+        return out
+
+class Concat(OnnxOpConverter):
+    """ Operator converter for Concat.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, args, params):
+        return AttrCvt(op_name='concatenate')((inputs,), args)
+
+class Scale(OnnxOpConverter):
+    """ Operator converter for Scale.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        scale = float(attr.get('scale', 1.0))
+        return inputs[0] * _expr.const(scale)
+
+
+class Selu(OnnxOpConverter):
+    """ Operator converter for Selu.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = float(attr.get('alpha', 1.6732))
+        gamma = float(attr.get('gamma', 1.0507))
+        return _expr.const(gamma) * (_expr.const(-alpha) *
+                                     _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) +
+                                     _op.nn.relu(inputs[0]))
+
+
+class ScaledTanh(OnnxOpConverter):
+    """ Operator converter for ScaledTanh.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = float(attr.get('alpha', 1.0))
+        beta = float(attr.get('beta', 1.0))
+        return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha)
+
+
+class SoftPlus(OnnxOpConverter):
+    """ Operator converter for SoftPlus.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return _op.log(_op.exp(inputs[0]) + _expr.const(1.))
+
+
+class Softsign(OnnxOpConverter):
+    """ Operator converter for Softsign.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return inputs[0] / (_expr.const(1.) + Absolute.get_converter(1)(inputs, attr, params))
+
+
+class Sub(Elemwise):
+    name = 'subtract'
+
+
+class Sum(OnnxOpConverter):
+    """ Operator converter for Sum.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # Onnx Sum Operator
+        for in_index in range(len(inputs) - 1):
+            inputs[in_index + 1] = _op.add(inputs[in_index], inputs[in_index + 1])
+
+        return inputs[len(inputs) - 1]
+
+
+class ThresholdedRelu(OnnxOpConverter):
+    """ Operator converter for ThresholdedRelu.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = float(attr.get('alpha', 0.0))
+        alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
+        mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
+        return inputs[0] * mask
+
+
+def _broadcast_constraint():
+
+    def _broadcast_check(attrs):
+        if attrs.get('axis', None):
+            return False
+        return True
+
+    return _broadcast_check, "Specifying broadcast axis not allowed."
+
+
+def _fully_connected(opset):
+
+    def _impl(inputs, attr, params):
+        # get number of channels
+        channels = infer_channels(inputs[1], params)
+        attr['units'] = channels
+        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
+
+    return _impl
+
+
+class Upsample(OnnxOpConverter):
+    """ Operator converter for Upsample (nearest mode).
+    """
+
+    @classmethod
+    def _impl_v7(cls, inputs, attr, params):
+        scales = attr.get('scales')
+        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
+        mode = attr.get('mode')
+        if mode == b'nearest':
+            method = "NEAREST_NEIGHBOR"
+        elif mode == b'linear':
+            method = "BILINEAR"
+        else:
+            raise ValueError("Invalid ONNX upsample mode: {}".format(mode))
+        attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'}
+        return AttrCvt('upsampling')(inputs, attr)
+
+
+class Shape(OnnxOpConverter):
+    """ Operator converter for Shape.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # Result of this operator is prominently used by reshape operator.
+        # Just pass the input as it is so that reshape_like can be used there.
+        logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)")
+        return inputs[0]
+
+class Cast(OnnxOpConverter):
+    """ Operator converter for Cast.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
+
+    @classmethod
+    def _impl_v5(cls, inputs, attr, params):
+        try:
+            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+            attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import onnx.mapping which is required {}".format(e))
+        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
+
+
+class Unsqueeze(OnnxOpConverter):
+    """ Operator converter for Unsqueeze.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        for axes in attr['axes']:
+            inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1)
+        return inputs[0]
+
+
+class Split(OnnxOpConverter):
+    """ Operator converter for Split.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        attr['indices_or_sections'] = []
+        index = 0
+        for i in attr['split'][:-1]:
+            index += i
+            attr['indices_or_sections'].append(index)
+        return AttrCvt(
+            'split',
+            ignores=['split'])(inputs, attr, params)
+
+
+class Slice(OnnxOpConverter):
+    """ Operator converter for Slice.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if isinstance(attr['starts'], int):
+            attr['starts'] = (attr['starts'],)
+            attr['ends'] = (attr['ends'],)
+
+        try:
+            # Update the starts and ends according to axes if required.
+            if isinstance(attr['axes'], int):
+                attr['axes'] = (attr['axes'],)
+
+            if (max(attr['axes']) + 1) != len(attr['axes']):
+                new_axes = []
+                new_starts = []
+                new_ends = []
+                pop_index = 0
+                for i in range(max(attr['axes']) + 1):
+                    if i in attr['axes']:
+                        new_axes.append(i)
+                        new_starts.append(attr['starts'][pop_index])
+                        new_ends.append(attr['ends'][pop_index])
+                        pop_index += 1
+                    else:
+                        new_axes.append(i)
+                        new_starts.append(0)
+                        new_ends.append(np.iinfo(np.int32).max)
+                attr['axes'] = new_axes
+                attr['starts'] = new_starts
+                attr['ends'] = new_ends
+        except KeyError:
+            pass
+
+        return AttrCvt('strided_slice',
+                       transforms={'starts': 'begin',
+                                   'ends': 'end'},
+                       ignores=['axes'])(inputs, attr)
+
+class Gather(OnnxOpConverter):
+    """ Operator converter for Gather.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        axis = attr.get('axis', 0)
+        return AttrCvt('take',
+                       extras={'axis':axis})(inputs, {})
+        #return _op.take(inputs[0], inputs[1], axis)
+
+class LRN(OnnxOpConverter):
+    """ Operator converter for Local Response Normalization.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        """LRN support only NCHW format
+        https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
+        """
+        axis = 1
+        alpha = attr.get('alpha', 0.0001)
+        beta = attr.get('beta', 0.75)
+        bias = attr.get('bias', 1.0)
+        nsize = attr.get('size')
+        attr = {'size':nsize, 'axis':axis, 'alpha':alpha, 'beta':beta, 'bias':bias}
+        return AttrCvt('lrn')(inputs, attr)
+
+class Maximum(OnnxOpConverter):
+    """ Operator converter for Maximum.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if not isinstance(inputs, list) or len(inputs) < 2:
+            raise ValueError("Expect minimum 2 inputs")
+        _max = inputs[0]
+        for i in range(1, len(inputs)):
+            _max = AttrCvt('maximum')([_max, inputs[i]], {})
+        return _max
+
+class Minimum(OnnxOpConverter):
+    """ Operator converter for Minimum.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if not isinstance(inputs, list) or len(inputs) < 2:
+            raise ValueError("Expect minimum 2 inputs")
+        _min = inputs[0]
+        for i in range(1, len(inputs)):
+            _min = AttrCvt('minimum')([_min, inputs[i]], {})
+        return _min
+
+class Mean(OnnxOpConverter):
+    """ Operator converter for Mean.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if not isinstance(inputs, list) or len(inputs) < 2:
+            raise ValueError("Expect minimum 2 inputs")
+        # avoid overflow
+        concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
+        return _op.mean(concat, axis=0, keepdims=False)
+
+class HardSigmoid(OnnxOpConverter):
+    """ Operator converter for HardSigmoid.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = attr.get('alpha', 0.2)
+        beta = attr.get('beta', 0.5)
+        transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta)
+        attr = {'a_min':0, 'a_max':1}
+        return AttrCvt('clip')([transformX], attr)
+
+class Reduce(OnnxOpConverter):
+    """ Operator converter for reduce ops.
+    """
+    name = ''
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if 'axes' in attr:
+            axis = attr.get('axes', 0)
+        else:
+            axis_len = len(infer_shape(inputs[0]))
+            axis = list(range(axis_len))
+        attr = {'axis':axis, 'keepdims':attr.get('keepdims', True)}
+        return AttrCvt(cls.name)(inputs, attr)
+
+class ReduceMax(Reduce):
+    """ Operator converter for ArgMax.
+    """
+    name = 'max'
+
+class ReduceMin(Reduce):
+    """ Operator converter for ArgMax.
+    """
+    name = 'min'
+
+class ReduceSum(Reduce):
+    """ Operator converter for ArgMax.
+    """
+    name = 'sum'
+
+class ReduceMean(Reduce):
+    """ Operator converter for ArgMax.
+    """
+    name = 'mean'
+
+class ArgMax(OnnxOpConverter):
+    """ Operator converter for ArgMax.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        axis = attr.get('axis', 0)
+        keepdims = attr.get('keepdims', True)
+        attr = {'axis':axis, 'keepdims':keepdims}
+        return AttrCvt('argmax')(inputs, attr)
+
+class ArgMin(OnnxOpConverter):
+    """ Operator converter for ArgMin.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        axis = attr.get('axis', 0)
+        keepdims = attr.get('keepdims', True)
+        attr = {'axis':axis, 'keepdims':keepdims}
+        return AttrCvt('argmin')(inputs, attr)
+
+class Softmax(OnnxOpConverter):
+    """ Operator converter for Softmax.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # set default value when axis is not set in the model
+        if 'axis' not in attr:
+            attr['axis'] = 1
+        return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params)
+
+class ConstantFill(OnnxOpConverter):
+    """ Operator converter for ConstantFill.
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_inputs = len(inputs)
+        if 'shape' in attr:
+            if num_inputs > 1:
+                raise ImportError(
+                    "Can't set shape and input tensor at a time")
+            shape = attr.pop('shape')
+        else:
+            if num_inputs == 1:
+                raise ImportError(
+                    "Either shape attribute or input should be set")
+            if 'input_as_shape' in attr and attr['input_as_shape']:
+                shape = params[get_name(inputs[0])].asnumpy()
+            else:
+                if 'extra_shape' in attr:
+                    raise ImportError(
+                        "Extra Shape not supported with fill_like")
+                return _op.full_like(inputs[0], inputs[1])
+
+        if 'extra_shape' in attr:
+            shape = shape + attr.pop('extra_shape')
+        return _op.full(inputs[0], shape)
+
+# compatible operators that do NOT require any conversion.
+_identity_list = []
+
+
+# _convert_map defines maps of name to converter functor(callable)
+# for 1 to 1 mapping, use Renamer if nothing but name is different
+# use AttrCvt if attributes need to be converted
+# for 1 to N mapping(composed), use custom callable functions
+# for N to 1 mapping, currently not supported(?)
+def _get_convert_map(opset):
+    return {
+        # defs/experimental
+        'Identity': Renamer('copy'),
+        # 'Affine'
+        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
+        'ScaledTanh': ScaledTanh.get_converter(opset),
+        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
+        'ConstantFill': ConstantFill.get_converter(opset),
+        # 'GivenTensorFill'
+        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
+        'Scale': Scale.get_converter(opset),
+        # 'GRUUnit'
+        # 'ATen'
+        # 'ImageScaler'
+        # 'MeanVarianceNormalization'
+        # 'Crop'
+        # 'Embedding'
+        'Upsample' : Upsample.get_converter(opset),
+        'SpatialBN': BatchNorm.get_converter(opset),
+
+        # defs/generator
+        # 'Constant' # Implemented
+        # 'RandomUniform'
+        # 'RandomNormal'
+        # 'RandomUniformLike'
+        # 'RandomNormalLike'
+
+        # defs/logical
+
+        # defs/math
+        'Add': Add.get_converter(opset),
+        'Sub': Sub.get_converter(opset),
+        'Mul': Mul.get_converter(opset),
+        'Div': Div.get_converter(opset),
+        'Neg': Renamer('negative'),
+        'Abs': Absolute.get_converter(opset),
+        'Reciprocal': Reciprocal.get_converter(opset),
+        'Floor': Renamer('floor'),
+        'Ceil': Renamer('ceil'),
+        'Sqrt': Renamer('sqrt'),
+        'Relu': Renamer('relu'),
+        'LeakyRelu': Renamer('leaky_relu'),
+        'Selu': Selu.get_converter(opset),
+        'Elu': Elu.get_converter(opset),
+        'Exp': Renamer('exp'),
+        'Log': Renamer('log'),
+        'Tanh': Renamer('tanh'),
+        'Pow': Renamer('power'),
+        'PRelu': Prelu.get_converter(opset),
+        'Sigmoid': Renamer('sigmoid'),
+        'HardSigmoid': HardSigmoid.get_converter(opset),
+        'Max': Maximum.get_converter(opset),
+        'Min': Minimum.get_converter(opset),
+        'Sum': Sum.get_converter(opset),
+        'Mean': Mean.get_converter(opset),
+        'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
+        # softmax default axis is different in onnx
+        'Softmax': Softmax.get_converter(opset),
+        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
+        # 'Hardmax'
+        'Softsign': Softsign.get_converter(opset),
+        'SoftPlus': SoftPlus.get_converter(opset),
+        'Gemm': Gemm.get_converter(opset),
+        'MatMul': MatMul.get_converter(opset),
+
+        # defs/nn
+        'AveragePool': AveragePool.get_converter(opset),
+        'MaxPool': MaxPool.get_converter(opset),
+        'Conv': Conv.get_converter(opset),
+        'ConvTranspose': ConvTranspose.get_converter(opset),
+        'GlobalAveragePool': Renamer('global_avg_pool2d'),
+        'GlobalMaxPool': Renamer('global_max_pool2d'),
+        'BatchNormalization': BatchNorm.get_converter(opset),
+        # 'InstanceNormalization'
+        # 'LpNormalization'
+        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
+        'Flatten': Renamer('flatten'),
+        'LRN': LRN.get_converter(opset),
+
+        # defs/reduction
+        'ReduceMax': ReduceMax.get_converter(opset),
+        'ReduceMin': ReduceMin.get_converter(opset),
+        'ReduceSum': ReduceSum.get_converter(opset),
+        'ReduceMean': ReduceMean.get_converter(opset),
+        # 'ReduceProd'
+        # 'ReduceLogSumExp'
+        'ArgMax': ArgMax.get_converter(opset),
+        'ArgMin': ArgMin.get_converter(opset),
+
+        # defs/tensor
+        'Cast': Cast.get_converter(opset),
+        'Reshape': Reshape.get_converter(opset),
+        'Concat': Concat.get_converter(opset),
+        'Split': Split.get_converter(opset),
+        'Slice': Slice.get_converter(opset),
+        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
+        'Gather': Gather.get_converter(opset),
+        'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
+        'Unsqueeze': Unsqueeze.get_converter(opset),
+        'Pad': Pad.get_converter(opset),
+        # TODO(zhreshold) Shape op is implemented as bypass op in relay
+        # 'Shape': Shape.get_converter(opset),
+    }
+
+
+class GraphProto(object):
+    """A helper class for handling Relay expression copying from pb2.GraphProto.
+    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
+
+        Parameters
+    ----------
+    shape : dict of str to tuple, optional
+        The input shape to the graph
+
+    dtype : str or dict of str to str
+        The input types to the graph
+    """
+
+    def __init__(self, shape, dtype):
+        self._nodes = {}
+        self._params = {}
+        self._renames = {}
+        self._num_input = 0
+        self._num_param = 0
+        self._shape = shape
+        self._dtype = dtype
+
+    def from_onnx(self, graph, opset):
+        """Construct Relay expression from ONNX graph.
+
+        Onnx graph is a python protobuf object.
+        The companion parameters will be handled automatically.
+        However, the input names from onnx graph is vague, mixing inputs and
+        network weights/bias such as "1", "2"...
+        For convenience, we rename the `real` input names to "input_0",
+        "input_1"... And renaming parameters to "param_0", "param_1"...
+
+        Parameters
+        ----------
+        graph : onnx protobuf object
+            The loaded onnx graph
+        opset : opset version
+
+        Returns
+        -------
+        sym : tvm.relay.expr.Function
+            The returned relay function
+        params : dict
+            A dict of name: tvm.nd.array pairs, used as pretrained weights
+        """
+        # parse network inputs to relay, aka parameters
+        for init_tensor in graph.initializer:
+            if not init_tensor.name.strip():
+                raise ValueError("Tensor's name is required.")
+            self._params[init_tensor.name] = self._parse_array(init_tensor)
+        for i in graph.input:
+            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
+            #  and the name is 'i.name'
+            i_name = self._parse_value_proto(i)
+            d_type = self._parse_dtype(i, 'float32')
+            if i_name in self._params:
+                # i is a param instead of input
+                self._num_param += 1
+                self._params[i_name] = self._params.pop(i_name)
+                self._nodes[i_name] = new_var(i_name,
+                                              shape=self._params[i_name].shape,
+                                              dtype=self._params[i_name].dtype)
+            else:
+                self._num_input += 1
+                shape = self._shape[i_name] if i_name in self._shape else ()
+                if isinstance(self._dtype, dict):
+                    dtype = self._dtype[i_name] if i_name in self._dtype else d_type
+                else:
+                    dtype = d_type
+                self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype)
+        # construct nodes, nodes are stored as directed acyclic graph
+        for node in graph.node:
+            op_name = node.op_type
+            attr = self._parse_attr(node.attribute)
+            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            if op_name == "Constant":
+                t_proto = self._parse_attr(node.attribute)["value"]
+                self._num_param += 1
+                self._params[node.output[0]] = self._parse_array(t_proto)
+                self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
+            else:
+                if op_name == "ConstantFill":
+                    fill_value = attr.get('value', 0.0)
+                    dtype = attr.get('dtype', b'int32').decode("utf-8")
+                    i_name = node.output[0]
+                    self._params[i_name] = fill_value
+                    self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
+                    inputs.append(self._nodes[i_name])
+
+                op = self._convert_operator(op_name, inputs, attr, opset)
+                node_output = self._fix_outputs(op_name, node.output)
+                if not isinstance(op, _expr.TupleWrapper):
+                    outputs_num = 1
+                else:
+                    outputs_num = len(op)
+                assert len(node_output) == outputs_num, (
+                    "Number of output mismatch {} vs {} in {}.".format(
+                        len(node_output), outputs_num, op_name))
+                if outputs_num == 1:
+                    self._nodes[node_output[0]] = op
+                else:
+                    for k, i in zip(list(node_output), range(len(node_output))):
+                        self._nodes[k] = op[i]
+
+        # now return the outputs
+        outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
+        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
+        func = _expr.Function(ir_pass.free_vars(outputs), outputs)
+        return func, self._params
+
+    def _parse_value_proto(self, value_proto):
+        """Parse ValueProto or raw str."""
+        try:
+            name = value_proto.name
+        except AttributeError:
+            name = value_proto
+        return name
+
+    def _parse_dtype(self, value_proto, dtype):
+        """Parse dtype."""
+        try:
+            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+            return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
+        except AttributeError:
+            return dtype
+
+    def _parse_array(self, tensor_proto):
+        """Grab data in TensorProto and convert to numpy array."""
+        try:
+            from onnx.numpy_helper import to_array
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import onnx which is required {}".format(e))
+        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
+        return _nd.array(np_array)
+
+    def _parse_attr(self, attr_proto):
+        """Convert a list of AttributeProto to a dict, with names as keys."""
+        attrs = {}
+        for a in attr_proto:
+            for f in ['f', 'i', 's']:
+                if a.HasField(f):
+                    attrs[a.name] = getattr(a, f)
+            for f in ['floats', 'ints', 'strings']:
+                if list(getattr(a, f)):
+                    assert a.name not in attrs, "Only one type of attr is allowed"
+                    attrs[a.name] = tuple(getattr(a, f))
+            for f in ['t']:
+                if a.HasField(f):
+                    attrs[a.name] = getattr(a, f)
+            for f in ['tensors']:
+                if list(getattr(a, f)):
+                    assert a.name not in attrs, "Only one type of attr is allowed"
+                    attrs[a.name] = tuple(getattr(a, f))
+            for f in ['g']:
+                if a.HasField(f):
+                    raise NotImplementedError(
+                        "Filed {} is not supported in relay.".format(f))
+            for f in ['graphs']:
+                if list(getattr(a, f)):
+                    raise NotImplementedError(
+                        "Filed {} is not supported in relay.".format(f))
+            if a.name not in attrs:
+                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
+        return attrs
+
+    def _convert_operator(self,
+                          op_name,
+                          inputs,
+                          attrs,
+                          opset):
+        """Convert ONNX operator into a Relay operator.
+        The converter must specify conversions explicity for incompatible name, and
+        apply handlers to operator attributes.
+
+        Parameters
+        ----------
+        op_name : str
+            Operator name, such as Convolution, FullyConnected
+        inputs : list of tvm.relay.expr.Function
+            List of inputs.
+        attrs : dict
+            Dict of operator attributes
+        opset : int
+            Opset version
+
+        Returns
+        -------
+        sym : tvm.relay.expr.Function
+            Converted relay function
+        """
+        convert_map = _get_convert_map(opset)
+        if op_name in _identity_list:
+            sym = get_relay_op(op_name)(*inputs, **attrs)
+        elif op_name in convert_map:
+            sym = convert_map[op_name](inputs, attrs, self._params)
+        else:
+            raise NotImplementedError(
+                "Operator {} not implemented.".format(op_name))
+        return sym
+
+    def _fix_outputs(self, op_name, outputs):
+        """A hack to handle dropout or similar operator that have more than one out
+        in ONNX.
+        """
+        if op_name == 'Dropout':
+            if len(outputs) == 1:
+                return outputs
+            # TODO(zhreshold): support dropout mask?
+            outputs = outputs[:-1]
+        return outputs
+
+def from_onnx(model,
+              shape=None,
+              dtype="float32"):
+    """Convert a ONNX model into an equivalent Relay Function.
+
+    ONNX graphs are represented as Python Protobuf objects.
+    The companion parameters will be handled automatically.
+    However, the input names from onnx graph is vague, mixing inputs and
+    network weights/bias such as "1", "2"...
+    For convenience, we rename the `real` input names to "input_0",
+    "input_1"... And renaming parameters to "param_0", "param_1"...
+
+    Parameters
+    ----------
+    model : protobuf object
+        ONNX ModelProto after ONNX v1.1.0
+
+    shape : dict of str to tuple, optional
+        The input shape to the graph
+
+    dtype : str or dict of str to str
+        The input types to the graph
+
+    Returns
+    -------
+    sym : tvm.relay.expr.Function
+        Compatible relay function
+
+    params : dict of str to tvm.NDArray
+        The parameter dict to be used by relay
+    """
+    g = GraphProto(shape, dtype)
+    graph = model.graph
+    try:
+        opset = model.opset_import[0].version if model.opset_import else 1
+    except AttributeError:
+        opset = 1
+    sym, params = g.from_onnx(graph, opset)
+    return sym, params
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
new file mode 100644
index 000000000..de95ff00a
--- /dev/null
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -0,0 +1,1033 @@
+import numpy as np
+import math
+import topi
+import topi.testing
+import tvm
+from tvm import relay
+from tvm.contrib import graph_runtime
+from nnvm.testing.config import ctx_list
+import onnx
+from onnx import helper, TensorProto
+import unittest
+
+def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32'):
+    """ Generic function to execute and get tvm output"""
+    target = 'llvm'
+    if isinstance(input_data, list):
+        input_names = {}
+        shape_dict = {}
+        dtype_dict = {}
+        for i, _ in enumerate(input_data):
+            input_names[i] = graph_def.graph.input[i].name
+            shape_dict[input_names[i]] = input_data[i].shape
+            dtype_dict[input_names[i]] = input_data[i].dtype
+    else:
+        input_names = graph_def.graph.input[0].name
+        shape_dict = {input_names: input_data.shape}
+        dtype_dict = {input_names: input_data.dtype}
+
+    sym, params = relay.frontend.from_onnx(graph_def, shape_dict)
+    with relay.build_config(opt_level=1):
+        graph, lib, params = relay.build(sym, target, params=params)
+
+    ctx = tvm.cpu(0)
+    from tvm.contrib import graph_runtime
+    m = graph_runtime.create(graph, lib, ctx)
+    # set inputs
+    if isinstance(input_data, list):
+        for i, e in enumerate(input_names):
+            m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
+    else:
+        m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype)))
+
+    m.set_input(**params)
+    # execute
+    m.run()
+    # get outputs
+    if isinstance(output_shape, list) and isinstance(output_dtype, list):
+        tvm_output_list = []
+        for i, _ in enumerate(output_shape):
+            tvm_output = m.get_output(i)
+            tvm_output_list.append(tvm_output.asnumpy())
+        return tvm_output_list
+    else:
+        tvm_output = m.get_output(0)
+        return tvm_output.asnumpy()
+
+def get_caffe2_output(model, x, dtype='float32'):
+    import caffe2.python.onnx.backend
+    prepared_backend = caffe2.python.onnx.backend.prepare(model)
+    W = {model.graph.input[0].name: x.astype(dtype)}
+    c2_out = prepared_backend.run(W)[0]
+    return c2_out
+
+
+def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
+    dtype = 'float32'
+    x = np.random.uniform(size=data_shape)
+    model = onnx.load_model(graph_file)
+    c2_out = get_caffe2_output(model, x, dtype)
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
+        tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
+
+def verify_super_resolution_example():
+    verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
+
+def verify_squeezenet1_1():
+    verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))
+
+def verify_lenet():
+    verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
+
+def verify_resnet18():
+    verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
+
+
+def test_reshape():
+    in_shape = (4, 3, 3, 4)
+    ref_shape = (6, 2, 4, 3)
+
+    ref_array = np.array(ref_shape)
+    ref_node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=['ref_in'],
+                                 value=onnx.helper.make_tensor(name = 'const_tensor',
+                                                               data_type = onnx.TensorProto.INT32,
+                                                               dims = ref_array.shape,
+                                                               vals = ref_array.flatten().astype(int)))
+    reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
+
+    graph = helper.make_graph([ref_node, reshape_node],
+                              "reshape_test",
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(ref_shape))])
+
+    model = helper.make_model(graph, producer_name='reshape_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=in_shape).astype('int32')
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
+
+    tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+
+def test_reshape_like():
+    in_shape = (4, 3, 3, 4)
+    ref_shape = (3, 4, 4, 3)
+
+    ref_array = np.random.uniform(size=ref_shape).astype('float32')
+    ref_node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=['ref_in'],
+                                 value=onnx.helper.make_tensor(name = 'const_tensor',
+                                                               data_type = onnx.TensorProto.FLOAT,
+                                                               dims = ref_array.shape,
+                                                               vals = ref_array.flatten().astype(float)))
+    copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
+    reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
+
+    graph = helper.make_graph([ref_node, copy_node, reshape_node],
+                              "reshape_like_test",
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(ref_shape))])
+
+    model = helper.make_model(graph, producer_name='reshape_like_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=in_shape).astype('float32')
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
+
+    tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+
+def _test_power_iteration(x_shape, y_shape):
+    if isinstance(y_shape, int):
+        y_shape = [y_shape]
+
+    x = np.random.uniform(size=x_shape).astype(np.float32)
+    y = np.random.uniform(size=y_shape).astype(np.float32)
+
+    np_res = np.power(x, y).astype(np.float32)
+
+    res = helper.make_node("Pow", ['x', 'y'], ['out'])
+
+    graph = helper.make_graph([res],
+                              'power_test',
+                              inputs = [helper.make_tensor_value_info("x",
+                                            TensorProto.FLOAT, list(x_shape)),
+                                        helper.make_tensor_value_info("y",
+                                            TensorProto.FLOAT, list(y_shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(np_res.shape))])
+
+    model = helper.make_model(graph, producer_name='power_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape)
+        tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_power():
+    _test_power_iteration((1, 3), (1))
+    _test_power_iteration((2, 3), (2, 3))
+    _test_power_iteration((2, 3), (1, 3))
+
+def test_squeeze():
+    in_shape = (1, 3, 1, 3, 1, 1)
+    out_shape = (3, 3)
+    y = helper.make_node("Squeeze", ['in'], ['out'], axes=[0, 2, 4, 5])
+
+    graph = helper.make_graph([y],
+                              'squeeze_test',
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(out_shape))])
+
+    model = helper.make_model(graph, producer_name='squeeze_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=in_shape).astype('float32')
+        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
+
+    tvm.testing.assert_allclose(out_shape, tvm_out.shape)
+
+def test_unsqueeze():
+    in_shape = (3, 3)
+    axis = (0, 3, 4)
+    out_shape = (1, 3, 3, 1, 1)
+    y = helper.make_node("Unsqueeze", ['in'], ['out'], axes=list(axis))
+
+    graph = helper.make_graph([y],
+                              'squeeze_test',
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(out_shape))])
+
+    model = helper.make_model(graph, producer_name='squeeze_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=in_shape).astype('float32')
+        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32')
+
+    tvm.testing.assert_allclose(out_shape, tvm_out.shape)
+
+def verify_gather(in_shape, indices, axis, dtype):
+    x = np.random.uniform(size=in_shape).astype(dtype)
+    indices = np.array(indices, dtype="int32")
+    out_np = np.take(x, indices, axis=axis)
+
+    y = helper.make_node("Gather", ['in', 'indices'], ['out'], axis=axis)
+
+    graph = helper.make_graph([y],
+                              'gather_test',
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(in_shape)),
+                                        helper.make_tensor_value_info("indices",
+                                            TensorProto.INT32, list(indices.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(out_np.shape))])
+    model = helper.make_model(graph, producer_name='gather_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape)
+        tvm.testing.assert_allclose(out_np, tvm_out)
+
+def test_gather():
+    verify_gather((4,), [1], 0, 'int32')
+    verify_gather((1,4), [0], 0, 'int32')
+    verify_gather((4,), [[[1,0],[0,1]]], 0, 'float32')
+    verify_gather((2,2), [[[1,0],[0,1]]], 1, 'int32')
+    verify_gather((3,3,3), [[[1,0]]], -1, 'int32')
+    verify_gather((4,3,5,6), [[2,1,0,0]], 0, 'float32')
+
+def _test_slice_iteration(indata, outdata, starts, ends, axes=None):
+    if axes:
+        y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends)
+    else:
+        y = helper.make_node("Slice", ['in'], ['out'], starts=starts, ends=ends)
+
+    graph = helper.make_graph([y],
+                              'slice_test',
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(outdata.shape))])
+
+    model = helper.make_model(graph, producer_name='slice_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+
+    tvm.testing.assert_allclose(outdata, tvm_out)
+
+def test_slice():
+    x = np.random.randn(20, 10, 5).astype(np.float32)
+    _test_slice_iteration(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
+    _test_slice_iteration(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
+    _test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1))
+    _test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1))
+
+def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
+    indata = np.random.uniform(-1, 1, size=inshape).astype(dtype)
+    outdata = outfunc(indata, **npargs)
+
+    y = helper.make_node(opname, ['in'], ['out'], **kwargs)
+
+    graph = helper.make_graph([y],
+                              opname+'_test',
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(outdata.shape))])
+
+    model = helper.make_model(graph, producer_name=opname+'_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype)
+
+    tvm.testing.assert_allclose(outdata, tvm_out)
+
+def test_floor():
+    _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {})
+
+def test_ceil():
+    _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {})
+
+def test_clip():
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              np.clip,
+                              {'a_min': -1.0, 'a_max': 1.0},
+                              'float32',
+                              'Clip',
+                              {'min': -1.0, 'max': 1.0})
+
+def test_matmul():
+    a_shape = (4, 3)
+    b_shape = (3, 4)
+
+    a_array = np.random.uniform(size=a_shape).astype('float32')
+    b_array = np.random.uniform(size=b_shape).astype('float32')
+    out_np = np.matmul(a_array, b_array)
+
+    mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
+
+    graph = helper.make_graph([mul_node],
+                              "matmul_test",
+                              inputs = [helper.make_tensor_value_info("a",
+                                            TensorProto.FLOAT, list(a_shape)),
+                                        helper.make_tensor_value_info("b",
+                                            TensorProto.FLOAT, list(b_shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(out_np.shape))])
+
+    model = helper.make_model(graph, producer_name='matmul_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape)
+        tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
+    in_array = np.random.uniform(size=shape).astype(dtype)
+
+    if alpha == None and beta == None and bias==None:
+        alpha = 0.0001
+        beta = 0.75
+        bias = 1.0
+        node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize)
+    else:
+        node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha,
+                                     beta=beta, bias=bias, size=nsize)
+
+    graph = helper.make_graph([node],
+                              "lrn_test",
+                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))],
+                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))])
+    model = helper.make_model(graph, producer_name='lrn_test')
+
+    def _get_python_lrn():
+        square_sum = np.zeros(shape).astype(dtype)
+        for n, c, h, w in np.ndindex(in_array.shape):
+            square_sum[n, c, h, w] = sum(in_array[n,
+                                         max(0, c - int(math.floor((nsize - 1) / 2))): \
+                                             min(5, c + int(math.ceil((nsize - 1) / 2)) + 1),
+                                         h,
+                                         w] ** 2)
+        py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta)
+        return py_out
+
+    for target, ctx in ctx_list():
+        input_name = model.graph.input[0].name
+        py_out = _get_python_lrn()
+        tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, 'float32')
+        tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_lrn():
+    verify_lrn((5, 5, 5, 5), 3, 'float32')
+    verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
+
+def _test_upsample_nearest():
+    scale = 2
+    in_shape = (1, 1, 3, 3)
+    out_shape = (1, 1, 3*scale, 3*scale)
+    y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
+
+    in_array = np.random.uniform(size=in_shape).astype(np.float32)
+    out_array = topi.testing.upsampling_python(in_array, scale, "NCHW")
+
+    graph = helper.make_graph([y],
+                              'upsample_nearest_test',
+                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+
+    model = helper.make_model(graph, producer_name='upsample_nearest_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
+        tvm.testing.assert_allclose(out_array, tvm_out)
+
+def _test_upsample_bilinear():
+    scale = 2
+    in_shape = (1, 1, 3, 3)
+    out_shape = (1, 1, 3*scale, 3*scale)
+    y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0])
+
+    in_array = np.random.uniform(size=in_shape).astype(np.float32)
+    out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW")
+
+    graph = helper.make_graph([y],
+                              'upsample_bilinear_test',
+                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))])
+
+    model = helper.make_model(graph, producer_name='upsample_bilinear_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32')
+        tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_upsample():
+    _test_upsample_nearest()
+    _test_upsample_bilinear()
+
+def _test_softmax(inshape, axis):
+    opname = 'Softmax'
+    indata = np.random.uniform(size=inshape).astype(np.float32)
+    outshape = inshape
+    outdata = topi.testing.softmax_python(indata)
+    if isinstance(axis, int):
+        y = helper.make_node(opname, ['in'], ['out'], axis = axis)
+    elif axis is None:
+        y = helper.make_node(opname, ['in'], ['out'])
+
+    graph = helper.make_graph([y],
+                              opname+'_test',
+                              inputs = [helper.make_tensor_value_info("in",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(outdata.shape))])
+
+    model = helper.make_model(graph, producer_name=opname+'_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 'float32')
+        tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_softmax():
+    _test_softmax((1, 10), None)
+    _test_softmax((1, 10), 1)
+
+def verify_min(input_dim):
+    dtype = 'float32'
+
+    a_np1 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np2 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np3 = np.random.uniform(size=input_dim).astype(dtype)
+
+    b_np = np.min((a_np1, a_np2, a_np3), axis=0)
+
+    min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"])
+
+    graph = helper.make_graph([min_node],
+                              "Min_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np2",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np3",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='Min_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_min():
+    verify_min((1, 3, 20, 20))
+    verify_min((20, 20))
+
+def verify_max(input_dim):
+    dtype = 'float32'
+
+    a_np1 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np2 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np3 = np.random.uniform(size=input_dim).astype(dtype)
+
+    b_np = np.max((a_np1, a_np2, a_np3), axis=0)
+
+    max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"])
+
+    graph = helper.make_graph([max_node],
+                              "Max_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np2",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np3",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='Max_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_max():
+    verify_max((1, 3, 20, 20))
+    verify_max((20, 20))
+
+def verify_mean(input_dim):
+    dtype = 'float32'
+
+    a_np1 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np2 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np3 = np.random.uniform(size=input_dim).astype(dtype)
+
+    b_np = np.mean((a_np1, a_np2, a_np3), axis=0)
+
+    mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"])
+
+    graph = helper.make_graph([mean_node],
+                              "Mean_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np2",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np3",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='Mean_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_mean():
+    verify_mean((1, 3, 20, 20))
+    verify_mean((20, 20))
+
+def verify_hardsigmoid(input_dim, alpha, beta):
+    dtype = 'float32'
+
+    a_np1 = np.random.uniform(size=input_dim).astype(dtype)
+
+    b_np = np.clip(a_np1 * alpha + beta, 0, 1)
+
+    hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta)
+
+    graph = helper.make_graph([hardsigmoid_node],
+                              "HardSigmoid_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='HardSigmoid_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_hardsigmoid():
+    verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6)
+    verify_hardsigmoid((20, 20), 0.3, 0.4)
+
+def verify_argmin(input_dim, axis=None, keepdims=None):
+    def _argmin_numpy(data, axis=0, keepdims=True):
+        result = np.argmin(data, axis=axis)
+        if (keepdims == 1):
+            result = np.expand_dims(result, axis)
+        return result.astype(data.dtype)
+
+    a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
+    if keepdims is None and axis is None:
+        b_np = _argmin_numpy(a_np1)
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'])
+    elif axis is None:
+        b_np = _argmin_numpy(a_np1, keepdims=keepdims)
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     keepdims=keepdims)
+    elif keepdims is None:
+        b_np = _argmin_numpy(a_np1, axis=axis)
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis)
+    else:
+        b_np = _argmin_numpy(a_np1, axis=axis, keepdims=keepdims)
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis,
+                                     keepdims=keepdims)
+    graph = helper.make_graph([node],
+                              "argmin_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.INT32, list(a_np1.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.INT32, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='argmin_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def verify_argmax(input_dim, axis=None, keepdims=None):
+    def _argmax_numpy(data, axis=0, keepdims=True):
+        result = np.argmax(data, axis=axis)
+        if (keepdims == 1):
+            result = np.expand_dims(result, axis)
+        return result.astype(data.dtype)
+
+    a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
+    if keepdims is None and axis is None:
+        b_np = _argmax_numpy(a_np1)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'])
+    elif axis is None:
+        b_np = _argmax_numpy(a_np1, keepdims=keepdims)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     keepdims=keepdims)
+    elif keepdims is None:
+        b_np = _argmax_numpy(a_np1, axis=axis)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis)
+    else:
+        b_np = _argmax_numpy(a_np1, axis=axis, keepdims=keepdims)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis,
+                                     keepdims=keepdims)
+
+    graph = helper.make_graph([node],
+                              "argmax_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.INT32, list(a_np1.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.INT32, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='argmax_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_arg_min_max():
+    '''Verify argmin and argmax'''
+    verify_argmin([3,4,4])
+    verify_argmax([3,4,4])
+    verify_argmin([3,4,4], axis=1)
+    verify_argmax([3,4,4], axis=0)
+    verify_argmin([3,4,4], keepdims=0)
+    verify_argmax([3,4,4], keepdims=1)
+    for axis in [None, 0,1,2]:
+        for keepdims in [None, True,False]:
+            verify_argmin([3,4,4], axis, keepdims)
+            verify_argmax([3,4,4], axis, keepdims)
+
+def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
+    input_a = np.random.uniform(size=input_dim).astype(dtype)
+    out = np.empty(shape=out_dim, dtype=dtype)
+    out.fill(value)
+
+    if is_shape == True:
+        fill_node = helper.make_node("ConstantFill", [], ["out"], shape=input_dim, value=value, **kwargs)
+    else:
+        fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)
+
+    graph = helper.make_graph([fill_node],
+                              "fill_test",
+                              inputs = [helper.make_tensor_value_info("input_a",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(out.shape))])
+
+    model = helper.make_model(graph, producer_name='fill_test')
+
+    for target, ctx in ctx_list():
+        if is_shape == True:
+            tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
+        else:
+            tvm_out = get_tvm_output(model, [input_a], target, ctx, out.shape)
+
+        tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_constantfill():
+    verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
+    verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
+    verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
+
+
+def verify_pad(indata, pads, value=0.0):
+    indata = np.array(indata).astype(np.float32)
+    #  numpy expect result
+    len_dim = len(pads) // 2
+    np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
+    outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
+    #  onnx graph
+    node = helper.make_node(
+        'Pad',
+        inputs=['input'],
+        outputs=['output'],
+        mode='constant',
+        pads=pads,
+        value=value
+    )
+    graph = helper.make_graph([node],
+                              'pad_test',
+                              inputs = [helper.make_tensor_value_info("input",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("output",
+                                            TensorProto.FLOAT, list(outdata.shape))])
+    model = helper.make_model(graph, producer_name='pad_test')
+    #  tvm result
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+    tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_pad():
+    verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0)
+    verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0)
+    verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0)
+
+def verify_reduce_x(name, indata, axis, keepdims):
+    indata = np.array(indata).astype(np.float32)
+    #  numpy expect result
+    if name == 'ReduceMax':
+        outdata = np.maximum.reduce(indata, axis=axis, keepdims=keepdims == 1)
+    elif name == 'ReduceMin':
+        outdata = np.minimum.reduce(indata, axis=axis, keepdims=keepdims == 1)
+    elif name == 'ReduceSum':
+        outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1)
+    elif name == 'ReduceMean':
+        outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1)
+    else:
+        raise Exception('unsupport op: {}'.format(name))
+    if len(np.asarray(outdata).shape) == 0:
+        outdata = np.asarray([outdata])
+    #  onnx graph
+    if axis is None:
+        node = helper.make_node(name, inputs=['input'], outputs=['output'],
+                                keepdims=keepdims)
+    else:
+        node = helper.make_node(name, inputs=['input'], outputs=['output'],
+                                axes=axis, keepdims=keepdims)
+    graph = helper.make_graph([node],
+                              '{}_test'.format(name),
+                              inputs = [helper.make_tensor_value_info("input",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("output",
+                                            TensorProto.FLOAT, list(outdata.shape))])
+    model = helper.make_model(graph, producer_name='{}_test'.format(name))
+    #  tvm result
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+    tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_reduce_max():
+    verify_reduce_x("ReduceMax",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceMax",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceMax",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def test_reduce_min():
+    verify_reduce_x("ReduceMin",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceMin",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceMin",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def test_reduce_sum():
+    verify_reduce_x("ReduceSum",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceSum",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceSum",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def test_reduce_mean():
+    verify_reduce_x("ReduceMean",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceMean",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceMean",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def verify_split(indata, outdatas, split, axis=0):
+    indata = np.array(indata).astype(np.float32)
+    outdatas = [np.array(o).astype(np.float32) for o in outdatas]
+    node = helper.make_node(
+        'Split',
+        inputs=['input'],
+        outputs=['output_{}'.format(i) for i in range(len(split))],
+        axis=axis,
+        split=split
+    )
+    graph = helper.make_graph([node],
+                              'split_test',
+                              inputs = [helper.make_tensor_value_info("input",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("output_{}".format(i),
+                                            TensorProto.FLOAT, list(outdatas[i].shape))
+                                            for i in range(len(split))
+                                         ])
+    model = helper.make_model(graph, producer_name='split_test')
+
+    for target, ctx in ctx_list():
+        output_shape = [o.shape for o in outdatas]
+        output_type = ['float32', 'float32', 'float32']
+        tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type)
+    for o, t in zip(outdatas, tvm_out):
+        tvm.testing.assert_allclose(o, t)
+
+def test_split():
+    # 1D
+    verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
+    verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
+    # 2D
+    verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
+                 [[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
+
+def test_binary_ops():
+    in_shape = (1, 2, 3, 3)
+    dtype = "float32"
+    out_shape = in_shape
+
+    def verify_binary_ops(op, x, y, out_np, broadcast=None):
+        if broadcast is None:
+            z = helper.make_node(op, ['in1', 'in2'], ['out'])
+        else:
+            z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
+        graph = helper.make_graph([z],
+                                   '_test',
+                                  inputs = [helper.make_tensor_value_info("in1",
+                                                TensorProto.FLOAT, list(in_shape)),
+                                            helper.make_tensor_value_info("in2",
+                                                TensorProto.FLOAT, list(in_shape))],
+                                  outputs = [helper.make_tensor_value_info("out",
+                                                TensorProto.FLOAT, list(out_shape))])
+        model = helper.make_model(graph, producer_name='_test')
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(model, [x, y], target, ctx)
+            tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+    x = np.random.uniform(size=in_shape).astype(dtype)
+    y = np.random.uniform(size=in_shape).astype(dtype)
+    z = np.random.uniform(size=(3,)).astype(dtype)
+    verify_binary_ops("Add",x, y, x + y, broadcast=None)
+    verify_binary_ops("Add", x, z,  x + z, broadcast=True)
+    verify_binary_ops("Sub", x, y, x - y, broadcast=None)
+    verify_binary_ops("Sub", x, z, x - z, broadcast=True)
+    verify_binary_ops("Mul",x, y, x * y, broadcast=None)
+    verify_binary_ops("Mul", x, z,  x * z, broadcast=True)
+    verify_binary_ops("Div", x, y, x / y, broadcast=None)
+    verify_binary_ops("Div", x, z, x / z, broadcast=True)
+    verify_binary_ops("Sum", x, y, x + y, broadcast=None)
+
+def test_single_ops():
+    in_shape = (1, 2, 3, 3)
+    dtype = "float32"
+    out_shape = in_shape
+
+    def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
+        z = helper.make_node(op, ['in1'], ['out'])
+        graph = helper.make_graph([z],
+                                   '_test',
+                                  inputs = [helper.make_tensor_value_info("in1",
+                                                TensorProto.FLOAT, list(in_shape)),],
+                                  outputs = [helper.make_tensor_value_info("out",
+                                                TensorProto.FLOAT, list(out_shape))])
+        model = helper.make_model(graph, producer_name='_test')
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(model, [x], target, ctx)
+            tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
+
+    x = np.random.uniform(size=in_shape).astype(dtype)
+    verify_single_ops("Neg",x, -x)
+    verify_single_ops("Abs",x, np.abs(x))
+    verify_single_ops("Reciprocal",x, 1/x)
+    verify_single_ops("Sqrt",x, np.sqrt(x))
+    verify_single_ops("Relu",x, np.maximum(x, 0))
+    verify_single_ops("Exp",x, np.exp(x))
+    verify_single_ops("Log",x, np.log(x))
+    verify_single_ops("Log",x, np.log(x))
+    verify_single_ops("Tanh",x, np.tanh(x))
+    verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)))
+    verify_single_ops("Softsign",x, x / (1 + np.abs(x)))
+    verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)))
+
+def test_leaky_relu():
+    def leaky_relu_x(x, alpha):
+        return np.where(x >= 0, x, x * alpha)
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              leaky_relu_x,
+                              {'alpha': 0.25},
+                              'float32',
+                              'LeakyRelu',
+                              {'alpha': 0.25})
+
+def test_elu():
+    def elu_x(x, alpha):
+        return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              elu_x,
+                              {'alpha': 0.25},
+                              'float32',
+                              'Elu',
+                              {'alpha': 0.25})
+
+def test_selu():
+    def selu_x(x, alpha, gamma):
+        return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              selu_x,
+                              {'alpha': 0.25, 'gamma': 0.3},
+                              'float32',
+                              'Selu',
+                              {'alpha': 0.25, 'gamma': 0.3})
+
+def test_ThresholdedRelu():
+    def ThresholdedRelu_x(x, alpha):
+        out_np = np.clip(x, alpha, np.inf)
+        out_np[out_np == alpha] = 0
+        return out_np
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              ThresholdedRelu_x,
+                              {'alpha': 0.25},
+                              'float32',
+                              'ThresholdedRelu',
+                              {'alpha': 0.25})
+
+def test_ScaledTanh():
+    def ScaledTanh_x(x, alpha, beta):
+        return alpha * np.tanh(beta * x)
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              ScaledTanh_x,
+                              {'alpha': 0.25, 'beta': 0.3},
+                              'float32',
+                              'ScaledTanh',
+                              {'alpha': 0.25, 'beta': 0.3})
+
+def test_ParametricSoftplus():
+    def ParametricSoftplus_x(x, alpha, beta):
+        return alpha * np.log(np.exp(beta * x) + 1)
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              ParametricSoftplus_x,
+                              {'alpha': 0.25, 'beta': 0.3},
+                              'float32',
+                              'ParametricSoftplus',
+                              {'alpha': 0.25, 'beta': 0.3})
+
+def test_Scale():
+    def Scale_x(x, scale):
+        return scale * x
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              Scale_x,
+                              {'scale': 0.25},
+                              'float32',
+                              'Scale',
+                              {'scale': 0.25})
+
+def test_LogSoftmax():
+    _test_onnx_op_elementwise((1, 4),
+                              topi.testing.log_softmax_python,
+                              {},
+                              'float32',
+                              'LogSoftmax',
+                              {'axis': 1})
+
+if __name__ == '__main__':
+    test_reshape()
+    test_reshape_like()
+    test_power()
+    test_squeeze()
+    test_unsqueeze()
+    test_slice()
+    test_floor()
+    test_ceil()
+    test_clip()
+    test_matmul()
+    test_gather()
+    test_lrn()
+    test_upsample()
+    test_forward_min()
+    test_forward_max()
+    test_forward_mean()
+    test_forward_hardsigmoid()
+    test_forward_arg_min_max()
+    test_softmax()
+    test_constantfill()
+    test_pad()
+    test_reduce_max()
+    test_reduce_min()
+    test_reduce_sum()
+    test_reduce_mean()
+    test_pad()
+    test_split()
+    test_binary_ops()
+    test_single_ops()
+    test_leaky_relu()
+    test_elu()
+    test_selu()
+    test_ThresholdedRelu()
+    test_ScaledTanh()
+    test_ParametricSoftplus()
+    test_Scale()
+    test_LogSoftmax()
diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh
index d3041e20b..bede96e80 100755
--- a/tests/scripts/task_python_frontend.sh
+++ b/tests/scripts/task_python_frontend.sh
@@ -32,3 +32,6 @@ python3 -m nose -v tests/python/frontend/mxnet || exit -1
 
 echo "Running relay Keras frontend test..."
 python3 -m nose -v tests/python/frontend/keras || exit -1
+
+echo "Running relay ONNX frondend test..."
+python3 -m nose -v tests/python/frontend/onnx || exit -1
diff --git a/tutorials/relay/from_onnx.py b/tutorials/relay/from_onnx.py
new file mode 100644
index 000000000..46b803a3a
--- /dev/null
+++ b/tutorials/relay/from_onnx.py
@@ -0,0 +1,93 @@
+"""
+Compile ONNX Models
+===================
+**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
+
+This article is an introductory tutorial to deploy ONNX models with Relay.
+
+For us to begin with, ONNX package must be installed.
+
+A quick solution is to install protobuf compiler, and
+
+.. code-block:: bash
+
+    pip install onnx --user
+
+or please refer to offical site.
+https://github.com/onnx/onnx
+"""
+import onnx
+import numpy as np
+import tvm
+import tvm.relay as relay
+
+def download(url, path, overwrite=False):
+    import os
+    if os.path.isfile(path) and not overwrite:
+        print('File {} existed, skip.'.format(path))
+        return
+    print('Downloading from url {} to {}'.format(url, path))
+    try:
+        import urllib.request
+        urllib.request.urlretrieve(url, path)
+    except:
+        import urllib
+        urllib.urlretrieve(url, path)
+
+######################################################################
+# Load pretrained ONNX model
+# ---------------------------------------------
+# The example super resolution model used here is exactly the same model in onnx tutorial
+# http://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html
+# we skip the pytorch model construction part, and download the saved onnx model
+model_url = ''.join(['https://gist.github.com/zhreshold/',
+                     'bcda4716699ac97ea44f791c24310193/raw/',
+                     '93672b029103648953c4e5ad3ac3aadf346a4cdc/',
+                     'super_resolution_0.2.onnx'])
+download(model_url, 'super_resolution.onnx', False)
+# now you have super_resolution.onnx on disk
+onnx_model = onnx.load('super_resolution.onnx')
+
+######################################################################
+# Load a test image
+# ---------------------------------------------
+# A single cat dominates the examples!
+from PIL import Image
+img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
+download(img_url, 'cat.png')
+img = Image.open('cat.png').resize((224, 224))
+img_ycbcr = img.convert("YCbCr")  # convert to YCbCr
+img_y, img_cb, img_cr = img_ycbcr.split()
+x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
+
+######################################################################
+# Compile the model with relay
+# ---------------------------------------------
+target = 'llvm'
+
+input_name = '1'
+shape_dict = {input_name: x.shape}
+sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
+
+with relay.build_config(opt_level=1):
+    intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target)
+
+######################################################################
+# Execute on TVM
+# ---------------------------------------------
+tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
+
+######################################################################
+# Display results
+# ---------------------------------------------
+# We put input and output image neck to neck
+from matplotlib import pyplot as plt
+out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode='L')
+out_cb = img_cb.resize(out_y.size, Image.BICUBIC)
+out_cr = img_cr.resize(out_y.size, Image.BICUBIC)
+result = Image.merge('YCbCr', [out_y, out_cb, out_cr]).convert('RGB')
+canvas = np.full((672, 672*2, 3), 255)
+canvas[0:224, 0:224, :] = np.asarray(img)
+canvas[:, 672:, :] = np.asarray(result)
+plt.imshow(canvas.astype(np.uint8))
+plt.show()
-- 
GitLab