diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 95633a4d4586da2436368f097b559147e00edc7d..20598400ce2144aad564a45e8da400b61e4fde6b 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -106,6 +106,30 @@ class StrAttrsDict(object): raise AttributeError("Required attribute {} not found.".format(key)) return default + def get_float_tuple(self, key, default=RequiredAttr()): + """Get float tuple attribute + + Parameters + ---------- + key : str + The attribute key + + default : float + The default value. + + Returns + ------- + value : The result + """ + + if key in self.attrs: + tshape = self.attrs[key] + return tuple(float(x.strip()) for x in + tshape.strip('()[]').split(',')) + if isinstance(default, RequiredAttr): + raise AttributeError("Required attribute {} not found.".format(key)) + return default + def get_tuple_tuple_int(self, key, default=RequiredAttr()): """Get int list attribute diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index f61c65bbaf6a3aec61592a1813c8b8d0d0ccf88d..7bffbd4f499eaae6b4b67ead040ef8dee8407332 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -241,6 +241,33 @@ def _mx_lrn(inputs, attrs): return _op.nn.lrn(inputs[0], **new_attrs) +def _mx_multibox_prior(inputs, attrs): + new_attrs = {} + new_attrs["sizes"] = attrs.get_float_tuple("sizes", (1.0, )) + new_attrs["steps"] = attrs.get_float_tuple("steps", (-1.0, -1.0)) + new_attrs["offsets"] = attrs.get_float_tuple("offsets", (0.5, 0.5)) + new_attrs["ratios"] = attrs.get_float_tuple("ratios", (1.0, )) + new_attrs["clip"] = attrs.get_bool("clip", False) + return _op.vision.multibox_prior(inputs[0], **new_attrs) + + +def _mx_multibox_detection(inputs, attrs): + new_attrs0 = {} + new_attrs0["clip"] = attrs.get_bool("clip", True) + new_attrs0["threshold"] = attrs.get_float("threshold", 0.01) + new_attrs0["variances"] = attrs.get_float_tuple("variances", (0.1, 0.1, + 0.2, 0.2)) + + new_attrs1 = {} + new_attrs1["overlap_threshold"] = attrs.get_float("nms_threshold", 0.5) + new_attrs1["force_suppress"] = attrs.get_bool("force_suppress", False) + new_attrs1["topk"] = attrs.get_int("nms_topk", -1) + + ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1], + inputs[2], **new_attrs0) + return _op.vision.nms(ret[0], ret[1], **new_attrs1) + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -327,13 +354,14 @@ _convert_map = { "LeakyReLU" : _mx_leaky_relu, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, + # vision + "_contrib_MultiBoxPrior" : _mx_multibox_prior, + "_contrib_MultiBoxDetection" : _mx_multibox_detection, # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # # "broadcast_to", # "gather_nd", - # "_contrib_MultiBoxPrior" : _rename("multibox_prior"), - # "_contrib_MultiBoxDetection" : _contrib_multibox_detection, # "Crop" : _crop_like, } diff --git a/python/tvm/relay/op/vision/__init__.py b/python/tvm/relay/op/vision/__init__.py index 9ecd8a84770a905905259181cbdcc4845ed9162d..ea3ed69e8f384dde5793a15de6c0da6ca9dbf6db 100644 --- a/python/tvm/relay/op/vision/__init__.py +++ b/python/tvm/relay/op/vision/__init__.py @@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs from .multibox import * from .nms import * +from . import _multibox diff --git a/python/tvm/relay/op/vision/_multibox.py b/python/tvm/relay/op/vision/_multibox.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ef43f7e06f9b3fc3f7ddded5a79db12993fecc --- /dev/null +++ b/python/tvm/relay/op/vision/_multibox.py @@ -0,0 +1,77 @@ +# pylint: disable=invalid-name, unused-argument +"""Definition of vision ops""" +from __future__ import absolute_import + +import topi +from topi.util import get_const_int, get_const_float, get_float_tuple +from .. import op as reg +from ..op import OpPattern + + +@reg.register_schedule("vision.multibox_prior") +def schedule_multibox_prior(_, outs, target): + """Schedule definition of multibox_prior""" + with target: + return topi.generic.schedule_multibox_prior(outs) + + +@reg.register_compute("vision.multibox_prior") +def compute_multibox_prior(attrs, inputs, _, target): + """Compute definition of multibox_prior""" + sizes = get_float_tuple(attrs.sizes) + ratios = get_float_tuple(attrs.ratios) + steps = get_float_tuple(attrs.steps) + offsets = get_float_tuple(attrs.offsets) + clip = bool(get_const_int(attrs.clip)) + return [ + topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios, steps, + offsets, clip) + ] + + +reg.register_pattern("vision.multibox_prior", OpPattern.OPAQUE) + + +# multibox_transform_loc +@reg.register_schedule("vision.multibox_transform_loc") +def schedule_multibox_transform_loc(_, outs, target): + """Schedule definition of multibox_detection""" + with target: + return topi.generic.schedule_multibox_transform_loc(outs) + + +@reg.register_compute("vision.multibox_transform_loc") +def compute_multibox_transform_loc(attrs, inputs, _, target): + """Compute definition of multibox_detection""" + clip = bool(get_const_int(attrs.clip)) + threshold = get_const_float(attrs.threshold) + variances = get_float_tuple(attrs.variances) + return topi.vision.ssd.multibox_transform_loc( + inputs[0], inputs[1], inputs[2], clip, threshold, variances) + + +reg.register_pattern("vision.multibox_transform_loc", OpPattern.OPAQUE) +reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE) + + +# non-maximum suppression +@reg.register_schedule("vision.nms") +def schedule_nms(_, outs, target): + """Schedule definition of nms""" + with target: + return topi.generic.schedule_nms(outs) + + +@reg.register_compute("vision.nms") +def compute_nms(attrs, inputs, _, target): + """Compute definition of nms""" + overlap_threshold = get_const_float(attrs.overlap_threshold) + force_suppress = bool(get_const_int(attrs.force_suppress)) + topk = get_const_int(attrs.topk) + return [ + topi.vision.nms(inputs[0], inputs[1], overlap_threshold, + force_suppress, topk) + ] + + +reg.register_pattern("vision.nms", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/vision/multibox.py b/python/tvm/relay/op/vision/multibox.py index b04610aaa0809ab5d33049241fcdfb0acb99372e..90591da925f5e967a57e8b734e5d1c6e4b5578c4 100644 --- a/python/tvm/relay/op/vision/multibox.py +++ b/python/tvm/relay/op/vision/multibox.py @@ -1,6 +1,7 @@ """Multibox operations.""" from __future__ import absolute_import as _abs from . import _make +from ...expr import TupleWrapper def multibox_prior(data, sizes=(1.0,), @@ -43,7 +44,7 @@ def multibox_transform_loc(cls_prob, anchor, clip=True, threshold=0.01, - variance=(0.1, 0.1, 0.2, 0.2)): + variances=(0.1, 0.1, 0.2, 0.2)): """Location transformation for multibox detection Parameters @@ -63,12 +64,13 @@ def multibox_transform_loc(cls_prob, threshold : double, optional Threshold to be a positive prediction. - variance : Tuple of float, optional - Variances to be decoded from box regression output. + variances : Tuple of float, optional + variances to be decoded from box regression output. Returns ------- ret : tuple of tvm.relay.Expr """ - return _make.multibox_transform_loc(cls_prob, loc_pred, anchor, clip, - threshold, variance) + return TupleWrapper(_make.multibox_transform_loc(cls_prob, loc_pred, + anchor, clip, threshold, + variances), 2) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 6bd331b9812043144df680f4a9dbe5aa1eca5740..aa31aa96ef45ea105056c79fe0df631dd7d654f8 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -1,11 +1,13 @@ """ Support level5 operator test cases. """ +import math import numpy as np import tvm from tvm import relay from tvm.relay.testing import ctx_list import topi.testing + def test_resize_infer_type(): n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) @@ -48,64 +50,163 @@ def test_resize(): for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout) + def test_multibox_prior(): + def get_ref_result(dshape, sizes=(1.0,), + ratios=(1.0,), steps=(-1.0, -1.0), + offsets=(0.5, 0.5), clip=True): + in_height = dshape[2] + in_width = dshape[3] + num_sizes = len(sizes) + num_ratios = len(ratios) + size_ratio_concat = sizes + ratios + steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height + steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width + offset_h = offsets[0] + offset_w = offsets[1] + + oshape = (1, in_height * in_width * (num_sizes + num_ratios - 1), 4) + dtype = "float32" + np_out = np.zeros(oshape).astype(dtype) + + for i in range(in_height): + center_h = (i + offset_h) * steps_h + for j in range(in_width): + center_w = (j + offset_w) * steps_w + for k in range(num_sizes + num_ratios - 1): + w = size_ratio_concat[k] * in_height / in_width / 2.0 if k < num_sizes else \ + size_ratio_concat[0] * in_height / in_width * math.sqrt(size_ratio_concat[k + 1]) / 2.0 + h = size_ratio_concat[k] / 2.0 if k < num_sizes else \ + size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0 + count = i * in_width * (num_sizes + num_ratios - 1) + j * (num_sizes + num_ratios - 1) + k + np_out[0][count][0] = center_w - w + np_out[0][count][1] = center_h - h + np_out[0][count][2] = center_w + w + np_out[0][count][3] = center_h + h + if clip: + np_out = np.clip(np_out, 0, 1) + + return np_out + + def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,), + ratios=(1.0,), steps=(-1.0, -1.0), + offsets=(0.5, 0.5), clip=True, check_size=False, + check_type_only=False): + + z = relay.vision.multibox_prior(x, sizes, ratios, steps, offsets, clip) + zz = relay.ir_pass.infer_type(z) + if check_size: + assert "sizes=" in z.astext() + assert zz.checked_type == relay.TensorType( + (1, dshape[2] * dshape[3] * (len(sizes) + len(ratios) - 1), 4), + "float32") + + if check_type_only: + return + + data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") + func = relay.Function([x], z) + func = relay.ir_pass.infer_type(func) + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res2 = intrp2.evaluate(func)(data) + tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) + sizes = (0.3, 1.5, 0.7) ratios = (1.3, 2.4) steps = (2.0, 1.5) offsets = (0.2, 0.3) - clip = True - - n, c, h, w = tvm.var("n"), 3, 56, 56 - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - - z = relay.vision.multibox_prior(x, sizes, ratios, - steps, offsets, clip) - assert "sizes=" in z.astext() - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.TensorType( - (1, h * w * (len(sizes) + len(ratios) - 1), 4), "float32") + dshape = (1, 3, 56, 56) + ref_res = get_ref_result(dshape, sizes, ratios, steps, offsets) + x = relay.var("x", relay.TensorType(dshape, "float32")) + verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets, + check_size=True) + y = relay.var("y", relay.TensorType((tvm.var("n"), 3, 56, 56), "float32")) + verify_multibox_prior(x, dshape, ref_res, sizes, ratios, steps, offsets, + check_size=True, check_type_only=True) - n, c, h, w = tvm.var("n"), 24, 32, 32 - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - z = relay.vision.multibox_prior(x) - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.TensorType( - (1, h * w, 4), "float32") + dshape = (1, 24, 32, 32) + ref_res = get_ref_result(dshape, clip=False) + x = relay.var("x", relay.TensorType(dshape, "float32")) + verify_multibox_prior(x, dshape, ref_res, clip=False) + y = relay.var("y", relay.TensorType((tvm.var("n"), 24, 32, 32), "float32")) + verify_multibox_prior(x, dshape, ref_res, clip=False, check_type_only=True) def test_nms(): - num_anchors = 60 - - overlap_threshold = 0.5 - force_suppress = True - nms_topk = 10 - - n = tvm.var("n") - x0 = relay.var("x0", relay.ty.TensorType((n, num_anchors, 6), "float32")) - x1 = relay.var("x1", relay.ty.TensorType((n,), "int")) + def verify_nms(x0_data, x1_data, dshape, ref_res, valid_count, + overlap_threshold=0.5, force_suppress=False, topk=-1, + check_type_only=False): + x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32")) + x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int")) + z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, topk) + assert "overlap_threshold" in z.astext() + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.ty.TensorType(dshape, "float32") - z = relay.vision.nms(x0, x1, overlap_threshold, force_suppress, nms_topk) + if check_type_only: + return - assert "overlap_threshold" in z.astext() - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.ty.TensorType( - (n, num_anchors, 6), "float32") + func = relay.Function([x0, x1], z) + func = relay.ir_pass.infer_type(func) + ctx_list = [("llvm", tvm.cpu(0))] + for target, ctx in ctx_list: + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(x0_data, x1_data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res2 = intrp2.evaluate(func)(x0_data, x1_data) + tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) - n = tvm.var("n") - x0 = relay.var("x0", relay.ty.TensorType((n, num_anchors, 6), "float32")) - x1 = relay.var("x1", relay.ty.TensorType((n,), "int")) + np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], + [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], + [1, 0.5, 100, 60, 70, 110]]]).astype("float32") + np_valid_count = np.array([4]).astype("int32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], + [0, 0.4, 4, 21, 19, 40], [-1, 0.9, 35, 61, 52, 79], + [-1, -1, -1, -1, -1, -1]]]) + num_anchors = 5 - z = relay.vision.nms(x0, x1) + dshape = (tvm.var("n"), num_anchors, 6) + verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], + force_suppress=True, topk=2, check_type_only=True) + dshape = (1, num_anchors, 6) + verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], + force_suppress=True, topk=2, check_type_only=False) - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.ty.TensorType( - (n, num_anchors, 6), "float32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], + [1, 0.7, 30, 60, 50, 80], [-1, 0.9, 35, 61, 52, 79], + [-1, -1, -1, -1, -1, -1]]]) + dshape = (tvm.var("n"), num_anchors, 6) + verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], + check_type_only=True) + dshape = (1, num_anchors, 6) + verify_nms(np_data, np_valid_count, dshape, np_result, dshape[0], + topk=3) def test_multibox_transform_loc(): def test_default_value(): - num_anchors = 5 - num_classes = 5 + num_anchors = 3 + num_classes = 3 + + np_cls_prob = np.array( + [[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], + [0.7, 0.1, 0.2]]]).astype("float32") + np_loc_preds = np.array( + [[0.1, -0.2, 0.3, 0.2, 0.2, 0.4, 0.5, -0.3, 0.7, -0.2, -0.4, + -0.8]]).astype("float32") + np_anchors = np.array( + [[[-0.1, -0.1, 0.1, 0.1], [-0.2, -0.2, 0.2, 0.2], + [1.2, 1.2, 1.5, 1.5]]]).astype("float32") + + expected_np_out = np.array([[[1, 0.69999999, 0, 0, 0.10818365, 0.10008108], + [0, 0.44999999, 1, 1, 1, 1], + [0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]]) + cls_prob = relay.var( "cls_prob", @@ -115,16 +216,31 @@ def test_multibox_transform_loc(): anchors = relay.var( "anchors", relay.ty.TensorType((1, num_anchors, 4), "float32")) - ret = relay.vision.multibox_transform_loc( + mtl = relay.vision.multibox_transform_loc( cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors) - ret = relay.ir_pass.infer_type(ret) + ret = relay.ir_pass.infer_type(mtl.astuple()) ref_type = relay.ty.TupleType( tvm.convert([ relay.ty.TensorType((1, num_anchors, 6), "float32"), relay.ty.TensorType((1, ), "int") ])) + assert ret.checked_type == ref_type + nms = relay.vision.nms(mtl[0], mtl[1]) + func = relay.Function([cls_prob, loc_pred, anchors], nms) + func = relay.ir_pass.infer_type(func) + ctx_list = [("llvm", tvm.cpu(0))] + for target, ctx in ctx_list: + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, + np_anchors) + tvm.testing.assert_allclose(op_res1.asnumpy(), expected_np_out, rtol=1e-5) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res2 = intrp2.evaluate(func)(np_cls_prob, np_loc_preds, + np_anchors) + tvm.testing.assert_allclose(op_res2.asnumpy(), expected_np_out, rtol=1e-5) + def test_threshold(): num_anchors = 5 num_classes = 5 @@ -137,15 +253,15 @@ def test_multibox_transform_loc(): anchors = relay.var( "anchors", relay.ty.TensorType((1, num_anchors, 4), "float32")) threshold = 0.02 - variance = (0.2, 0.2, 0.3, 0.3) + variances = (0.2, 0.2, 0.3, 0.3) ret = relay.vision.multibox_transform_loc( cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors, threshold=threshold, - variance=variance) - ret = relay.ir_pass.infer_type(ret) + variances=variances) + ret = relay.ir_pass.infer_type(ret.astuple()) ref_type = relay.ty.TupleType( tvm.convert([ relay.ty.TensorType((n, num_anchors, 6), "float32"), diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index de9ff90ae26ba39bafc6622f74f55421c84942e6..edfb0e467e1fe7aa08691b3c0320f41a5990a01d 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -78,6 +78,28 @@ def get_const_int(expr): return int(expr.value) +def get_const_float(expr): + """Verifies expr is a floating point and get the constant value. + + Parameters + ---------- + expr : tvm.Expr or float + The input expression. + + Returns + ------- + out_value : float + The output. + """ + if isinstance(expr, float): + return float(expr) + if not isinstance(expr, tvm.expr.FloatImm): + expr = tvm.ir_pass.Simplify(expr) + if not isinstance(expr, tvm.expr.FloatImm): + raise ValueError("Expect value to be constant float") + return float(expr.value) + + def equal_const_int(expr, value): """Returns if expr equals value. @@ -120,6 +142,26 @@ def get_const_tuple(in_tuple): return out_tuple +def get_float_tuple(in_tuple): + """Verifies input tuple is FloatImm, returns tuple of float. + + Parameters + ---------- + in_tuple : tuple of Expr + The input. + + Returns + ------- + out_tuple : tuple of float + The output. + """ + out_tuple = () + for elem in in_tuple: + value = get_const_float(elem) + out_tuple = out_tuple + (value, ) + return out_tuple + + def simplify(expr): """Simplify the expression if it is Expr, directly return if it is int. diff --git a/tutorials/nnvm/deploy_ssd.py b/tutorials/nnvm/deploy_ssd.py index 3f5f89a632b6dc25de4b3fa2eb7cfdb73b71b5f3..9afa113959f0407e5ecb08e54c7078c89b62dccb 100644 --- a/tutorials/nnvm/deploy_ssd.py +++ b/tutorials/nnvm/deploy_ssd.py @@ -5,7 +5,7 @@ Deploy Single Shot Multibox Detector(SSD) model This article is an introductory tutorial to deploy SSD models with TVM. We will use mxnet pretrained SSD model with Resnet50 as body network and -convert it to NNVM graph. +convert it to NNVM graph; """ import os import zipfile @@ -16,6 +16,7 @@ import numpy as np from nnvm import compiler from nnvm.frontend import from_mxnet +from tvm import relay from tvm.contrib.download import download from tvm.contrib import graph_runtime from mxnet.model import load_checkpoint @@ -58,7 +59,7 @@ image_url = "https://cloud.githubusercontent.com/assets/3307514/20012567/" \ inference_symbol_folder = "c1904e900848df4548ce5dfb18c719c7-a28c4856c827fe766aa3da0e35bad41d44f0fb26" inference_symbol_url = "https://gist.github.com/kevinthesun/c1904e900848df4548ce5dfb18c719c7/" \ "archive/a28c4856c827fe766aa3da0e35bad41d44f0fb26.zip" - + dir = "ssd_model" if not os.path.exists(dir): os.makedirs(dir) @@ -77,13 +78,31 @@ zip_ref.extractall(dir) zip_ref.close() ###################################################################### -# Convert and compile model with NNVM for CPU. +# Convert and compile model with NNVM or Relay for CPU. sym = mx.sym.load("%s/%s/ssd_resnet50_inference.json" % (dir, inference_symbol_folder)) _, arg_params, aux_params = load_checkpoint("%s/%s" % (dir, model_name), 0) -net, params = from_mxnet(sym, arg_params, aux_params) -with compiler.build_config(opt_level=3): - graph, lib, params = compiler.build(net, target, {"data": dshape}, params=params) + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument( + "-f", "--frontend", + help="Frontend for compilation, nnvm or relay", + type=str, + default="nnvm") +args = parser.parse_args() +if args.frontend == "relay": + net, params = relay.frontend.from_mxnet(sym, {"data": dshape}, arg_params=arg_params, aux_params=aux_params) + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(net, target, params=params) +elif args.frontend == "nnvm": + net, params = from_mxnet(sym, arg_params, aux_params) + with compiler.build_config(opt_level=3): + graph, lib, params = compiler.build( + net, target, {"data": dshape}, params=params) +else: + parser.print_help() + parser.exit() ###################################################################### # Create TVM runtime and do inference @@ -141,4 +160,3 @@ def display(img, out, thresh=0.5): image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) display(image, tvm_output.asnumpy()[0], thresh=0.45) -