From b7beb1ebefa18e29bfbf8ff1c4f8f0c8892d93bc Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Tue, 21 Aug 2018 18:35:32 -0700
Subject: [PATCH] [AUTOTVM] Allow fallback for template & Fix bugs in tuners
 (#1615)

* support fallback & fix bugs in tuners & clean up topi tests

* update task extraction

* update task extraction

* fix arm tutorial

* Update tune_nnvm_arm.py
---
 nnvm/python/nnvm/compiler/build_module.py     |   5 +-
 .../compiler/test_autotvm_task_extraction.py  |  63 +++++++
 python/tvm/autotvm/__init__.py                |   3 +-
 python/tvm/autotvm/measure/measure.py         |   5 +-
 python/tvm/autotvm/measure/measure_methods.py |  44 ++---
 python/tvm/autotvm/task/__init__.py           |   2 +-
 python/tvm/autotvm/task/dispatcher.py         | 117 +++++++++----
 python/tvm/autotvm/task/nnvm_integration.py   | 117 +++++++++----
 python/tvm/autotvm/task/space.py              |  56 ++++++-
 python/tvm/autotvm/task/task.py               |   2 +-
 python/tvm/autotvm/tophub.py                  |   7 +-
 python/tvm/autotvm/tuner/ga_tuner.py          |  10 +-
 python/tvm/autotvm/tuner/model_based_tuner.py |  33 ++--
 .../tvm/autotvm/tuner/sa_model_optimizer.py   |   2 +-
 python/tvm/autotvm/tuner/tuner.py             |  13 +-
 .../tvm/autotvm/tuner/xgboost_cost_model.py   | 119 ++++++++-----
 python/tvm/autotvm/tuner/xgboost_tuner.py     |  17 +-
 python/tvm/exec/tophub.py                     |   9 +-
 python/tvm/target.py                          |   1 +
 .../unittest/test_autotvm_dispatch_context.py |  44 +++--
 tests/python/unittest/test_autotvm_space.py   |  15 +-
 .../unittest/test_autotvm_xgboost_model.py    |   6 +-
 topi/python/topi/arm_cpu/conv2d.py            |  86 +++++++---
 topi/python/topi/arm_cpu/depthwise_conv2d.py  |  14 +-
 topi/python/topi/x86/injective.py             |   2 +-
 topi/tests/python/common.py                   |  12 ++
 .../python/test_topi_bitserial_conv2d.py      |  25 ++-
 .../python/test_topi_bitserial_conv2d_rasp.py |  16 +-
 topi/tests/python/test_topi_bnn.py            |   2 +-
 topi/tests/python/test_topi_broadcast.py      |  25 +--
 topi/tests/python/test_topi_clip.py           |   3 +-
 topi/tests/python/test_topi_conv2d.py         |  47 ------
 topi/tests/python/test_topi_conv2d_hwcn.py    |  14 +-
 topi/tests/python/test_topi_conv2d_nchw.py    | 157 ++++++++++++------
 .../python/test_topi_conv2d_transpose_nchw.py |  22 +--
 topi/tests/python/test_topi_dense.py          |   9 +-
 .../python/test_topi_depthwise_conv2d.py      |  38 ++---
 tutorials/autotvm/tune_nnvm_arm.py            |   5 +-
 38 files changed, 756 insertions(+), 411 deletions(-)
 create mode 100644 nnvm/tests/python/compiler/test_autotvm_task_extraction.py
 create mode 100644 topi/tests/python/common.py
 delete mode 100644 topi/tests/python/test_topi_conv2d.py

diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py
index 217598c9d..6fab4460b 100644
--- a/nnvm/python/nnvm/compiler/build_module.py
+++ b/nnvm/python/nnvm/compiler/build_module.py
@@ -239,8 +239,9 @@ def build(graph, target=None, shape=None, dtype="float32",
         raise ValueError("Target is not set in env or passed as argument.")
     target = tvm.target.create(target)
 
-    # if not inside an autotvm config dispatch context, load pre-tuned parameters from TopHub
-    if autotvm.task.DispatchContext.current is None:
+    # If current dispatch context is fallback context (the default root context),
+    # then load pre-tuned parameters from TopHub
+    if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
         tophub_context = autotvm.tophub.context(target)
     else:
         tophub_context = autotvm.util.EmptyContext()
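
A minimal sketch of what this check enables, assuming the resnet-18 workload from
nnvm.testing and a hypothetical tuning log 'my_tuning.log': entering
apply_history_best installs an ApplyHistoryBest context, so the isinstance check
above is False and the TopHub download is skipped.

    # Sketch; 'my_tuning.log' is a hypothetical tuning log, not part of this patch.
    import nnvm.compiler
    import nnvm.testing
    from tvm import autotvm

    net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=1)
    with autotvm.apply_history_best('my_tuning.log'):
        graph, lib, params = nnvm.compiler.build(
            net, target='llvm', shape={'data': (1, 3, 224, 224)}, params=params)
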
diff --git a/nnvm/tests/python/compiler/test_autotvm_task_extraction.py b/nnvm/tests/python/compiler/test_autotvm_task_extraction.py
new file mode 100644
index 000000000..fd14934f8
--- /dev/null
+++ b/nnvm/tests/python/compiler/test_autotvm_task_extraction.py
@@ -0,0 +1,63 @@
+"""Test task extraction for autotvm"""
+
+import nnvm.testing
+import nnvm.compiler
+from tvm import autotvm
+
+def get_network(name, batch_size):
+    """Get the symbol definition and random weight of a network"""
+    input_shape = (batch_size, 3, 224, 224)
+    output_shape = (batch_size, 1000)
+
+    if name == 'resnet-18':
+        net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
+    elif name == 'mobilenet':
+        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
+    elif name == 'squeezenet v1.1':
+        net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
+    elif name == 'vgg-16':
+        net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size)
+    elif name == 'dcgan':
+        net, params = nnvm.testing.dcgan.get_workload(batch_size=batch_size)
+        input_shape = (batch_size, 100)
+    else:
+        raise ValueError("Unsupported network: " + name)
+
+    return net, params, input_shape, output_shape
+
+def test_task_extraction():
+    target = 'llvm'
+    dtype = 'float32'
+
+    net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_graph(net, target=target,
+                                            shape={'data': input_shape}, dtype=dtype,
+                                            symbols=(nnvm.sym.conv2d,))
+    assert len(tasks) == 12
+
+    net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_graph(net, target=target,
+                                            shape={'data': input_shape}, dtype=dtype,
+                                            symbols=(nnvm.sym.dense,))
+    assert len(tasks) == 1
+
+    net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_graph(net, target=target,
+                                            shape={'data': input_shape}, dtype=dtype,
+                                            symbols=(nnvm.sym.conv2d, nnvm.sym.dense))
+    assert len(tasks) == 13
+
+    net, params, input_shape, out_shape = get_network('mobilenet', batch_size=1)
+    tasks = autotvm.task.extract_from_graph(net, target=target,
+                                            shape={'data': input_shape}, dtype=dtype,
+                                            symbols=(nnvm.sym.conv2d, nnvm.sym.dense))
+    assert len(tasks) == 20
+
+    net, params, input_shape, out_shape = get_network('dcgan', batch_size=1)
+    tasks = autotvm.task.extract_from_graph(net, target=target,
+                                            shape={'data': input_shape}, dtype=dtype,
+                                            symbols=(nnvm.sym.conv2d_transpose,))
+    assert len(tasks) == 4
+
+if __name__ == '__main__':
+    test_task_extraction()
diff --git a/python/tvm/autotvm/__init__.py b/python/tvm/autotvm/__init__.py
index 5b312d93d..625b50c10 100644
--- a/python/tvm/autotvm/__init__.py
+++ b/python/tvm/autotvm/__init__.py
@@ -25,5 +25,6 @@ from . import tophub
 from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
 from .tuner import callback
 from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
-    ApplyHistoryBest as apply_history_best
+    register_topi_compute, register_topi_schedule, \
+    DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best
 from .env import GLOBAL_SCOPE
diff --git a/python/tvm/autotvm/measure/measure.py b/python/tvm/autotvm/measure/measure.py
index 2325a970b..2d780eeaf 100644
--- a/python/tvm/autotvm/measure/measure.py
+++ b/python/tvm/autotvm/measure/measure.py
@@ -89,8 +89,9 @@ def measure_option(measure_func,
 
         callable: customized build function for other backends (e.g. VTA).
                   See measure/measure_methods.py::default_build_func for example.
-    check_correctness: bool
-        Whether check correctness after measurement. This will use llvm cpu as reference.
+    check_correctness: bool, optional
+        Whether to check correctness after measurement. This will use the llvm cpu
+        target to generate the reference output.
     replay_db : Database, optional
         The database that we retrieve saved MeasureResult from.
 
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index d845cc1f8..2d740b949 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -83,7 +83,7 @@ def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
         The priority of this request, larger is more prior
     timeout: float, optional
         The timeout of this check (units: seconds).
-        If time is out, a RuntimerError will be raised.
+        If time is out, a RuntimeError will be raised.
     """
     def _check():
         remote = request_remote(device_key, tracker_addr, priority)
@@ -281,11 +281,11 @@ def rpc(key,
         results: List of MeasureResult
             The results for input_pack
         """
-        remote = request_remote(key, (host, port), priority, session_timeout)
+        remote_args = (key, (host, port), priority, session_timeout)
 
         res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
                               ref_input, ref_output,
-                              remote)
+                              remote_args)
         return res
 
     fmeasure.pack_size = pack_size
@@ -294,7 +294,7 @@ def rpc(key,
 
 
 def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
-                    ref_input=None, ref_output=None, remote=None):
+                    ref_input=None, ref_output=None, remote_args=None):
     """Measure the time cost for a pack of inputs.
 
     (Note: A pack is a list of inputs which will be measured inside a same RPC session)
@@ -318,8 +318,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
         Reference input for checking correctness
     ref_output: Array of np.ndarray, optional
         Reference output for checking correctness
-    remote: RPCSession, optional
-        The remote RPC session
+    remote_args: Tuple, optional
+        The arguments to request_remote. If not None, a remote RPC device will be used.
 
     Returns
     -------
@@ -327,7 +327,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
         The list of results of measurement.
     """
     res_pack = []
-    tmp_dir = util.tempdir() if remote else None
+    tmp_dir = util.tempdir() if remote_args else None
+    assert len(input_pack) == 1, "Only a pack size of 1 is supported for now"
 
     for inp in input_pack:
         tic = time.time()
@@ -360,31 +361,36 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
                                           tstamp - tic, tstamp))
             continue
 
-        # upload built module
-        if remote:
-            remote.upload(tmp_dir.relpath(filename))
-            func = remote.load_module(filename)
-            ctx = remote.context(str(inp.target), 0)
-            time_f = func.time_evaluator(
-                func.entry_name, ctx, number=number, repeat=repeat)
-        else:
-            ctx = context(str(inp.target), 0)
-            time_f = func.time_evaluator(
-                func.entry_name, ctx, number=number, repeat=repeat)
-
         # measure time
         errno = MeasureErrorNo.NO_ERROR
         try:
+            # upload built module
+            if remote_args:
+                remote = request_remote(*remote_args)
+                remote.upload(tmp_dir.relpath(filename))
+                func = remote.load_module(filename)
+                ctx = remote.context(str(inp.target), 0)
+                time_f = func.time_evaluator(
+                    func.entry_name, ctx, number=number, repeat=repeat)
+            else:
+                ctx = context(str(inp.target), 0)
+                time_f = func.time_evaluator(
+                    func.entry_name, ctx, number=number, repeat=repeat)
+
+            # set input
             if ref_input:
                 args = [nd.array(x, ctx=ctx) for x in ref_input]
             else:
                 args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx)
                         for x in arg_bufs]
+
             costs = time_f(*args).results
             if len(costs) > 2:  # remove largest and smallest value to reduce variance
                 costs = list(costs)
                 costs.sort()
                 costs = tuple(costs[1:-1])
+
+            # check correctness of output
             if ref_output:
                 for expected, real in zip(ref_output, args):
                     if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
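
A minimal usage sketch of the RPC measurement path this refactor feeds (the device
key, host, and port are illustrative): because only remote_args is kept, each
measurement batch opens a fresh RPC session instead of pinning one session across
the whole build-and-run loop.

    from tvm import autotvm

    # 'rk3399' is a hypothetical device key registered with an RPC tracker.
    measure_option = autotvm.measure_option(
        measure_func=autotvm.measure.rpc('rk3399', host='127.0.0.1', port=9190),
        number=4)
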
diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py
index 0d43f9265..7592fc5af 100644
--- a/python/tvm/autotvm/task/__init__.py
+++ b/python/tvm/autotvm/task/__init__.py
@@ -9,7 +9,7 @@ of typical tasks of interest.
 from .task import Task, create, register, template, get_config, args_to_workload
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
-from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, dispatcher
+from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, FallbackContext, dispatcher
 
 from .topi_integration import register_topi_compute, register_topi_schedule
 from .nnvm_integration import extract_from_graph
diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index 93f6d584a..ec1dcc44f 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -21,7 +21,7 @@ import numpy as np
 
 from tvm import target as _target
 
-from .space import ConfigSpace
+from .space import FallbackConfigEntity
 
 logger = logging.getLogger('autotvm')
 
@@ -34,9 +34,36 @@ class DispatchContext(object):
     """
     current = None
 
+    def __init__(self):
+        self._old_ctx = DispatchContext.current
+
     def query(self, target, workload):
         """
-        Query the context to get the specific implementation.
+        Query the context to get the specific config for a template.
+        If the result cannot be found inside this context, this function will
+        query the upper contexts.
+
+        Parameters
+        ----------
+        target: Target
+            The current target
+        workload : Workload
+            The current workload.
+
+        Returns
+        -------
+        cfg : ConfigSpace
+            The specific configuration.
+        """
+        ret = self._query_inside(target, workload)
+        if ret is None:
+            ret = self._old_ctx.query(target, workload)
+        return ret
+
+    def _query_inside(self, target, workload):
+        """
+        Query the context to get the specific config for a template.
+        This function only queries the config inside this context.
 
         Parameters
         ----------
@@ -117,17 +144,17 @@ def dispatcher(fworkload):
     def dispatch_func(func, *args, **kwargs):
         """The wrapped dispatch function"""
         tgt = _target.current_target()
-        context = DispatchContext.current
-        if context is None:
-            raise RuntimeError("DispatchContext is not initialized")
         workload = func(*args, **kwargs)
-        cfg = context.query(tgt, workload)
-        if cfg.template_key:
-            return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
-        else:
-            assert dispatch_dict, "No func registered for this dispatcher"
+        cfg = DispatchContext.current.query(tgt, workload)
+        if cfg.is_fallback and not cfg.template_key:
+            # first try 'direct' template
+            if 'direct' in dispatch_dict:
+                return dispatch_dict['direct'](cfg, *args, **kwargs)
+            # otherwise pick a random template
             for v in dispatch_dict.values():
                 return v(cfg, *args, **kwargs)
+        else:
+            return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
 
     fdecorate = decorate(fworkload, dispatch_func)
     fdecorate.register = register
@@ -135,7 +162,7 @@ def dispatcher(fworkload):
 
 
 class ApplyConfig(DispatchContext):
-    """Apply a specific config entity during query.
+    """Apply a deterministic config entity for all queries.
 
     Parameters
     ----------
@@ -147,7 +174,7 @@ class ApplyConfig(DispatchContext):
         self._config = config
         self.workload = None
 
-    def query(self, target, workload):
+    def _query_inside(self, target, workload):
         """Override query"""
         self.workload = workload
         return self._config
@@ -164,20 +191,12 @@ class ApplyHistoryBest(DispatchContext):
         If is str, then it should be the filename of a records log file.
                    Each row of this file is an encoded record pair.
         Otherwise, it is an iterator.
-    default: ConfigEntity, optional
-        The default config to return when no history records
-    allow_fallback: bool
-        Whether allow to use a fallback configuration if cannot find
-        tuned result.
     """
-    def __init__(self, records, default=None, allow_fallback=False):
+    def __init__(self, records):
         super(ApplyHistoryBest, self).__init__()
 
         self.best_by_targetkey = {}
         self.best_by_model = {}
-        self._default = default
-        self._allow_fallback = allow_fallback
-        self.fallback = {}
 
         if records:
             self.load(records)
@@ -234,7 +253,7 @@ class ApplyHistoryBest(DispatchContext):
 
         logger.debug("Finish loading %d records", counter)
 
-    def query(self, target, workload):
+    def _query_inside(self, target, workload):
         if target is None:
             raise RuntimeError("Need a target context to find the history best. "
                                "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
@@ -254,20 +273,50 @@ class ApplyHistoryBest(DispatchContext):
             if key in self.best_by_targetkey:
                 return self.best_by_targetkey[key][0].config
 
-        if self._default:
-            return self._default
+        return None
+
+
+class FallbackContext(DispatchContext):
+    """
+    A fallback dispatch context.
+
+    Any tunable template can be called under this context.
+    This is the root context.
+    """
+
+    def __init__(self):
+        super(FallbackContext, self).__init__()
+        self.memory = {}
+        self.silent = False
+
+    def _query_inside(self, target, workload):
+        key = (str(target), workload)
+        if key in self.memory:
+            return self.memory[key]
 
-        if self._allow_fallback:
-            key = (target, workload)
-            if key in self.fallback:
-                return self.fallback[key]
+        if not self.silent:
             logger.warning(
                 "Cannot find config for target=%s, workload=%s. A fallback configuration "
                 "is used, which may bring great performance regression.", target, workload)
-            cfg = ConfigSpace()
-            self.fallback[key] = cfg
-            return cfg
+        cfg = FallbackConfigEntity()
+
+        # cache this config
+        self.memory[key] = cfg
+        return cfg
+
+    def clear_cache(self, target, workload):
+        """Clear fallback cache. Pass the same argument as _query_inside to this function
+        to clean the cache.
+
+        Parameters
+        ----------
+        target: Target
+            The current target
+        workload : Workload
+            The current workload.
+        """
+        key = (str(target), workload)
+        if key in self.memory:
+            del self.memory[key]
 
-        raise RuntimeError(
-            "Cannot find config for target=%s, workload=%s. You need to do tuning "
-            "for this workload to get the config." % (target, workload))
+DispatchContext.current = FallbackContext()
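
A minimal sketch of the query chain these classes implement (the context class and
workload key below are hypothetical): a context whose _query_inside returns None
defers to whatever context was current when it was created, bottoming out at the
root FallbackContext installed on the line above.

    from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext

    class EmptyHistory(DispatchContext):
        """Hypothetical context that holds no tuned records of its own."""
        def _query_inside(self, target, workload):
            return None  # nothing here; DispatchContext.query asks upward

    with EmptyHistory():
        # answered by the root FallbackContext (and logs a fallback warning)
        cfg = DispatchContext.current.query(None, ('conv2d', (1, 3, 224, 224)))
        assert cfg.is_fallback
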
diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py
index 1b50869fc..9138cc288 100644
--- a/python/tvm/autotvm/task/nnvm_integration.py
+++ b/python/tvm/autotvm/task/nnvm_integration.py
@@ -7,11 +7,10 @@ import warnings
 import logging
 
 
-from ... import tensor, placeholder, target as _target
+from ... import tensor, placeholder, create_schedule, target as _target
 
 from ..util import get_const_tuple
 from .task import create, register
-from .dispatcher import ApplyHistoryBest
 
 logger = logging.getLogger('autotvm')
 
@@ -56,40 +55,68 @@ class TaskExtractEnv:
         import topi
         import nnvm
 
+        # NOTE: To add more symbols, you only need to change the following lists
+        # nnvm symbol -> topi compute
         self.symbol2topi = {
             nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
-            nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose],
+            nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
+            nnvm.sym.dense: [topi.nn.dense],
         }
 
+        # topi compute -> autotvm task name
         self.topi_to_task = {
             topi.nn.conv2d: "topi_nn_conv2d",
             topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
             topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
+            topi.nn.dense: "topi_nn_dense",
         }
 
-        self._register_dummy()
+        self.topi_to_schedule = {
+            topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
+                             topi.generic.schedule_conv2d_nhwc],
+            topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
+                                            topi.generic.schedule_depthwise_conv2d_nhwc],
+            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
+            topi.nn.dense: [topi.generic.schedule_dense],
+        }
+
+        self._register_tracing()
         self._register_topi_task()
         self.task_collection = []
+        self.wanted_topi_funcs = list(self.topi_to_task.keys())
+
+    def _register_tracing(self):
+        """Register tracing function to track the topi function call"""
+        # register topi compute for "tracing" target
+        for topi_compute in self.topi_to_task:
+            def _local_scope(compute_func):
+                """start a scope to hold the local function in for loop"""
 
-    def _register_dummy(self):
-        """Register dummy function to track the topi function call"""
-        for func in self.topi_to_task:
-            def _local_scope(local_func):
-                """build a scope to holds the function"""
-                @local_func.register("dummy", )
-                def _dummy_func(*args, **kwargs):
+                @compute_func.register("tracing", )
+                def _tracing_topi_compute(*args, **kwargs):
                     assert not kwargs, "Do not support extracting tuning tasks when" \
                                        "kwargs is used in TOPI function call." \
                                        "Please modify it to use only positional args."
 
-                    if (self.topi_to_task[local_func], serialize_args(args)) \
-                            not in self.task_collection:
-                        self.task_collection.append((self.topi_to_task[local_func],
-                                                     serialize_args(args)))
-                    with _target.create("opencl"):
-                        return local_func(*args)
+                    if compute_func in self.wanted_topi_funcs:  # record this call
+                        key = (self.topi_to_task[compute_func], serialize_args(args))
+                        if key not in self.task_collection:
+                            self.task_collection.append(key)
+
+                    return compute_func.fdefault(*args)
+            _local_scope(topi_compute)
+
+        # register topi schedule for "tracing" target
+        for topi_compute in self.topi_to_task:
+            for topi_schedule in self.topi_to_schedule[topi_compute]:
+                def _local_scope_(schedule_func):
+                    """start a scope to hold the local function in for loop"""
 
-            _local_scope(func)
+                    @schedule_func.register("tracing", )
+                    def _tracing_topi_schedule(outs):
+                        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
+                        return create_schedule([x.op for x in outs])
+                _local_scope_(topi_schedule)
 
     def _register_topi_task(self):
         """register tuning wrapper for topi function"""
@@ -125,17 +152,47 @@ class TaskExtractEnv:
             s = topi.generic.schedule_conv2d_transpose_nchw([C])
             return s, [A, W, C]
 
-    def reset(self):
-        """Reset task collections"""
+        @register("topi_nn_dense")
+        def _topi_nn_dense(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            data, weight, bias = args
+            C = topi.nn.dense(*args, **kwargs)
+            s = topi.generic.schedule_dense([C])
+            if bias is not None:
+                return s, [data, weight, bias, C]
+            return s, [data, weight, C]
+
+    def reset(self, wanted_topi_funcs):
+        """Reset task collections
+
+        Parameters
+        ----------
+        wanted_topi_funcs: List of function
+            The topi functions to be extracted
+        """
         self.task_collection = []
+        self.wanted_topi_funcs = wanted_topi_funcs
 
     def get_tasks(self):
-        """Get collected tasks"""
+        """Get collected tasks
+
+        Returns
+        -------
+        tasks: List of tuple(name, args)
+            A list of tasks extracted from the nnvm graph
+        """
         return self.task_collection
 
     @staticmethod
     def get():
-        """Get the single instance of TaskExtractEnv"""
+        """Get the single instance of TaskExtractEnv
+
+        Returns
+        -------
+        env: TaskExtractEnv
+            The single instance of TaskExtractEnv
+        """
         if not TaskExtractEnv.current:
             TaskExtractEnv.current = TaskExtractEnv()
         return TaskExtractEnv.current
@@ -144,8 +201,8 @@ class TaskExtractEnv:
 def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
     """ Extract tuning tasks from a nnvm graph.
 
-    This function collects tunning tasks by building the graph
-    with a "dummy" target and tracing all the calls to topi.
+    This function collects tuning tasks by building the graph
+    with a "tracing" target and tracing all the calls to topi.
 
     Parameters
     ----------
@@ -158,7 +215,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
     target: tvm.target.Target
         The compilation target
     symbols : Array of nnvm.symbol
-        Array of nnvm symbols
+        Array of nnvm symbols to be tuned
     target_host: tvm.target.Target
         The host compilation target
 
@@ -179,16 +236,16 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
             warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
 
     # run compiler to collect all TOPI calls during compilation
-    env.reset()
+    env.reset(topi_funcs)
 
     # disable logger temporarily
     old_state = logger.disabled
     logger.disabled = True
 
-    # use a dummy target to do a fake compile for collecting topi calls
-    dummy_target = _target.create("opencl -device=dummy")
-    with ApplyHistoryBest([], allow_fallback=True):
-        nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)
+    # use a "tracing" target to do a fake compile for collecting topi calls
+    tracing_target = _target.create("llvm -device=tracing")
+    nnvm.compiler.engine.clear_cache()
+    nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
 
     logger.disabled = old_state
 
diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py
index ea823c6f2..5a34353ac 100644
--- a/python/tvm/autotvm/task/space.py
+++ b/python/tvm/autotvm/task/space.py
@@ -567,15 +567,16 @@ class ConfigSpace(object):
     """
     def __init__(self):
         # private dict to provide sugar
-        self.space_map = OrderedDict()  # name -> space
+        self.space_map = OrderedDict()    # name -> space
         self._collect = True
         self._length = None
-        self._entity_map = OrderedDict()
+        self._entity_map = OrderedDict()  # name -> entity
         self._constraints = []
         self.errors = []
         self.template_key = None
         self.code_hash = None
         self.flop = 0
+        self.is_fallback = False
 
     @staticmethod
     def axis(var):
@@ -607,6 +608,15 @@ class ConfigSpace(object):
             If is 'candidate', try listed candidate.
         kwargs: dict
             extra arguments for policy
+            see the examples below for how to use a filter
+
+        Examples
+        --------
+        >>> # use custom candidates
+        >>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
+
+        >>> # use a filter that only accepts split schemes whose innermost tile is at most 4
+        >>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4)
         """
         axes = [axis]
         return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)
@@ -889,3 +899,45 @@ class ConfigEntity(ConfigSpace):
     def __repr__(self):
         return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
                                 self.code_hash, self.index)
+
+class FallbackConfigEntity(ConfigSpace):
+    """The config entity created to support fallback"""
+
+    def __init__(self):
+        super(FallbackConfigEntity, self).__init__()
+        self.is_fallback = True
+
+    def fallback_split(self, name, constraints):
+        """Fallback a split knob
+
+        Parameters
+        ----------
+        name: str
+            name of the knob
+        constraints: List of int
+            The maximum tile size for every dimension. Value `-1` means no constraint.
+
+        Examples
+        --------
+        If you use cfg.define_split('tile_0', 128, num_outputs=3),
+        Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4]
+
+        If you use cfg.define_split('tile_0', 49, num_outputs=3),
+        Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1]
+        """
+        space = self.space_map[name]
+        assert len(constraints) == space.num_outputs
+        indices = np.arange(space.num_outputs)
+
+        # '-1' means no constraint
+        constraints = [x if x != -1 else 1e10 for x in constraints]
+
+        for entity in reversed(space.entities):
+            if all([entity.size[i] <= constraints[i] for i in indices]):
+                self._entity_map[name] = entity
+                return
+
+        raise RuntimeError("Cannot find feasible fallback split entity for node: " + name)
+
+    def __repr__(self):
+        return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index f8923fca5..ab52788c8 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -206,7 +206,7 @@ def args_to_workload(x):
     elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
         return x.value
     elif x is None:
-        return None
+        return 0
     else:
         raise RuntimeError('Do not support type "%s" in argument. Consider using '
                            'primitive types only' % type(x))
diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py
index e11bb7a4f..3d7b249df 100644
--- a/python/tvm/autotvm/tophub.py
+++ b/python/tvm/autotvm/tophub.py
@@ -28,7 +28,7 @@ def _alias(name):
     return table.get(name, name)
 
 
-def context(target, extra_files=None, allow_fallback=False):
+def context(target, extra_files=None):
     """Return the dispatch context with pre-tuned parameters.
     The corresponding downloaded *.log files under tophub root path will be loaded.
     Users can also add their own files in argument `extra_files`.
@@ -39,12 +39,9 @@ def context(target, extra_files=None, allow_fallback=False):
         The compilation target
     extra_files: list of str, optional
         Extra log files to load
-    allow_fallback: bool
-        Whether allow to use a fallback configuration if cannot find
-        tuned result.
     """
     rootpath = AUTOTVM_TOPHUB_ROOT_PATH
-    best_context = ApplyHistoryBest([], allow_fallback=allow_fallback)
+    best_context = ApplyHistoryBest([])
 
     if isinstance(target, str):
         target = _target.create(target)
diff --git a/python/tvm/autotvm/tuner/ga_tuner.py b/python/tvm/autotvm/tuner/ga_tuner.py
index b92737ed5..b9d900e49 100644
--- a/python/tvm/autotvm/tuner/ga_tuner.py
+++ b/python/tvm/autotvm/tuner/ga_tuner.py
@@ -86,13 +86,9 @@ class GATuner(Tuner):
 
             # cross over
             indices = np.arange(len(genes))
-            max_score = np.max(scores)
-            if max_score < 1e-8:
-                probs = np.empty_like(scores)
-                probs[:] = 1.0 / len(scores)
-            else:
-                scores /= max_score
-                probs = scores / np.sum(scores)
+            scores += 1e-8
+            scores /= np.max(scores)
+            probs = scores / np.sum(scores)
             tmp_genes = []
             for _ in range(self.pop_size):
                 p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs)
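
A minimal numpy sketch of the new selection probabilities: shifting scores by 1e-8
before normalizing keeps probs finite and uniform even when every candidate scored
zero (e.g. all measurements failed), which previously required a special case.

    import numpy as np

    scores = np.zeros(8)  # a population where every measurement failed
    scores += 1e-8
    scores /= np.max(scores)
    probs = scores / np.sum(scores)
    assert np.isclose(probs.sum(), 1.0) and not np.isnan(probs).any()
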
diff --git a/python/tvm/autotvm/tuner/model_based_tuner.py b/python/tvm/autotvm/tuner/model_based_tuner.py
index d1c1b16d3..62fc57f2e 100644
--- a/python/tvm/autotvm/tuner/model_based_tuner.py
+++ b/python/tvm/autotvm/tuner/model_based_tuner.py
@@ -8,7 +8,7 @@ import gc
 import numpy as np
 
 from .tuner import Tuner
-
+from ..env import GLOBAL_SCOPE
 
 class FeatureCache(object):
     """Feature cache manager for cache sharing between different cost models"""
@@ -119,11 +119,9 @@ class CostModel(object):
         """
         raise NotImplementedError()
 
-    def clone_new(self):
-        """Clone a new model with the same parameters.
-        This function will only copy hyperparameters of the tuner, not all the trained model
-
-        This is used for deriving a base model conveniently
+    def spawn_base_model(self):
+        """Clone a base model with the same parameters.
+        The base model is used to fit history data in transfer learning.
 
         Returns
         -------
@@ -221,7 +219,9 @@ class ModelBasedTuner(Tuner):
                     break
                 self.trial_pt += 1
 
-            if self.trial_pt >= len(self.trials):  # trial list is empty, choose randomly
+            if self.trial_pt >= len(self.trials) - int(0.05 * self.plan_size):
+                # if the trial list is empty or the tuner is doing the last
+                # 5% of trials (epsilon-greedy), choose randomly
                 index = np.random.randint(len(self.space))
                 while index in self.visited:
                     index = np.random.randint(len(self.space))
@@ -264,18 +264,16 @@ class ModelBasedTuner(Tuner):
             self.train_ct += 1
 
     def load_history(self, data_set):
-        # filter data, only pick the data with a same task
-        data = []
-        for inp, res in data_set:
-            if inp.task.name == self.task.name and \
-                            inp.config.template_key == self.task.config_space.template_key:
-                data.append((inp, res))
-        if not data:
-            return
+        # set in_tuning to True to make the feature extraction consistent
+        GLOBAL_SCOPE.in_tuning = True
 
         # fit base model
-        base_model = self.cost_model.clone_new()
-        base_model.fit_log(data, self.plan_size)
+        base_model = self.cost_model.spawn_base_model()
+        success = base_model.fit_log(data_set, self.plan_size)
+
+        if not success:
+            GLOBAL_SCOPE.in_tuning = False
+            return
 
         # use base model to select initial points
         if not self.trials:
@@ -285,6 +283,7 @@ class ModelBasedTuner(Tuner):
             self.trial_pt = 0
 
         self.cost_model.load_basemodel(base_model)
+        GLOBAL_SCOPE.in_tuning = False
 
     def has_next(self):
         return len(self.visited) < len(self.space)
diff --git a/python/tvm/autotvm/tuner/sa_model_optimizer.py b/python/tvm/autotvm/tuner/sa_model_optimizer.py
index 6e1c373c1..1947c6dde 100644
--- a/python/tvm/autotvm/tuner/sa_model_optimizer.py
+++ b/python/tvm/autotvm/tuner/sa_model_optimizer.py
@@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
 
             new_scores = model.predict(new_points)
 
-            ac_prob = np.exp((new_scores - scores) / t)
+            ac_prob = np.exp((new_scores - scores) / (t + 1e-2))
             ac_index = np.random.random(len(ac_prob)) < ac_prob
 
             points[ac_index] = new_points[ac_index]
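
A minimal numpy sketch of the patched acceptance rule: adding 1e-2 to the
temperature keeps the Metropolis probability finite when the annealing schedule
cools t all the way to 0, instead of dividing by zero.

    import numpy as np

    t = 0.0  # fully cooled temperature
    scores = np.array([1.00, 1.00])
    new_scores = np.array([0.99, 1.01])
    ac_prob = np.exp((new_scores - scores) / (t + 1e-2))  # finite, no NaN
    accept = np.random.random(len(ac_prob)) < ac_prob
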
diff --git a/python/tvm/autotvm/tuner/tuner.py b/python/tvm/autotvm/tuner/tuner.py
index 91004cba4..cffbb9798 100644
--- a/python/tvm/autotvm/tuner/tuner.py
+++ b/python/tvm/autotvm/tuner/tuner.py
@@ -31,6 +31,10 @@ class Tuner(object):
         self.best_measure_pair = None
         self.best_iter = 0
 
+        # time to live: the number of trials left before tuning will stop
+        self.ttl = None
+        self.n_trial = None
+
     def has_next(self):
         """Whether has next untried config in the space
 
@@ -76,7 +80,7 @@ class Tuner(object):
         measure_option: dict
             The options for how to measure generated code.
             You should use the return value of autotvm.measure_option for this argument.
-        early_stopping: int
+        early_stopping: int, optional
             Stop the tuning early when no better config is found within this number of trials
         callbacks: List of callable
             A list of callback functions. The signature of callback function is
@@ -87,6 +91,8 @@ class Tuner(object):
         measure_batch = create_measure_batch(self.task, measure_option)
         n_parallel = getattr(measure_batch, 'n_parallel', 1)
         early_stopping = early_stopping or 1e9
+        self.n_trial = n_trial
+
         old_level = logger.level
 
         GLOBAL_SCOPE.in_tuning = True
@@ -127,11 +133,12 @@ class Tuner(object):
             for callback in callbacks:
                 callback(self, inputs, results)
 
-            if i > self.best_iter + early_stopping:
+            self.ttl = min(early_stopping + self.best_iter, n_trial) - i
+            if i >= self.best_iter + early_stopping:
                 logger.debug("Early stopped. Best iter: %d.", self.best_iter)
                 break
 
-            if error_ct > 50:
+            if error_ct > 150:
                 logger.warning("Too many errors happen in the tuning. Now is in debug mode")
                 logger.setLevel(logging.DEBUG)
             else:
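
A minimal usage sketch of the new bookkeeping (task creation is elided; trial
counts are illustrative): with n_trial=1000 and early_stopping=200, tuning stops
after 200 trials without improvement, and callbacks can read tuner.ttl and
tuner.n_trial, presumably for progress reporting.

    from tvm import autotvm
    from tvm.autotvm.tuner import GATuner

    tuner = GATuner(task)  # `task` created earlier via autotvm.task.create
    tuner.tune(n_trial=1000, early_stopping=200,
               measure_option=autotvm.measure_option(measure_func='local'),
               callbacks=[autotvm.callback.log_to_file('tune.log')])
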
diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py
index 178e92476..bda3ee26e 100644
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -31,8 +31,12 @@ class XGBoostCostModel(CostModel):
         If is 'curve', use sampled curve feature (relation feature).
 
         Note on choosing feature type:
-        For single task tuning, 'itervar' and 'knob' is good.
+        For single task tuning, 'itervar' and 'knob' are good.
                                 'itervar' is more accurate but 'knob' is much faster.
+                                There are some constraints on 'itervar'; if you run into
+                                problems with feature extraction when using 'itervar',
+                                you can switch to 'knob'.
+
         For cross-shape tuning (e.g. many convolutions with different shapes),
                                'itervar' and 'curve' have better transferability,
                                'knob' is faster.
@@ -46,8 +50,11 @@ class XGBoostCostModel(CostModel):
         The number of threads.
     log_interval: int, optional
         If not None, the cost model will print a training log every `log_interval` iterations.
+    upper_model: XGBoostCostModel, optional
+        The upper model used in transfer learning
     """
-    def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25):
+    def __init__(self, task, feature_type, loss_type, num_threads=4, log_interval=25,
+                 upper_model=None):
         super(XGBoostCostModel, self).__init__()
 
         if xgb is None:
@@ -109,35 +116,51 @@ class XGBoostCostModel(CostModel):
         else:
             raise RuntimeError("Invalid feature type " + feature_type)
 
-        self.feature_cache = FeatureCache()
+        if upper_model:  # share the same feature cache with the upper model
+            self.feature_cache = upper_model.feature_cache
+        else:
+            self.feature_cache = FeatureCache()
+        self.upper_model = upper_model
         self.feature_extra_ct = 0
         self.pool = None
         self.base_model = None
-        self.upper_model = None
 
         self._sample_size = 0
+        self._reset_pool(self.space, self.target, self.task)
 
-        self._reset_pool()
+    def _reset_pool(self, space, target, task):
+        """reset processing pool for feature extraction"""
+
+        if self.upper_model:  # the base model reuses the upper model's pool
+            self.upper_model._reset_pool(space, target, task)
+            return
+
+        self._close_pool()
 
-    def _reset_pool(self):
-        # reset processing pool for feature extraction
-        if self.pool:
-            self.pool.terminate()
-            self.pool.join()
-            del self.pool
         # use global variable to pass common arguments
         global _extract_space, _extract_target, _extract_task
-        _extract_space = self.space
-        _extract_target = self.target
-        _extract_task = self.task
+        _extract_space = space
+        _extract_target = target
+        _extract_task = task
         self.pool = multiprocessing.Pool(self.num_threads)
 
+    def _close_pool(self):
+        if self.pool:
+            self.pool.terminate()
+            self.pool.join()
+            self.pool = None
+
+    def _get_pool(self):
+        if self.upper_model:
+            return self.upper_model._get_pool()
+        return self.pool
+
     def _base_model_discount(self):
-        return 1.0 / (2 ** (self._sample_size / 50.0))
+        return 1.0 / (2 ** (self._sample_size / 64.0))
 
     def fit(self, xs, ys, plan_size):
         tic = time.time()
-        self._reset_pool()
+        self._reset_pool(self.space, self.target, self.task)
 
         x_train = self._get_feature(xs)
         y_train = np.array(ys)
@@ -150,8 +173,12 @@ class XGBoostCostModel(CostModel):
         self._sample_size = len(x_train)
 
         if self.base_model:
-            dtrain.set_base_margin(self._base_model_discount() *
-                                   self.base_model.predict(xs, output_margin=True))
+            discount = self._base_model_discount()
+            if discount < 0.05:  # discard base model
+                self.base_model.upper_model = None
+                self.base_model = None
+            else:
+                dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True))
 
         self.bst = xgb.train(self.xgb_params, dtrain,
                              num_boost_round=8000,
@@ -172,11 +199,19 @@ class XGBoostCostModel(CostModel):
 
     def fit_log(self, records, plan_size):
         tic = time.time()
-        self._reset_pool()
 
-        args = list(records)
-        logger.debug("XGB load %d entries from history log file", len(args))
+        # filter data, only picking entries from the same task
+        data = []
+        for inp, res in records:
+            if inp.task.name == self.task.name and \
+                            inp.config.template_key == self.task.config_space.template_key:
+                data.append((inp, res))
+
+        logger.debug("XGB load %d entries from history log file", len(data))
 
+        # extract feature
+        self._reset_pool(self.space, self.target, self.task)
+        pool = self._get_pool()
         if self.fea_type == 'itervar':
             feature_extract_func = _extract_itervar_feature_log
         elif self.fea_type == 'knob':
@@ -185,10 +220,21 @@ class XGBoostCostModel(CostModel):
             feature_extract_func = _extract_curve_feature_log
         else:
             raise RuntimeError("Invalid feature type: " + self.fea_type)
-        res = self.pool.map(feature_extract_func, args)
-        xs, ys = zip(*res)
-        xs, ys = np.array(xs), np.array(ys)
+        res = pool.map(feature_extract_func, data)
+
+        # filter out features with different shapes
+        fea_len = len(self._get_feature([0])[0])
+
+        xs, ys = [], []
+        for x, y in res:
+            if len(x) == fea_len:
+                xs.append(x)
+                ys.append(y)
 
+        if len(xs) < 500:  # not enough samples
+            return False
+
+        xs, ys = np.array(xs), np.array(ys)
         x_train = xs
         y_train = ys
         y_max = np.max(y_train)
@@ -212,6 +258,8 @@ class XGBoostCostModel(CostModel):
 
         logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
 
+        return True
+
     def predict(self, xs, output_margin=False):
         feas = self._get_feature(xs)
         dtest = xgb.DMatrix(feas)
@@ -224,20 +272,12 @@ class XGBoostCostModel(CostModel):
 
     def load_basemodel(self, base_model):
         self.base_model = base_model
-        if isinstance(base_model, XGBoostCostModel):
-            # share feature cache
-            base_model.feature_cache = self.feature_cache
-
-            # close thread pool
-            if base_model.pool:
-                base_model.pool.terminate()
-                base_model.pool.join()
-                del base_model.pool
-            self.base_model.upper_model = self
-
-    def clone_new(self):
+        self.base_model._close_pool()
+        self.base_model.upper_model = self
+
+    def spawn_base_model(self):
         return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
-                                self.num_threads, self.log_interval)
+                                self.num_threads, self.log_interval, self)
 
     def _get_feature(self, indexes):
         """get features for indexes, run extraction if we do not have cache for them"""
@@ -251,7 +291,7 @@ class XGBoostCostModel(CostModel):
         need_extract = [x for x in indexes if x not in fea_cache]
 
         if need_extract:
-            pool = self.pool if self.upper_model is None else self.upper_model.pool
+            pool = self._get_pool()
             feas = pool.map(self.feature_extract_func, need_extract)
             for i, fea in zip(need_extract, feas):
                 fea_cache[i] = fea
@@ -261,6 +301,9 @@ class XGBoostCostModel(CostModel):
             ret[i, :] = fea_cache[ii]
         return ret
 
+    def __del__(self):
+        self._close_pool()
+
 
 _extract_space = None
 _extract_target = None
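
A minimal sketch of the transfer-learning flow these methods implement
('history.log' is hypothetical, `task` is created elsewhere): a base model is
spawned with the same hyperparameters, fit on filtered history records, and
installed under the live model, whose contribution then decays via
_base_model_discount() as fresh on-device samples accumulate.

    from tvm import autotvm
    from tvm.autotvm.tuner.xgboost_cost_model import XGBoostCostModel

    records = list(autotvm.record.load_from_file('history.log'))
    model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
    base = model.spawn_base_model()
    if base.fit_log(records, plan_size=64):  # False below 500 usable rows
        model.load_basemodel(base)
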
diff --git a/python/tvm/autotvm/tuner/xgboost_tuner.py b/python/tvm/autotvm/tuner/xgboost_tuner.py
index 237ac4e19..886c82a4d 100644
--- a/python/tvm/autotvm/tuner/xgboost_tuner.py
+++ b/python/tvm/autotvm/tuner/xgboost_tuner.py
@@ -20,8 +20,12 @@ class XGBTuner(ModelBasedTuner):
         If is 'curve', use sampled curve feature (relation feature).
 
         Note on choosing feature type:
-        For single task tuning, 'itervar' and 'knob' is good.
+        For single task tuning, 'itervar' and 'knob' are good.
                                 'itervar' is more accurate but 'knob' is much faster.
+                                There are some constraints on 'itervar'; if you run into
+                                problems with feature extraction when using 'itervar',
+                                you can switch to 'knob'.
+
         For cross-shape tuning (e.g. many convolutions with different shapes),
                               'itervar' and 'curve' have better transferability,
                                'knob' is faster.
@@ -32,8 +36,7 @@ class XGBTuner(ModelBasedTuner):
         If is 'rank', use pairwise rank loss to train cost model.
                      The cost model predicts relative rank score.
     num_threads: int, optional
         The number of threads.
     optimizer: str or ModelOptimizer, optional
         If is 'sa', use a default simulated annealing optimizer.
         Otherwise it should be a ModelOptimizer object.
     diversity_filter_ratio: int or float, optional
@@ -45,7 +48,7 @@ class XGBTuner(ModelBasedTuner):
         If is 0, output nothing.
         Otherwise, output debug information every `verbose` iterations.
     """
-    def __init__(self, task, plan_size=32,
+    def __init__(self, task, plan_size=64,
                  feature_type='itervar', loss_type='rank', num_threads=None,
                  optimizer='sa', diversity_filter_ratio=None, log_interval=50):
         cost_model = XGBoostCostModel(task,
@@ -62,3 +65,9 @@ class XGBTuner(ModelBasedTuner):
 
         super(XGBTuner, self).__init__(task, cost_model, optimizer,
                                        plan_size, diversity_filter_ratio)
+
+    def tune(self, *args, **kwargs):  # pylint: disable=arguments-differ
+        super(XGBTuner, self).tune(*args, **kwargs)
+
+        # manually close pool to avoid multiprocessing issues
+        self.cost_model._close_pool()
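
The same transfer-learning flow at the tuner level (the log path is illustrative):
load_history routes previous records through spawn_base_model/fit_log, and the
overridden tune() closes the feature-extraction pool once tuning finishes.

    from tvm import autotvm
    from tvm.autotvm.tuner import XGBTuner

    tuner = XGBTuner(task, plan_size=64)  # `task` from autotvm.task.create
    tuner.load_history(autotvm.record.load_from_file('previous.log'))
    tuner.tune(n_trial=100,
               measure_option=autotvm.measure_option(measure_func='local'))
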
diff --git a/python/tvm/exec/tophub.py b/python/tvm/exec/tophub.py
index 9dd951a52..9bfd68665 100644
--- a/python/tvm/exec/tophub.py
+++ b/python/tvm/exec/tophub.py
@@ -8,8 +8,8 @@ from ..autotvm.tophub import list_packages, download_package
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--download", type=str, nargs='+',
-                        help="Target to download. Use 'all' to download for all targets")
+    parser.add_argument("-d", "--download", type=str, nargs='+',
+                        help="The targets to download. Use 'all' to download for all targets")
     parser.add_argument("-l", "--list", action='store_true', help="List available packages")
     args = parser.parse_args()
 
@@ -21,8 +21,7 @@ if __name__ == '__main__':
         print("-" * 41)
         for target, info in info:
             print("%-20s %-20s" % (target, "%.2f MB" % (info['size']/1000000)))
-
-    if args.download:
+    elif args.download:
         info = list_packages()
         all_targets = [x[0] for x in info]
         if 'all' in args.download:
@@ -34,3 +33,5 @@ if __name__ == '__main__':
             if t not in all_targets:
                 print("Warning : cannot find tuned parameters of " + t + ". (ignored)")
             download_package(t)
+    else:
+        parser.print_help()
diff --git a/python/tvm/target.py b/python/tvm/target.py
index e2d780f75..9d5200661 100644
--- a/python/tvm/target.py
+++ b/python/tvm/target.py
@@ -263,6 +263,7 @@ def override_native_generic_func(func_name):
                     "Keyword arguments cannot be used when invoking generic_func %s" % func_name)
             return generic_func_node(*args)
         fresult = decorate(fdefault, dispatch_func)
+        fresult.fdefault = fdefault
         fresult.register = register
         return fresult
     return fdecorate
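
A pure-Python stand-in (not the TVM implementation) for the pattern this one-line
change enables: keeping a handle to the default implementation so the "tracing"
registrations above can record a call and still delegate to the generic compute
via compute_func.fdefault(*args).

    def make_generic(fdefault):
        """Toy generic function with per-key overrides and a reachable default."""
        table = {}
        def dispatch(*args, key=None):
            return table.get(key, fdefault)(*args)
        dispatch.fdefault = fdefault              # what this patch exposes
        dispatch.register = lambda key, f: table.__setitem__(key, f)
        return dispatch

    conv = make_generic(lambda x: x + 1)          # hypothetical default impl
    conv.register('tracing', lambda x: conv.fdefault(x))
    assert conv(1, key='tracing') == 2
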
diff --git a/tests/python/unittest/test_autotvm_dispatch_context.py b/tests/python/unittest/test_autotvm_dispatch_context.py
index 6c718e5bd..1f2a7e276 100644
--- a/tests/python/unittest/test_autotvm_dispatch_context.py
+++ b/tests/python/unittest/test_autotvm_dispatch_context.py
@@ -3,34 +3,48 @@ The dispatcher can choose which template to use according
 to the parameters of workload"""
 
 from collections import namedtuple
+from tvm import autotvm
 from tvm.autotvm.task import dispatcher, DispatchContext
 
-SimpleWorkload = namedtuple("SimpleWorkload", ["key"])
-SimpleConfig = namedtuple("SimpleConfig", ["template_key"])
+SimpleConfig = namedtuple('SimpleConfig', ('template_key', 'is_fallback'))
 
 def test_dispatch():
     @dispatcher
     def my_dispatcher(a, b):
-        return SimpleWorkload(key=a + b)
-
-    @my_dispatcher.register("spatial_pack")
-    def _sp_pack_add(cfg, a, b):
-        return b + 100
+        return (a, b)
 
     @my_dispatcher.register("im2col")
-    def _im2col_add(cfg, a, b):
-        return a + 1
+    def _im2col(cfg, a, b):
+        return a
+
+    @my_dispatcher.register("spatial_pack")
+    def _spatial_pack(cfg, a, b):
+        return b
 
     class SimpleDispatcher(DispatchContext):
         def query(self, target, workload):
-            tkey = "spatial_pack" if workload.key > 2 else "im2col"
-            return SimpleConfig(tkey)
+            a, b = workload
+            tkey = "spatial_pack" if a + b > 2 else "im2col"
+            cfg = SimpleConfig(tkey, False)
+            return cfg
 
     with SimpleDispatcher():
-        # im2col
-        assert my_dispatcher(1, 0) == 2
-        # spack
-        assert my_dispatcher(1, 100) == 200
+        # this will call im2col
+        assert my_dispatcher(1, 0) == 1
+
+        # this will call spatial pack
+        assert my_dispatcher(1, 100) == 100
+
+def test_fallback():
+
+    @autotvm.template
+    def simple_template(a, b):
+        cfg = autotvm.get_config()
+        assert cfg.is_fallback
+
+    simple_template(2, 3)
+
 
 if __name__ == "__main__":
     test_dispatch()
+    test_fallback()
diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py
index 0320ef1c6..e51e34e95 100644
--- a/tests/python/unittest/test_autotvm_space.py
+++ b/tests/python/unittest/test_autotvm_space.py
@@ -1,7 +1,7 @@
 """Test space definition primitives"""
 
 import tvm
-from tvm.autotvm.task.space import ConfigSpace
+from tvm.autotvm.task.space import ConfigSpace, FallbackConfigEntity
 
 def gemm_func(cfg, N):
     A = tvm.placeholder((N, N), name='A')
@@ -26,5 +26,18 @@ def test_split():
     assert len(cfg) == 64
     assert len(cfg.space_map['tile_y']) == 8
 
+    # test fallback
+    cfg = FallbackConfigEntity()
+    cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)
+    cfg.fallback_split('tile_n', [-1, 8, 4])
+
+    assert cfg['tile_n'].size == [4, 8, 4]
+
+    cfg = FallbackConfigEntity()
+    cfg.define_split('tile_n', cfg.axis(49), num_outputs=3)
+    cfg.fallback_split('tile_n', [-1, 8, 4])
+
+    assert cfg['tile_n'].size == [7, 7, 1]
+
 if __name__ == '__main__':
     test_split()
diff --git a/tests/python/unittest/test_autotvm_xgboost_model.py b/tests/python/unittest/test_autotvm_xgboost_model.py
index 3488d0f59..58da219f2 100644
--- a/tests/python/unittest/test_autotvm_xgboost_model.py
+++ b/tests/python/unittest/test_autotvm_xgboost_model.py
@@ -12,7 +12,7 @@ from test_autotvm_common import get_sample_task, get_sample_records
 
 def test_fit():
     task, target = get_sample_task()
-    records = get_sample_records(n=100)
+    records = get_sample_records(n=500)
 
     base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
     base_model.fit_log(records, plan_size=32)
@@ -20,8 +20,8 @@ def test_fit():
     upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank')
     upper_model.load_basemodel(base_model)
 
-    xs = np.arange(100)
-    ys = np.arange(100)
+    xs = np.arange(10)
+    ys = np.arange(10)
 
     upper_model.fit(xs, ys, plan_size=32)
 
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index 48bb4fb02..a3945a4c9 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -27,7 +27,14 @@ def _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
 @autotvm.task.dispatcher
 def conv2d_arm_cpu(data, kernel, strides, padding, layout, out_dtype):
     """TOPI compute callback. Mark this function as a dispatcher, so
-    this template can assign config according to workload"""
+    this template can assign config according to workload
+
+    Returns
+    -------
+    workload: Tuple
+        Dispatcher will use this workload to query corresponding config.
+        Then use cfg.template_key to call a registered template.
+    """
     return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
 
 @conv2d_arm_cpu.register(['direct'])
@@ -70,8 +77,10 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
 
 def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
     assert layout == "NCHW", "Only support NCHW"
-    out_dtype = out_dtype or data.dtype
+    # create workload according to raw arguments
+    wkl = _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype)
 
+    out_dtype = out_dtype or data.dtype
     N, CI, IH, IW = get_const_tuple(data.shape)
     if len(kernel.shape) == 4:
         pre_packed = False
@@ -113,6 +122,18 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
     cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
     # ====================================================================
 
+    if cfg.is_fallback:
+        if num_tile == 2:
+            cfg.fallback_split('tile_co', [-1, 8])
+            cfg.fallback_split('tile_oh', [-1, 2])
+            cfg.fallback_split('tile_ow', [-1, 8])
+        else:
+            cfg.fallback_split('tile_co', [-1, 16, 4])
+            cfg.fallback_split('tile_oh', [-1, 1, 1])
+            cfg.fallback_split('tile_ow', [-1, 1, 4])
+        cfg['ann_reduce'].anns = ['unroll', 'unroll']
+        cfg['ann_spatial'].anns = ['none', 'unroll', 'vec']
+
     VC = cfg["tile_co"].size[-1]
     VH = cfg["tile_oh"].size[-1]
     VW = cfg["tile_ow"].size[-1]
@@ -145,8 +166,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
     output = tvm.compute(oshape, lambda n, co, h, w:
                          conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
                          name='output_unpack', tag='spatial_conv2d_output',
-                         attrs={'workload': _conv_arg_to_workload(data, kernel, strides, padding,
-                                                                  layout, out_dtype)})
+                         attrs={'workload': wkl})
     return output
 
 def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
@@ -212,6 +232,10 @@ def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype):
     return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size)
 
 def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size):
+    # create workload according to raw arguments
+    wkl = _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout,
+                                         out_dtype, tile_size)
+
     N, CI, IH, IW = get_const_tuple(data.shape)
     if len(kernel.shape) == 4:
         pre_computed = False
@@ -333,10 +357,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_
     output = tvm.compute((N, K, H, W), lambda n, k, h, w:
                          Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
                          name='output', tag='winograd_conv2d_output',
-                         attrs={'workload': _winograd_conv_arg_to_workload(
-                             data, kernel, strides, padding, layout, out_dtype, tile_size)})
+                         attrs={'workload': wkl})
 
-    # we have to manually assign effective GFLOP for winogard
+    # we have to manually assign effective GFLOP for winograd
     cfg.add_flop(2 * N * K * H * W * KH * KW * C)
     return output
 
@@ -358,30 +381,29 @@ def _schedule_winograd(cfg, s, output, last):
         kernel, G = U.op.input_tensors
         s[G].compute_inline()
         eps, nu, k, c, kk, = s[U].op.axis
-        r_kh, r_kw = s[U].op.reduce_axis
-        s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk)
-        s[U].unroll(eps)
-        s[U].unroll(nu)
-        s[U].unroll(r_kh)
-        s[U].unroll(r_kw)
-        s[U].vectorize(kk)
         if autotvm.GLOBAL_SCOPE.in_tuning:
             # kernel transformation will be pre-computed during compilation, so we skip
             # this part to make tuning records correct
-            s[U].pragma(k, 'debug_skip_region')
+            s[U].pragma(eps, 'debug_skip_region')
         else:
+            r_kh, r_kw = s[U].op.reduce_axis
+            s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk)
+            for axis in [eps, nu, r_kh, r_kw]:
+                s[U].unroll(axis)
+            s[U].vectorize(kk)
             s[U].parallel(k)
 
+        if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+            s[kernel].compute_inline()
+
     # transform image
     DD = s.cache_read(d, 'global', [V])
     s[B].compute_inline()
     eps, nu, b, c, bb = s[V].op.axis
     r_eps, r_nu = s[V].op.reduce_axis
     s[V].reorder(b, c, eps, nu, r_eps, r_nu, bb)
-    s[V].unroll(eps)
-    s[V].unroll(nu)
-    s[V].unroll(r_eps)
-    s[V].unroll(r_nu)
+    for axis in [eps, nu, r_eps, r_nu]:
+        s[V].unroll(axis)
     s[DD].compute_at(s[V], c)
     s[V].vectorize(bb)
     s[V].parallel(b)
@@ -405,10 +427,8 @@ def _schedule_winograd(cfg, s, output, last):
     s[A].compute_inline()
     k, b, vh, vw = s[Y].op.axis
     r_eps, r_nu = s[Y].op.reduce_axis
-    s[Y].unroll(vh)
-    s[Y].unroll(vw)
-    s[Y].unroll(r_eps)
-    s[Y].unroll(r_nu)
+    for axis in [vh, vw, r_eps, r_nu]:
+        s[Y].unroll(axis)
 
     # output
     n, co, h, w = s[last].op.axis
@@ -444,6 +464,7 @@ def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_d
         [data, raw_kernel, strides, padding, layout, out_dtype])
 
 
+##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
 @conv2d_winograd_without_weight_transform.register(['arm_cpu'])
 @autotvm.task.dispatcher
 def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size):
@@ -472,6 +493,7 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
     return s
 
 
+##### REGISTER ALTER OP LAYOUT #####
 @conv2d_alter_layout.register(["arm_cpu", "mali"])
 def _alter_conv2d_layout(attrs, inputs, tinfos):
     """Alter op layout for pre-computing kernel transformation"""
@@ -493,18 +515,30 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
         # query config of this workload
         workload = _conv_arg_to_workload(tinfos[0], tinfos[1], strides, padding,
                                          layout, out_dtype)
-        cfg = autotvm.task.DispatchContext.current.query(tvm.target.current_target(), workload)
+        cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload)
+
+        if cfg.is_fallback: # if is fallback, clear query cache and return None
+            context = autotvm.DispatchContext.current
+            while not isinstance(context, autotvm.FallbackContext):
+                context = context._old_ctx
+            context.clear_cache(tvm.target.current_target(), workload)
+            return None
 
         if cfg.template_key == 'direct':  # packing weight tensor
             new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
             return sym.conv2d(*copy_inputs, **new_attrs)
         else:  # pre-compute weight transformation in winograd
-            tile_size = 4
+            if "-device=arm_cpu" in tvm.target.current_target().options:
+                tile_size = 4
+                VC = cfg['tile_k'].size[-1]
+            else:
+                from ..mali.conv2d import _pick_tile_size
+                tile_size = _pick_tile_size(tinfos[0], tinfos[1])
+                VC = cfg['tile_bna'].val
 
             weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1],
                                                                   tile_size=tile_size)
             CO, CI, KH, KW = get_const_tuple(tinfos[1].shape)
-            VC = cfg['tile_k'].size[-1]
             weight = sym.reshape(weight,
                                  shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
             weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3])
diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py
index 8aafc4363..e066a1e29 100644
--- a/topi/python/topi/arm_cpu/depthwise_conv2d.py
+++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py
@@ -14,16 +14,21 @@ autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
 
 # register customized schedule for arm cpu.
 @autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct')
-def schedule_depthwise_conv2d_nchw_(cfg, outs):
+def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
     """Schedule depthwise conv2d
 
     Parameters
     ----------
     cfg: ConfigEntity
-        The configuration of this tempalte
+        The configuration of this template
     outs: Array of Tensor
         The computation graph description of depthwise convolution2d
         in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for depthwise_conv2d nchw.
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
@@ -38,6 +43,11 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
         cfg.define_split('tile_h', h, num_outputs=2)
         cfg.define_split('tile_w', w, num_outputs=2)
 
+        if cfg.is_fallback:
+            cfg.fallback_split('tile_c', [-1, 8])
+            cfg.fallback_split('tile_h', [-1, 2])
+            cfg.fallback_split('tile_w', [-1, 8])
+
         # park data to vector form  [n, c, h, w] -> [n, C, h, w, VC]
         A0 = s.cache_read(data_pad, "global", C)
         _, c, h, w = s[A0].op.axis
diff --git a/topi/python/topi/x86/injective.py b/topi/python/topi/x86/injective.py
index ac552903a..06847bf9f 100644
--- a/topi/python/topi/x86/injective.py
+++ b/topi/python/topi/x86/injective.py
@@ -29,7 +29,7 @@ def schedule_injective(outs):
     elif len(s[x].op.axis) >= 3:
         fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
         s[x].parallel(fused)
-    else:
+    elif len(s[x].op.axis) >= 1:
         s[x].parallel(s[x].op.axis[0])
     return s
 
diff --git a/topi/tests/python/common.py b/topi/tests/python/common.py
new file mode 100644
index 000000000..d992be929
--- /dev/null
+++ b/topi/tests/python/common.py
@@ -0,0 +1,12 @@
+"""Common utility for topi test"""
+
+def get_all_backend():
+    """return all supported target
+
+    Returns
+    -------
+    targets: list
+        A list of all supported targets
+    """
+    return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
+            'llvm -device=arm_cpu']
diff --git a/topi/tests/python/test_topi_bitserial_conv2d.py b/topi/tests/python/test_topi_bitserial_conv2d.py
index 6df18483a..82af0006c 100644
--- a/topi/tests/python/test_topi_bitserial_conv2d.py
+++ b/topi/tests/python/test_topi_bitserial_conv2d.py
@@ -1,11 +1,8 @@
-import os
 import numpy as np
 import tvm
 import topi
 import topi.testing
-from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
-from tvm.contrib import util
 from tvm.contrib.pickle_memoize import memoize
 
 def generate_quantized_np(shape, bits, out_dtype):
@@ -16,23 +13,23 @@ def generate_quantized_np(shape, bits, out_dtype):
 def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding, 
     activation_bits, weight_bits, dorefa):
     in_height = in_width = in_size
-    input_type='uint32'
-    out_dtype='int32'
+    input_type = 'uint32'
+    out_dtype = 'int32'
 
     with tvm.target.create('llvm'):
         A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A')
         W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W')
         B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, 
-            out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
+                                     out_dtype=out_dtype, layout="NCHW", dorefa=dorefa)
         s = topi.generic.schedule_bitserial_conv2d_nchw([B])
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
-    dtype = A.dtype
 
+    @memoize("topi.tests.test_topi_bitseral_conv2d_nchw")
     def get_ref_data():
-        a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
-        w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
+        a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
+        w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
         if dorefa:
             w_ = np.copy(w_np).astype(out_dtype)
             for x in np.nditer(w_, op_flags=['readwrite']):
@@ -61,16 +58,16 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
         A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
         W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
         B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, 
-                            layout="NHWC", dorefa=dorefa)
+                                     layout="NHWC", dorefa=dorefa)
         s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
-    dtype = A.dtype
 
+    @memoize("topi.tests.test_topi_bitseral_conv2d_nhwc")
     def get_ref_data():
-        a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type)
-        w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type)
+        a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type)
+        w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type)
         if dorefa:
             w_ = np.copy(w_np).astype(out_dtype)
             for x in np.nditer(w_, op_flags=['readwrite']):
@@ -109,4 +106,4 @@ def test_bitserial_conv2d():
     verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False)
 
 if __name__ == "__main__":
-    test_bitserial_conv2d()
\ No newline at end of file
+    test_bitserial_conv2d()
diff --git a/topi/tests/python/test_topi_bitserial_conv2d_rasp.py b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py
index 3de954abc..de467818d 100644
--- a/topi/tests/python/test_topi_bitserial_conv2d_rasp.py
+++ b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py
@@ -4,10 +4,6 @@ import numpy as np
 import tvm
 import topi
 import topi.testing
-from topi.util import get_const_tuple
-from tvm.contrib import util
-
-target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'
 
 def generate_quantized_np(shape, bits, out_dtype):
     np.random.seed(0)
@@ -17,20 +13,19 @@ def generate_quantized_np(shape, bits, out_dtype):
 
 # Verify that certain special instructions from the tensorize pass exist
 def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, 
-                        activation_bits, weight_bits, dorefa):
+                                 activation_bits, weight_bits, dorefa):
     in_height = in_width = in_size
-    input_type='uint32'
-    out_dtype='int32'
+    input_type = 'uint32'
+    out_dtype = 'int32'
 
     with tvm.target.arm_cpu('rasp3b'):
         A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
         W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
         B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, 
-                            layout="NHWC", dorefa=dorefa)
+                                     layout="NHWC", dorefa=dorefa)
         s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
 
-    
-    func = tvm.build(s, [A, W, B], target)
+    func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b'))
    
     assembly = func.get_source('asm')
     matches = re.findall("vpadal", assembly)
@@ -47,7 +42,6 @@ def test_bitserial_conv2d():
     stride = 1
     pad = 1
 
-
     verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False)
     verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False)
 
diff --git a/topi/tests/python/test_topi_bnn.py b/topi/tests/python/test_topi_bnn.py
index 90abc68e6..cf9f377e9 100644
--- a/topi/tests/python/test_topi_bnn.py
+++ b/topi/tests/python/test_topi_bnn.py
@@ -28,7 +28,7 @@ def verify_binary_dense(batch, in_dim, out_dim):
         a_np = (np.random.randint(2, size=(batch, in_dim)) * 2 - 1).astype(dtype)
         b_np = (np.random.randint(2, size=(out_dim, in_dim)) * 2 - 1).astype(dtype)
         c_np = np.dot(a_np, b_np.T)
-        return (a_np, b_np, c_np)
+        return a_np, b_np, c_np
 
     a_np, b_np, c_np = get_ref_data()
 
diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py
index f888033b3..4ed5b3170 100644
--- a/topi/tests/python/test_topi_broadcast.py
+++ b/topi/tests/python/test_topi_broadcast.py
@@ -1,5 +1,5 @@
 """Test code for broadcasting operators."""
-import os
+from common import get_all_backend
 import numpy as np
 import tvm
 import topi
@@ -8,6 +8,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
     # Build the logic and compile the function
     A = tvm.placeholder(shape=in_shape, name="A")
     B = fbcast(A, out_shape)
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
@@ -21,16 +22,11 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
         out_npy = np.broadcast_to(data_npy, out_shape)
         data_nd = tvm.nd.array(data_npy, ctx)
         out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
-        for _ in range(1):
-            foo(data_nd, out_nd)
+        foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("vulkan")
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
-    check_device("nvptx")
+    for target in get_all_backend():
+        check_device(target)
     check_device("sdaccel")
 
 
@@ -45,9 +41,10 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
     B = (tvm.var("B", dtype=dtype) if rhs_shape is None
          else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
     C = ftopi(A, B)
-    if (isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr)):
+    if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
         assert(isinstance(C, tvm.expr.Expr))
         return
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
@@ -82,12 +79,8 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
         foo(lhs_nd, rhs_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
 
-    check_device("opencl")
-    check_device("vulkan")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
-    check_device("nvptx")
+    for target in get_all_backend():
+        check_device(target)
     check_device("sdaccel")
 
 def test_broadcast_to():
diff --git a/topi/tests/python/test_topi_clip.py b/topi/tests/python/test_topi_clip.py
index ffc89aeb9..f1367463e 100644
--- a/topi/tests/python/test_topi_clip.py
+++ b/topi/tests/python/test_topi_clip.py
@@ -5,6 +5,7 @@ import topi
 from topi.util import get_const_tuple
 from tvm.contrib.pickle_memoize import memoize
 
+from common import get_all_backend
 
 def verify_clip(N, a_min, a_max, dtype):
     A = tvm.placeholder((N, N), dtype=dtype, name='A')
@@ -34,7 +35,7 @@ def verify_clip(N, a_min, a_max, dtype):
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'opencl', 'sdaccel']:
+    for device in get_all_backend():
         check_device(device)
 
 def test_clip():
diff --git a/topi/tests/python/test_topi_conv2d.py b/topi/tests/python/test_topi_conv2d.py
deleted file mode 100644
index 365fdf551..000000000
--- a/topi/tests/python/test_topi_conv2d.py
+++ /dev/null
@@ -1,47 +0,0 @@
-"""Example code to do conv2d."""
-import os
-import numpy as np
-import tvm
-from tvm import autotvm
-import topi
-import topi.testing
-from tvm.contrib.pickle_memoize import memoize
-from topi.util import get_const_tuple
-
-
-def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, padding):
-    in_height = in_width = in_size
-
-    with tvm.target.arm_cpu():
-        A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
-        W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
-        B = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), 'NCHW', 'float32')
-        s = topi.generic.schedule_conv2d_nchw([B])
-
-    a_shape = get_const_tuple(A.shape)
-    w_shape = get_const_tuple(W.shape)
-    dtype = A.dtype
-
-    @memoize("topi.tests.test_topi_conv2d.verify_conv2d")
-    def get_ref_data():
-        a_np = np.random.uniform(size=a_shape).astype(dtype)
-        w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
-        return a_np, w_np, b_np
-
-    a_np, w_np, b_np = get_ref_data()
-
-    ctx = tvm.cpu(0)
-    a = tvm.nd.array(a_np, ctx)
-    w = tvm.nd.array(w_np, ctx)
-    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-    func = tvm.build(s, [A, W, B], "llvm")
-    func(a, w, b)
-    np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-
-def test_conv2d():
-    with autotvm.tophub.context(tvm.target.arm_cpu('rasp3b'), allow_fallback=True):
-        verify_conv2d(1, 56, 64, 64, 3, 1, 1)
-
-if __name__ == "__main__":
-    test_conv2d()
diff --git a/topi/tests/python/test_topi_conv2d_hwcn.py b/topi/tests/python/test_topi_conv2d_hwcn.py
index 1ff4b0247..af1afcb9e 100644
--- a/topi/tests/python/test_topi_conv2d_hwcn.py
+++ b/topi/tests/python/test_topi_conv2d_hwcn.py
@@ -43,14 +43,12 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-        with tvm.build_config(auto_unroll_max_step=128,
-                              unroll_explicit=(device != "cuda")):
-            func1 = tvm.build(s1, [A, W, B], device)
-            func2 = tvm.build(s2, [A, W, C], device)
-            func1(a, w, b)
-            func2(a, w, c)
-            np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-            np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+        func1 = tvm.build(s1, [A, W, B], device)
+        func2 = tvm.build(s2, [A, W, C], device)
+        func1(a, w, b)
+        func2(a, w, c)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+        np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
     for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py
index c663384b8..6f367d10c 100644
--- a/topi/tests/python/test_topi_conv2d_nchw.py
+++ b/topi/tests/python/test_topi_conv2d_nchw.py
@@ -1,31 +1,41 @@
 """Example code to do convolution."""
-import os
+
 import numpy as np
 import tvm
+from tvm import autotvm
 import topi
 import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
-def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
+from common import get_all_backend
+
+def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
     print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
 
     in_height = in_width = in_size
 
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
     W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
+    bias = tvm.placeholder((num_filter, 1, 1), name='bias')
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
     dtype = A.dtype
 
     @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
         dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
-        b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
-        c_np = np.maximum(b_np, 0)
+        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
+        if add_bias:
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
         return a_np, w_np, b_np, c_np
 
     a_np, w_np, b_np, c_np = get_ref_data()
@@ -38,66 +48,103 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
-            B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW')
-            C = topi.nn.relu(B)
-            s1 = topi.generic.schedule_conv2d_nchw([B])
-            s2 = topi.generic.schedule_conv2d_nchw([C])
+            C = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW', out_dtype=dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = topi.generic.schedule_conv2d_nchw([C])
+
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
+        b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-        no_unroll_explicit = device in ["cuda", "nvptx", "rocm"]
-        with tvm.build_config(auto_unroll_max_step=1400,
-                              unroll_explicit=not no_unroll_explicit):
-            func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
-            func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
-            func1(a, w, b)
-            func2(a, w, c)
-            np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-            np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+        if add_bias:
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func(a, w, c)
+        np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in get_all_backend():
         check_device(device)
 
 
 def test_conv2d_nchw():
+    autotvm.DispatchContext.current.silent = True
+
     # ResNet18 workloads
-    verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
-    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
-    verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
-    verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
-    verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
-    verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
-    verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
-    verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
-    verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
-    verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
-    verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
-    verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
-    # ResNet50 workloads
-    verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0)
-    verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0)
-    verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0)
-    verify_conv2d_nchw(1, 128, 28, 512, 1, 1, 0)
-    verify_conv2d_nchw(1, 256, 56, 512, 1, 2, 0)
-    verify_conv2d_nchw(1, 512, 28, 128, 1, 1, 0)
-    verify_conv2d_nchw(1, 512, 28, 256, 1, 2, 0)
-    verify_conv2d_nchw(1, 256, 14, 1024, 1, 1, 0)
-    verify_conv2d_nchw(1, 512, 28, 1024, 1, 2, 0)
-    verify_conv2d_nchw(1, 1024, 14, 256, 1, 1, 0)
-    verify_conv2d_nchw(1, 1024, 14, 512, 1, 2, 0)
-    verify_conv2d_nchw(1, 512, 7, 2048, 1, 2, 0)
-    verify_conv2d_nchw(1, 1024, 14, 2048, 1, 2, 0)
-    verify_conv2d_nchw(1, 2048, 7, 512, 1, 1, 0)
-    # Vgg16 workloads
-    verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
-    # Super resolution workloads
-    verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2)
-    verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1)
-    verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1)
-    verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1)
+    verify_conv2d_nchw(1,   3, 224,  64, 7, 2, 3)
+    verify_conv2d_nchw(1,  64,  56,  64, 3, 1, 1)
+    verify_conv2d_nchw(1,  64,  56,  64, 1, 1, 0)
+    verify_conv2d_nchw(1,  64,  56, 128, 3, 2, 1)
+    verify_conv2d_nchw(1,  64,  56, 128, 1, 2, 0)
+    verify_conv2d_nchw(1, 128,  28, 128, 3, 1, 1)
+    verify_conv2d_nchw(1, 128,  28, 256, 3, 2, 1)
+    verify_conv2d_nchw(1, 128,  28, 256, 1, 2, 0)
+    verify_conv2d_nchw(1, 256,  14, 256, 3, 1, 1)
+    verify_conv2d_nchw(1, 256,  14, 512, 3, 2, 1)
+    verify_conv2d_nchw(1, 256,  14, 512, 1, 2, 0)
+    verify_conv2d_nchw(1, 512,   7, 512, 3, 1, 1)
+
+    # bias, relu
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True)
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True)
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)
+
     # dilation = 2
-    verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1, dilation=2)
+    verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2)
+
+    # weird workloads
+    verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1)
+    verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2)
+
+    # inception v3 workloads
+    verify_conv2d_nchw(1,    3, 299,  32, 3, 2, 0)
+    verify_conv2d_nchw(1,   32, 149,  32, 3, 1, 0)
+    verify_conv2d_nchw(1,   32, 147,  64, 3, 1, 1)
+    verify_conv2d_nchw(1,   64,  73,  80, 1, 1, 0)
+    verify_conv2d_nchw(1,   80,  73, 192, 3, 1, 0)
+    verify_conv2d_nchw(1,  192,  35,  64, 1, 1, 0)
+    verify_conv2d_nchw(1,  192,  35,  48, 1, 1, 0)
+    verify_conv2d_nchw(1,   48,  35,  64, 5, 1, 2)
+    verify_conv2d_nchw(1,   64,  35,  96, 3, 1, 1)
+    verify_conv2d_nchw(1,   96,  35,  96, 3, 1, 1)
+    verify_conv2d_nchw(1,  192,  35,  32, 1, 1, 0)
+    verify_conv2d_nchw(1,  256,  35,  64, 1, 1, 0)
+    verify_conv2d_nchw(1,  256,  35,  48, 1, 1, 0)
+    verify_conv2d_nchw(1,  288,  35,  64, 1, 1, 0)
+    verify_conv2d_nchw(1,  288,  35,  48, 1, 1, 0)
+    verify_conv2d_nchw(1,  288,  35, 384, 3, 2, 0)
+    # verify_conv2d_nchw(1,   96,  35,  96, 3, 2, 0)
+    # verify_conv2d_nchw(1,  768,  17, 192, 1, 1, 0)
+    # verify_conv2d_nchw(1,  768,  17, 128, 1, 1, 0)
+    # verify_conv2d_nchw(1,  128,  17, 128, 1, 1, 0)
+    # verify_conv2d_nchw(1,  128,  17, 192, 7, 1, 3)
+    # verify_conv2d_nchw(1,  128,  17, 128, 7, 1, 3)
+    # verify_conv2d_nchw(1,  128,  17, 192, 1, 1, 0)
+    # verify_conv2d_nchw(1,  768,  17, 160, 1, 1, 0)
+    # verify_conv2d_nchw(1,  160,  17, 160, 1, 1, 0)
+    # verify_conv2d_nchw(1,  160,  17, 192, 7, 1, 3)
+    # verify_conv2d_nchw(1,  160,  17, 160, 7, 1, 3)
+    # verify_conv2d_nchw(1,  160,  17, 192, 1, 1, 0)
+    # verify_conv2d_nchw(1,  192,  17, 192, 1, 1, 0)
+    # verify_conv2d_nchw(1,  192,  17, 192, 7, 1, 3)
+    # verify_conv2d_nchw(1,  192,  17, 320, 3, 2, 0)
+    # verify_conv2d_nchw(1,  192,  17, 192, 3, 2, 0)
+    verify_conv2d_nchw(1, 1280,   8, 320, 1, 1, 0)
+    verify_conv2d_nchw(1, 1280,   8, 384, 1, 1, 0)
+    verify_conv2d_nchw(1,  384,   8, 384, 1, 1, 0)
+    verify_conv2d_nchw(1,  384,   8, 384, 3, 1, 1)
+    verify_conv2d_nchw(1, 1280,   8, 448, 1, 1, 0)
+    verify_conv2d_nchw(1,  448,   8, 384, 3, 1, 1)
+    verify_conv2d_nchw(1, 1280,   8, 192, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048,   8, 320, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048,   8, 384, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048,   8, 448, 1, 1, 0)
+    verify_conv2d_nchw(1, 2048,   8, 192, 1, 1, 0)
+
 
 if __name__ == "__main__":
     test_conv2d_nchw()
diff --git a/topi/tests/python/test_topi_conv2d_transpose_nchw.py b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
index 0c9854000..5f65c038b 100644
--- a/topi/tests/python/test_topi_conv2d_transpose_nchw.py
+++ b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
@@ -6,14 +6,13 @@ import topi.testing
 from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
+from common import get_all_backend
 
 def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
     in_height = in_width = in_size
 
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
     W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
-    B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], padding, A.dtype)
-    C = topi.nn.relu(B)
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -36,22 +35,23 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
+            C = topi.nn.relu(B)
             s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
             s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-        with tvm.build_config(auto_unroll_max_step=128,
-                              unroll_explicit=(device != "cuda")):
-            func1 = tvm.build(s1, [A, W, B], device)
-            func2 = tvm.build(s2, [A, W, C], device)
-            func1(a, w, b)
-            func2(a, w, c)
-            np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-            np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+        func1 = tvm.build(s1, [A, W, B], device)
+        func2 = tvm.build(s2, [A, W, C], device)
+        func1(a, w, b)
+        func2(a, w, c)
+        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
+        np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in get_all_backend():
         check_device(device)
 
 
diff --git a/topi/tests/python/test_topi_dense.py b/topi/tests/python/test_topi_dense.py
index 2df43eb30..92f95f3e0 100644
--- a/topi/tests/python/test_topi_dense.py
+++ b/topi/tests/python/test_topi_dense.py
@@ -6,13 +6,12 @@ import topi.testing
 from topi.util import get_const_tuple
 from tvm.contrib.pickle_memoize import memoize
 
+from common import get_all_backend
 
 def verify_dense(batch, in_dim, out_dim, use_bias=True):
     A = tvm.placeholder((batch, in_dim), name='A')
     B = tvm.placeholder((out_dim, in_dim), name='B')
     C = tvm.placeholder((out_dim,), name='C')
-    D = topi.nn.dense(A, B, C if use_bias else None)
-    D = topi.nn.relu(D)
     dtype = A.dtype
 
     # use memoize to pickle the test data for next time use
@@ -36,6 +35,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            D = topi.nn.dense(A, B, C if use_bias else None)
+            D = topi.nn.relu(D)
             s = topi.generic.schedule_dense(D)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
@@ -45,13 +46,15 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
         f(a, b, c, d)
         np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+    for device in get_all_backend():
         check_device(device)
 
 def test_dense():
     verify_dense(1, 1024, 1000, use_bias=True)
     verify_dense(1, 1024, 1000, use_bias=False)
 
+    verify_dense(2, 1024, 1000, use_bias=True)
+
 
 if __name__ == "__main__":
     test_dense()
diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py
index 3086054ba..8c27af839 100644
--- a/topi/tests/python/test_topi_depthwise_conv2d.py
+++ b/topi/tests/python/test_topi_depthwise_conv2d.py
@@ -2,11 +2,10 @@ import tvm
 import topi
 import topi.testing
 import numpy as np
-from scipy import signal
 from topi.util import get_const_tuple
 from tvm.contrib.pickle_memoize import memoize
-from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc
 
+from common import get_all_backend
 
 def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1):
     in_width = in_height
@@ -18,10 +17,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
     DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
-    # declare
-    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
-    ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
-    Relu = topi.nn.relu(ScaleShift)
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -30,6 +25,10 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            # declare
+            DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding)
+            ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
+            Relu = topi.nn.relu(ScaleShift)
             # schedule
             s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
             s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
@@ -88,12 +87,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
         np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
-    check_device("vulkan")
-    check_device("nvptx")
+    for device in get_all_backend():
+        check_device(device)
 
 
 def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1):
@@ -107,11 +102,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter')
     Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
     Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
-    # declare
-    DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
-    ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
-    Relu = topi.nn.relu(ScaleShift)
-    # schedule
 
     def check_device(device):
         ctx = tvm.context(device, 0)
@@ -121,6 +111,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         print("Running on target: %s" % device)
 
         with tvm.target.create(device):
+            # declare
+            DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding)
+            ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
+            Relu = topi.nn.relu(ScaleShift)
+            # schedule
             s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
             s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
             s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
@@ -180,12 +175,9 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
         np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
-    check_device("vulkan")
-    check_device("nvptx")
+    for device in get_all_backend():
+        check_device(device)
+
 
 def test_depthwise_conv2d():
     print("testing nchw")
diff --git a/tutorials/autotvm/tune_nnvm_arm.py b/tutorials/autotvm/tune_nnvm_arm.py
index f3d1c62bd..e85786037 100644
--- a/tutorials/autotvm/tune_nnvm_arm.py
+++ b/tutorials/autotvm/tune_nnvm_arm.py
@@ -312,7 +312,9 @@ def tune_and_evaluate():
 
         # upload module to device
         print("Upload...")
-        remote = autotvm.measure.request_remote(device_key, timeout=10000)
+        remote = autotvm.measure.request_remote(device_key,
+                                                tracker_addr=('localhost', 9190),
+                                                timeout=10000)
         remote.upload(tmp.relpath(filename))
         rlib = remote.load_module(filename)
 
@@ -333,7 +335,6 @@ def tune_and_evaluate():
 
 # We do not run the tuning in our webpage server since it takes too long.
 # Uncomment the following line to run by yourself.
-
 # tune_and_evaluate()
 
 ######################################################################
-- 
GitLab