From a0062582c59c6c84aac43475bcecf96a0ca00401 Mon Sep 17 00:00:00 2001
From: eqy <eqy@cs.washington.edu>
Date: Mon, 24 Dec 2018 11:28:36 -0800
Subject: [PATCH] [RELAY][AUTOTVM] Extract tuning tasks from Relay programs
 (#2181)

---
 python/tvm/autotvm/task/__init__.py           |   1 +
 python/tvm/autotvm/task/nnvm_integration.py   | 231 +++---------------
 python/tvm/autotvm/task/relay_integration.py  | 200 +++++++++++++++
 python/tvm/autotvm/task/topi_integration.py   | 192 ++++++++++++++-
 .../relay/test_autotvm_task_extraction.py     |  56 +++++
 topi/python/topi/x86/conv2d.py                |   2 +-
 topi/python/topi/x86/depthwise_conv2d.py      |   2 +-
 7 files changed, 477 insertions(+), 207 deletions(-)
 create mode 100644 python/tvm/autotvm/task/relay_integration.py
 create mode 100644 tests/python/relay/test_autotvm_task_extraction.py
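
A minimal usage sketch of the new entry point this patch adds; the network helper and op handles mirror the test added below, and the `llvm` target is illustrative:

```python
# Sketch only: extract conv2d tuning tasks from a Relay program.
import tvm.relay.testing
from tvm import relay, autotvm

# Build a Relay function and its parameters from a bundled test network.
net, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1)

# Collect one tuning task per unique conv2d workload in the program.
tasks = autotvm.task.extract_from_program(net, params=params,
                                          ops=(relay.op.nn.conv2d,),
                                          target='llvm')
for task in tasks:
    print(task.name, task.args)
```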

diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py
index 04bcec92f..f6ea07c27 100644
--- a/python/tvm/autotvm/task/__init__.py
+++ b/python/tvm/autotvm/task/__init__.py
@@ -14,3 +14,4 @@ from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBe
 
 from .topi_integration import register_topi_compute, register_topi_schedule
 from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
+from .relay_integration import extract_from_program, extract_from_multiple_program
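
With this export, both extraction front ends live under the same package; a quick import sketch:

```python
# Both task-extraction front ends are now importable from tvm.autotvm.task.
from tvm.autotvm.task import extract_from_graph     # NNVM graphs
from tvm.autotvm.task import extract_from_program   # Relay programs
```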
diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py
index 6a07194a5..cd7337586 100644
--- a/python/tvm/autotvm/task/nnvm_integration.py
+++ b/python/tvm/autotvm/task/nnvm_integration.py
@@ -7,208 +7,13 @@ import warnings
 import logging
 
 
-from ... import tensor, placeholder, create_schedule, target as _target
+from ... import target as _target
 
-from ..util import get_const_tuple
-from .task import create, register
+from .task import create
+from .topi_integration import TaskExtractEnv
 
 logger = logging.getLogger('autotvm')
 
-def serialize_args(args):
-    """serialize arguments of a topi function to a hashable tuple.
-
-    Parameters
-    ----------
-    args: list of hashable or Tensor
-    """
-    ret = []
-    for t in args:
-        if isinstance(t, tensor.Tensor):
-            ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
-        else:
-            ret.append(t)
-    return tuple(ret)
-
-
-def deserialize_args(args):
-    """The inverse function of :code:`serialize_args`.
-
-    Parameters
-    ----------
-    args: list of hashable or Tensor
-    """
-    ret = []
-    for t in args:
-        if isinstance(t, tuple) and t[0] == 'TENSOR':
-            ret.append(placeholder(shape=t[1], dtype=t[2]))
-        else:
-            ret.append(t)
-    return ret
-
-
-# Task extractor for nnvm graph
-class TaskExtractEnv:
-    """Global environment for extracting tuning tasks from nnvm graph"""
-    current = None
-
-    def __init__(self):
-        import topi
-        import nnvm
-
-        # NOTE: To add more symbols, you only need to change the following lists
-        # nnvm symbol -> topi compute
-        self.symbol2topi = {
-            nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
-                              topi.nn.group_conv2d_nchw],
-            nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
-            nnvm.sym.dense: [topi.nn.dense],
-        }
-
-        # topi compute -> autotvm task name
-        self.topi_to_task = {
-            topi.nn.conv2d: "topi_nn_conv2d",
-            topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
-            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
-            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
-            topi.nn.dense: "topi_nn_dense",
-        }
-
-        self.topi_to_schedule = {
-            topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
-                             topi.generic.schedule_conv2d_nhwc],
-            topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
-                                            topi.generic.schedule_depthwise_conv2d_nhwc],
-            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
-            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
-            topi.nn.dense: [topi.generic.schedule_dense],
-        }
-
-        self._register_tracing()
-        self._register_topi_task()
-        self.task_collection = []
-        self.wanted_topi_funcs = list(self.topi_to_task.keys())
-
-    def _register_tracing(self):
-        """Register tracing function to track the topi function call"""
-        # register topi compute for "tracing" target
-        for topi_compute in self.topi_to_task:
-            def _local_scope(compute_func):
-                """start a scope to hold the local function in for loop"""
-
-                @compute_func.register("tracing", )
-                def _tracing_topi_compute(*args, **kwargs):
-                    assert not kwargs, "Do not support extracting tuning tasks when" \
-                                       "kwargs is used in TOPI function call." \
-                                       "Please modify it to use only positional args."
-
-                    if compute_func in self.wanted_topi_funcs:  # record this call
-                        key = (self.topi_to_task[compute_func], serialize_args(args))
-                        if key not in self.task_collection:
-                            self.task_collection.append(key)
-
-                    return compute_func.fdefault(*args)
-            _local_scope(topi_compute)
-
-        # register topi schedule for "tracing" target
-        for topi_compute in self.topi_to_task:
-            for topi_schedule in self.topi_to_schedule[topi_compute]:
-                def _local_scope_(schedule_func):
-                    """start a scope to hold the local function in for loop"""
-
-                    @schedule_func.register("tracing", )
-                    def _tracing_topi_compute(outs):
-                        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
-                        return create_schedule([x.op for x in outs])
-                _local_scope_(topi_schedule)
-
-    def _register_topi_task(self):
-        """register tuning wrapper for topi function"""
-        import topi
-
-        # Tuning wrapper for topi functions
-        @register("topi_nn_conv2d")
-        def _topi_nn_conv2d(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            layout = args[-2]
-            assert layout == 'NCHW', "only support NCHW currently"
-            C = topi.nn.conv2d(*args, **kwargs)
-            s = topi.generic.schedule_conv2d_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_depthwise_conv2d_nchw")
-        def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_depthwise_conv2d_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_group_conv2d_nchw")
-        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_group_conv2d_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_conv2d_transpose_nchw")
-        def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
-            s = topi.generic.schedule_conv2d_transpose_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_dense")
-        def _topi_nn_dense(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            data, weight, bias = args
-            C = topi.nn.dense(*args, **kwargs)
-            s = topi.generic.schedule_dense([C])
-            if bias is not None:
-                return s, [data, weight, bias, C]
-            return s, [data, weight, C]
-
-    def reset(self, wanted_topi_funcs):
-        """Reset task collections
-
-        Parameters
-        ----------
-        wanted_topi_funcs: List of function
-            The topi function to be extracted
-        """
-        self.task_collection = []
-        self.wanted_topi_funcs = wanted_topi_funcs
-
-    def get_tasks(self):
-        """Get collected tasks
-
-        Returns
-        -------
-        tasks: List of tuple(name, args)
-            A list of tasks extracted from the nnvm graph
-        """
-        return self.task_collection
-
-    @staticmethod
-    def get():
-        """Get the single instance of TaskExtractEnv
-
-        Returns
-        -------
-        env: TaskExtractEnv
-            The single instance of TaskExtractEnv
-        """
-        if not TaskExtractEnv.current:
-            TaskExtractEnv.current = TaskExtractEnv()
-        return TaskExtractEnv.current
-
 
 def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
     """ Extract tuning tasks from a nnvm graph.
@@ -237,13 +42,24 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
         collected tasks
     """
     import nnvm.compiler
+    import nnvm
+    import topi
 
     env = TaskExtractEnv.get()
 
+    # NOTE: To add more symbols, you only need to change the following dict
+    # nnvm symbol -> topi compute
+    SYMBOL2TOPI = {
+        nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
+                          topi.nn.group_conv2d_nchw],
+        nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
+        nnvm.sym.dense: [topi.nn.dense],
+    }
+
     topi_funcs = []
     for sym_name in symbols:
-        if sym_name in env.symbol2topi:
-            topi_funcs.extend(env.symbol2topi[sym_name])
+        if sym_name in SYMBOL2TOPI:
+            topi_funcs.extend(SYMBOL2TOPI[sym_name])
         else:
             warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
 
@@ -297,13 +113,24 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
         collected tasks
     """
     import nnvm.compiler
+    import nnvm
+    import topi
 
     env = TaskExtractEnv.get()
 
+    # NOTE: To add more symbols, you only need to change the following dict
+    # nnvm symbol -> topi compute
+    SYMBOL2TOPI = {
+        nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
+                          topi.nn.group_conv2d_nchw],
+        nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
+        nnvm.sym.dense: [topi.nn.dense],
+    }
+
     topi_funcs = []
     for sym_name in symbols:
-        if sym_name in env.symbol2topi:
-            topi_funcs.extend(env.symbol2topi[sym_name])
+        if sym_name in SYMBOL2TOPI:
+            topi_funcs.extend(SYMBOL2TOPI[sym_name])
         else:
             warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
 
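The NNVM front end keeps the same call shape after this refactor; a hedged sketch (the shape dict and symbol handle are illustrative):

```python
# Sketch only: the NNVM entry point is unchanged by this patch.
import nnvm.testing
import nnvm.symbol as sym
from tvm import autotvm

net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=1)
tasks = autotvm.task.extract_from_graph(net,
                                        shape={'data': (1, 3, 224, 224)},
                                        dtype='float32',
                                        target='llvm',
                                        symbols=(sym.conv2d,))
```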
diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py
new file mode 100644
index 000000000..21acf257f
--- /dev/null
+++ b/python/tvm/autotvm/task/relay_integration.py
@@ -0,0 +1,200 @@
+# pylint: disable=unused-variable,invalid-name
+"""
+Decorator and utilities for the integration with TOPI and Relay.
+Largely adapted from the NNVM integration implemented by @MerryMercy.
+"""
+import threading
+import warnings
+import logging
+
+
+from ... import tensor, placeholder, target as _target
+from ..util import get_const_tuple
+
+from .task import create
+from .topi_integration import TaskExtractEnv
+
+logger = logging.getLogger('autotvm')
+
+
+def serialize_args(args):
+    """serialize arguments of a topi function to a hashable tuple.
+
+    Parameters
+    ----------
+    args: list of hashable or Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, tensor.Tensor):
+            ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
+        else:
+            ret.append(t)
+    return tuple(ret)
+
+
+def deserialize_args(args):
+    """The inverse function of :code:`serialize_args`.
+
+    Parameters
+    ----------
+    args: list of hashable or Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, tuple) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+def extract_from_program(func, params, ops, target, target_host=None):
+    """ Extract tuning tasks from a relay program.
+
+    This function collects tuning tasks by building the program
+    with a "tracing" target and tracing all the calls to topi.
+
+    Parameters
+    ----------
+    func: relay.expr.Function
+        The func to tune
+    params: dict of str to numpy array
+        The associated parameters of the program
+    ops: List of relay op
+        List of relay ops to be tuned
+    target: tvm.target.Target
+        The compilation target
+    target_host: tvm.target.Target
+        The host compilation target
+
+    Returns
+    -------
+    task: Array of autotvm.task.Task
+        collected tasks
+    """
+    env = TaskExtractEnv.get()
+    import tvm.relay.op
+    from tvm import relay
+    import topi
+
+    # NOTE: To add more ops, you only need to change the following dict
+    # relay op -> topi compute
+    OP2TOPI = {
+        tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
+                                 topi.nn.group_conv2d_nchw],
+        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
+        tvm.relay.op.nn.dense: [topi.nn.dense],
+    }
+
+    topi_funcs = []
+    for op_name in ops:
+        if op_name in OP2TOPI:
+            topi_funcs.extend(OP2TOPI[op_name])
+        else:
+            warnings.warn("Op %s is not tunable, ignored" % op_name)
+
+    # run compiler to collect all TOPI calls during compilation
+    env.reset(topi_funcs)
+
+    # disable logger temporarily
+    old_state = logger.disabled
+    logger.disabled = True
+
+    # use a "tracing" target to do a fake compile for collecting topi calls
+    tracing_target = _target.create("llvm -device=tracing")
+    relay.backend.compile_engine.get().clear()
+    # wrap build call in thread to avoid multiprocessing problems
+    build_thread = threading.Thread(target=relay.build, args=(func,
+                                                              tracing_target,
+                                                              target_host,
+                                                              params))
+    build_thread.start()
+    build_thread.join()
+    logger.disabled = old_state
+
+    # create tasks for target
+    tasks = []
+    for task_name, args in env.get_tasks():
+        tasks.append(create(task_name, args,
+                            target=target, target_host=target_host,
+                            template_key='direct'))
+
+    return tasks
+
+
+def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
+    """ Extract tuning tasks from multiple relay programs.
+
+    This function is the multiple-program version of extract_from_program.
+
+    Parameters
+    ----------
+    funcs: List of relay.expr.Function
+        The list of functions to tune
+    params: List of dict of str to numpy array
+        The associated parameters of the programs
+    ops: List of relay op
+        List of relay ops to be tuned
+    target: tvm.target.Target
+        The compilation target
+    target_host: tvm.target.Target
+        The host compilation target
+
+    Returns
+    -------
+    task: Array of autotvm.task.Task
+        collected tasks
+    """
+    env = TaskExtractEnv.get()
+    import tvm.relay.op
+    from tvm import relay
+    import topi
+
+    # NOTE: To add more ops, you only need to change the following dict
+    # relay op -> topi compute
+    OP2TOPI = {
+        tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
+                                 topi.nn.group_conv2d_nchw],
+        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
+        tvm.relay.op.nn.dense: [topi.nn.dense],
+    }
+
+    topi_funcs = []
+    for op_name in ops:
+        if op_name in OP2TOPI:
+            topi_funcs.extend(OP2TOPI[op_name])
+        else:
+            warnings.warn("Op %s is not tunable, ignored" % op_name)
+
+    # run compiler to collect all TOPI calls during compilation
+    env.reset(topi_funcs)
+
+    # disable logger temporarily
+    old_state = logger.disabled
+    logger.disabled = True
+
+    # use a "tracing" target to do a fake compile for collecting topi calls
+    tracing_target = _target.create("llvm -device=tracing")
+
+    for func, param in zip(funcs, params):
+        # wrap build call in thread to avoid multiprocessing problems
+        build_thread = threading.Thread(target=relay.build, args=(func,
+                                                                  tracing_target,
+                                                                  target_host,
+                                                                  param))
+        build_thread.start()
+        build_thread.join()
+
+    logger.disabled = old_state
+
+    # create tasks for target
+    tasks = []
+    for task_name, args in env.get_tasks():
+        tasks.append(create(task_name, args,
+                            target=target, target_host=target_host,
+                            template_key='direct'))
+
+    return tasks
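
A hedged sketch of the multiple-program variant: `funcs` and `params` are parallel lists, and workloads shared across programs are collected only once by the singleton extraction environment:

```python
# Sketch only: extract tasks from several Relay programs in one pass.
import tvm.relay.testing
from tvm import relay, autotvm

net1, params1 = relay.testing.resnet.get_workload(num_layers=18, batch_size=1)
net2, params2 = relay.testing.mobilenet.get_workload(batch_size=1)

tasks = autotvm.task.extract_from_multiple_program(
    [net1, net2], [params1, params2],
    ops=(relay.op.nn.conv2d,), target='llvm')
```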
diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index f005ee0c9..412d7ae0e 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -11,16 +11,202 @@ tuple.
 See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
 """
 
-from ... import _api_internal, tensor
-
-from .task import args_to_workload, dispatcher
+from ... import _api_internal, tensor, placeholder, create_schedule
 
+from .task import args_to_workload, dispatcher, register
+from ..util import get_const_tuple
 
 # A table that records all registered dispatcher for all targets
 _REGISTED_DISPATHCER = {
 }
 
 
+def serialize_args(args):
+    """serialize arguments of a topi function to a hashable tuple.
+
+    Parameters
+    ----------
+    args: list of hashable or Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, tensor.Tensor):
+            ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
+        else:
+            ret.append(t)
+    return tuple(ret)
+
+
+def deserialize_args(args):
+    """The inverse function of :code:`serialize_args`.
+
+    Parameters
+    ----------
+    args: list of hashable or Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, tuple) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+# Task extractor for nnvm graphs and relay programs
+class TaskExtractEnv:
+    """Global environment for extracting tuning tasks from nnvm graph"""
+    current = None
+
+    def __init__(self):
+        import topi
+
+        # topi compute -> autotvm task name
+        self.topi_to_task = {
+            topi.nn.conv2d: "topi_nn_conv2d",
+            topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
+            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
+            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
+            topi.nn.dense: "topi_nn_dense",
+        }
+
+        self.topi_to_schedule = {
+            topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
+                             topi.generic.schedule_conv2d_nhwc],
+            topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
+                                            topi.generic.schedule_depthwise_conv2d_nhwc],
+            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
+            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
+            topi.nn.dense: [topi.generic.schedule_dense],
+        }
+
+        self._register_tracing()
+        self._register_topi_task()
+        self.task_collection = []
+        self.wanted_topi_funcs = list(self.topi_to_task.keys())
+
+    def _register_tracing(self):
+        """Register tracing function to track the topi function call"""
+        # register topi compute for "tracing" target
+        for topi_compute in self.topi_to_task:
+            def _local_scope(compute_func):
+                """start a scope to hold the local function in for loop"""
+
+                @compute_func.register("tracing", )
+                def _tracing_topi_compute(*args, **kwargs):
+                    assert not kwargs, "Do not support extracting tuning tasks when" \
+                                       "kwargs is used in TOPI function call." \
+                                       "Please modify it to use only positional args."
+
+                    if compute_func in self.wanted_topi_funcs:  # record this call
+                        key = (self.topi_to_task[compute_func], serialize_args(args))
+                        if key not in self.task_collection:
+                            self.task_collection.append(key)
+
+                    return compute_func.fdefault(*args)
+            _local_scope(topi_compute)
+
+        # register topi schedule for "tracing" target
+        for topi_compute in self.topi_to_task:
+            for topi_schedule in self.topi_to_schedule[topi_compute]:
+                def _local_scope_(schedule_func):
+                    """start a scope to hold the local function in for loop"""
+
+                    @schedule_func.register("tracing", )
+                    def _tracing_topi_schedule(outs):
+                        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
+                        return create_schedule([x.op for x in outs])
+                _local_scope_(topi_schedule)
+
+    def _register_topi_task(self):
+        """register tuning wrapper for topi function"""
+        import topi
+
+        # Tuning wrapper for topi functions
+        @register("topi_nn_conv2d")
+        def _topi_nn_conv2d(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            layout = args[-2]
+            assert layout == 'NCHW', "only support NCHW currently"
+            C = topi.nn.conv2d(*args, **kwargs)
+            s = topi.generic.schedule_conv2d_nchw([C])
+            return s, [A, W, C]
+
+        @register("topi_nn_depthwise_conv2d_nchw")
+        def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
+            s = topi.generic.schedule_depthwise_conv2d_nchw([C])
+            return s, [A, W, C]
+
+        @register("topi_nn_group_conv2d_nchw")
+        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
+            s = topi.generic.schedule_group_conv2d_nchw([C])
+            return s, [A, W, C]
+
+        @register("topi_nn_conv2d_transpose_nchw")
+        def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
+            s = topi.generic.schedule_conv2d_transpose_nchw([C])
+            return s, [A, W, C]
+
+        @register("topi_nn_dense")
+        def _topi_nn_dense(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            data, weight, bias = args
+            C = topi.nn.dense(*args, **kwargs)
+            s = topi.generic.schedule_dense([C])
+            if bias is not None:
+                return s, [data, weight, bias, C]
+            return s, [data, weight, C]
+
+    def reset(self, wanted_topi_funcs):
+        """Reset task collections
+
+        Parameters
+        ----------
+        wanted_topi_funcs: List of function
+            The topi function to be extracted
+        """
+        self.task_collection = []
+        self.wanted_topi_funcs = wanted_topi_funcs
+
+    def get_tasks(self):
+        """Get collected tasks
+
+        Returns
+        -------
+        tasks: List of tuple(name, args)
+            A list of tasks extracted from the nnvm graph or relay program
+        """
+        return self.task_collection
+
+    @staticmethod
+    def get():
+        """Get the single instance of TaskExtractEnv
+
+        Returns
+        -------
+        env: TaskExtractEnv
+            The single instance of TaskExtractEnv
+        """
+        if not TaskExtractEnv.current:
+            TaskExtractEnv.current = TaskExtractEnv()
+        return TaskExtractEnv.current
+
+
 def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
     """Register a tunable template for a topi compute function.
 
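A small round-trip sketch for the (de)serialization helpers now hosted in topi_integration (shapes and names are illustrative):

```python
# Sketch only: a Tensor serializes to a hashable ('TENSOR', shape, dtype)
# tuple and deserializes back to a fresh placeholder.
import tvm
from tvm.autotvm.task.topi_integration import serialize_args, deserialize_args

data = tvm.placeholder((1, 3, 224, 224), name='data', dtype='float32')
key = serialize_args([data, 'NCHW'])
assert key == (('TENSOR', (1, 3, 224, 224), 'float32'), 'NCHW')
restored = deserialize_args(key)  # [new placeholder, 'NCHW']
```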
diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py
new file mode 100644
index 000000000..8c93e4a56
--- /dev/null
+++ b/tests/python/relay/test_autotvm_task_extraction.py
@@ -0,0 +1,56 @@
+"""Test task extraction for autotvm"""
+import tvm.relay.testing
+from tvm import relay
+from tvm import autotvm
+
+def get_network(name, batch_size):
+    """Get the symbol definition and random weight of a network"""
+    input_shape = (batch_size, 3, 224, 224)
+
+    if name == 'resnet-18':
+        net, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=batch_size)
+    elif name == 'mobilenet':
+        net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
+    elif name == 'dcgan':
+        net, params = relay.testing.dcgan.get_workload(batch_size=batch_size)
+        input_shape = (batch_size, 100)
+    else:
+        raise ValueError("Unsupported network: " + name)
+
+    return net, params, input_shape
+
+def test_task_extraction():
+    target = 'llvm'
+
+    net, params, input_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_program(net, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
+    assert len(tasks) == 12
+
+    net, params, input_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_program(net, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.dense,))
+    assert len(tasks) == 1
+
+    net, params, input_shape = get_network('resnet-18', batch_size=1)
+    tasks = autotvm.task.extract_from_program(net, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+    assert len(tasks) == 13
+
+    net, params, input_shape = get_network('mobilenet', batch_size=1)
+    tasks = autotvm.task.extract_from_program(net, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+    assert len(tasks) == 20
+
+    net, params, input_shape = get_network('dcgan', batch_size=1)
+    tasks = autotvm.task.extract_from_program(net, target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d_transpose,))
+    assert len(tasks) == 4
+
+if __name__ == '__main__':
+    test_task_extraction()
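
Extracted tasks feed straight into the usual AutoTVM tuning loop; a hedged follow-up sketch, with tuner choice, trial count, and log file name illustrative:

```python
# Sketch only: tune each extracted task with a stock AutoTVM tuner.
import tvm.relay.testing
from tvm import relay, autotvm
from tvm.autotvm.tuner import RandomTuner

net, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1)
tasks = autotvm.task.extract_from_program(net, params=params,
                                          ops=(relay.op.nn.conv2d,),
                                          target='llvm')

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=4))

for task in tasks:
    tuner = RandomTuner(task)
    tuner.tune(n_trial=16,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('conv2d_tuning.log')])
```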
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 1a7373626..fe38b38d3 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -2,7 +2,7 @@
 """Conv2D schedule on x86"""
 import tvm
 from tvm import autotvm
-from tvm.autotvm.task.nnvm_integration import deserialize_args
+from tvm.autotvm.task.topi_integration import deserialize_args
 from tvm.autotvm.task import get_config
 from .. import generic, tag
 from .. import nn
diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py
index 8f37a0316..64858df91 100644
--- a/topi/python/topi/x86/depthwise_conv2d.py
+++ b/topi/python/topi/x86/depthwise_conv2d.py
@@ -4,7 +4,7 @@ import tvm
 from tvm import autotvm
 from tvm.autotvm.task import get_config
 from tvm.autotvm.task.space import SplitEntity
-from tvm.autotvm.task.nnvm_integration import deserialize_args
+from tvm.autotvm.task.topi_integration import deserialize_args
 from .. import generic, tag
 from ..nn.pad import pad
 from ..util import get_const_tuple
-- 
GitLab