diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py
index b34d3bbd8283c4eb357ca2ec7558510d4f6b3f43..0ab6f4a331ed2466a515aa887811e5541d0bdaad 100644
--- a/nnvm/python/nnvm/frontend/onnx.py
+++ b/nnvm/python/nnvm/frontend/onnx.py
@@ -4,87 +4,119 @@ from __future__ import absolute_import as _abs
 import tvm
 from .. import symbol as _sym
 from .. import graph as _graph
-from .. compiler import graph_util
+from ..compiler import graph_util
 from .common import get_nnvm_op, Renamer, AttrConverter as AttrCvt
 
 __all__ = ['from_onnx']
 
-def _revert_caffe2_pad(attr):
-    """Caffe2 require two times the normal padding."""
-    if len(attr) == 4:
-        attr = attr[:2]
-    elif len(attr) == 2:
-        pass
-    else:
-        raise ValueError("Invalid caffe2 type padding: {}".format(attr))
-    return attr
 
-def _math_name_picker(surfix):
-    def _impl(attr):
-        if attr.get('broadcast', 0):
-            return 'broadcast_' + surfix
-        return 'elemwise_' + surfix
-    return _impl
+class OnnxOpConverter(object):
+    """ A helper class for holding onnx op converters.
+    """
 
-def _broadcast_constraint():
-    def _broadcast_check(attrs):
-        if attrs.get('axis', None):
-            return False
-        return True
-    return _broadcast_check, "Specifying broadcast axis not allowed."
+    @classmethod
+    def get_converter(cls, opset):
+        """ Get converter matches given opset.
 
-def _dimension_picker(prefix, surfix=''):
-    def _impl(attr):
-        kernel = attr['kernel_shape']
-        if len(kernel) == 2:
-            return prefix + '2d' + surfix
+        :param opset: opset from model.
+        :return: converter, which should be `_impl_vx`, where x is the largest
+            supported version that is smaller than or equal to the given opset.
+        """
+        versions = [
+            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
+        ]
+        versions = sorted(versions + [opset])
+        version = versions[
+            max([i for i, v in enumerate(versions) if v == opset]) - 1]
+        if hasattr(cls, '_impl_v{}'.format(version)):
+            return getattr(cls, '_impl_v{}'.format(version))
         else:
-            raise NotImplementedError("Only 2d kernel supported.")
-    return _impl
+            raise NotImplementedError(
+                'opset version {} of {} not implemented'.format(
+                    version, cls.__name__))
 
-def _dimension_constraint():
-    def _dim_check(attrs):
-        if len(attrs['kernel_shape']) == 2:
-            return True
-        return False
-    return _dim_check, "Only 2d kernel supported."
 
-def _infer_channels(inputs, params, transpose=False):
-    """A hack for getting 'channles' or 'units' since onnx don't provide
-    these attributes. We check the shape of weights provided to get the number.
+class Elemwise(OnnxOpConverter):
+    """ A helper class for elemwise op converters.
     """
-    g = _graph.create(inputs)
-    shape_dict = {k: v.shape for k, v in params.items()}
-    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
-    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
-    return channels
 
-def _elemwise(name):
-    def _impl(inputs, attr, *args):
-        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
-        op_name = _math_name_picker(name)(attr)
+    name = ''
+
+    @classmethod
+    def _math_name_picker(cls, suffix):
+
+        def _impl(attr):
+            if attr.get('broadcast', 0):
+                return 'broadcast_' + suffix
+            return 'elemwise_' + suffix
+
+        return _impl
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
+            len(inputs))
+        op_name = cls._math_name_picker(cls.name)(attr)
         axis = int(attr.get('axis', 0))
         conv_ops = ["conv2d", "conv2d_transpose"]
         if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
             # TODO(zhreshold): remove hard coded infershape
             inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
         return get_nnvm_op(op_name)(*inputs)
-    return _impl
 
-def _pooling(name):
-    return AttrCvt(
-        op_name=_dimension_picker(name),
-        transforms={
-            'kernel_shape': 'pool_size',
-            'pads': ('padding', (0, 0), _revert_caffe2_pad)},
-        # very weird attributes here in onnx, force check
-        ignores=['dilations'],
-        # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
-        extras={'ceil_mode': False},
-        custom_check=_dimension_constraint())
-
-def _conv():
-    def _impl(inputs, attr, params):
+
+class Pool(OnnxOpConverter):
+    """ A helper class for pool op converters.
+    """
+
+    name = ''
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(
+            op_name=_dimension_picker(cls.name),
+            transforms={
+                'kernel_shape': 'pool_size',
+                'pads': ('padding', (0, 0), _revert_caffe2_pad)
+            },
+            # very weird attributes here in onnx, force check
+            ignores=['dilations'],
+            # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
+            extras={'ceil_mode': False},
+            custom_check=_dimension_constraint())(inputs, attr, params)
+
+
+class Absolute(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))
+
+
+class Add(Elemwise):
+    name = 'add'
+
+
+class AveragePool(Pool):
+    name = 'avg_pool'
+
+
+class BatchNorm(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # TODO(zhreshold): 'spatial' is not properly handled here.
+        return AttrCvt(
+            op_name='batch_norm',
+            disables=['momentum'],
+            ignores=['spatial', 'is_test', 'consumed_inputs'])(inputs, attr,
+                                                               params)
+
+
+class Conv(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
         # get number of channels
         channels = _infer_channels(inputs[1], params)
         attr['channels'] = channels
@@ -94,13 +126,16 @@ def _conv():
                 'kernel_shape': 'kernel_size',
                 'dilations': ('dilation', (0, 0)),
                 'pads': ('padding', (0, 0), _revert_caffe2_pad),
-                'group': ('groups', 1)},
+                'group': ('groups', 1)
+            },
             extras={'use_bias': len(inputs) == 3},
-            custom_check=_dimension_constraint())(inputs, attr)
-    return _impl
+            custom_check=_dimension_constraint())(inputs, attr, params)
 
-def _conv_transpose():
-    def _impl(inputs, attr, params):
+
+class ConvTranspose(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
         # get number of channels
         channels = _infer_channels(inputs[1], params, True)
         attr['channels'] = channels
@@ -111,31 +146,34 @@ def _conv_transpose():
             transforms={
                 'kernel_shape': 'kernel_size',
                 'dilations': ('dilation', (0, 0)),
-                'pads': ('padding', (0, 0), _revert_caffe2_pad)},
+                'pads': ('padding', (0, 0), _revert_caffe2_pad)
+            },
             disables=['output_shape'],
             extras={'use_bias': len(inputs) == 3},
-            custom_check=_dimension_constraint())(inputs, attr)
-    return _impl
+            custom_check=_dimension_constraint())(inputs, attr, params)
 
-def _fully_connected():
-    def _impl(inputs, attr, params):
-        # get number of channels
-        channels = _infer_channels(inputs[1], params)
-        attr['units'] = channels
-        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
-    return _impl
 
-def _batch_norm():
-    # TODO(zhreshold): 'spatial' is not properly handled here.
-    return AttrCvt(
-        op_name='batch_norm',
-        disables=['momentum'],
-        ignores=['spatial', 'is_test', 'consumed_inputs'])
+class Div(Elemwise):
+    name = 'div'
 
 
-def _gemm():
-    def _impl(inputs, attr, params):
-        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs))
+class Elu(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = float(attr.get('alpha', 1.0))
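+        # ELU: alpha * (exp(x) - 1) for x < 0 and x for x >= 0, built from relu/exp.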
+        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(
+            inputs[0])
+
+
+class Gemm(OnnxOpConverter):
+    """ Operator converter for Gemm.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(
+            len(inputs))
         # Y = alpha * A * B + beta * C
         alpha = float(attr.get('alpha', 1.0))
         beta = float(attr.get('beta', 1.0))
@@ -147,217 +185,325 @@ def _gemm():
             inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
         if not transB:
             inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
-        return _sym.dense(alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
-    return _impl
+        return _sym.dense(
+            alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
 
-def _thresholded_relu():
-    def _impl(inputs, attr, params):
-        alpha = float(attr.get('alpha', 0.0))
-        return _sym.relu(inputs[0] - alpha)
-    return _impl
 
-def _scaled_tanh():
-    def _impl(inputs, attr, params):
-        alpha = float(attr.get('alpha', 1.0))
-        beta = float(attr.get('beta', 1.0))
-        return _sym.tanh(beta * inputs[0]) * alpha
-    return _impl
+class MaxPool(Pool):
+    name = 'max_pool'
 
-def parametric_soft_plus():
-    def _impl(inputs, attr, params):
+
+class Mul(Elemwise):
+    name = 'mul'
+
+
+class Pad(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(
+            op_name='pad',
+            transforms={
+                'value': 'pad_value',
+                'pads': 'pad_width'
+            },
+            custom_check=lambda attrs: attrs.get('mode') == 'constant')(
+                inputs, attr, params)
+
+
+class ParametricSoftPlus(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
         alpha = float(attr.get('alpha', 1.0))
         beta = float(attr.get('beta', 1.0))
         return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha
-    return _impl
 
-def _scale():
-    def _impl(inputs, attr, params):
+
+class Prelu(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(
+            len(inputs))
+        channels = _infer_channels(inputs[1], params, False)
+        if channels == 1:
+            return inputs[0] * inputs[1]
+        return _sym.broadcast_mul(inputs[0], inputs[1])
+
+
+class Reciprocal(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return 1.0 / inputs[0]
+
+
+class Reshape(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return _sym.reshape(inputs[0], shape=attr['shape'])
+
+    @classmethod
+    def _impl_v5(cls, inputs, attr, params):
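+        # Since opset 5 the target shape is a second input rather than an
+        # attribute; here it is read as a constant initializer from params.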
+        return _sym.reshape(
+            inputs[0],
+            shape=tuple(params[inputs[1].list_output_names()[0]].asnumpy()))
+
+
+class Scale(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
         scale = float(attr.get('scale', 1.0))
         return inputs[0] * scale
-    return _impl
 
-def _absolute():
-    """This is a workaround."""
-    def _impl(inputs, attr, params):
-        return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))
-    return _impl
 
-def _reciprocal():
-    def _impl(inputs, attr, params):
-        return 1.0 / inputs[0]
-    return _impl
+class Selu(OnnxOpConverter):
 
-def _selu():
-    def _impl(inputs, attr, params):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
         alpha = float(attr.get('alpha', 1.6732))
         gamma = float(attr.get('gamma', 1.0507))
-        return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0]))
-                        + _sym.relu(inputs[0]))
-    return _impl
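+        # Same relu/exp composition as Elu above, scaled by gamma.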
+        return gamma * (
+            -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))
 
-def _elu():
-    def _impl(inputs, attr, params):
+
+class ScaledTanh(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
         alpha = float(attr.get('alpha', 1.0))
-        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0])
-    return _impl
+        beta = float(attr.get('beta', 1.0))
+        return _sym.tanh(beta * inputs[0]) * alpha
 
-def _prelu():
-    def _impl(inputs, attr, params):
-        assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
-        channels = _infer_channels(inputs[1], params, False)
-        if channels == 1:
-            return inputs[0] * inputs[1]
-        return _sym.broadcast_mul(inputs[0], inputs[1])
-    return _impl
 
-def _softsign():
-    def _impl(inputs, attr, params):
-        return inputs[0] / (1 + _absolute()(inputs, attr, params))
-    return _impl
+class SoftPlus(OnnxOpConverter):
 
-def _softplus():
-    def _impl(inputs, attr, params):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
         return _sym.log(_sym.exp(inputs[0]) + 1)
-    return _impl
 
-def _pad():
-    def _impl(inputs, attr, params):
-        # get number of channels
-        channels = _infer_channels(inputs[1], params, True)
-        attr['channels'] = channels
-        groups = attr.pop('group')
-        attr['groups'] = groups
-        return AttrCvt(
-            op_name='pad',
-            transforms={
-                'value': 'pad_value',
-                'pads': 'pad_width'},
-            custom_check=lambda attrs: attrs.get('mode') == 'constant')(inputs, attr)
+
+class Softsign(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
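+        # Softsign: x / (1 + |x|), reusing the Absolute converter for |x|.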
+        return inputs[0] / (1 + Absolute.get_converter(1)(inputs, attr, params))
+
+
+class Sub(Elemwise):
+    name = 'sub'
+
+
+class Sum(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ONNX Sum is variadic; accumulate the inputs with pairwise broadcast_add.
+        for in_index in range(len(inputs) - 1):
+            inputs[in_index + 1] = _sym.broadcast_add(inputs[in_index],
+                                                      inputs[in_index + 1])
+
+        return inputs[len(inputs) - 1]
+
+
+class ThresholdedRelu(OnnxOpConverter):
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = float(attr.get('alpha', 0.0))
+        return _sym.relu(inputs[0] - alpha)
+
+
+def _revert_caffe2_pad(attr):
+    """Caffe2 require two times the normal padding."""
+    if len(attr) == 4:
+        attr = attr[:2]
+    elif len(attr) == 2:
+        pass
+    else:
+        raise ValueError("Invalid caffe2 type padding: {}".format(attr))
+    return attr
+
+
+def _broadcast_constraint():
+
+    def _broadcast_check(attrs):
+        if attrs.get('axis', None):
+            return False
+        return True
+
+    return _broadcast_check, "Specifying broadcast axis not allowed."
+
+
+def _dimension_picker(prefix, suffix=''):
+
+    def _impl(attr):
+        kernel = attr['kernel_shape']
+        if len(kernel) == 2:
+            return prefix + '2d' + suffix
+        else:
+            raise NotImplementedError("Only 2d kernel supported.")
+
     return _impl
 
-def _sum():
+
+def _dimension_constraint():
+
+    def _dim_check(attrs):
+        if len(attrs['kernel_shape']) == 2:
+            return True
+        return False
+
+    return _dim_check, "Only 2d kernel supported."
+
+
+def _infer_channels(inputs, params, transpose=False):
+    """A hack for getting 'channles' or 'units' since onnx don't provide
+    these attributes. We check the shape of weights provided to get the number.
+    """
+    g = _graph.create(inputs)
+    shape_dict = {k: v.shape for k, v in params.items()}
+    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
+    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
+    return channels
+
+
+def _fully_connected(opset):
+
     def _impl(inputs, attr, params):
-        # Onnx Sum Operator
-        for in_index in range(len(inputs)-1):
-            inputs[in_index+1] = _sym.broadcast_add(inputs[in_index], inputs[in_index+1])
+        # get number of channels
+        channels = _infer_channels(inputs[1], params)
+        attr['units'] = channels
+        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)
 
-        return inputs[len(inputs)-1]
     return _impl
 
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
+
 # _convert_map defines maps of name to converter functor(callable)
 # for 1 to 1 mapping, use Renamer if nothing but name is different
 # use AttrCvt if attributes need to be converted
 # for 1 to N mapping(composed), use custom callable functions
 # for N to 1 mapping, currently not supported(?)
-_convert_map = {
-    # defs/experimental
-    'Identity'      : Renamer('copy'),
-    # 'Affine'
-    'ThresholdedRelu': _thresholded_relu(),
-    'ScaledTanh'    : _scaled_tanh(),
-    'ParametricSoftplus': parametric_soft_plus(),
-    # 'ConstantFill'
-    # 'GivenTensorFill'
-    'FC'            : AttrCvt('dense', ignores=['axis', 'axis_w']),
-    'Scale'         : _scale(),
-    # 'GRUUnit'
-    # 'ATen'
-    # 'ImageScaler'
-    # 'MeanVarianceNormalization'
-    # 'Crop'
-    # 'Embedding'
-    # 'Upsample'
-    'SpatialBN'     : _batch_norm(),
-
-    # defs/generator
-    # 'Constant'
-    # 'RandomUniform'
-    # 'RandomNormal'
-    # 'RandomUniformLike'
-    # 'RandomNormalLike'
-
-    # defs/logical
-
-    # defs/math
-    'Add'           : _elemwise('add'),
-    'Sub'           : _elemwise('sub'),
-    'Mul'           : _elemwise('mul'),
-    'Div'           : _elemwise('div'),
-    'Neg'           : Renamer('negative'),
-    'Abs'           : _absolute(),
-    'Reciprocal'    : _reciprocal(),
-    # 'Floor'
-    # 'Ceil'
-    'Sqrt'          : Renamer('sqrt'),
-    'Relu'          : Renamer('relu'),
-    'LeakyRelu'     : Renamer('leaky_relu'),
-    'Selu'          : _selu(),
-    'Elu'           : _elu(),
-    'Exp'           : Renamer('exp'),
-    'Log'           : Renamer('log'),
-    'Tanh'          : Renamer('tanh'),
-    # 'Pow'
-    'PRelu'         : _prelu(),
-    'Sigmoid'       : Renamer('sigmoid'),
-    # 'HardSigmoid'
-    # 'Max' : this is the elemwise maximum
-    # 'Min' : this is the elemwise minimum
-    'Sum'           : _sum(),
-    # 'Mean'
-    # 'Clip'
-    # softmax default axis is different in onnx
-    'Softmax'       : AttrCvt('softmax', {'axis': ('axis', 1)}),
-    'LogSoftmax'    : AttrCvt('log_softmax', {'axis': ('axis', 1)}),
-    # 'Hardmax'
-    'Softsign'      : _softsign(),
-    'SoftPlus'      : _softplus(),
-    'Gemm'          : _gemm(),
-    # 'MatMul'  batch stacked dot operation
-
-    # defs/nn
-    'AveragePool'   : _pooling('avg_pool'),
-    'MaxPool'       : _pooling('max_pool'),
-    'Conv'          : _conv(),
-    'ConvTranspose' : _conv_transpose(),
-    'GlobalAveragePool': Renamer('global_avg_pool2d'),
-    'GlobalMaxPool' : Renamer('global_max_pool2d'),
-    'BatchNormalization': _batch_norm(),
-    # 'InstanceNormalization'
-    # 'LpNormalization'
-    'Dropout'       : AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
-    'Flatten'       : Renamer('flatten'),
-    # 'LRN'
-
-    # defs/reduction
-    'ReduceMax'     : AttrCvt('max', {'axes', 'axis'}),
-    'ReduceMin'     : AttrCvt('min', {'axes', 'axis'}),
-    'ReduceSum'     : AttrCvt('sum', {'axes', 'axis'}),
-    # 'ReduceMean'
-    # 'ReduceProd'
-    # 'ReduceLogSumExp'
-    # 'ArgMax'
-    # 'ArgMin'
-
-    # defs/tensor
-    'Cast'          : AttrCvt('cast', {'to': 'dtype'}),
-    'Reshape'       : Renamer('reshape'),
-    'Concat'        : Renamer('concatenate'),
-    'Split'         : AttrCvt('split', {'split': 'indices_or_sections'}),
-    # 'Slice'
-    'Transpose'     : AttrCvt('transpose', {'perm': 'axes'}),
-    # 'Gather'
-    # 'Squeeze'
-    'Pad'           : _pad(),
-}
+def _get_convert_map(opset):
+    return {
+        # defs/experimental
+        'Identity': Renamer('copy'),
+        # 'Affine'
+        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
+        'ScaledTanh': ScaledTanh.get_converter(opset),
+        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
+        # 'ConstantFill'
+        # 'GivenTensorFill'
+        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
+        'Scale': Scale.get_converter(opset),
+        # 'GRUUnit'
+        # 'ATen'
+        # 'ImageScaler'
+        # 'MeanVarianceNormalization'
+        # 'Crop'
+        # 'Embedding'
+        # 'Upsample'
+        'SpatialBN': BatchNorm.get_converter(opset),
+
+        # defs/generator
+        # 'Constant'
+        # 'RandomUniform'
+        # 'RandomNormal'
+        # 'RandomUniformLike'
+        # 'RandomNormalLike'
+
+        # defs/logical
+
+        # defs/math
+        'Add': Add.get_converter(opset),
+        'Sub': Sub.get_converter(opset),
+        'Mul': Mul.get_converter(opset),
+        'Div': Div.get_converter(opset),
+        'Neg': Renamer('negative'),
+        'Abs': Absolute.get_converter(opset),
+        'Reciprocal': Reciprocal.get_converter(opset),
+        # 'Floor'
+        # 'Ceil'
+        'Sqrt': Renamer('sqrt'),
+        'Relu': Renamer('relu'),
+        'LeakyRelu': Renamer('leaky_relu'),
+        'Selu': Selu.get_converter(opset),
+        'Elu': Elu.get_converter(opset),
+        'Exp': Renamer('exp'),
+        'Log': Renamer('log'),
+        'Tanh': Renamer('tanh'),
+        # 'Pow'
+        'PRelu': Prelu.get_converter(opset),
+        'Sigmoid': Renamer('sigmoid'),
+        # 'HardSigmoid'
+        # 'Max' : this is the elemwise maximum
+        # 'Min' : this is the elemwise minimum
+        'Sum': Sum.get_converter(opset),
+        # 'Mean'
+        # 'Clip'
+        # softmax default axis is different in onnx
+        'Softmax': AttrCvt('softmax', {'axis': ('axis', 1)}),
+        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
+        # 'Hardmax'
+        'Softsign': Softsign.get_converter(opset),
+        'SoftPlus': SoftPlus.get_converter(opset),
+        'Gemm': Gemm.get_converter(opset),
+        # 'MatMul'  batch stacked dot operation
+
+        # defs/nn
+        'AveragePool': AveragePool.get_converter(opset),
+        'MaxPool': MaxPool.get_converter(opset),
+        'Conv': Conv.get_converter(opset),
+        'ConvTranspose': ConvTranspose.get_converter(opset),
+        'GlobalAveragePool': Renamer('global_avg_pool2d'),
+        'GlobalMaxPool': Renamer('global_max_pool2d'),
+        'BatchNormalization': BatchNorm.get_converter(opset),
+        # 'InstanceNormalization'
+        # 'LpNormalization'
+        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
+        'Flatten': Renamer('flatten'),
+        # 'LRN'
+
+        # defs/reduction
+        'ReduceMax': AttrCvt('max', {'axes', 'axis'}),
+        'ReduceMin': AttrCvt('min', {'axes', 'axis'}),
+        'ReduceSum': AttrCvt('sum', {'axes', 'axis'}),
+        # 'ReduceMean'
+        # 'ReduceProd'
+        # 'ReduceLogSumExp'
+        # 'ArgMax'
+        # 'ArgMin'
+
+        # defs/tensor
+        'Cast': AttrCvt('cast', {'to': 'dtype'}),
+        'Reshape': Reshape.get_converter(opset),
+        'Concat': Renamer('concatenate'),
+        'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
+        # 'Slice'
+        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
+        # 'Gather'
+        # 'Squeeze'
+        'Pad': Pad.get_converter(opset),
+    }
 
 
 class GraphProto(object):
     """A helper class for handling nnvm graph copying from pb2.GraphProto.
     Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
     """
+
     def __init__(self):
         self._nodes = {}
         self._params = {}
@@ -365,7 +511,7 @@ class GraphProto(object):
         self._num_input = 0
         self._num_param = 0
 
-    def from_onnx(self, graph):
+    def from_onnx(self, graph, opset):
         """Construct nnvm nodes from onnx graph.
         The inputs from onnx graph is vague, only providing "1", "2"...
         For convenience, we rename the `real` input names to "input_0",
@@ -375,6 +521,7 @@ class GraphProto(object):
         ----------
         graph : onnx protobuf object
             The loaded onnx graph
+        opset : int
+            Opset version
 
         Returns
         -------
@@ -410,7 +557,7 @@ class GraphProto(object):
             op_name = node.op_type
             attr = self._parse_attr(node.attribute)
             inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
-            op = self._convert_operator(op_name, inputs, attr)
+            op = self._convert_operator(op_name, inputs, attr, opset)
             node_output = self._fix_outputs(op_name, node.output)
             assert len(node_output) == len(op.list_output_names()), (
                 "Number of output mismatch {} vs {} in {}.".format(
@@ -438,7 +585,8 @@ class GraphProto(object):
         try:
             from onnx.numpy_helper import to_array
         except ImportError as e:
-            raise ImportError("Unable to import onnx which is required {}".format(e))
+            raise ImportError(
+                "Unable to import onnx which is required {}".format(e))
         np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
         return tvm.nd.array(np_array)
 
@@ -455,15 +603,23 @@ class GraphProto(object):
                     attrs[a.name] = tuple(getattr(a, f))
             for f in ['t', 'g']:
                 if a.HasField(f):
-                    raise NotImplementedError("Filed {} is not supported in nnvm.".format(f))
+                    raise NotImplementedError(
+                        "Filed {} is not supported in nnvm.".format(f))
             for f in ['tensors', 'graphs']:
                 if list(getattr(a, f)):
-                    raise NotImplementedError("Filed {} is not supported in nnvm.".format(f))
+                    raise NotImplementedError(
+                        "Filed {} is not supported in nnvm.".format(f))
             if a.name not in attrs:
                 raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
         return attrs
 
-    def _convert_operator(self, op_name, inputs, attrs, identity_list=None, convert_map=None):
+    def _convert_operator(self,
+                          op_name,
+                          inputs,
+                          attrs,
+                          opset,
+                          identity_list=None,
+                          convert_map=None):
         """Convert from onnx operator to nnvm operator.
         The converter must specify conversions explicity for incompatible name, and
         apply handlers to operator attributes.
@@ -476,6 +632,8 @@ class GraphProto(object):
             List of input symbols.
         attrs : dict
             Dict of operator attributes
+        opset : int
+            Opset version
         identity_list : list
             List of operators that don't require conversion
         convert_map : dict
@@ -489,13 +647,14 @@ class GraphProto(object):
             Converted nnvm Symbol
         """
         identity_list = identity_list if identity_list else _identity_list
-        convert_map = convert_map if convert_map else _convert_map
+        convert_map = convert_map if convert_map else _get_convert_map(opset)
         if op_name in identity_list:
             sym = get_nnvm_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
             sym = convert_map[op_name](inputs, attrs, self._params)
         else:
-            raise NotImplementedError("Operator {} not implemented.".format(op_name))
+            raise NotImplementedError(
+                "Operator {} not implemented.".format(op_name))
         return sym
 
     def _fix_outputs(self, op_name, outputs):
@@ -510,7 +669,7 @@ class GraphProto(object):
         return outputs
 
 
-def from_onnx(graph):
+def from_onnx(model):
     """Load onnx graph which is a python protobuf object into nnvm graph.
     The companion parameters will be handled automatically.
     The inputs from onnx graph is vague, only providing "1", "2"...
@@ -519,8 +678,8 @@ def from_onnx(graph):
 
     Parameters
     ----------
-    graph : protobuf object
-        ONNX GraphProto, or ONNX ModelProto after ONNX v0.2
+    model : protobuf object
+        ONNX ModelProto after ONNX v1.1.0
 
     Returns
     -------
@@ -531,8 +690,7 @@ def from_onnx(graph):
         Dict of converted parameters stored in tvm.ndarray format
     """
     g = GraphProto()
-    if hasattr(graph, 'graph'):
-        # it's a ModelProto wrapper
-        graph = graph.graph
-    sym, params = g.from_onnx(graph)
+    graph = model.graph
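+    # ONNX models carry their opset in opset_import; fall back to opset 1 when absent.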
+    opset = model.opset_import[0].version if model.opset_import else 1
+    sym, params = g.from_onnx(graph, opset)
     return sym, params
diff --git a/nnvm/tests/ci_build/install/ubuntu_install_onnx.sh b/nnvm/tests/ci_build/install/ubuntu_install_onnx.sh
index 37bd4e66edde1f094980453e97d1f273804a822b..138e70d712002f424c70e201380f46ce8af96158 100644
--- a/nnvm/tests/ci_build/install/ubuntu_install_onnx.sh
+++ b/nnvm/tests/ci_build/install/ubuntu_install_onnx.sh
@@ -1,5 +1,5 @@
-pip2 install onnx>=0.2.0
-pip3 install onnx>=0.2.0
+pip2 install 'onnx>=1.1.0'
+pip3 install 'onnx>=1.1.0'
 
 pip2 install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
 pip2 install torchvision
diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py
index d71e171836aefc5d953f4002846ac60b49d64673..3a7076d17b68868d6a2abeac0f5f6dc13f4a0036 100644
--- a/nnvm/tests/python/frontend/onnx/test_forward.py
+++ b/nnvm/tests/python/frontend/onnx/test_forward.py
@@ -14,8 +14,8 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
         c2_out = prepared_backend.run(W)[0]
         return c2_out
 
-    def get_tvm_output(graph, x, target, ctx, dtype='float32'):
-        new_sym, params = nnvm.frontend.from_onnx(graph)
+    def get_tvm_output(model, x, target, ctx, dtype='float32'):
+        new_sym, params = nnvm.frontend.from_onnx(model)
         shape_dict = {'input_0': x.shape}
         graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
         m = graph_runtime.create(graph, lib, ctx)
diff --git a/nnvm/tests/python/frontend/onnx/test_graph.py b/nnvm/tests/python/frontend/onnx/test_graph.py
index 6d64f47469caaeba7ac59ccaafc867ca99108fcd..89f13b447991124af106481e6ab7496a645d57fd 100644
--- a/nnvm/tests/python/frontend/onnx/test_graph.py
+++ b/nnvm/tests/python/frontend/onnx/test_graph.py
@@ -5,8 +5,8 @@ from nnvm.compiler import graph_util, graph_attr
 from model_zoo import super_resolution, super_resolution_sym
 
 def compare_graph(onnx_file, nnvm_sym, ishape):
-    onnx_graph = onnx.load(onnx_file)
-    onnx_sym, params = nnvm.frontend.from_onnx(onnx_graph)
+    onnx_model = onnx.load(onnx_file)
+    onnx_sym, params = nnvm.frontend.from_onnx(onnx_model)
     g1 = nnvm.graph.create(onnx_sym)
     g2 = nnvm.graph.create(nnvm_sym)
     ishapes = {'input_0': ishape}
diff --git a/nnvm/tutorials/from_onnx.py b/nnvm/tutorials/from_onnx.py
index 5a01fb60dc533fa9b250d7b7d8865ca309520ee7..c5680f4ae816e04b7a571070f0a831e33478e98e 100644
--- a/nnvm/tutorials/from_onnx.py
+++ b/nnvm/tutorials/from_onnx.py
@@ -44,9 +44,9 @@ model_url = ''.join(['https://gist.github.com/zhreshold/',
                      'super_resolution_0.2.onnx'])
 download(model_url, 'super_resolution.onnx', True)
 # now you have super_resolution.onnx on disk
-onnx_graph = onnx.load('super_resolution.onnx')
+onnx_model = onnx.load('super_resolution.onnx')
 # we can load the graph as NNVM compatible model
-sym, params = nnvm.frontend.from_onnx(onnx_graph)
+sym, params = nnvm.frontend.from_onnx(onnx_model)
 
 ######################################################################
 # Load a test image