From 30a5a6007de3d3ab2a4f7fb6ed637021a42f1322 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" <cheungchih@gmail.com> Date: Fri, 11 Jan 2019 14:36:52 -0800 Subject: [PATCH] [RELAY][FRONTEND]Onnx to relay frontend (#2302) --- docs/install/from_source.rst | 2 +- .../python/frontend/onnx/test_forward.py | 16 +- python/tvm/relay/expr.py | 8 + python/tvm/relay/frontend/__init__.py | 1 + python/tvm/relay/frontend/common.py | 182 +++ python/tvm/relay/frontend/nnvm_common.py | 12 +- python/tvm/relay/frontend/onnx.py | 1090 +++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 1033 ++++++++++++++++ tests/scripts/task_python_frontend.sh | 3 + tutorials/relay/from_onnx.py | 93 ++ 10 files changed, 2421 insertions(+), 19 deletions(-) create mode 100644 python/tvm/relay/frontend/onnx.py create mode 100644 tests/python/frontend/onnx/test_forward.py create mode 100644 tutorials/relay/from_onnx.py diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 00cb96fe4..81d06f1dc 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -35,7 +35,7 @@ Our goal is to build the shared libraries: .. code:: bash sudo apt-get update - sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev + sudo apt-get install -y python python-dev python-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake The minimal building requirements are diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index 143ae8583..a98ef297f 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -910,7 +910,7 @@ def test_single_ops(): model = helper.make_model(graph, producer_name='_test') for target, ctx in ctx_list(): tvm_out = get_tvm_output(model, [x], target, ctx) - tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol) x = np.random.uniform(size=in_shape).astype(dtype) verify_single_ops("Neg",x, -x) @@ -918,13 +918,13 @@ def test_single_ops(): verify_single_ops("Reciprocal",x, 1/x, rtol=1e-5, atol=1e-5) verify_single_ops("Sqrt",x, np.sqrt(x), rtol=1e-5, atol=1e-5) verify_single_ops("Relu",x, np.maximum(x, 0)) - verify_single_ops("Exp",x, np.exp(x)) - verify_single_ops("Log",x, np.log(x)) - verify_single_ops("Log",x, np.log(x)) - verify_single_ops("Tanh",x, np.tanh(x)) - verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x))) - verify_single_ops("Softsign",x, x / (1 + np.abs(x))) - verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x))) + verify_single_ops("Exp",x, np.exp(x), rtol=1e-5, atol=1e-5) + verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5) + verify_single_ops("Log",x, np.log(x), rtol=1e-5, atol=1e-5) + verify_single_ops("Tanh",x, np.tanh(x), rtol=1e-5, atol=1e-5) + verify_single_ops("Sigmoid",x, 1 / (1 + np.exp(-x)), rtol=1e-5, atol=1e-5) + verify_single_ops("Softsign",x, x / (1 + np.abs(x)), rtol=1e-5, atol=1e-5) + verify_single_ops("SoftPlus",x, np.log(1 + np.exp(x)), rtol=1e-5, atol=1e-5) def test_leaky_relu(): def leaky_relu_x(x, alpha): diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 9de0344bf..f510d6195 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -465,6 +465,14 @@ def const(value, dtype=None): """ if isinstance(value, (_base.numeric_types, (bool, list))): value = _np.array(value, dtype=dtype) + if not dtype: + # when dtype is None: int maps to "int32", float maps to "float32" + 
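+            # e.g. _np.array(1) is int64 and _np.array(1.0) is float64 on
+            # most platforms; downcast so the resulting constant defaults
+            # to a 32-bit dtype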
map_dtype = {
+                _np.dtype('int64'): _np.int32,
+                _np.dtype('float64'): _np.float32
+            }.get(value.dtype, None)
+            if map_dtype:
+                value = value.astype(map_dtype)
 
     if isinstance(value, (_np.ndarray, _np.generic)):
         value = _nd.array(value)
diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index fd0199aaf..30d544694 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -9,3 +9,4 @@ from __future__ import absolute_import
 from .mxnet import from_mxnet
 from .keras import from_keras
+from .onnx import from_onnx
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 2d5817734..0464f0b8b 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -1,6 +1,10 @@
 """Common utilities"""
 from __future__ import absolute_import as _abs
+import logging
+from topi.util import get_const_tuple
 from .. import expr as _expr
+from .. import ir_pass
+from .. import op as _op
 
 
 class RequiredAttr(object):
@@ -204,6 +209,30 @@ class StrAttrsDict(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
 
+def get_relay_op(op_name):
+    """Get the callable function from Relay based on operator name.
+    Parameters
+    ----------
+    op_name : str
+        The Relay operator name.
+    """
+    if '.' in op_name:
+        # explicit hierarchical modules
+        op = _op
+        try:
+            for opn in op_name.split('.'):
+                op = getattr(op, opn)
+        except AttributeError:
+            op = None
+    else:
+        # try searching for the op in various modules
+        for candidate in (_op, _op.nn, _op.image):
+            op = getattr(candidate, op_name, None)
+            if op is not None:
+                break
+    if not op:
+        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
+    return op
 
 class ExprTable(object):
     """Table storing Relay expressions by names."""
@@ -227,3 +256,156 @@ class ExprTable(object):
     def set_expr(self, name, expr):
         assert isinstance(expr, _expr.Expr)
         self.exprs[name] = expr
+
+
+class AttrCvt(object):
+    """Common attribute converter. An AttrCvt instance is a callable:
+    ```
+    attr_converter = AttrCvt(op_name, transforms={'a':'b', 'c':('d', 1)})
+    new_sym = attr_converter(inputs, attrs)
+    ```
+
+    Parameters
+    ----------
+    op_name : str or callable
+        If set as str, returned operator name is the str.
+        If set as callable, returned operator is the str returned by calling:
+        `op_name = func(attr)`
+    transforms : dict of `new_name` or `(new_name, default_value, transform_function)`
+        If only a new_name is provided, the attribute is renamed.
+        If a default_value is provided, the attribute is considered optional.
+        If a transform function is provided, the original attribute value is
+        passed through it.
+    excludes : list
+        A list of excluded attributes that should `NOT` appear.
+        A NotImplementedError is raised if one does.
+    disables : list
+        A list of attributes that are disabled in relay. Logged as warnings.
+    ignores : list
+        A list of attributes that are ignored in relay. Logged at debug level.
+    extras : dict
+        Additional attributes that are always added to the returned
+        attribute dict.
+    custom_check : tuple of (callable, str)
+        A (function, message) pair. The function takes the attribute dict and
+        returns a bool; a RuntimeError carrying the message is raised if it
+        does not return True.
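+
+    Examples
+    --------
+    A sketch with illustrative attribute names:
+    ```
+    cvt = AttrCvt('conv2d',
+                  transforms={'kernel_shape': 'kernel_size',
+                              'pads': ('padding', (0, 0))},
+                  ignores=['dilations'])
+    sym = cvt(inputs, attrs, params)  # resolves and calls _op.nn.conv2d
+    ```
+    Here 'kernel_shape' is renamed to 'kernel_size', 'pads' is renamed to
+    'padding' (optional, with declared default (0, 0)), and 'dilations' is
+    dropped with a debug log.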
+ """ + def __init__(self, op_name, transforms=None, + excludes=None, disables=None, ignores=None, + extras=None, custom_check=None): + self._op_name = op_name + self._transforms = transforms if transforms else {} + self._excludes = excludes if excludes else [] + self._disables = disables if disables else [] + self._ignores = ignores if ignores else [] + self._extras = extras if extras else {} + self._custom_check = custom_check + + def __call__(self, inputs, attrs, *args): + # apply custom check + if self._custom_check: + func, msg = self._custom_check + if not func(attrs): + raise RuntimeError("Check failed: {}".format(msg)) + # get new op_name + if isinstance(self._op_name, str): + op_name = self._op_name + else: + assert callable(self._op_name), "op_name can either be string or callable" + op_name = self._op_name(attrs) + # convert attributes + new_attrs = {} + for k in attrs.keys(): + if k in self._excludes: + raise NotImplementedError("Attribute {} not supported yet.".format(k)) + elif k in self._disables: + logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name) + elif k in self._ignores: + logging.debug("Attribute %s is ignored in relay.sym.%s", k, op_name) + elif k in self._transforms: + new_name, defaults, transform = self._parse_default(self._transforms[k]) + if defaults is None: + new_attr = self._required_attr(attrs, k) + else: + new_attr = attrs.get(k, None) + if new_attr is None: + new_attrs[new_name] = defaults + else: + new_attrs[new_name] = transform(new_attr) + else: + # copy + new_attrs[k] = attrs[k] + # add extras + new_attrs.update(self._extras) + return get_relay_op(op_name)(*inputs, **new_attrs) + + def _parse_default(self, target): + """Helper function to parse default values.""" + if not isinstance(target, (list, tuple)): + k, v, t = target, None, lambda x: x + elif len(target) == 1: + k, v, t = target[0], None, lambda x: x + elif len(target) == 2: + k, v, t = target[0], target[1], lambda x: x + elif len(target) > 2: + k, v, t = target[0], target[1], target[2] + else: + k = None # should raise + if not isinstance(k, str): + msg = "{} is not a valid target, (name, default) expected.".format(target) + raise ValueError(msg) + return k, v, t + + def _parse_bool(self, value): + """Helper function to parse default boolean values.""" + if isinstance(value, str): + return value.strip().lower() in ['true', '1', 't', 'y', 'yes'] + return bool(value) + + def _required_attr(self, attr, key): + """Wrapper for getting required attributes.""" + assert isinstance(attr, dict) + if key not in attr: + raise AttributeError("Required attribute {} not found.".format(key)) + return attr[key] + +def get_name(node): + name = '' + if hasattr(node, "name_hint"): + name = node.name_hint + return name + +def infer_shape(inputs): + """A method to get the output shape of an intermediate node in the graph.""" + out_type = ir_pass.infer_type(inputs) + out_shapes = get_const_tuple(out_type.checked_type.shape) + return out_shapes + +def infer_channels(inputs, transpose=False): + """A hack for getting 'channels' or 'units' since caffe2 does not provide + these attributes. We check the shape of weights provided to get the number. 
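+
+    For example, a conv2d weight of shape (64, 3, 3, 3) yields 64 channels,
+    while with `transpose=True` the second dimension would be used instead.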
+    """
+    out_type = ir_pass.infer_type(inputs)
+    out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
+    return channels
+
+def new_var(name_hint,
+            type_annotation=None,
+            shape=None,
+            dtype="float32"):
+    return _expr.var(name_hint, type_annotation, shape, dtype)
+
+class Renamer(object):
+    """A simple renamer for operators.
+
+    Parameters
+    ----------
+    new_name : str
+        The new name for the operator
+    """
+    def __init__(self, new_name):
+        self._new_name = new_name
+
+    def __call__(self, inputs, attrs, *args):
+        return get_relay_op(self._new_name)(*inputs, **attrs)
diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py
index 17502dbaa..d75b0cac6 100644
--- a/python/tvm/relay/frontend/nnvm_common.py
+++ b/python/tvm/relay/frontend/nnvm_common.py
@@ -4,15 +4,7 @@ from __future__ import absolute_import as _abs
 from .. import expr as _expr
 from .. import op as _op
-
-def _get_relay_op(op_name):
-    op = _op
-    for path in op_name.split("."):
-        op = getattr(op, path)
-    if not op:
-        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
-    return op
-
+from .common import get_relay_op
 
 def _warn_not_used(attr, op='nnvm'):
     import warnings
@@ -22,7 +14,7 @@ def _warn_not_used(attr, op='nnvm'):
 
 def _rename(new_op):
     if isinstance(new_op, str):
-        new_op = _get_relay_op(new_op)
+        new_op = get_relay_op(new_op)
     # attrs are ignored.
     def impl(inputs, _, _dtype='float32'):
         return new_op(*inputs)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
new file mode 100644
index 000000000..effe50e06
--- /dev/null
+++ b/python/tvm/relay/frontend/onnx.py
@@ -0,0 +1,1090 @@
+# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
+"""ONNX: Open Neural Network Exchange frontend for Relay."""
+from __future__ import absolute_import as _abs
+
+import logging
+import numpy as np
+from ... import nd as _nd
+from .. import ir_pass
+from .. import expr as _expr
+from .. import op as _op
+from .common import AttrCvt, Renamer
+from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name
+
+__all__ = ['from_onnx']
+
+def dimension_picker(prefix, suffix=''):
+    def _impl(attr):
+        kernel = attr['kernel_shape']
+        if len(kernel) == 2:
+            return prefix + '2d' + suffix
+        else:
+            raise NotImplementedError("Only 2d kernel supported.")
+
+    return _impl
+
+def revert_caffe2_pad(pads):
+    """Caffe2 requires two times the normal padding."""
+    if len(pads) == 4:
+        pads = pads[:2]
+    elif len(pads) == 2:
+        pass
+    else:
+        raise ValueError("Invalid caffe2 type padding: {}".format(pads))
+    return pads
+
+def dimension_constraint():
+    def _dim_check(attrs):
+        if len(attrs['kernel_shape']) == 2:
+            return True
+        return False
+
+    return _dim_check, "Only 2d kernel supported."
+
+class OnnxOpConverter(object):
+    """ A helper class for holding onnx op converters.
+    """
+
+    @classmethod
+    def get_converter(cls, opset):
+        """ Get the converter matching the given opset.
+
+        Parameters
+        ----------
+        opset: int
+            opset from model.
+
+        Returns
+        -------
+        converter, which should be an `_impl_vx` method, where x is the
+        largest supported version smaller than or equal to the opset.
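+
+        For example, if a converter defines `_impl_v1` and `_impl_v7`, an
+        opset of 6 dispatches to `_impl_v1` and an opset of 8 to `_impl_v7`.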
+ """ + versions = [ + int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d + ] + versions = sorted(versions + [opset]) + version = versions[ + max([i for i, v in enumerate(versions) if v == opset]) - 1] + if hasattr(cls, '_impl_v{}'.format(version)): + return getattr(cls, '_impl_v{}'.format(version)) + raise NotImplementedError( + 'opset version {} of {} not implemented'.format( + version, cls.__name__)) + + +class Elemwise(OnnxOpConverter): + """ A helper class for elemwise op converters. + """ + name = '' + + @classmethod + def _impl_v1(cls, inputs, attr, params): + assert len(inputs) == 2, "Math op take 2 inputs, {} given".format( + len(inputs)) + op_name = cls.name + conv_ops = ["conv2d", "conv2d_transpose"] + if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops): + # TODO(zhreshold): remove hard coded infershape + axis = int(attr.get('axis', 0)) + inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2) + return get_relay_op(op_name)(*inputs) + +class Pool(OnnxOpConverter): + """ A helper class for pool op converters. + """ + name = '' + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad) + }, + # very weird attributes here in onnx, force check + ignores=['dilations'], + # TODO(zhreshold): make sure ceil_mode in onnx, and layout? + extras={'ceil_mode': False}, + custom_check=dimension_constraint())(inputs, attr, params) + + +class Absolute(OnnxOpConverter): + """ Operator converter for Absolute. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.nn.relu(inputs[0]) + _op.nn.relu(_op.negative(inputs[0])) + + +class Add(Elemwise): + """ Operator converter for Add. + """ + name = 'add' + + +class AveragePool(Pool): + """ Operator converter for AveragePool. + """ + name = 'avg_pool' + + +class BatchNorm(OnnxOpConverter): + """ Operator converter for BatchNorm. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + # TODO(zhreshold): 'spatial' is not properly handled here. + out = AttrCvt( + op_name='batch_norm', + ignores=['spatial', 'is_test', 'consumed_inputs', 'momentum'])(inputs, attr, + params) + return out[0] + + +class Conv(OnnxOpConverter): + """ Operator converter for Conv. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + # get number of channels + out = AttrCvt(op_name=dimension_picker('conv'), + transforms={ + 'kernel_shape': 'kernel_size', + 'dilations': ('dilation', (0, 0)), + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'group': ('groups', 1)}, + custom_check=dimension_constraint())(inputs[:2], attr, params) + use_bias = len(inputs) == 3 + if use_bias: + out = _op.nn.bias_add(out, inputs[2]) + return out + + +class ConvTranspose(OnnxOpConverter): + """ Operator converter for ConvTranspose. 
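+
+    The output channel count is taken from the second dimension of the
+    (in_channels, out_channels, kH, kW) weight tensor via `infer_channels`.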
+ """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + # get number of channels + channels = infer_channels(inputs[1], True) + attr['channels'] = channels + groups = attr.pop('group') + attr['groups'] = groups + out = AttrCvt( + op_name=dimension_picker('conv', '_transpose'), + transforms={ + 'kernel_shape': 'kernel_size', + 'dilations': ('dilation', (0, 0)), + 'pads': ('padding', (0, 0), revert_caffe2_pad) + }, + disables=['output_shape'], + custom_check=dimension_constraint())(inputs[:2], attr, params) + use_bias = len(inputs) == 3 + if use_bias: + out = _op.nn.bias_add(out, inputs[2]) + return out + + +class Div(Elemwise): + name = 'divide' + + +class Elu(OnnxOpConverter): + """ Operator converter for Elu. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + alpha = float(attr.get('alpha', 1.0)) + return _expr.const(-alpha) * _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + \ + _op.nn.relu(inputs[0]) + + +class Gemm(OnnxOpConverter): + """ Operator converter for Gemm. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format( + len(inputs)) + # Y = alpha * A * B + beta * C + alpha = float(attr.get('alpha', 1.0)) + beta = float(attr.get('beta', 1.0)) + transA = int(attr.get('transA', 0)) + transB = int(attr.get('transB', 0)) + # get number of channels + channels = infer_channels(inputs[1], not transB) + if transA: + inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) + if not transB: + inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) + inputs[0] = _op.nn.batch_flatten(inputs[0]) + out = _op.nn.dense(_expr.const(alpha) * inputs[0], + inputs[1], units=channels) + return _op.nn.bias_add(out, _expr.const(beta) * inputs[2]) + +class MatMul(OnnxOpConverter): + """ Operator converter for MatMul. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs)) + input_1_t = _op.transpose(inputs[1], axes=(1, 0)) + return _op.nn.dense(inputs[0], input_1_t) + +class MaxPool(Pool): + name = 'max_pool' + + +class Mul(Elemwise): + name = 'multiply' + + +class Pad(OnnxOpConverter): + """ Operator converter for Pad. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + pad_width = [] + pads = attr.pop('paddings') + dims = int(len(pads) / 2) + for i in range(dims): + pad_width.append((pads[i], pads[i+dims])) + attr['pad_width'] = pad_width + + return AttrCvt( + _op.nn.pad, + transforms={ + 'value': 'pad_value', + }, + ignores=['mode'], + custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant', + 'split mode != constant'))(inputs, attr, params) + + @classmethod + def _impl_v2(cls, inputs, attr, params): + pad_width = [] + pads = attr.pop('pads') + dims = int(len(pads) / 2) + for i in range(dims): + pad_width.append((pads[i], pads[i+dims])) + attr['pad_width'] = pad_width + + return AttrCvt( + 'pad', + transforms={ + 'value': 'pad_value', + }, + ignores=['mode'], + custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant', + 'split mode != constant'))(inputs, attr, params) + + +class ParametricSoftPlus(OnnxOpConverter): + """ Operator converter for ParametricSoftPlus. 
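+
+    Computes `alpha * log(exp(beta * x) + 1)` elementwise with relay ops.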
+ """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + alpha = _expr.const(float(attr.get('alpha', 1.0))) + beta = _expr.const(float(attr.get('beta', 1.0))) + return _op.log(_op.exp(beta * inputs[0]) + _expr.const(1.)) * alpha + + +class Prelu(OnnxOpConverter): + """ Operator converter for Prelu. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs)) + return _op.nn.prelu(inputs[0], inputs[1]) + + +class Reciprocal(OnnxOpConverter): + """ Operator converter for Reciprocal. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _expr.const(1.0) / inputs[0] + +class Reshape(OnnxOpConverter): + """ Operator converter for Reshape. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'shape' in attr: + return _op.reshape(inputs[0], attr['shape']) + + if get_name(inputs[1]) in params: + shape = tuple(params[inputs[1].name_hint].asnumpy()) + out = _op.reshape(inputs[0], shape) + else: + out = _op.reshape_like(inputs[0], inputs[1]) + + return out + +class Concat(OnnxOpConverter): + """ Operator converter for Concat. + """ + + @classmethod + def _impl_v1(cls, inputs, args, params): + return AttrCvt(op_name='concatenate')((inputs,), args) + +class Scale(OnnxOpConverter): + """ Operator converter for Scale. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + scale = float(attr.get('scale', 1.0)) + return inputs[0] * _expr.const(scale) + + +class Selu(OnnxOpConverter): + """ Operator converter for Selu. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + alpha = float(attr.get('alpha', 1.6732)) + gamma = float(attr.get('gamma', 1.0507)) + return _expr.const(gamma) * (_expr.const(-alpha) * + _op.nn.relu(_expr.const(1.) - _op.exp(inputs[0])) + + _op.nn.relu(inputs[0])) + + +class ScaledTanh(OnnxOpConverter): + """ Operator converter for ScaledTanh. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + alpha = float(attr.get('alpha', 1.0)) + beta = float(attr.get('beta', 1.0)) + return _op.tanh(_expr.const(beta) * inputs[0]) * _expr.const(alpha) + + +class SoftPlus(OnnxOpConverter): + """ Operator converter for SoftPlus. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.log(_op.exp(inputs[0]) + _expr.const(1.)) + + +class Softsign(OnnxOpConverter): + """ Operator converter for Softsign. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return inputs[0] / (_expr.const(1.) + Absolute.get_converter(1)(inputs, attr, params)) + + +class Sub(Elemwise): + name = 'subtract' + + +class Sum(OnnxOpConverter): + """ Operator converter for Sum. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + # Onnx Sum Operator + for in_index in range(len(inputs) - 1): + inputs[in_index + 1] = _op.add(inputs[in_index], inputs[in_index + 1]) + + return inputs[len(inputs) - 1] + + +class ThresholdedRelu(OnnxOpConverter): + """ Operator converter for ThresholdedRelu. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + alpha = float(attr.get('alpha', 0.0)) + alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha)) + mask = _op.greater(inputs[0], alpha_tensor).astype("float32") + return inputs[0] * mask + + +def _broadcast_constraint(): + + def _broadcast_check(attrs): + if attrs.get('axis', None): + return False + return True + + return _broadcast_check, "Specifying broadcast axis not allowed." 
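+
+
+def _activation_identities_sketch():
+    """An illustrative aside, not used by the converters: several ONNX ops
+    above are rewritten in terms of relu/exp rather than mapped 1:1, e.g.
+
+        Abs(x) = relu(x) + relu(-x)
+        Elu(x) = -alpha * relu(1 - exp(x)) + relu(x)
+        ThresholdedRelu(x) = x * (x > alpha)
+
+    This is a quick numpy sanity check of those identities (the helper name
+    is illustrative).
+    """
+    x = np.linspace(-3, 3, 101).astype('float32')
+    relu = lambda v: np.maximum(v, 0)
+    alpha = 1.0
+    # reference ELU vs. the relu/exp decomposition used by the Elu converter
+    elu_ref = np.where(x < 0, alpha * (np.exp(x) - 1), x)
+    assert np.allclose(elu_ref, -alpha * relu(1 - np.exp(x)) + relu(x))
+    # Absolute decomposition used by the Absolute converter
+    assert np.allclose(np.abs(x), relu(x) + relu(-x))
+    # masking form used by the ThresholdedRelu converter
+    assert np.allclose(np.where(x > alpha, x, 0), x * (x > alpha))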
+ + +def _fully_connected(opset): + + def _impl(inputs, attr, params): + # get number of channels + channels = infer_channels(inputs[1], params) + attr['units'] = channels + return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr) + + return _impl + + +class Upsample(OnnxOpConverter): + """ Operator converter for Upsample (nearest mode). + """ + + @classmethod + def _impl_v7(cls, inputs, attr, params): + scales = attr.get('scales') + assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3] + mode = attr.get('mode') + if mode == b'nearest': + method = "NEAREST_NEIGHBOR" + elif mode == b'linear': + method = "BILINEAR" + else: + raise ValueError("Invalid ONNX upsample mode: {}".format(mode)) + attr = {'scale':int(scales[-1]), 'method':method, 'layout':'NCHW'} + return AttrCvt('upsampling')(inputs, attr) + + +class Shape(OnnxOpConverter): + """ Operator converter for Shape. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + # Result of this operator is prominently used by reshape operator. + # Just pass the input as it is so that reshape_like can be used there. + logging.warning("Shape: Differently implemented in relay as a bypass (dummy operator)") + return inputs[0] + +class Cast(OnnxOpConverter): + """ Operator converter for Cast. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) + + @classmethod + def _impl_v5(cls, inputs, attr, params): + try: + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']] + except ImportError as e: + raise ImportError( + "Unable to import onnx.mapping which is required {}".format(e)) + return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr) + + +class Unsqueeze(OnnxOpConverter): + """ Operator converter for Unsqueeze. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + for axes in attr['axes']: + inputs[0] = _op.expand_dims(inputs[0], axis=axes, num_newaxis=1) + return inputs[0] + + +class Split(OnnxOpConverter): + """ Operator converter for Split. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + attr['indices_or_sections'] = [] + index = 0 + for i in attr['split'][:-1]: + index += i + attr['indices_or_sections'].append(index) + return AttrCvt( + 'split', + ignores=['split'])(inputs, attr, params) + + +class Slice(OnnxOpConverter): + """ Operator converter for Slice. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if isinstance(attr['starts'], int): + attr['starts'] = (attr['starts'],) + attr['ends'] = (attr['ends'],) + + try: + # Update the starts and ends according to axes if required. + if isinstance(attr['axes'], int): + attr['axes'] = (attr['axes'],) + + if (max(attr['axes']) + 1) != len(attr['axes']): + new_axes = [] + new_starts = [] + new_ends = [] + pop_index = 0 + for i in range(max(attr['axes']) + 1): + if i in attr['axes']: + new_axes.append(i) + new_starts.append(attr['starts'][pop_index]) + new_ends.append(attr['ends'][pop_index]) + pop_index += 1 + else: + new_axes.append(i) + new_starts.append(0) + new_ends.append(np.iinfo(np.int32).max) + attr['axes'] = new_axes + attr['starts'] = new_starts + attr['ends'] = new_ends + except KeyError: + pass + + return AttrCvt('strided_slice', + transforms={'starts': 'begin', + 'ends': 'end'}, + ignores=['axes'])(inputs, attr) + +class Gather(OnnxOpConverter): + """ Operator converter for Gather. 
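+
+    Lowered to relay `take` along the given axis, matching the semantics
+    of `numpy.take`.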
+ """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get('axis', 0) + return AttrCvt('take', + extras={'axis':axis})(inputs, {}) + #return _op.take(inputs[0], inputs[1], axis) + +class LRN(OnnxOpConverter): + """ Operator converter for Local Response Normalization. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + """LRN support only NCHW format + https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN + """ + axis = 1 + alpha = attr.get('alpha', 0.0001) + beta = attr.get('beta', 0.75) + bias = attr.get('bias', 1.0) + nsize = attr.get('size') + attr = {'size':nsize, 'axis':axis, 'alpha':alpha, 'beta':beta, 'bias':bias} + return AttrCvt('lrn')(inputs, attr) + +class Maximum(OnnxOpConverter): + """ Operator converter for Maximum. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if not isinstance(inputs, list) or len(inputs) < 2: + raise ValueError("Expect minimum 2 inputs") + _max = inputs[0] + for i in range(1, len(inputs)): + _max = AttrCvt('maximum')([_max, inputs[i]], {}) + return _max + +class Minimum(OnnxOpConverter): + """ Operator converter for Minimum. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if not isinstance(inputs, list) or len(inputs) < 2: + raise ValueError("Expect minimum 2 inputs") + _min = inputs[0] + for i in range(1, len(inputs)): + _min = AttrCvt('minimum')([_min, inputs[i]], {}) + return _min + +class Mean(OnnxOpConverter): + """ Operator converter for Mean. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if not isinstance(inputs, list) or len(inputs) < 2: + raise ValueError("Expect minimum 2 inputs") + # avoid overflow + concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0) + return _op.mean(concat, axis=0, keepdims=False) + +class HardSigmoid(OnnxOpConverter): + """ Operator converter for HardSigmoid. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + alpha = attr.get('alpha', 0.2) + beta = attr.get('beta', 0.5) + transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta) + attr = {'a_min':0, 'a_max':1} + return AttrCvt('clip')([transformX], attr) + +class Reduce(OnnxOpConverter): + """ Operator converter for reduce ops. + """ + name = '' + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis':axis, 'keepdims':attr.get('keepdims', True)} + return AttrCvt(cls.name)(inputs, attr) + +class ReduceMax(Reduce): + """ Operator converter for ArgMax. + """ + name = 'max' + +class ReduceMin(Reduce): + """ Operator converter for ArgMax. + """ + name = 'min' + +class ReduceSum(Reduce): + """ Operator converter for ArgMax. + """ + name = 'sum' + +class ReduceMean(Reduce): + """ Operator converter for ArgMax. + """ + name = 'mean' + +class ArgMax(OnnxOpConverter): + """ Operator converter for ArgMax. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get('axis', 0) + keepdims = attr.get('keepdims', True) + attr = {'axis':axis, 'keepdims':keepdims} + return AttrCvt('argmax')(inputs, attr) + +class ArgMin(OnnxOpConverter): + """ Operator converter for ArgMin. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get('axis', 0) + keepdims = attr.get('keepdims', True) + attr = {'axis':axis, 'keepdims':keepdims} + return AttrCvt('argmin')(inputs, attr) + +class Softmax(OnnxOpConverter): + """ Operator converter for Softmax. 
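+
+    ONNX defaults `axis` to 1, which differs from the relay default, so the
+    attribute is filled in explicitly before conversion.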
+ """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + # set default value when axis is not set in the model + if 'axis' not in attr: + attr['axis'] = 1 + return AttrCvt('softmax', transforms={'axis': ('axis', 1)})(inputs, attr, params) + +class ConstantFill(OnnxOpConverter): + """ Operator converter for ConstantFill. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + num_inputs = len(inputs) + if 'shape' in attr: + if num_inputs > 1: + raise ImportError( + "Can't set shape and input tensor at a time") + shape = attr.pop('shape') + else: + if num_inputs == 1: + raise ImportError( + "Either shape attribute or input should be set") + if 'input_as_shape' in attr and attr['input_as_shape']: + shape = params[get_name(inputs[0])].asnumpy() + else: + if 'extra_shape' in attr: + raise ImportError( + "Extra Shape not supported with fill_like") + return _op.full_like(inputs[0], inputs[1]) + + if 'extra_shape' in attr: + shape = shape + attr.pop('extra_shape') + return _op.full(inputs[0], shape) + +# compatible operators that do NOT require any conversion. +_identity_list = [] + + +# _convert_map defines maps of name to converter functor(callable) +# for 1 to 1 mapping, use Renamer if nothing but name is different +# use AttrCvt if attributes need to be converted +# for 1 to N mapping(composed), use custom callable functions +# for N to 1 mapping, currently not supported(?) +def _get_convert_map(opset): + return { + # defs/experimental + 'Identity': Renamer('copy'), + # 'Affine' + 'ThresholdedRelu': ThresholdedRelu.get_converter(opset), + 'ScaledTanh': ScaledTanh.get_converter(opset), + 'ParametricSoftplus': ParametricSoftPlus.get_converter(opset), + 'ConstantFill': ConstantFill.get_converter(opset), + # 'GivenTensorFill' + 'FC': AttrCvt('dense', ignores=['axis', 'axis_w']), + 'Scale': Scale.get_converter(opset), + # 'GRUUnit' + # 'ATen' + # 'ImageScaler' + # 'MeanVarianceNormalization' + # 'Crop' + # 'Embedding' + 'Upsample' : Upsample.get_converter(opset), + 'SpatialBN': BatchNorm.get_converter(opset), + + # defs/generator + # 'Constant' # Implemented + # 'RandomUniform' + # 'RandomNormal' + # 'RandomUniformLike' + # 'RandomNormalLike' + + # defs/logical + + # defs/math + 'Add': Add.get_converter(opset), + 'Sub': Sub.get_converter(opset), + 'Mul': Mul.get_converter(opset), + 'Div': Div.get_converter(opset), + 'Neg': Renamer('negative'), + 'Abs': Absolute.get_converter(opset), + 'Reciprocal': Reciprocal.get_converter(opset), + 'Floor': Renamer('floor'), + 'Ceil': Renamer('ceil'), + 'Sqrt': Renamer('sqrt'), + 'Relu': Renamer('relu'), + 'LeakyRelu': Renamer('leaky_relu'), + 'Selu': Selu.get_converter(opset), + 'Elu': Elu.get_converter(opset), + 'Exp': Renamer('exp'), + 'Log': Renamer('log'), + 'Tanh': Renamer('tanh'), + 'Pow': Renamer('power'), + 'PRelu': Prelu.get_converter(opset), + 'Sigmoid': Renamer('sigmoid'), + 'HardSigmoid': HardSigmoid.get_converter(opset), + 'Max': Maximum.get_converter(opset), + 'Min': Minimum.get_converter(opset), + 'Sum': Sum.get_converter(opset), + 'Mean': Mean.get_converter(opset), + 'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}), + # softmax default axis is different in onnx + 'Softmax': Softmax.get_converter(opset), + 'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}), + # 'Hardmax' + 'Softsign': Softsign.get_converter(opset), + 'SoftPlus': SoftPlus.get_converter(opset), + 'Gemm': Gemm.get_converter(opset), + 'MatMul': MatMul.get_converter(opset), + + # defs/nn + 'AveragePool': AveragePool.get_converter(opset), + 
'MaxPool': MaxPool.get_converter(opset), + 'Conv': Conv.get_converter(opset), + 'ConvTranspose': ConvTranspose.get_converter(opset), + 'GlobalAveragePool': Renamer('global_avg_pool2d'), + 'GlobalMaxPool': Renamer('global_max_pool2d'), + 'BatchNormalization': BatchNorm.get_converter(opset), + # 'InstanceNormalization' + # 'LpNormalization' + 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), + 'Flatten': Renamer('flatten'), + 'LRN': LRN.get_converter(opset), + + # defs/reduction + 'ReduceMax': ReduceMax.get_converter(opset), + 'ReduceMin': ReduceMin.get_converter(opset), + 'ReduceSum': ReduceSum.get_converter(opset), + 'ReduceMean': ReduceMean.get_converter(opset), + # 'ReduceProd' + # 'ReduceLogSumExp' + 'ArgMax': ArgMax.get_converter(opset), + 'ArgMin': ArgMin.get_converter(opset), + + # defs/tensor + 'Cast': Cast.get_converter(opset), + 'Reshape': Reshape.get_converter(opset), + 'Concat': Concat.get_converter(opset), + 'Split': Split.get_converter(opset), + 'Slice': Slice.get_converter(opset), + 'Transpose': AttrCvt('transpose', {'perm': 'axes'}), + 'Gather': Gather.get_converter(opset), + 'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}), + 'Unsqueeze': Unsqueeze.get_converter(opset), + 'Pad': Pad.get_converter(opset), + # TODO(zhreshold) Shape op is implemented as bypass op in relay + # 'Shape': Shape.get_converter(opset), + } + + +class GraphProto(object): + """A helper class for handling Relay expression copying from pb2.GraphProto. + Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto + + Parameters + ---------- + shape : dict of str to tuple, optional + The input shape to the graph + + dtype : str or dict of str to str + The input types to the graph + """ + + def __init__(self, shape, dtype): + self._nodes = {} + self._params = {} + self._renames = {} + self._num_input = 0 + self._num_param = 0 + self._shape = shape + self._dtype = dtype + + def from_onnx(self, graph, opset): + """Construct Relay expression from ONNX graph. + + Onnx graph is a python protobuf object. + The companion parameters will be handled automatically. + However, the input names from onnx graph is vague, mixing inputs and + network weights/bias such as "1", "2"... + For convenience, we rename the `real` input names to "input_0", + "input_1"... And renaming parameters to "param_0", "param_1"... 
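+
+        Graph initializers become entries in the returned `params` dict, and
+        `Constant` nodes are likewise lifted into `params` rather than being
+        converted as operators.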
+ + Parameters + ---------- + graph : onnx protobuf object + The loaded onnx graph + opset : opset version + + Returns + ------- + sym : tvm.relay.expr.Function + The returned relay function + params : dict + A dict of name: tvm.nd.array pairs, used as pretrained weights + """ + # parse network inputs to relay, aka parameters + for init_tensor in graph.initializer: + if not init_tensor.name.strip(): + raise ValueError("Tensor's name is required.") + self._params[init_tensor.name] = self._parse_array(init_tensor) + for i in graph.input: + # from onnx v0.2, GraphProto.input has type ValueInfoProto, + # and the name is 'i.name' + i_name = self._parse_value_proto(i) + d_type = self._parse_dtype(i, 'float32') + if i_name in self._params: + # i is a param instead of input + self._num_param += 1 + self._params[i_name] = self._params.pop(i_name) + self._nodes[i_name] = new_var(i_name, + shape=self._params[i_name].shape, + dtype=self._params[i_name].dtype) + else: + self._num_input += 1 + shape = self._shape[i_name] if i_name in self._shape else () + if isinstance(self._dtype, dict): + dtype = self._dtype[i_name] if i_name in self._dtype else d_type + else: + dtype = d_type + self._nodes[i_name] = new_var(i_name, shape=shape, dtype=dtype) + # construct nodes, nodes are stored as directed acyclic graph + for node in graph.node: + op_name = node.op_type + attr = self._parse_attr(node.attribute) + inputs = [self._nodes[self._renames.get(i, i)] for i in node.input] + if op_name == "Constant": + t_proto = self._parse_attr(node.attribute)["value"] + self._num_param += 1 + self._params[node.output[0]] = self._parse_array(t_proto) + self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims)) + else: + if op_name == "ConstantFill": + fill_value = attr.get('value', 0.0) + dtype = attr.get('dtype', b'int32').decode("utf-8") + i_name = node.output[0] + self._params[i_name] = fill_value + self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype) + inputs.append(self._nodes[i_name]) + + op = self._convert_operator(op_name, inputs, attr, opset) + node_output = self._fix_outputs(op_name, node.output) + if not isinstance(op, _expr.TupleWrapper): + outputs_num = 1 + else: + outputs_num = len(op) + assert len(node_output) == outputs_num, ( + "Number of output mismatch {} vs {} in {}.".format( + len(node_output), outputs_num, op_name)) + if outputs_num == 1: + self._nodes[node_output[0]] = op + else: + for k, i in zip(list(node_output), range(len(node_output))): + self._nodes[k] = op[i] + + # now return the outputs + outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + func = _expr.Function(ir_pass.free_vars(outputs), outputs) + return func, self._params + + def _parse_value_proto(self, value_proto): + """Parse ValueProto or raw str.""" + try: + name = value_proto.name + except AttributeError: + name = value_proto + return name + + def _parse_dtype(self, value_proto, dtype): + """Parse dtype.""" + try: + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name + except AttributeError: + return dtype + + def _parse_array(self, tensor_proto): + """Grab data in TensorProto and convert to numpy array.""" + try: + from onnx.numpy_helper import to_array + except ImportError as e: + raise ImportError( + "Unable to import onnx which is required {}".format(e)) + np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims)) + return 
_nd.array(np_array) + + def _parse_attr(self, attr_proto): + """Convert a list of AttributeProto to a dict, with names as keys.""" + attrs = {} + for a in attr_proto: + for f in ['f', 'i', 's']: + if a.HasField(f): + attrs[a.name] = getattr(a, f) + for f in ['floats', 'ints', 'strings']: + if list(getattr(a, f)): + assert a.name not in attrs, "Only one type of attr is allowed" + attrs[a.name] = tuple(getattr(a, f)) + for f in ['t']: + if a.HasField(f): + attrs[a.name] = getattr(a, f) + for f in ['tensors']: + if list(getattr(a, f)): + assert a.name not in attrs, "Only one type of attr is allowed" + attrs[a.name] = tuple(getattr(a, f)) + for f in ['g']: + if a.HasField(f): + raise NotImplementedError( + "Filed {} is not supported in relay.".format(f)) + for f in ['graphs']: + if list(getattr(a, f)): + raise NotImplementedError( + "Filed {} is not supported in relay.".format(f)) + if a.name not in attrs: + raise ValueError("Cannot parse attribute: \n{}\n.".format(a)) + return attrs + + def _convert_operator(self, + op_name, + inputs, + attrs, + opset): + """Convert ONNX operator into a Relay operator. + The converter must specify conversions explicity for incompatible name, and + apply handlers to operator attributes. + + Parameters + ---------- + op_name : str + Operator name, such as Convolution, FullyConnected + inputs : list of tvm.relay.expr.Function + List of inputs. + attrs : dict + Dict of operator attributes + opset : int + Opset version + + Returns + ------- + sym : tvm.relay.expr.Function + Converted relay function + """ + convert_map = _get_convert_map(opset) + if op_name in _identity_list: + sym = get_relay_op(op_name)(*inputs, **attrs) + elif op_name in convert_map: + sym = convert_map[op_name](inputs, attrs, self._params) + else: + raise NotImplementedError( + "Operator {} not implemented.".format(op_name)) + return sym + + def _fix_outputs(self, op_name, outputs): + """A hack to handle dropout or similar operator that have more than one out + in ONNX. + """ + if op_name == 'Dropout': + if len(outputs) == 1: + return outputs + # TODO(zhreshold): support dropout mask? + outputs = outputs[:-1] + return outputs + +def from_onnx(model, + shape=None, + dtype="float32"): + """Convert a ONNX model into an equivalent Relay Function. + + ONNX graphs are represented as Python Protobuf objects. + The companion parameters will be handled automatically. + However, the input names from onnx graph is vague, mixing inputs and + network weights/bias such as "1", "2"... + For convenience, we rename the `real` input names to "input_0", + "input_1"... And renaming parameters to "param_0", "param_1"... 
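+
+    A minimal usage sketch (the model file and input name here are
+    hypothetical):
+
+        import onnx
+        model = onnx.load('model.onnx')
+        sym, params = from_onnx(model, shape={'input_0': (1, 3, 224, 224)})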
+ + Parameters + ---------- + model : protobuf object + ONNX ModelProto after ONNX v1.1.0 + + shape : dict of str to tuple, optional + The input shape to the graph + + dtype : str or dict of str to str + The input types to the graph + + Returns + ------- + sym : tvm.relay.expr.Function + Compatible relay function + + params : dict of str to tvm.NDArray + The parameter dict to be used by relay + """ + g = GraphProto(shape, dtype) + graph = model.graph + try: + opset = model.opset_import[0].version if model.opset_import else 1 + except AttributeError: + opset = 1 + sym, params = g.from_onnx(graph, opset) + return sym, params diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py new file mode 100644 index 000000000..de95ff00a --- /dev/null +++ b/tests/python/frontend/onnx/test_forward.py @@ -0,0 +1,1033 @@ +import numpy as np +import math +import topi +import topi.testing +import tvm +from tvm import relay +from tvm.contrib import graph_runtime +from nnvm.testing.config import ctx_list +import onnx +from onnx import helper, TensorProto +import unittest + +def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32'): + """ Generic function to execute and get tvm output""" + target = 'llvm' + if isinstance(input_data, list): + input_names = {} + shape_dict = {} + dtype_dict = {} + for i, _ in enumerate(input_data): + input_names[i] = graph_def.graph.input[i].name + shape_dict[input_names[i]] = input_data[i].shape + dtype_dict[input_names[i]] = input_data[i].dtype + else: + input_names = graph_def.graph.input[0].name + shape_dict = {input_names: input_data.shape} + dtype_dict = {input_names: input_data.dtype} + + sym, params = relay.frontend.from_onnx(graph_def, shape_dict) + with relay.build_config(opt_level=1): + graph, lib, params = relay.build(sym, target, params=params) + + ctx = tvm.cpu(0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + # set inputs + if isinstance(input_data, list): + for i, e in enumerate(input_names): + m.set_input(input_names[i], tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + else: + m.set_input(input_names, tvm.nd.array(input_data.astype(input_data.dtype))) + + m.set_input(**params) + # execute + m.run() + # get outputs + if isinstance(output_shape, list) and isinstance(output_dtype, list): + tvm_output_list = [] + for i, _ in enumerate(output_shape): + tvm_output = m.get_output(i) + tvm_output_list.append(tvm_output.asnumpy()) + return tvm_output_list + else: + tvm_output = m.get_output(0) + return tvm_output.asnumpy() + +def get_caffe2_output(model, x, dtype='float32'): + import caffe2.python.onnx.backend + prepared_backend = caffe2.python.onnx.backend.prepare(model) + W = {model.graph.input[0].name: x.astype(dtype)} + c2_out = prepared_backend.run(W)[0] + return c2_out + + +def verify_onnx_forward_impl(graph_file, data_shape, out_shape): + dtype = 'float32' + x = np.random.uniform(size=data_shape) + model = onnx.load_model(graph_file) + c2_out = get_caffe2_output(model, x, dtype) + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype) + tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) + +def verify_super_resolution_example(): + verify_onnx_forward_impl(super_resolution, (1, 1, 224, 224), (1, 1, 672, 672)) + +def verify_squeezenet1_1(): + verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000)) + +def verify_lenet(): + verify_onnx_forward_impl(lenet, (1, 1, 28, 
28), (1, 10)) + +def verify_resnet18(): + verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000)) + + +def test_reshape(): + in_shape = (4, 3, 3, 4) + ref_shape = (6, 2, 4, 3) + + ref_array = np.array(ref_shape) + ref_node = onnx.helper.make_node('Constant', + inputs=[], + outputs=['ref_in'], + value=onnx.helper.make_tensor(name = 'const_tensor', + data_type = onnx.TensorProto.INT32, + dims = ref_array.shape, + vals = ref_array.flatten().astype(int))) + reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"]) + + graph = helper.make_graph([ref_node, reshape_node], + "reshape_test", + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(ref_shape))]) + + model = helper.make_model(graph, producer_name='reshape_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('int32') + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + +def test_reshape_like(): + in_shape = (4, 3, 3, 4) + ref_shape = (3, 4, 4, 3) + + ref_array = np.random.uniform(size=ref_shape).astype('float32') + ref_node = onnx.helper.make_node('Constant', + inputs=[], + outputs=['ref_in'], + value=onnx.helper.make_tensor(name = 'const_tensor', + data_type = onnx.TensorProto.FLOAT, + dims = ref_array.shape, + vals = ref_array.flatten().astype(float))) + copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"]) + reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"]) + + graph = helper.make_graph([ref_node, copy_node, reshape_node], + "reshape_like_test", + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(ref_shape))]) + + model = helper.make_model(graph, producer_name='reshape_like_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('float32') + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + +def _test_power_iteration(x_shape, y_shape): + if isinstance(y_shape, int): + y_shape = [y_shape] + + x = np.random.uniform(size=x_shape).astype(np.float32) + y = np.random.uniform(size=y_shape).astype(np.float32) + + np_res = np.power(x, y).astype(np.float32) + + res = helper.make_node("Pow", ['x', 'y'], ['out']) + + graph = helper.make_graph([res], + 'power_test', + inputs = [helper.make_tensor_value_info("x", + TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("y", + TensorProto.FLOAT, list(y_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(np_res.shape))]) + + model = helper.make_model(graph, producer_name='power_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [x, y], target, ctx, np_res.shape) + tvm.testing.assert_allclose(np_res, tvm_out, rtol=1e-5, atol=1e-5) + +def test_power(): + _test_power_iteration((1, 3), (1)) + _test_power_iteration((2, 3), (2, 3)) + _test_power_iteration((2, 3), (1, 3)) + +def test_squeeze(): + in_shape = (1, 3, 1, 3, 1, 1) + out_shape = (3, 3) + y = helper.make_node("Squeeze", ['in'], ['out'], axes=[0, 2, 4, 5]) + + graph = helper.make_graph([y], + 'squeeze_test', + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_shape))]) + + 
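+    # only the output shape is checked below, since Squeeze does not
+    # change element values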
model = helper.make_model(graph, producer_name='squeeze_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('float32') + tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32') + + tvm.testing.assert_allclose(out_shape, tvm_out.shape) + +def test_unsqueeze(): + in_shape = (3, 3) + axis = (0, 3, 4) + out_shape = (1, 3, 3, 1, 1) + y = helper.make_node("Unsqueeze", ['in'], ['out'], axes=list(axis)) + + graph = helper.make_graph([y], + 'squeeze_test', + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_shape))]) + + model = helper.make_model(graph, producer_name='squeeze_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('float32') + tvm_out = get_tvm_output(model, x, target, ctx, out_shape, 'float32') + + tvm.testing.assert_allclose(out_shape, tvm_out.shape) + +def verify_gather(in_shape, indices, axis, dtype): + x = np.random.uniform(size=in_shape).astype(dtype) + indices = np.array(indices, dtype="int32") + out_np = np.take(x, indices, axis=axis) + + y = helper.make_node("Gather", ['in', 'indices'], ['out'], axis=axis) + + graph = helper.make_graph([y], + 'gather_test', + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", + TensorProto.INT32, list(indices.shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_np.shape))]) + model = helper.make_model(graph, producer_name='gather_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [x, indices], target, ctx, out_np.shape) + tvm.testing.assert_allclose(out_np, tvm_out) + +def test_gather(): + verify_gather((4,), [1], 0, 'int32') + verify_gather((1,4), [0], 0, 'int32') + verify_gather((4,), [[[1,0],[0,1]]], 0, 'float32') + verify_gather((2,2), [[[1,0],[0,1]]], 1, 'int32') + verify_gather((3,3,3), [[[1,0]]], -1, 'int32') + verify_gather((4,3,5,6), [[2,1,0,0]], 0, 'float32') + +def _test_slice_iteration(indata, outdata, starts, ends, axes=None): + if axes: + y = helper.make_node("Slice", ['in'], ['out'], axes=axes, starts=starts, ends=ends) + else: + y = helper.make_node("Slice", ['in'], ['out'], starts=starts, ends=ends) + + graph = helper.make_graph([y], + 'slice_test', + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(indata.shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name='slice_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32') + + tvm.testing.assert_allclose(outdata, tvm_out) + +def test_slice(): + x = np.random.randn(20, 10, 5).astype(np.float32) + _test_slice_iteration(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1)) + _test_slice_iteration(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4)) + _test_slice_iteration(x, x[:, 1:1000], (1), (1000), (1)) + _test_slice_iteration(x, x[:, 0:-1], (0), (-1), (1)) + +def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs): + indata = np.random.uniform(-1, 1, size=inshape).astype(dtype) + outdata = outfunc(indata, **npargs) + + y = helper.make_node(opname, ['in'], ['out'], **kwargs) + + graph = helper.make_graph([y], + opname+'_test', + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(indata.shape))], + outputs = 
[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name=opname+'_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, dtype) + + tvm.testing.assert_allclose(outdata, tvm_out) + +def test_floor(): + _test_onnx_op_elementwise((2, 4, 5, 6), np.floor, {}, 'float32', 'Floor', {}) + +def test_ceil(): + _test_onnx_op_elementwise((2, 4, 5, 6), np.ceil, {}, 'float32', 'Ceil', {}) + +def test_clip(): + _test_onnx_op_elementwise((2, 4, 5, 6), + np.clip, + {'a_min': -1.0, 'a_max': 1.0}, + 'float32', + 'Clip', + {'min': -1.0, 'max': 1.0}) + +def test_matmul(): + a_shape = (4, 3) + b_shape = (3, 4) + + a_array = np.random.uniform(size=a_shape).astype('float32') + b_array = np.random.uniform(size=b_shape).astype('float32') + out_np = np.matmul(a_array, b_array) + + mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) + + graph = helper.make_graph([mul_node], + "matmul_test", + inputs = [helper.make_tensor_value_info("a", + TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", + TensorProto.FLOAT, list(b_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_np.shape))]) + + model = helper.make_model(graph, producer_name='matmul_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, out_np.shape) + tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + +def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None): + in_array = np.random.uniform(size=shape).astype(dtype) + + if alpha == None and beta == None and bias==None: + alpha = 0.0001 + beta = 0.75 + bias = 1.0 + node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], size=nsize) + else: + node = onnx.helper.make_node('LRN', inputs=['in'], outputs=['out'], alpha=alpha, + beta=beta, bias=bias, size=nsize) + + graph = helper.make_graph([node], + "lrn_test", + inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))], + outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))]) + model = helper.make_model(graph, producer_name='lrn_test') + + def _get_python_lrn(): + square_sum = np.zeros(shape).astype(dtype) + for n, c, h, w in np.ndindex(in_array.shape): + square_sum[n, c, h, w] = sum(in_array[n, + max(0, c - int(math.floor((nsize - 1) / 2))): \ + min(5, c + int(math.ceil((nsize - 1) / 2)) + 1), + h, + w] ** 2) + py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta) + return py_out + + for target, ctx in ctx_list(): + input_name = model.graph.input[0].name + py_out = _get_python_lrn() + tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, 'float32') + tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_lrn(): + verify_lrn((5, 5, 5, 5), 3, 'float32') + verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0) + +def _test_upsample_nearest(): + scale = 2 + in_shape = (1, 1, 3, 3) + out_shape = (1, 1, 3*scale, 3*scale) + y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) + + in_array = np.random.uniform(size=in_shape).astype(np.float32) + out_array = topi.testing.upsampling_python(in_array, scale, "NCHW") + + graph = helper.make_graph([y], + 'upsample_nearest_test', + inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_shape))]) + + model = helper.make_model(graph, producer_name='upsample_nearest_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32') + tvm.testing.assert_allclose(out_array, tvm_out) + +def _test_upsample_bilinear(): + scale = 2 + in_shape = (1, 1, 3, 3) + out_shape = (1, 1, 3*scale, 3*scale) + y = helper.make_node("Upsample", ['in'], ['out'], mode='linear', scales=[1.0, 1.0, 2.0, 2.0]) + + in_array = np.random.uniform(size=in_shape).astype(np.float32) + out_array = topi.testing.bilinear_resize_python(in_array, (3*scale, 3*scale), "NCHW") + + graph = helper.make_graph([y], + 'upsample_bilinear_test', + inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) + + model = helper.make_model(graph, producer_name='upsample_bilinear_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 'float32') + tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) + +def test_upsample(): + _test_upsample_nearest() + _test_upsample_bilinear() + +def _test_softmax(inshape, axis): + opname = 'Softmax' + indata = np.random.uniform(size=inshape).astype(np.float32) + outshape = inshape + outdata = topi.testing.softmax_python(indata) + if isinstance(axis, int): + y = helper.make_node(opname, ['in'], ['out'], axis = axis) + elif axis is None: + y = helper.make_node(opname, ['in'], ['out']) + + graph = helper.make_graph([y], + opname+'_test', + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(indata.shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name=opname+'_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 'float32') + tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + +def test_softmax(): + _test_softmax((1, 10), None) + _test_softmax((1, 10), 1) + +def verify_min(input_dim): + dtype = 'float32' + + a_np1 = np.random.uniform(size=input_dim).astype(dtype) + a_np2 = np.random.uniform(size=input_dim).astype(dtype) + a_np3 = np.random.uniform(size=input_dim).astype(dtype) + + b_np = np.min((a_np1, a_np2, a_np3), axis=0) + + min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"]) + + graph = helper.make_graph([min_node], + "Min_test", + inputs = [helper.make_tensor_value_info("a_np1", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", + TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", + TensorProto.FLOAT, list(input_dim))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(b_np.shape))]) + + model = helper.make_model(graph, producer_name='Min_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape) + tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5) + +def test_forward_min(): + verify_min((1, 3, 20, 20)) + verify_min((20, 20)) + +def verify_max(input_dim): + dtype = 'float32' + + a_np1 = np.random.uniform(size=input_dim).astype(dtype) + a_np2 = np.random.uniform(size=input_dim).astype(dtype) + a_np3 = np.random.uniform(size=input_dim).astype(dtype) + + b_np = np.max((a_np1, a_np2, a_np3), axis=0) + + max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"]) + + graph = 
+    graph = helper.make_graph([max_node],
+                              "Max_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np2",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np3",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='Max_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_max():
+    verify_max((1, 3, 20, 20))
+    verify_max((20, 20))
+
+def verify_mean(input_dim):
+    dtype = 'float32'
+
+    a_np1 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np2 = np.random.uniform(size=input_dim).astype(dtype)
+    a_np3 = np.random.uniform(size=input_dim).astype(dtype)
+
+    b_np = np.mean((a_np1, a_np2, a_np3), axis=0)
+
+    mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"])
+
+    graph = helper.make_graph([mean_node],
+                              "Mean_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np2",
+                                            TensorProto.FLOAT, list(input_dim)),
+                                        helper.make_tensor_value_info("a_np3",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='Mean_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, b_np.shape)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_mean():
+    verify_mean((1, 3, 20, 20))
+    verify_mean((20, 20))
+
+def verify_hardsigmoid(input_dim, alpha, beta):
+    dtype = 'float32'
+
+    a_np1 = np.random.uniform(size=input_dim).astype(dtype)
+
+    b_np = np.clip(a_np1 * alpha + beta, 0, 1)
+
+    hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], alpha=alpha, beta=beta)
+
+    graph = helper.make_graph([hardsigmoid_node],
+                              "HardSigmoid_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='HardSigmoid_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_hardsigmoid():
+    verify_hardsigmoid((1, 3, 20, 20), 0.5, 0.6)
+    verify_hardsigmoid((20, 20), 0.3, 0.4)
+
+def verify_argmin(input_dim, axis=None, keepdims=None):
+    def _argmin_numpy(data, axis=0, keepdims=True):
+        result = np.argmin(data, axis=axis)
+        if keepdims == 1:
+            result = np.expand_dims(result, axis)
+        return result.astype(data.dtype)
+
+    a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
+    if keepdims is None and axis is None:
+        b_np = _argmin_numpy(a_np1)
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'])
+    elif axis is None:
+        b_np = _argmin_numpy(a_np1, keepdims=keepdims)
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     keepdims=keepdims)
+    elif keepdims is None:
+        b_np = _argmin_numpy(a_np1, axis=axis)
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis)
+    else:
+        b_np = _argmin_numpy(a_np1, axis=axis, keepdims=keepdims)
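+        # both axis and keepdims were given explicitly; forward both
+        # attributes so the converter is exercised with the full attribute set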
+        node = onnx.helper.make_node('ArgMin',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis,
+                                     keepdims=keepdims)
+    graph = helper.make_graph([node],
+                              "argmin_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.INT32, list(a_np1.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.INT32, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='argmin_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def verify_argmax(input_dim, axis=None, keepdims=None):
+    def _argmax_numpy(data, axis=0, keepdims=True):
+        result = np.argmax(data, axis=axis)
+        if keepdims == 1:
+            result = np.expand_dims(result, axis)
+        return result.astype(data.dtype)
+
+    a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
+    if keepdims is None and axis is None:
+        b_np = _argmax_numpy(a_np1)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'])
+    elif axis is None:
+        b_np = _argmax_numpy(a_np1, keepdims=keepdims)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     keepdims=keepdims)
+    elif keepdims is None:
+        b_np = _argmax_numpy(a_np1, axis=axis)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis)
+    else:
+        b_np = _argmax_numpy(a_np1, axis=axis, keepdims=keepdims)
+        node = onnx.helper.make_node('ArgMax',
+                                     inputs=['a_np1'],
+                                     outputs=['out'],
+                                     axis=axis,
+                                     keepdims=keepdims)
+
+    graph = helper.make_graph([node],
+                              "argmax_test",
+                              inputs = [helper.make_tensor_value_info("a_np1",
+                                            TensorProto.INT32, list(a_np1.shape))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.INT32, list(b_np.shape))])
+
+    model = helper.make_model(graph, producer_name='argmax_test')
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, b_np.dtype)
+        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_forward_arg_min_max():
+    '''Verify argmin and argmax'''
+    verify_argmin([3, 4, 4])
+    verify_argmax([3, 4, 4])
+    verify_argmin([3, 4, 4], axis=1)
+    verify_argmax([3, 4, 4], axis=0)
+    verify_argmin([3, 4, 4], keepdims=0)
+    verify_argmax([3, 4, 4], keepdims=1)
+    for axis in [None, 0, 1, 2]:
+        for keepdims in [None, True, False]:
+            verify_argmin([3, 4, 4], axis, keepdims)
+            verify_argmax([3, 4, 4], axis, keepdims)
+
+def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
+    input_a = np.random.uniform(size=input_dim).astype(dtype)
+    out = np.empty(shape=out_dim, dtype=dtype)
+    out.fill(value)
+
+    if is_shape:
+        fill_node = helper.make_node("ConstantFill", [], ["out"], shape=input_dim, value=value, **kwargs)
+    else:
+        fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)
+
+    graph = helper.make_graph([fill_node],
+                              "fill_test",
+                              inputs = [helper.make_tensor_value_info("input_a",
+                                            TensorProto.FLOAT, list(input_dim))],
+                              outputs = [helper.make_tensor_value_info("out",
+                                            TensorProto.FLOAT, list(out.shape))])
+
+    model = helper.make_model(graph, producer_name='fill_test')
+
+    for target, ctx in ctx_list():
+        if is_shape:
+            tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
+        else:
+            tvm_out = get_tvm_output(model, [input_a], target, ctx, out.shape)
+
+        tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_constantfill():
+    verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
+    verify_constantfill(False, (2, 3, 4, 5), (2, 3, 4, 5), 10, 'float32')
+    verify_constantfill(True, (2, 3, 4, 5), (2, 3, 4, 5, 4, 5, 6), 10, 'float32', extra_shape=(4, 5, 6))
+
+
+def verify_pad(indata, pads, value=0.0):
+    indata = np.array(indata).astype(np.float32)
+    # numpy expected result
+    len_dim = len(pads) // 2
+    np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)]
+    outdata = np.pad(indata, pad_width=np_pads, mode='constant', constant_values=value)
+    # onnx graph
+    node = helper.make_node(
+        'Pad',
+        inputs=['input'],
+        outputs=['output'],
+        mode='constant',
+        pads=pads,
+        value=value
+    )
+    graph = helper.make_graph([node],
+                              'pad_test',
+                              inputs = [helper.make_tensor_value_info("input",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("output",
+                                            TensorProto.FLOAT, list(outdata.shape))])
+    model = helper.make_model(graph, producer_name='pad_test')
+    # tvm result
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+        tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_pad():
+    verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 0.0)
+    verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 0.0)
+    verify_pad(np.random.randn(3, 2).astype(np.float32), [0, 0, 1, 0], 5.0)
+
+def verify_reduce_x(name, indata, axis, keepdims):
+    indata = np.array(indata).astype(np.float32)
+    # numpy expected result
+    if name == 'ReduceMax':
+        outdata = np.maximum.reduce(indata, axis=axis, keepdims=keepdims == 1)
+    elif name == 'ReduceMin':
+        outdata = np.minimum.reduce(indata, axis=axis, keepdims=keepdims == 1)
+    elif name == 'ReduceSum':
+        outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1)
+    elif name == 'ReduceMean':
+        outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1)
+    else:
+        raise Exception('unsupported op: {}'.format(name))
+    if len(np.asarray(outdata).shape) == 0:
+        outdata = np.asarray([outdata])
+    # onnx graph
+    if axis is None:
+        node = helper.make_node(name, inputs=['input'], outputs=['output'],
+                                keepdims=keepdims)
+    else:
+        node = helper.make_node(name, inputs=['input'], outputs=['output'],
+                                axes=axis, keepdims=keepdims)
+    graph = helper.make_graph([node],
+                              '{}_test'.format(name),
+                              inputs = [helper.make_tensor_value_info("input",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("output",
+                                            TensorProto.FLOAT, list(outdata.shape))])
+    model = helper.make_model(graph, producer_name='{}_test'.format(name))
+    # tvm result
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 'float32')
+        tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+
+def test_reduce_max():
+    verify_reduce_x("ReduceMax",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceMax",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceMax",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def test_reduce_min():
+    verify_reduce_x("ReduceMin",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceMin",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceMin",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def test_reduce_sum():
+    verify_reduce_x("ReduceSum",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceSum",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceSum",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def test_reduce_mean():
+    verify_reduce_x("ReduceMean",
+                    np.random.randn(3, 2, 2).astype(np.float32),
+                    axis=None, keepdims=1)
+    verify_reduce_x("ReduceMean",
+                    np.random.randn(3, 2, 3).astype(np.float32),
+                    axis=None, keepdims=0)
+    verify_reduce_x("ReduceMean",
+                    np.random.randn(3, 3, 3).astype(np.float32),
+                    axis=(1,), keepdims=1)
+
+def verify_split(indata, outdatas, split, axis=0):
+    indata = np.array(indata).astype(np.float32)
+    outdatas = [np.array(o).astype(np.float32) for o in outdatas]
+    node = helper.make_node(
+        'Split',
+        inputs=['input'],
+        outputs=['output_{}'.format(i) for i in range(len(split))],
+        axis=axis,
+        split=split
+    )
+    graph = helper.make_graph([node],
+                              'split_test',
+                              inputs = [helper.make_tensor_value_info("input",
+                                            TensorProto.FLOAT, list(indata.shape))],
+                              outputs = [helper.make_tensor_value_info("output_{}".format(i),
+                                            TensorProto.FLOAT, list(outdatas[i].shape))
+                                         for i in range(len(split))
+                                        ])
+    model = helper.make_model(graph, producer_name='split_test')
+
+    for target, ctx in ctx_list():
+        output_shape = [o.shape for o in outdatas]
+        # one output dtype per split, so the 2-way split case also works
+        output_type = ['float32'] * len(split)
+        tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, output_type)
+        for o, t in zip(outdatas, tvm_out):
+            tvm.testing.assert_allclose(o, t)
+
+def test_split():
+    # 1D
+    verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3., 4.], [5., 6.]], [2, 2, 2], 0)
+    verify_split([1., 2., 3., 4., 5., 6.], [[1., 2.], [3.], [4., 5., 6.]], [2, 1, 3], 0)
+    # 2D
+    verify_split([[1., 2., 3., 4.], [7., 8., 9., 10.]],
+                 [[[1., 2.], [7., 8.]], [[3., 4.], [9., 10.]]], [2, 2], 1)
+
+def test_binary_ops():
+    in_shape = (1, 2, 3, 3)
+    dtype = "float32"
+    out_shape = in_shape
+
+    def verify_binary_ops(op, x, y, out_np, broadcast=None):
+        if broadcast is None:
+            z = helper.make_node(op, ['in1', 'in2'], ['out'])
+        else:
+            z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1)
+        graph = helper.make_graph([z],
+                                  '_test',
+                                  inputs = [helper.make_tensor_value_info("in1",
+                                                TensorProto.FLOAT, list(in_shape)),
+                                            helper.make_tensor_value_info("in2",
+                                                TensorProto.FLOAT, list(in_shape))],
+                                  outputs = [helper.make_tensor_value_info("out",
+                                                TensorProto.FLOAT, list(out_shape))])
+        model = helper.make_model(graph, producer_name='_test')
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(model, [x, y], target, ctx)
+            tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+    x = np.random.uniform(size=in_shape).astype(dtype)
+    y = np.random.uniform(size=in_shape).astype(dtype)
+    z = np.random.uniform(size=(3,)).astype(dtype)
+    verify_binary_ops("Add", x, y, x + y, broadcast=None)
+    verify_binary_ops("Add", x, z, x + z, broadcast=True)
+    verify_binary_ops("Sub", x, y, x - y, broadcast=None)
+    verify_binary_ops("Sub", x, z, x - z, broadcast=True)
+    verify_binary_ops("Mul", x, y, x * y, broadcast=None)
+    verify_binary_ops("Mul", x, z, x * z, broadcast=True)
+    verify_binary_ops("Div", x, y, x / y, broadcast=None)
+    verify_binary_ops("Div", x, z, x / z, broadcast=True)
+    verify_binary_ops("Sum", x, y, x + y, broadcast=None)
+
+def test_single_ops():
+    in_shape = (1, 2, 3, 3)
+    dtype = "float32"
+    out_shape = in_shape
+
+    def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
+        z = helper.make_node(op, ['in1'], ['out'])
+        graph = helper.make_graph([z],
+                                  '_test',
+                                  inputs = [helper.make_tensor_value_info("in1",
+                                                TensorProto.FLOAT, list(in_shape))],
+                                  outputs = [helper.make_tensor_value_info("out",
+                                                TensorProto.FLOAT, list(out_shape))])
+        model = helper.make_model(graph, producer_name='_test')
+        for target, ctx in ctx_list():
+            tvm_out = get_tvm_output(model, [x], target, ctx)
+            tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
+
+    x = np.random.uniform(size=in_shape).astype(dtype)
+    verify_single_ops("Neg", x, -x)
+    verify_single_ops("Abs", x, np.abs(x))
+    verify_single_ops("Reciprocal", x, 1/x)
+    verify_single_ops("Sqrt", x, np.sqrt(x))
+    verify_single_ops("Relu", x, np.maximum(x, 0))
+    verify_single_ops("Exp", x, np.exp(x))
+    verify_single_ops("Log", x, np.log(x))
+    verify_single_ops("Tanh", x, np.tanh(x))
+    verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x)))
+    verify_single_ops("Softsign", x, x / (1 + np.abs(x)))
+    verify_single_ops("SoftPlus", x, np.log(1 + np.exp(x)))
+
+def test_leaky_relu():
+    def leaky_relu_x(x, alpha):
+        return np.where(x >= 0, x, x * alpha)
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              leaky_relu_x,
+                              {'alpha': 0.25},
+                              'float32',
+                              'LeakyRelu',
+                              {'alpha': 0.25})
+
+def test_elu():
+    def elu_x(x, alpha):
+        return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              elu_x,
+                              {'alpha': 0.25},
+                              'float32',
+                              'Elu',
+                              {'alpha': 0.25})
+
+def test_selu():
+    def selu_x(x, alpha, gamma):
+        return gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1.0))
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              selu_x,
+                              {'alpha': 0.25, 'gamma': 0.3},
+                              'float32',
+                              'Selu',
+                              {'alpha': 0.25, 'gamma': 0.3})
+
+def test_ThresholdedRelu():
+    def ThresholdedRelu_x(x, alpha):
+        out_np = np.clip(x, alpha, np.inf)
+        out_np[out_np == alpha] = 0
+        return out_np
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              ThresholdedRelu_x,
+                              {'alpha': 0.25},
+                              'float32',
+                              'ThresholdedRelu',
+                              {'alpha': 0.25})
+
+def test_ScaledTanh():
+    def ScaledTanh_x(x, alpha, beta):
+        return alpha * np.tanh(beta * x)
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              ScaledTanh_x,
+                              {'alpha': 0.25, 'beta': 0.3},
+                              'float32',
+                              'ScaledTanh',
+                              {'alpha': 0.25, 'beta': 0.3})
+
+def test_ParametricSoftplus():
+    def ParametricSoftplus_x(x, alpha, beta):
+        return alpha * np.log(np.exp(beta * x) + 1)
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              ParametricSoftplus_x,
+                              {'alpha': 0.25, 'beta': 0.3},
+                              'float32',
+                              'ParametricSoftplus',
+                              {'alpha': 0.25, 'beta': 0.3})
+
+def test_Scale():
+    def Scale_x(x, scale):
+        return scale * x
+    _test_onnx_op_elementwise((2, 4, 5, 6),
+                              Scale_x,
+                              {'scale': 0.25},
+                              'float32',
+                              'Scale',
+                              {'scale': 0.25})
+
+def test_LogSoftmax():
+    _test_onnx_op_elementwise((1, 4),
+                              topi.testing.log_softmax_python,
+                              {},
+                              'float32',
+                              'LogSoftmax',
+                              {'axis': 1})
+
+if __name__ == '__main__':
+    test_reshape()
+    test_reshape_like()
+    test_power()
+    test_squeeze()
+    test_unsqueeze()
+    test_slice()
+    test_floor()
+    test_ceil()
+    test_clip()
+    test_matmul()
+    test_gather()
+    test_lrn()
+    test_upsample()
+    test_forward_min()
+    test_forward_max()
+    test_forward_mean()
+    test_forward_hardsigmoid()
+    test_forward_arg_min_max()
+    test_softmax()
+    test_constantfill()
+    test_pad()
+    test_reduce_max()
+    test_reduce_min()
+    test_reduce_sum()
+    test_reduce_mean()
+    test_split()
+    test_binary_ops()
+    test_single_ops()
+    test_leaky_relu()
+    test_elu()
+    test_selu()
+    test_ThresholdedRelu()
+    test_ScaledTanh()
+    test_ParametricSoftplus()
+    test_Scale()
+    test_LogSoftmax()
diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh
index d3041e20b..bede96e80 100755
--- a/tests/scripts/task_python_frontend.sh
+++ b/tests/scripts/task_python_frontend.sh
@@ -32,3 +32,6 @@ python3 -m nose -v tests/python/frontend/mxnet || exit -1
 
 echo "Running relay Keras frontend test..."
 python3 -m nose -v tests/python/frontend/keras || exit -1
+
+echo "Running relay ONNX frontend test..."
+python3 -m nose -v tests/python/frontend/onnx || exit -1
diff --git a/tutorials/relay/from_onnx.py b/tutorials/relay/from_onnx.py
new file mode 100644
index 000000000..46b803a3a
--- /dev/null
+++ b/tutorials/relay/from_onnx.py
@@ -0,0 +1,93 @@
+"""
+Compile ONNX Models
+===================
+**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
+
+This article is an introductory tutorial to deploy ONNX models with Relay.
+
+To begin, the ONNX package must be installed.
+
+A quick solution is to install the protobuf compiler, and
+
+.. code-block:: bash
+
+    pip install onnx --user
+
+or please refer to the official site:
+https://github.com/onnx/onnx
+"""
+import onnx
+import numpy as np
+import tvm
+import tvm.relay as relay
+
+def download(url, path, overwrite=False):
+    import os
+    if os.path.isfile(path) and not overwrite:
+        print('File {} exists, skipping.'.format(path))
+        return
+    print('Downloading from url {} to {}'.format(url, path))
+    try:
+        import urllib.request
+        urllib.request.urlretrieve(url, path)
+    except ImportError:
+        # Python 2 fallback
+        import urllib
+        urllib.urlretrieve(url, path)
+
+######################################################################
+# Load pretrained ONNX model
+# ---------------------------------------------
+# The example super-resolution model used here is exactly the same model as in
+# the ONNX tutorial
+# http://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html
+# We skip the PyTorch model construction part and download the saved ONNX model.
+model_url = ''.join(['https://gist.github.com/zhreshold/',
+                     'bcda4716699ac97ea44f791c24310193/raw/',
+                     '93672b029103648953c4e5ad3ac3aadf346a4cdc/',
+                     'super_resolution_0.2.onnx'])
+download(model_url, 'super_resolution.onnx', False)
+# now you have super_resolution.onnx on disk
+onnx_model = onnx.load('super_resolution.onnx')
+
+######################################################################
+# Load a test image
+# ---------------------------------------------
+# A single cat dominates the examples!
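+# The network upscales the luminance (Y) channel of a YCbCr image, so we
+# resize the test image to the expected 224x224 input and feed only the Y
+# channel, shaped NCHW as (1, 1, 224, 224).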
+from PIL import Image
+img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
+download(img_url, 'cat.png')
+img = Image.open('cat.png').resize((224, 224))
+img_ycbcr = img.convert("YCbCr")  # convert to YCbCr
+img_y, img_cb, img_cr = img_ycbcr.split()
+x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
+
+######################################################################
+# Compile the model with relay
+# ---------------------------------------------
+target = 'llvm'
+
+input_name = '1'
+shape_dict = {input_name: x.shape}
+sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
+
+with relay.build_config(opt_level=1):
+    intrp = relay.build_module.create_executor('graph', sym, tvm.cpu(0), target)
+
+######################################################################
+# Execute on TVM
+# ---------------------------------------------
+dtype = 'float32'
+tvm_output = intrp.evaluate(sym)(tvm.nd.array(x.astype(dtype)), **params).asnumpy()
+
+######################################################################
+# Display results
+# ---------------------------------------------
+# We put the input and output images side by side.
+from matplotlib import pyplot as plt
+out_y = Image.fromarray(np.uint8((tvm_output[0, 0]).clip(0, 255)), mode='L')
+out_cb = img_cb.resize(out_y.size, Image.BICUBIC)
+out_cr = img_cr.resize(out_y.size, Image.BICUBIC)
+result = Image.merge('YCbCr', [out_y, out_cb, out_cr]).convert('RGB')
+canvas = np.full((672, 672*2, 3), 255)
+canvas[0:224, 0:224, :] = np.asarray(img)
+canvas[:, 672:, :] = np.asarray(result)
+plt.imshow(canvas.astype(np.uint8))
+plt.show()
--
GitLab