From b7beb1ebefa18e29bfbf8ff1c4f8f0c8892d93bc Mon Sep 17 00:00:00 2001 From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn> Date: Tue, 21 Aug 2018 18:35:32 -0700 Subject: [PATCH] [AUTOTVM] Allow fallback for template & Fix bugs in tuners (#1615) * support fallback & fix bugs in tuners & clean topi test * update task extraction * update task extraction * fix arm tutorial * Update tune_nnvm_arm.py --- nnvm/python/nnvm/compiler/build_module.py | 5 +- .../compiler/test_autotvm_task_extraction.py | 63 +++++++ python/tvm/autotvm/__init__.py | 3 +- python/tvm/autotvm/measure/measure.py | 5 +- python/tvm/autotvm/measure/measure_methods.py | 44 ++--- python/tvm/autotvm/task/__init__.py | 2 +- python/tvm/autotvm/task/dispatcher.py | 117 +++++++++---- python/tvm/autotvm/task/nnvm_integration.py | 117 +++++++++---- python/tvm/autotvm/task/space.py | 56 ++++++- python/tvm/autotvm/task/task.py | 2 +- python/tvm/autotvm/tophub.py | 7 +- python/tvm/autotvm/tuner/ga_tuner.py | 10 +- python/tvm/autotvm/tuner/model_based_tuner.py | 33 ++-- .../tvm/autotvm/tuner/sa_model_optimizer.py | 2 +- python/tvm/autotvm/tuner/tuner.py | 13 +- .../tvm/autotvm/tuner/xgboost_cost_model.py | 119 ++++++++----- python/tvm/autotvm/tuner/xgboost_tuner.py | 17 +- python/tvm/exec/tophub.py | 9 +- python/tvm/target.py | 1 + .../unittest/test_autotvm_dispatch_context.py | 44 +++-- tests/python/unittest/test_autotvm_space.py | 15 +- .../unittest/test_autotvm_xgboost_model.py | 6 +- topi/python/topi/arm_cpu/conv2d.py | 86 +++++++--- topi/python/topi/arm_cpu/depthwise_conv2d.py | 14 +- topi/python/topi/x86/injective.py | 2 +- topi/tests/python/common.py | 12 ++ .../python/test_topi_bitserial_conv2d.py | 25 ++- .../python/test_topi_bitserial_conv2d_rasp.py | 16 +- topi/tests/python/test_topi_bnn.py | 2 +- topi/tests/python/test_topi_broadcast.py | 25 +-- topi/tests/python/test_topi_clip.py | 3 +- topi/tests/python/test_topi_conv2d.py | 47 ------ topi/tests/python/test_topi_conv2d_hwcn.py | 14 +- topi/tests/python/test_topi_conv2d_nchw.py | 157 ++++++++++++------ .../python/test_topi_conv2d_transpose_nchw.py | 22 +-- topi/tests/python/test_topi_dense.py | 9 +- .../python/test_topi_depthwise_conv2d.py | 38 ++--- tutorials/autotvm/tune_nnvm_arm.py | 5 +- 38 files changed, 756 insertions(+), 411 deletions(-) create mode 100644 nnvm/tests/python/compiler/test_autotvm_task_extraction.py create mode 100644 topi/tests/python/common.py delete mode 100644 topi/tests/python/test_topi_conv2d.py diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 217598c9d..6fab4460b 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -239,8 +239,9 @@ def build(graph, target=None, shape=None, dtype="float32", raise ValueError("Target is not set in env or passed as argument.") target = tvm.target.create(target) - # if not inside an autotvm config dispatch context, load pre-tuned parameters from TopHub - if autotvm.task.DispatchContext.current is None: + # If current dispatch context is fallback context (the default root context), + # then load pre-tuned parameters from TopHub + if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): tophub_context = autotvm.tophub.context(target) else: tophub_context = autotvm.util.EmptyContext() diff --git a/nnvm/tests/python/compiler/test_autotvm_task_extraction.py b/nnvm/tests/python/compiler/test_autotvm_task_extraction.py new file mode 100644 index 000000000..fd14934f8 --- /dev/null +++ 
b/nnvm/tests/python/compiler/test_autotvm_task_extraction.py @@ -0,0 +1,63 @@ +"""Test task extraction for autotvm""" + +import nnvm.testing +import nnvm.compiler +from tvm import autotvm + +def get_network(name, batch_size): + """Get the symbol definition and random weight of a network""" + input_shape = (batch_size, 3, 224, 224) + output_shape = (batch_size, 1000) + + if name == 'resnet-18': + net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=batch_size) + elif name == 'mobilenet': + net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size) + elif name == 'squeezenet v1.1': + net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1') + elif name == 'vgg-16': + net, params = nnvm.testing.vgg.get_workload(num_layers=16, batch_size=batch_size) + elif name == 'dcgan': + net, params = nnvm.testing.dcgan.get_workload(batch_size=batch_size) + input_shape = (batch_size, 100) + else: + raise ValueError("Unsupported network: " + name) + + return net, params, input_shape, output_shape + +def test_task_extraction(): + target = 'llvm' + dtype = 'float32' + + net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1) + tasks = autotvm.task.extract_from_graph(net, target=target, + shape={'data': input_shape}, dtype=dtype, + symbols=(nnvm.sym.conv2d,)) + assert len(tasks) == 12 + + net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1) + tasks = autotvm.task.extract_from_graph(net, target=target, + shape={'data': input_shape}, dtype=dtype, + symbols=(nnvm.sym.dense,)) + assert len(tasks) == 1 + + net, params, input_shape, out_shape = get_network('resnet-18', batch_size=1) + tasks = autotvm.task.extract_from_graph(net, target=target, + shape={'data': input_shape}, dtype=dtype, + symbols=(nnvm.sym.conv2d, nnvm.sym.dense)) + assert len(tasks) == 13 + + net, params, input_shape, out_shape = get_network('mobilenet', batch_size=1) + tasks = autotvm.task.extract_from_graph(net, target=target, + shape={'data': input_shape}, dtype=dtype, + symbols=(nnvm.sym.conv2d, nnvm.sym.dense)) + assert len(tasks) == 20 + + net, params, input_shape, out_shape = get_network('dcgan', batch_size=1) + tasks = autotvm.task.extract_from_graph(net, target=target, + shape={'data': input_shape}, dtype=dtype, + symbols=(nnvm.sym.conv2d_transpose,)) + assert len(tasks) == 4 + +if __name__ == '__main__': + test_task_extraction() diff --git a/python/tvm/autotvm/__init__.py b/python/tvm/autotvm/__init__.py index 5b312d93d..625b50c10 100644 --- a/python/tvm/autotvm/__init__.py +++ b/python/tvm/autotvm/__init__.py @@ -25,5 +25,6 @@ from . import tophub from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo from .tuner import callback from .task import template, get_config, create, ConfigSpace, ConfigEntity, \ - ApplyHistoryBest as apply_history_best + register_topi_compute, register_topi_schedule, \ + DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best from .env import GLOBAL_SCOPE diff --git a/python/tvm/autotvm/measure/measure.py b/python/tvm/autotvm/measure/measure.py index 2325a970b..2d780eeaf 100644 --- a/python/tvm/autotvm/measure/measure.py +++ b/python/tvm/autotvm/measure/measure.py @@ -89,8 +89,9 @@ def measure_option(measure_func, callable: customized build function for other backends (e.g. VTA). See measure/measure_methods.py::default_build_func for example. - check_correctness: bool - Whether check correctness after measurement. This will use llvm cpu as reference. 
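The new test above only checks how many tasks extract_from_graph returns; the other half of the workflow is applying tuned records at build time. The build_module.py hunk at the top of the patch now loads TopHub parameters only when the active dispatch context is still the root FallbackContext, so a user-supplied log always takes precedence. A minimal sketch of that flow, reusing net, params and input_shape from the get_network helper above; 'tuned.log' is a placeholder file name:

    import nnvm.compiler
    from tvm import autotvm

    # apply_history_best is the ApplyHistoryBest alias exported by this patch.
    # Because it is not a FallbackContext, build() skips the TopHub lookup and
    # answers every template query from the given log instead.
    with autotvm.apply_history_best('tuned.log'):
        graph, lib, params = nnvm.compiler.build(
            net, target='llvm', shape={'data': input_shape},
            dtype='float32', params=params)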
+ check_correctness: bool, optional + Whether check correctness after measurement. This will use llvm cpu target to generate + reference output. replay_db : Database, optional The database that we retrieve saved MeasureResult from. diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index d845cc1f8..2d740b949 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -83,7 +83,7 @@ def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10): The priority of this request, larger is more prior timeout: float, optional The timeout of this check (units: seconds). - If time is out, a RuntimerError will be raised. + If time is out, a RuntimeError will be raised. """ def _check(): remote = request_remote(device_key, tracker_addr, priority) @@ -281,11 +281,11 @@ def rpc(key, results: List of MeasureResult The results for input_pack """ - remote = request_remote(key, (host, port), priority, session_timeout) + remote_args = (key, (host, port), priority, session_timeout) res = _measure_common(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output, - remote) + remote_args) return res fmeasure.pack_size = pack_size @@ -294,7 +294,7 @@ def rpc(key, def _measure_common(input_pack, build_func, build_kwargs, number, repeat, - ref_input=None, ref_output=None, remote=None): + ref_input=None, ref_output=None, remote_args=None): """Measure the time cost for a pack of inputs. (Note: A pack is a list of inputs which will be measured inside a same RPC session) @@ -318,8 +318,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, Reference input for checking correctness ref_output: Array of np.ndarray, optional Reference output for checking correctness - remote: RPCSession, optional - The remote RPC session + remote_args: Tuple, optional + The arguments to request_remote. If is not None, will use remote rpc devices. Returns ------- @@ -327,7 +327,8 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, The list of results of measurement. 
""" res_pack = [] - tmp_dir = util.tempdir() if remote else None + tmp_dir = util.tempdir() if remote_args else None + assert len(input_pack) == 1, "Only supports input_pack == 1 for now" for inp in input_pack: tic = time.time() @@ -360,31 +361,36 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat, tstamp - tic, tstamp)) continue - # upload built module - if remote: - remote.upload(tmp_dir.relpath(filename)) - func = remote.load_module(filename) - ctx = remote.context(str(inp.target), 0) - time_f = func.time_evaluator( - func.entry_name, ctx, number=number, repeat=repeat) - else: - ctx = context(str(inp.target), 0) - time_f = func.time_evaluator( - func.entry_name, ctx, number=number, repeat=repeat) - # measure time errno = MeasureErrorNo.NO_ERROR try: + # upload built module + if remote_args: + remote = request_remote(*remote_args) + remote.upload(tmp_dir.relpath(filename)) + func = remote.load_module(filename) + ctx = remote.context(str(inp.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat) + else: + ctx = context(str(inp.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat) + + # set input if ref_input: args = [nd.array(x, ctx=ctx) for x in ref_input] else: args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx) for x in arg_bufs] + costs = time_f(*args).results if len(costs) > 2: # remove largest and smallest value to reduce variance costs = list(costs) costs.sort() costs = tuple(costs[1:-1]) + + # check correctness of output if ref_output: for expected, real in zip(ref_output, args): if not np.allclose(expected, real.asnumpy(), rtol=1e-4): diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py index 0d43f9265..7592fc5af 100644 --- a/python/tvm/autotvm/task/__init__.py +++ b/python/tvm/autotvm/task/__init__.py @@ -9,7 +9,7 @@ of typical tasks of interest. from .task import Task, create, register, template, get_config, args_to_workload from .space import ConfigSpace, ConfigEntity from .code_hash import attach_code_hash, attach_code_hash_to_arg -from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, dispatcher +from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, FallbackContext, dispatcher from .topi_integration import register_topi_compute, register_topi_schedule from .nnvm_integration import extract_from_graph diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 93f6d584a..ec1dcc44f 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -21,7 +21,7 @@ import numpy as np from tvm import target as _target -from .space import ConfigSpace +from .space import FallbackConfigEntity logger = logging.getLogger('autotvm') @@ -34,9 +34,36 @@ class DispatchContext(object): """ current = None + def __init__(self): + self._old_ctx = DispatchContext.current + def query(self, target, workload): """ - Query the context to get the specific implementation. + Query the context to get the specific config for a template. + If cannot find the result inside this context, this function will query it + from the upper contexts. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + + Returns + ------- + cfg : ConfigSpace + The specific configuration. 
+ """ + ret = self._query_inside(target, workload) + if ret is None: + ret = self._old_ctx.query(target, workload) + return ret + + def _query_inside(self, target, workload): + """ + Query the context to get the specific config for a template. + This function only query config inside this context. Parameters ---------- @@ -117,17 +144,17 @@ def dispatcher(fworkload): def dispatch_func(func, *args, **kwargs): """The wrapped dispatch function""" tgt = _target.current_target() - context = DispatchContext.current - if context is None: - raise RuntimeError("DispatchContext is not initialized") workload = func(*args, **kwargs) - cfg = context.query(tgt, workload) - if cfg.template_key: - return dispatch_dict[cfg.template_key](cfg, *args, **kwargs) - else: - assert dispatch_dict, "No func registered for this dispatcher" + cfg = DispatchContext.current.query(tgt, workload) + if cfg.is_fallback and not cfg.template_key: + # first try 'direct' template + if 'direct' in dispatch_dict: + return dispatch_dict['direct'](cfg, *args, **kwargs) + # otherwise pick a random template for v in dispatch_dict.values(): return v(cfg, *args, **kwargs) + else: + return dispatch_dict[cfg.template_key](cfg, *args, **kwargs) fdecorate = decorate(fworkload, dispatch_func) fdecorate.register = register @@ -135,7 +162,7 @@ def dispatcher(fworkload): class ApplyConfig(DispatchContext): - """Apply a specific config entity during query. + """Apply a deterministic config entity for all queries. Parameters ---------- @@ -147,7 +174,7 @@ class ApplyConfig(DispatchContext): self._config = config self.workload = None - def query(self, target, workload): + def _query_inside(self, target, workload): """Override query""" self.workload = workload return self._config @@ -164,20 +191,12 @@ class ApplyHistoryBest(DispatchContext): If is str, then it should be the filename of a records log file. Each row of this file is an encoded record pair. Otherwise, it is an iterator. - default: ConfigEntity, optional - The default config to return when no history records - allow_fallback: bool - Whether allow to use a fallback configuration if cannot find - tuned result. """ - def __init__(self, records, default=None, allow_fallback=False): + def __init__(self, records): super(ApplyHistoryBest, self).__init__() self.best_by_targetkey = {} self.best_by_model = {} - self._default = default - self._allow_fallback = allow_fallback - self.fallback = {} if records: self.load(records) @@ -234,7 +253,7 @@ class ApplyHistoryBest(DispatchContext): logger.debug("Finish loading %d records", counter) - def query(self, target, workload): + def _query_inside(self, target, workload): if target is None: raise RuntimeError("Need a target context to find the history best. " "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`" @@ -254,20 +273,50 @@ class ApplyHistoryBest(DispatchContext): if key in self.best_by_targetkey: return self.best_by_targetkey[key][0].config - if self._default: - return self._default + return None + + +class FallbackContext(DispatchContext): + """ + A fallback dispatch context. + + Any tunable template can be called under this context. + This is the root context. 
+ """ + + def __init__(self): + super(FallbackContext, self).__init__() + self.memory = {} + self.silent = False + + def _query_inside(self, target, workload): + key = (str(target), workload) + if key in self.memory: + return self.memory[key] - if self._allow_fallback: - key = (target, workload) - if key in self.fallback: - return self.fallback[key] + if not self.silent: logger.warning( "Cannot find config for target=%s, workload=%s. A fallback configuration " "is used, which may bring great performance regression.", target, workload) - cfg = ConfigSpace() - self.fallback[key] = cfg - return cfg + cfg = FallbackConfigEntity() + + # cache this config + self.memory[key] = cfg + return cfg + + def clear_cache(self, target, workload): + """Clear fallback cache. Pass the same argument as _query_inside to this function + to clean the cache. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + """ + key = (str(target), workload) + if key in self.memory: + del self.memory[key] - raise RuntimeError( - "Cannot find config for target=%s, workload=%s. You need to do tuning " - "for this workload to get the config." % (target, workload)) +DispatchContext.current = FallbackContext() diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py index 1b50869fc..9138cc288 100644 --- a/python/tvm/autotvm/task/nnvm_integration.py +++ b/python/tvm/autotvm/task/nnvm_integration.py @@ -7,11 +7,10 @@ import warnings import logging -from ... import tensor, placeholder, target as _target +from ... import tensor, placeholder, create_schedule, target as _target from ..util import get_const_tuple from .task import create, register -from .dispatcher import ApplyHistoryBest logger = logging.getLogger('autotvm') @@ -56,40 +55,68 @@ class TaskExtractEnv: import topi import nnvm + # NOTE: To add more symbols, you only need to change the following lists + # nnvm symbol -> topi compute self.symbol2topi = { nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw], - nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose], + nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], + nnvm.sym.dense: [topi.nn.dense], } + # topi compute -> autotvm task name self.topi_to_task = { topi.nn.conv2d: "topi_nn_conv2d", topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw", topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw", + topi.nn.dense: "topi_nn_dense", } - self._register_dummy() + self.topi_to_schedule = { + topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw, + topi.generic.schedule_conv2d_nhwc], + topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw, + topi.generic.schedule_depthwise_conv2d_nhwc], + topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw], + topi.nn.dense: [topi.generic.schedule_dense], + } + + self._register_tracing() self._register_topi_task() self.task_collection = [] + self.wanted_topi_funcs = list(self.topi_to_task.keys()) + + def _register_tracing(self): + """Register tracing function to track the topi function call""" + # register topi compute for "tracing" target + for topi_compute in self.topi_to_task: + def _local_scope(compute_func): + """start a scope to hold the local function in for loop""" - def _register_dummy(self): - """Register dummy function to track the topi function call""" - for func in self.topi_to_task: - def _local_scope(local_func): - """build a scope to holds the function""" - @local_func.register("dummy", ) - 
def _dummy_func(*args, **kwargs): + @compute_func.register("tracing", ) + def _tracing_topi_compute(*args, **kwargs): assert not kwargs, "Do not support extracting tuning tasks when" \ "kwargs is used in TOPI function call." \ "Please modify it to use only positional args." - if (self.topi_to_task[local_func], serialize_args(args)) \ - not in self.task_collection: - self.task_collection.append((self.topi_to_task[local_func], - serialize_args(args))) - with _target.create("opencl"): - return local_func(*args) + if compute_func in self.wanted_topi_funcs: # record this call + key = (self.topi_to_task[compute_func], serialize_args(args)) + if key not in self.task_collection: + self.task_collection.append(key) + + return compute_func.fdefault(*args) + _local_scope(topi_compute) + + # register topi schedule for "tracing" target + for topi_compute in self.topi_to_task: + for topi_schedule in self.topi_to_schedule[topi_compute]: + def _local_scope_(schedule_func): + """start a scope to hold the local function in for loop""" - _local_scope(func) + @schedule_func.register("tracing", ) + def _tracing_topi_compute(outs): + outs = [outs] if isinstance(outs, tensor.Tensor) else outs + return create_schedule([x.op for x in outs]) + _local_scope_(topi_schedule) def _register_topi_task(self): """register tuning wrapper for topi function""" @@ -125,17 +152,47 @@ class TaskExtractEnv: s = topi.generic.schedule_conv2d_transpose_nchw([C]) return s, [A, W, C] - def reset(self): - """Reset task collections""" + @register("topi_nn_dense") + def _topi_nn_dense(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, weight, bias = args + C = topi.nn.dense(*args, **kwargs) + s = topi.generic.schedule_dense([C]) + if bias is not None: + return s, [data, weight, bias, C] + return s, [data, weight, C] + + def reset(self, wanted_topi_funcs): + """Reset task collections + + Parameters + ---------- + wanted_topi_funcs: List of function + The topi function to be extracted + """ self.task_collection = [] + self.wanted_topi_funcs = wanted_topi_funcs def get_tasks(self): - """Get collected tasks""" + """Get collected tasks + + Returns + ------- + tasks: List of tuple(name, args) + A list of tasks extracted from the nnvm graph + """ return self.task_collection @staticmethod def get(): - """Get the single instance of TaskExtractEnv""" + """Get the single instance of TaskExtractEnv + + Returns + ------- + env: TaskExtractEnv + The single instance of TaskExtractEnv + """ if not TaskExtractEnv.current: TaskExtractEnv.current = TaskExtractEnv() return TaskExtractEnv.current @@ -144,8 +201,8 @@ class TaskExtractEnv: def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): """ Extract tuning tasks from a nnvm graph. - This function collects tunning tasks by building the graph - with a "dummy" target and tracing all the calls to topi. + This function collects tuning tasks by building the graph + with a "tracing" target and tracing all the calls to topi. 
Parameters ---------- @@ -158,7 +215,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): target: tvm.target.Target The compilation target symbols : Array of nnvm.symbol - Array of nnvm symbols + Array of nnvm symbols want to be tuned target_host: tvm.target.Target The host compilation target @@ -179,16 +236,16 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): warnings.warn("Symbol %s is not tunable, ignored" % sym_name) # run compiler to collect all TOPI calls during compilation - env.reset() + env.reset(topi_funcs) # disable logger temporarily old_state = logger.disabled logger.disabled = True - # use a dummy target to do a fake compile for collecting topi calls - dummy_target = _target.create("opencl -device=dummy") - with ApplyHistoryBest([], allow_fallback=True): - nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype) + # use a "tracing" target to do a fake compile for collecting topi calls + tracing_target = _target.create("llvm -device=tracing") + nnvm.compiler.engine.clear_cache() + nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype) logger.disabled = old_state diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index ea823c6f2..5a34353ac 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -567,15 +567,16 @@ class ConfigSpace(object): """ def __init__(self): # private dict to provide sugar - self.space_map = OrderedDict() # name -> space + self.space_map = OrderedDict() # name -> space self._collect = True self._length = None - self._entity_map = OrderedDict() + self._entity_map = OrderedDict() # name -> entity self._constraints = [] self.errors = [] self.template_key = None self.code_hash = None self.flop = 0 + self.is_fallback = False @staticmethod def axis(var): @@ -607,6 +608,15 @@ class ConfigSpace(object): If is 'candidate', try listed candidate. kwargs: dict extra arguments for policy + see examples below for how to use filter + + Examples + -------- + >>> # use custom candidates + >>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]]) + + >>> # use a filter that only accepts the split scheme whose inner most tile is less then 4 + >>> cfg.define_split('tile_y', y, policy='all', filter=lambda x: x.size[-1] <= 4) """ axes = [axis] return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs) @@ -889,3 +899,45 @@ class ConfigEntity(ConfigSpace): def __repr__(self): return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash, self.index) + +class FallbackConfigEntity(ConfigSpace): + """The config entity created to support fallback""" + + def __init__(self): + super(FallbackConfigEntity, self).__init__() + self.is_fallback = True + + def fallback_split(self, name, constraints): + """Fallback a split knob + + Parameters + ---------- + name: str + name of the knob + constraints: List of int + The maximum tile size for every dimension. Value `-1` means no constraint. 
+ + Examples + -------- + If you use cfg.define_split('tile_0', 128, num_outputs=3), + Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4] + + If you use cfg.define_split('tile_0', 49, num_outputs=3), + Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1] + """ + space = self.space_map[name] + assert len(constraints) == space.num_outputs + indices = np.arange(space.num_outputs) + + # '-1' means no constraint + constraints = [x if x != -1 else 1e10 for x in constraints] + + for entity in reversed(space.entities): + if all([entity.size[i] <= constraints[i] for i in indices]): + self._entity_map[name] = entity + return + + raise RuntimeError("Cannot find feasible fallback split entity for node: " + name) + + def __repr__(self): + return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index f8923fca5..ab52788c8 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -206,7 +206,7 @@ def args_to_workload(x): elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): return x.value elif x is None: - return None + return 0 else: raise RuntimeError('Do not support type "%s" in argument. Consider to use' 'primitive types only' % type(x)) diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index e11bb7a4f..3d7b249df 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -28,7 +28,7 @@ def _alias(name): return table.get(name, name) -def context(target, extra_files=None, allow_fallback=False): +def context(target, extra_files=None): """Return the dispatch context with pre-tuned parameters. The corresponding downloaded *.log files under tophub root path will be loaded. Users can also add their own files in argument `extra_files`. @@ -39,12 +39,9 @@ def context(target, extra_files=None, allow_fallback=False): The compilation target extra_files: list of str, optional Extra log files to load - allow_fallback: bool - Whether allow to use a fallback configuration if cannot find - tuned result. 
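FallbackConfigEntity is what lets a template run when no tuned record exists: knobs are defined as usual and, when cfg.is_fallback is set, pinned with fallback_split instead of raising. A minimal sketch of that pattern in a toy template (matmul_template is hypothetical and only illustrates the API; the knob handling mirrors the examples above):

    import tvm
    from tvm import autotvm

    @autotvm.template
    def matmul_template(N, L, M, dtype='float32'):
        A = tvm.placeholder((N, L), name='A', dtype=dtype)
        B = tvm.placeholder((L, M), name='B', dtype=dtype)
        k = tvm.reduce_axis((0, L), name='k')
        C = tvm.compute((N, M),
                        lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k),
                        name='C')
        s = tvm.create_schedule(C.op)
        y, x = s[C].op.axis

        cfg = autotvm.get_config()
        cfg.define_split('tile_y', y, num_outputs=2)
        cfg.define_split('tile_x', x, num_outputs=2)

        if cfg.is_fallback:
            # No tuned record was found: pin the splits to a safe default
            # instead of aborting the build.
            cfg.fallback_split('tile_y', [-1, 8])
            cfg.fallback_split('tile_x', [-1, 8])

        yo, yi = cfg['tile_y'].apply(s, C, y)
        xo, xi = cfg['tile_x'].apply(s, C, x)
        s[C].reorder(yo, xo, yi, xi)
        return s, [A, B, C]

Calling such a template without any tuned log now just emits the fallback warning once and proceeds, which is exactly what the arm_cpu templates later in this patch rely on.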
""" rootpath = AUTOTVM_TOPHUB_ROOT_PATH - best_context = ApplyHistoryBest([], allow_fallback=allow_fallback) + best_context = ApplyHistoryBest([]) if isinstance(target, str): target = _target.create(target) diff --git a/python/tvm/autotvm/tuner/ga_tuner.py b/python/tvm/autotvm/tuner/ga_tuner.py index b92737ed5..b9d900e49 100644 --- a/python/tvm/autotvm/tuner/ga_tuner.py +++ b/python/tvm/autotvm/tuner/ga_tuner.py @@ -86,13 +86,9 @@ class GATuner(Tuner): # cross over indices = np.arange(len(genes)) - max_score = np.max(scores) - if max_score < 1e-8: - probs = np.empty_like(scores) - probs[:] = 1.0 / len(scores) - else: - scores /= max_score - probs = scores / np.sum(scores) + scores += 1e-8 + scores /= np.max(scores) + probs = scores / np.sum(scores) tmp_genes = [] for _ in range(self.pop_size): p1, p2 = np.random.choice(indices, size=2, replace=False, p=probs) diff --git a/python/tvm/autotvm/tuner/model_based_tuner.py b/python/tvm/autotvm/tuner/model_based_tuner.py index d1c1b16d3..62fc57f2e 100644 --- a/python/tvm/autotvm/tuner/model_based_tuner.py +++ b/python/tvm/autotvm/tuner/model_based_tuner.py @@ -8,7 +8,7 @@ import gc import numpy as np from .tuner import Tuner - +from ..env import GLOBAL_SCOPE class FeatureCache(object): """Feature cache manager for cache sharing between different cost models""" @@ -119,11 +119,9 @@ class CostModel(object): """ raise NotImplementedError() - def clone_new(self): - """Clone a new model with the same parameters. - This function will only copy hyperparameters of the tuner, not all the trained model - - This is used for deriving a base model conveniently + def spawn_base_model(self): + """Clone a base model with the same parameters. + The base model is used to fit history data in transfer learning. Returns ------- @@ -221,7 +219,9 @@ class ModelBasedTuner(Tuner): break self.trial_pt += 1 - if self.trial_pt >= len(self.trials): # trial list is empty, choose randomly + if self.trial_pt >= len(self.trials) - int(0.05 * self.plan_size): + # if the trial list is empty or + # the tuner is doing the last 5% trials (e-greedy), choose randomly index = np.random.randint(len(self.space)) while index in self.visited: index = np.random.randint(len(self.space)) @@ -264,18 +264,16 @@ class ModelBasedTuner(Tuner): self.train_ct += 1 def load_history(self, data_set): - # filter data, only pick the data with a same task - data = [] - for inp, res in data_set: - if inp.task.name == self.task.name and \ - inp.config.template_key == self.task.config_space.template_key: - data.append((inp, res)) - if not data: - return + # set in_tuning as True to make the feature extraction consistent + GLOBAL_SCOPE.in_tuning = True # fit base model - base_model = self.cost_model.clone_new() - base_model.fit_log(data, self.plan_size) + base_model = self.cost_model.spawn_base_model() + success = base_model.fit_log(data_set, self.plan_size) + + if not success: + GLOBAL_SCOPE.in_tuning = False + return # use base model to select initial points if not self.trials: @@ -285,6 +283,7 @@ class ModelBasedTuner(Tuner): self.trial_pt = 0 self.cost_model.load_basemodel(base_model) + GLOBAL_SCOPE.in_tuning = False def has_next(self): return len(self.visited) < len(self.space) diff --git a/python/tvm/autotvm/tuner/sa_model_optimizer.py b/python/tvm/autotvm/tuner/sa_model_optimizer.py index 6e1c373c1..1947c6dde 100644 --- a/python/tvm/autotvm/tuner/sa_model_optimizer.py +++ b/python/tvm/autotvm/tuner/sa_model_optimizer.py @@ -87,7 +87,7 @@ class SimulatedAnnealingOptimizer(ModelOptimizer): new_scores = 
model.predict(new_points) - ac_prob = np.exp((new_scores - scores) / t) + ac_prob = np.exp((new_scores - scores) / (t + 1e-2)) ac_index = np.random.random(len(ac_prob)) < ac_prob points[ac_index] = new_points[ac_index] diff --git a/python/tvm/autotvm/tuner/tuner.py b/python/tvm/autotvm/tuner/tuner.py index 91004cba4..cffbb9798 100644 --- a/python/tvm/autotvm/tuner/tuner.py +++ b/python/tvm/autotvm/tuner/tuner.py @@ -31,6 +31,10 @@ class Tuner(object): self.best_measure_pair = None self.best_iter = 0 + # time to leave + self.ttl = None + self.n_trial = None + def has_next(self): """Whether has next untried config in the space @@ -76,7 +80,7 @@ class Tuner(object): measure_option: dict The options for how to measure generated code. You should use the return value ot autotvm.measure_option for this argument. - early_stopping: int + early_stopping: int, optional Early stop the tuning when not finding better configs in this number of trials callbacks: List of callable A list of callback functions. The signature of callback function is @@ -87,6 +91,8 @@ class Tuner(object): measure_batch = create_measure_batch(self.task, measure_option) n_parallel = getattr(measure_batch, 'n_parallel', 1) early_stopping = early_stopping or 1e9 + self.n_trial = n_trial + old_level = logger.level GLOBAL_SCOPE.in_tuning = True @@ -127,11 +133,12 @@ class Tuner(object): for callback in callbacks: callback(self, inputs, results) - if i > self.best_iter + early_stopping: + self.ttl = min(early_stopping + self.best_iter, n_trial) - i + if i >= self.best_iter + early_stopping: logger.debug("Early stopped. Best iter: %d.", self.best_iter) break - if error_ct > 50: + if error_ct > 150: logger.warning("Too many errors happen in the tuning. Now is in debug mode") logger.setLevel(logging.DEBUG) else: diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 178e92476..bda3ee26e 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -31,8 +31,12 @@ class XGBoostCostModel(CostModel): If is 'curve', use sampled curve feature (relation feature). Note on choosing feature type: - For single task tuning, 'itervar' and 'knob' is good. + For single task tuning, 'itervar' and 'knob' are good. 'itervar' is more accurate but 'knob' is much faster. + There are some constraints on 'itervar', if you meet + problems with feature extraction when using 'itervar', + you can swith to 'knob'. + For cross-shape tuning (e.g. many convolutions with different shapes), 'itervar' and 'curve' has better transferability, 'knob' is faster. @@ -46,8 +50,11 @@ class XGBoostCostModel(CostModel): The number of threads. log_interval: int, optional If is not none, the cost model will print training log every `log_interval` iterations. 
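The tuner changes above track n_trial and a time-to-live counter, make early stopping trigger at exactly early_stopping unimproved trials (>= instead of >), and only switch the logger to debug mode after 150 measurement errors instead of 50. A typical call site, assuming task is an autotvm Task (for example one returned by extract_from_graph) and measure_opt was built with autotvm.measure_option:

    from tvm import autotvm

    tuner = autotvm.tuner.XGBTuner(task, feature_type='knob', loss_type='rank')
    tuner.tune(n_trial=1000,
               early_stopping=200,   # stop after 200 trials without a better config
               measure_option=measure_opt,
               callbacks=[autotvm.callback.log_to_file('tune.log')])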
+ upper_model: XGBoostCostModel, optional + The upper model used in transfer learning """ - def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25): + def __init__(self, task, feature_type, loss_type, num_threads=4, log_interval=25, + upper_model=None): super(XGBoostCostModel, self).__init__() if xgb is None: @@ -109,35 +116,51 @@ class XGBoostCostModel(CostModel): else: raise RuntimeError("Invalid feature type " + feature_type) - self.feature_cache = FeatureCache() + if upper_model: # share a same feature cache with upper model + self.feature_cache = upper_model.feature_cache + else: + self.feature_cache = FeatureCache() + self.upper_model = upper_model self.feature_extra_ct = 0 self.pool = None self.base_model = None - self.upper_model = None self._sample_size = 0 + self._reset_pool(self.space, self.target, self.task) - self._reset_pool() + def _reset_pool(self, space, target, task): + """reset processing pool for feature extraction""" + + if self.upper_model: # base model will reuse upper model's pool, + self.upper_model._reset_pool(space, target, task) + return + + self._close_pool() - def _reset_pool(self): - # reset processing pool for feature extraction - if self.pool: - self.pool.terminate() - self.pool.join() - del self.pool # use global variable to pass common arguments global _extract_space, _extract_target, _extract_task - _extract_space = self.space - _extract_target = self.target - _extract_task = self.task + _extract_space = space + _extract_target = target + _extract_task = task self.pool = multiprocessing.Pool(self.num_threads) + def _close_pool(self): + if self.pool: + self.pool.terminate() + self.pool.join() + self.pool = None + + def _get_pool(self): + if self.upper_model: + return self.upper_model._get_pool() + return self.pool + def _base_model_discount(self): - return 1.0 / (2 ** (self._sample_size / 50.0)) + return 1.0 / (2 ** (self._sample_size / 64.0)) def fit(self, xs, ys, plan_size): tic = time.time() - self._reset_pool() + self._reset_pool(self.space, self.target, self.task) x_train = self._get_feature(xs) y_train = np.array(ys) @@ -150,8 +173,12 @@ class XGBoostCostModel(CostModel): self._sample_size = len(x_train) if self.base_model: - dtrain.set_base_margin(self._base_model_discount() * - self.base_model.predict(xs, output_margin=True)) + discount = self._base_model_discount() + if discount < 0.05: # discard base model + self.base_model.upper_model = None + self.base_model = None + else: + dtrain.set_base_margin(discount * self.base_model.predict(xs, output_margin=True)) self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=8000, @@ -172,11 +199,19 @@ class XGBoostCostModel(CostModel): def fit_log(self, records, plan_size): tic = time.time() - self._reset_pool() - args = list(records) - logger.debug("XGB load %d entries from history log file", len(args)) + # filter data, only pick the data with a same task + data = [] + for inp, res in records: + if inp.task.name == self.task.name and \ + inp.config.template_key == self.task.config_space.template_key: + data.append((inp, res)) + + logger.debug("XGB load %d entries from history log file", len(data)) + # extract feature + self._reset_pool(self.space, self.target, self.task) + pool = self._get_pool() if self.fea_type == 'itervar': feature_extract_func = _extract_itervar_feature_log elif self.fea_type == 'knob': @@ -185,10 +220,21 @@ class XGBoostCostModel(CostModel): feature_extract_func = _extract_curve_feature_log else: raise RuntimeError("Invalid feature type: " + 
self.fea_type) - res = self.pool.map(feature_extract_func, args) - xs, ys = zip(*res) - xs, ys = np.array(xs), np.array(ys) + res = pool.map(feature_extract_func, data) + + # filter out feature with different shapes + fea_len = len(self._get_feature([0])[0]) + + xs, ys = [], [] + for x, y in res: + if len(x) == fea_len: + xs.append(x) + ys.append(y) + if len(xs) < 500: # no enough samples + return False + + xs, ys = np.array(xs), np.array(ys) x_train = xs y_train = ys y_max = np.max(y_train) @@ -212,6 +258,8 @@ class XGBoostCostModel(CostModel): logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs)) + return True + def predict(self, xs, output_margin=False): feas = self._get_feature(xs) dtest = xgb.DMatrix(feas) @@ -224,20 +272,12 @@ class XGBoostCostModel(CostModel): def load_basemodel(self, base_model): self.base_model = base_model - if isinstance(base_model, XGBoostCostModel): - # share feature cache - base_model.feature_cache = self.feature_cache - - # close thread pool - if base_model.pool: - base_model.pool.terminate() - base_model.pool.join() - del base_model.pool - self.base_model.upper_model = self - - def clone_new(self): + self.base_model._close_pool() + self.base_model.upper_model = self + + def spawn_base_model(self): return XGBoostCostModel(self.task, self.fea_type, self.loss_type, - self.num_threads, self.log_interval) + self.num_threads, self.log_interval, self) def _get_feature(self, indexes): """get features for indexes, run extraction if we do not have cache for them""" @@ -251,7 +291,7 @@ class XGBoostCostModel(CostModel): need_extract = [x for x in indexes if x not in fea_cache] if need_extract: - pool = self.pool if self.upper_model is None else self.upper_model.pool + pool = self._get_pool() feas = pool.map(self.feature_extract_func, need_extract) for i, fea in zip(need_extract, feas): fea_cache[i] = fea @@ -261,6 +301,9 @@ class XGBoostCostModel(CostModel): ret[i, :] = fea_cache[ii] return ret + def __del__(self): + self._close_pool() + _extract_space = None _extract_target = None diff --git a/python/tvm/autotvm/tuner/xgboost_tuner.py b/python/tvm/autotvm/tuner/xgboost_tuner.py index 237ac4e19..886c82a4d 100644 --- a/python/tvm/autotvm/tuner/xgboost_tuner.py +++ b/python/tvm/autotvm/tuner/xgboost_tuner.py @@ -20,8 +20,12 @@ class XGBTuner(ModelBasedTuner): If is 'curve', use sampled curve feature (relation feature). Note on choosing feature type: - For single task tuning, 'itervar' and 'knob' is good. + For single task tuning, 'itervar' and 'knob' are good. 'itervar' is more accurate but 'knob' is much faster. + There are some constraints on 'itervar', if you meet + problems with feature extraction when using 'itervar', + you can swith to 'knob'. + For cross-shape tuning (e.g. many convolutions with different shapes), 'itervar' and 'curve' has better transferability, 'knob' is faster. @@ -32,8 +36,7 @@ class XGBTuner(ModelBasedTuner): If is 'rank', use pairwise rank loss to train cost model. The cost model predicts relative rank score. num_threads: int, optional - The number of threads. - optimizer: str or ModelOptimizer, optional + The number of threads. optimizer: str or ModelOptimizer, optional If is 'sa', use a default simulated annealing optimizer. Otherwise it should be a ModelOptimizer object. diversity_filter_ratio: int or float, optional @@ -45,7 +48,7 @@ class XGBTuner(ModelBasedTuner): If is 0, output nothing. Otherwise, output debug information every `verbose` iterations. 
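fit_log now filters the history down to records of the same task and template, drops feature vectors whose length does not match, and returns False when fewer than 500 usable samples remain, so load_history can bail out instead of fitting a meaningless base model. A sketch of the transfer-learning entry point; 'previous.log' is a placeholder and autotvm.record.load_from_file is assumed to be the usual record loader:

    import os
    from tvm import autotvm

    tuner = autotvm.tuner.XGBTuner(task)   # task: an existing autotvm Task
    if os.path.isfile('previous.log'):
        # spawn_base_model/fit_log run under the hood; if the old log holds too
        # few matching records they return False and tuning starts from scratch.
        tuner.load_history(autotvm.record.load_from_file('previous.log'))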
""" - def __init__(self, task, plan_size=32, + def __init__(self, task, plan_size=64, feature_type='itervar', loss_type='rank', num_threads=None, optimizer='sa', diversity_filter_ratio=None, log_interval=50): cost_model = XGBoostCostModel(task, @@ -62,3 +65,9 @@ class XGBTuner(ModelBasedTuner): super(XGBTuner, self).__init__(task, cost_model, optimizer, plan_size, diversity_filter_ratio) + + def tune(self, *args, **kwargs): # pylint: disable=arguments-differ + super(XGBTuner, self).tune(*args, **kwargs) + + # manually close pool to avoid multiprocessing issues + self.cost_model._close_pool() diff --git a/python/tvm/exec/tophub.py b/python/tvm/exec/tophub.py index 9dd951a52..9bfd68665 100644 --- a/python/tvm/exec/tophub.py +++ b/python/tvm/exec/tophub.py @@ -8,8 +8,8 @@ from ..autotvm.tophub import list_packages, download_package if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--download", type=str, nargs='+', - help="Target to download. Use 'all' to download for all targets") + parser.add_argument("-d", "--download", type=str, nargs='+', + help="The targets to download. Use 'all' to download for all targets") parser.add_argument("-l", "--list", action='store_true', help="List available packages") args = parser.parse_args() @@ -21,8 +21,7 @@ if __name__ == '__main__': print("-" * 41) for target, info in info: print("%-20s %-20s" % (target, "%.2f MB" % (info['size']/1000000))) - - if args.download: + elif args.download: info = list_packages() all_targets = [x[0] for x in info] if 'all' in args.download: @@ -34,3 +33,5 @@ if __name__ == '__main__': if t not in all_targets: print("Warning : cannot find tuned parameters of " + t + ". (ignored)") download_package(t) + else: + parser.print_help() diff --git a/python/tvm/target.py b/python/tvm/target.py index e2d780f75..9d5200661 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -263,6 +263,7 @@ def override_native_generic_func(func_name): "Keyword arguments cannot be used when invoking generic_func %s" % func_name) return generic_func_node(*args) fresult = decorate(fdefault, dispatch_func) + fresult.fdefault = fdefault fresult.register = register return fresult return fdecorate diff --git a/tests/python/unittest/test_autotvm_dispatch_context.py b/tests/python/unittest/test_autotvm_dispatch_context.py index 6c718e5bd..1f2a7e276 100644 --- a/tests/python/unittest/test_autotvm_dispatch_context.py +++ b/tests/python/unittest/test_autotvm_dispatch_context.py @@ -3,34 +3,48 @@ The dispatcher can choose which template to use according to the parameters of workload""" from collections import namedtuple +from tvm import autotvm from tvm.autotvm.task import dispatcher, DispatchContext -SimpleWorkload = namedtuple("SimpleWorkload", ["key"]) -SimpleConfig = namedtuple("SimpleConfig", ["template_key"]) +SimpleConfig = namedtuple('SimpleConfig', ('template_key', 'is_fallback')) def test_dispatch(): @dispatcher def my_dispatcher(a, b): - return SimpleWorkload(key=a + b) - - @my_dispatcher.register("spatial_pack") - def _sp_pack_add(cfg, a, b): - return b + 100 + return (a, b) @my_dispatcher.register("im2col") - def _im2col_add(cfg, a, b): - return a + 1 + def _im2col(cfg, a, b): + return a + + @my_dispatcher.register("spatial_pack") + def _spatial_pack(cfg, a, b): + return b class SimpleDispatcher(DispatchContext): def query(self, target, workload): - tkey = "spatial_pack" if workload.key > 2 else "im2col" - return SimpleConfig(tkey) + a, b = workload + tkey = "spatial_pack" if a + b > 2 else "im2col" + cfg = 
SimpleConfig(tkey, False) + return cfg with SimpleDispatcher(): - # im2col - assert my_dispatcher(1, 0) == 2 - # spack - assert my_dispatcher(1, 100) == 200 + # this will call im2col + assert my_dispatcher(1, 0) == 1 + + # this will call spatial pack + assert my_dispatcher(1, 100) == 100 + +def test_fallback(): + + @autotvm.template + def simple_template(a, b): + cfg = autotvm.get_config() + assert cfg.is_fallback + + simple_template(2, 3) + if __name__ == "__main__": test_dispatch() + test_fallback() diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py index 0320ef1c6..e51e34e95 100644 --- a/tests/python/unittest/test_autotvm_space.py +++ b/tests/python/unittest/test_autotvm_space.py @@ -1,7 +1,7 @@ """Test space definition primitives""" import tvm -from tvm.autotvm.task.space import ConfigSpace +from tvm.autotvm.task.space import ConfigSpace, FallbackConfigEntity def gemm_func(cfg, N): A = tvm.placeholder((N, N), name='A') @@ -26,5 +26,18 @@ def test_split(): assert len(cfg) == 64 assert len(cfg.space_map['tile_y']) == 8 + # test fallback + cfg = FallbackConfigEntity() + cfg.define_split('tile_n', cfg.axis(128), num_outputs=3) + cfg.fallback_split('tile_n', [-1, 8, 4]) + + assert cfg['tile_n'].size == [4, 8, 4] + + cfg = FallbackConfigEntity() + cfg.define_split('tile_n', cfg.axis(49), num_outputs=3) + cfg.fallback_split('tile_n', [-1, 8, 4]) + + assert cfg['tile_n'].size == [7, 7, 1] + if __name__ == '__main__': test_split() diff --git a/tests/python/unittest/test_autotvm_xgboost_model.py b/tests/python/unittest/test_autotvm_xgboost_model.py index 3488d0f59..58da219f2 100644 --- a/tests/python/unittest/test_autotvm_xgboost_model.py +++ b/tests/python/unittest/test_autotvm_xgboost_model.py @@ -12,7 +12,7 @@ from test_autotvm_common import get_sample_task, get_sample_records def test_fit(): task, target = get_sample_task() - records = get_sample_records(n=100) + records = get_sample_records(n=500) base_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank') base_model.fit_log(records, plan_size=32) @@ -20,8 +20,8 @@ def test_fit(): upper_model = XGBoostCostModel(task, feature_type='itervar', loss_type='rank') upper_model.load_basemodel(base_model) - xs = np.arange(100) - ys = np.arange(100) + xs = np.arange(10) + ys = np.arange(10) upper_model.fit(xs, ys, plan_size=32) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 48bb4fb02..a3945a4c9 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -27,7 +27,14 @@ def _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype): @autotvm.task.dispatcher def conv2d_arm_cpu(data, kernel, strides, padding, layout, out_dtype): """TOPI compute callback. Mark this function as a dispatcher, so - this template can assign config according to workload""" + this template can assign config according to workload + + Returns + ------- + workload: Tuple + Dispatcher will use this workload to query corresponding config. + Then use cfg.template_key to call a registered template. 
+ """ return _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype) @conv2d_arm_cpu.register(['direct']) @@ -70,8 +77,10 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs): def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile): assert layout == "NCHW", "Only support NCHW" - out_dtype = out_dtype or data.dtype + # create workload according to raw arguments + wkl = _conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype) + out_dtype = out_dtype or data.dtype N, CI, IH, IW = get_const_tuple(data.shape) if len(kernel.shape) == 4: pre_packed = False @@ -113,6 +122,18 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') # ==================================================================== + if cfg.is_fallback: + if num_tile == 2: + cfg.fallback_split('tile_co', [-1, 8]) + cfg.fallback_split('tile_oh', [-1, 2]) + cfg.fallback_split('tile_ow', [-1, 8]) + else: + cfg.fallback_split('tile_co', [-1, 16, 4]) + cfg.fallback_split('tile_oh', [-1, 1, 1]) + cfg.fallback_split('tile_ow', [-1, 1, 4]) + cfg['ann_reduce'].anns = ['unroll', 'unroll'] + cfg['ann_spatial'].anns = ['none', 'unroll', 'vec'] + VC = cfg["tile_co"].size[-1] VH = cfg["tile_oh"].size[-1] VW = cfg["tile_ow"].size[-1] @@ -145,8 +166,7 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n output = tvm.compute(oshape, lambda n, co, h, w: conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], name='output_unpack', tag='spatial_conv2d_output', - attrs={'workload': _conv_arg_to_workload(data, kernel, strides, padding, - layout, out_dtype)}) + attrs={'workload': wkl}) return output def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, @@ -212,6 +232,10 @@ def decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype): return _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size) def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size): + # create workload according to raw arguments + wkl = _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, + out_dtype, tile_size) + N, CI, IH, IW = get_const_tuple(data.shape) if len(kernel.shape) == 4: pre_computed = False @@ -333,10 +357,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, layout, out_dtype, tile_ output = tvm.compute((N, K, H, W), lambda n, k, h, w: Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], name='output', tag='winograd_conv2d_output', - attrs={'workload': _winograd_conv_arg_to_workload( - data, kernel, strides, padding, layout, out_dtype, tile_size)}) + attrs={'workload': wkl}) - # we have to manually assign effective GFLOP for winogard + # we have to manually assign effective GFLOP for winograd cfg.add_flop(2 * N * K * H * W * KH * KW * C) return output @@ -358,30 +381,29 @@ def _schedule_winograd(cfg, s, output, last): kernel, G = U.op.input_tensors s[G].compute_inline() eps, nu, k, c, kk, = s[U].op.axis - r_kh, r_kw = s[U].op.reduce_axis - s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk) - s[U].unroll(eps) - s[U].unroll(nu) - s[U].unroll(r_kh) - s[U].unroll(r_kw) - s[U].vectorize(kk) if autotvm.GLOBAL_SCOPE.in_tuning: # kernel transformation will be pre-computed during compilation, so we skip # this part to make tuning records correct - s[U].pragma(k, 'debug_skip_region') + s[U].pragma(eps, 'debug_skip_region') else: + r_kh, r_kw = s[U].op.reduce_axis + s[U].reorder(k, c, eps, nu, r_kh, r_kw, 
kk) + for axis in [eps, nu, r_kh, r_kw]: + s[U].unroll(axis) + s[U].vectorize(kk) s[U].parallel(k) + if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + # transform image DD = s.cache_read(d, 'global', [V]) s[B].compute_inline() eps, nu, b, c, bb = s[V].op.axis r_eps, r_nu = s[V].op.reduce_axis s[V].reorder(b, c, eps, nu, r_eps, r_nu, bb) - s[V].unroll(eps) - s[V].unroll(nu) - s[V].unroll(r_eps) - s[V].unroll(r_nu) + for axis in [eps, nu, r_eps, r_nu]: + s[V].unroll(axis) s[DD].compute_at(s[V], c) s[V].vectorize(bb) s[V].parallel(b) @@ -405,10 +427,8 @@ def _schedule_winograd(cfg, s, output, last): s[A].compute_inline() k, b, vh, vw = s[Y].op.axis r_eps, r_nu = s[Y].op.reduce_axis - s[Y].unroll(vh) - s[Y].unroll(vw) - s[Y].unroll(r_eps) - s[Y].unroll(r_nu) + for axis in [vh, vw, r_eps, r_nu]: + s[Y].unroll(axis) # output n, co, h, w = s[last].op.axis @@ -444,6 +464,7 @@ def _winograd_conv_arg_to_workload(data, kernel, strides, padding, layout, out_d [data, raw_kernel, strides, padding, layout, out_dtype]) +##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM ##### @conv2d_winograd_without_weight_transform.register(['arm_cpu']) @autotvm.task.dispatcher def winograd_ww_config_dispatcher_(data, kernel, strides, padding, layout, out_dtype, tile_size): @@ -472,6 +493,7 @@ def schedule_conv2d_winograd_without_weight_transform_(cfg, outs): return s +##### REGISTER ALTER OP LAYOUT ##### @conv2d_alter_layout.register(["arm_cpu", "mali"]) def _alter_conv2d_layout(attrs, inputs, tinfos): """Alter op layout for pre-computing kernel transformation""" @@ -493,18 +515,30 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): # query config of this workload workload = _conv_arg_to_workload(tinfos[0], tinfos[1], strides, padding, layout, out_dtype) - cfg = autotvm.task.DispatchContext.current.query(tvm.target.current_target(), workload) + cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload) + + if cfg.is_fallback: # if is fallback, clear query cache and return None + context = autotvm.DispatchContext.current + while not isinstance(context, autotvm.FallbackContext): + context = context._old_ctx + context.clear_cache(tvm.target.current_target(), workload) + return None if cfg.template_key == 'direct': # packing weight tensor new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1]) return sym.conv2d(*copy_inputs, **new_attrs) else: # pre-compute weight transformation in winograd - tile_size = 4 + if "-device=arm_cpu" in tvm.target.current_target().options: + tile_size = 4 + VC = cfg['tile_k'].size[-1] + else: + from ..mali.conv2d import _pick_tile_size + tile_size = _pick_tile_size(tinfos[0], tinfos[1]) + VC = cfg['tile_bna'].val weight = sym.contrib.conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size) CO, CI, KH, KW = get_const_tuple(tinfos[1].shape) - VC = cfg['tile_k'].size[-1] weight = sym.reshape(weight, shape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) weight = sym.transpose(weight, axes=[0, 1, 2, 4, 3]) diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py index 8aafc4363..e066a1e29 100644 --- a/topi/python/topi/arm_cpu/depthwise_conv2d.py +++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py @@ -14,16 +14,21 @@ autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct', # register customized schedule for arm cpu. 
@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct') -def schedule_depthwise_conv2d_nchw_(cfg, outs): +def schedule_depthwise_conv2d_nchw_arm(cfg, outs): """Schedule depthwise conv2d Parameters ---------- cfg: ConfigEntity - The configuration of this tempalte + The configuration of this template outs: Array of Tensor The computation graph description of depthwise convolution2d in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d nchw. """ outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs s = tvm.create_schedule([x.op for x in outs]) @@ -38,6 +43,11 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs): cfg.define_split('tile_h', h, num_outputs=2) cfg.define_split('tile_w', w, num_outputs=2) + if cfg.is_fallback: + cfg.fallback_split('tile_c', [-1, 8]) + cfg.fallback_split('tile_h', [-1, 2]) + cfg.fallback_split('tile_w', [-1, 8]) + # park data to vector form [n, c, h, w] -> [n, C, h, w, VC] A0 = s.cache_read(data_pad, "global", C) _, c, h, w = s[A0].op.axis diff --git a/topi/python/topi/x86/injective.py b/topi/python/topi/x86/injective.py index ac552903a..06847bf9f 100644 --- a/topi/python/topi/x86/injective.py +++ b/topi/python/topi/x86/injective.py @@ -29,7 +29,7 @@ def schedule_injective(outs): elif len(s[x].op.axis) >= 3: fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1]) s[x].parallel(fused) - else: + elif len(s[x].op.axis) >= 1: s[x].parallel(s[x].op.axis[0]) return s diff --git a/topi/tests/python/common.py b/topi/tests/python/common.py new file mode 100644 index 000000000..d992be929 --- /dev/null +++ b/topi/tests/python/common.py @@ -0,0 +1,12 @@ +"""Common utility for topi test""" + +def get_all_backend(): + """return all supported target + + Returns + ------- + targets: list + A list of all supported targets + """ + return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx', + 'llvm -device=arm_cpu'] diff --git a/topi/tests/python/test_topi_bitserial_conv2d.py b/topi/tests/python/test_topi_bitserial_conv2d.py index 6df18483a..82af0006c 100644 --- a/topi/tests/python/test_topi_bitserial_conv2d.py +++ b/topi/tests/python/test_topi_bitserial_conv2d.py @@ -1,11 +1,8 @@ -import os import numpy as np import tvm import topi import topi.testing -from tvm.contrib.pickle_memoize import memoize from topi.util import get_const_tuple -from tvm.contrib import util from tvm.contrib.pickle_memoize import memoize def generate_quantized_np(shape, bits, out_dtype): @@ -16,23 +13,23 @@ def generate_quantized_np(shape, bits, out_dtype): def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding, activation_bits, weight_bits, dorefa): in_height = in_width = in_size - input_type='uint32' - out_dtype='int32' + input_type = 'uint32' + out_dtype = 'int32' with tvm.target.create('llvm'): A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A') W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W') B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, - out_dtype=out_dtype, layout="NCHW", dorefa=dorefa) + out_dtype=out_dtype, layout="NCHW", dorefa=dorefa) s = topi.generic.schedule_bitserial_conv2d_nchw([B]) a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) - dtype = A.dtype + @memoize("topi.tests.test_topi_bitseral_conv2d_nchw") def get_ref_data(): - a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type) 
- w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type) + a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type) + w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type) if dorefa: w_ = np.copy(w_np).astype(out_dtype) for x in np.nditer(w_, op_flags=['readwrite']): @@ -61,16 +58,16 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, - layout="NHWC", dorefa=dorefa) + layout="NHWC", dorefa=dorefa) s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) - dtype = A.dtype + @memoize("topi.tests.test_topi_bitseral_conv2d_nhwc") def get_ref_data(): - a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type) - w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type) + a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_type) + w_np = generate_quantized_np(get_const_tuple(w_shape), weight_bits, input_type) if dorefa: w_ = np.copy(w_np).astype(out_dtype) for x in np.nditer(w_, op_flags=['readwrite']): @@ -109,4 +106,4 @@ def test_bitserial_conv2d(): verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False) if __name__ == "__main__": - test_bitserial_conv2d() \ No newline at end of file + test_bitserial_conv2d() diff --git a/topi/tests/python/test_topi_bitserial_conv2d_rasp.py b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py index 3de954abc..de467818d 100644 --- a/topi/tests/python/test_topi_bitserial_conv2d_rasp.py +++ b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py @@ -4,10 +4,6 @@ import numpy as np import tvm import topi import topi.testing -from topi.util import get_const_tuple -from tvm.contrib import util - -target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon' def generate_quantized_np(shape, bits, out_dtype): np.random.seed(0) @@ -17,20 +13,19 @@ def generate_quantized_np(shape, bits, out_dtype): # Verify that certain special instructions from the tensorize pass exist def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, - activation_bits, weight_bits, dorefa): + activation_bits, weight_bits, dorefa): in_height = in_width = in_size - input_type='uint32' - out_dtype='int32' + input_type = 'uint32' + out_dtype = 'int32' with tvm.target.arm_cpu('rasp3b'): A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, - layout="NHWC", dorefa=dorefa) + layout="NHWC", dorefa=dorefa) s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) - - func = tvm.build(s, [A, W, B], target) + func = tvm.build(s, [A, W, B], tvm.target.arm_cpu('rasp3b')) assembly = func.get_source('asm') matches = re.findall("vpadal", assembly) @@ -47,7 +42,6 @@ def test_bitserial_conv2d(): stride = 1 pad = 1 - verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False) verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False) diff --git 
a/topi/tests/python/test_topi_bnn.py b/topi/tests/python/test_topi_bnn.py index 90abc68e6..cf9f377e9 100644 --- a/topi/tests/python/test_topi_bnn.py +++ b/topi/tests/python/test_topi_bnn.py @@ -28,7 +28,7 @@ def verify_binary_dense(batch, in_dim, out_dim): a_np = (np.random.randint(2, size=(batch, in_dim)) * 2 - 1).astype(dtype) b_np = (np.random.randint(2, size=(out_dim, in_dim)) * 2 - 1).astype(dtype) c_np = np.dot(a_np, b_np.T) - return (a_np, b_np, c_np) + return a_np, b_np, c_np a_np, b_np, c_np = get_ref_data() diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index f888033b3..4ed5b3170 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -1,5 +1,5 @@ """Test code for broadcasting operators.""" -import os +from common import get_all_backend import numpy as np import tvm import topi @@ -8,6 +8,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast): # Build the logic and compile the function A = tvm.placeholder(shape=in_shape, name="A") B = fbcast(A, out_shape) + def check_device(device): ctx = tvm.context(device, 0) if not ctx.exist: @@ -21,16 +22,11 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast): out_npy = np.broadcast_to(data_npy, out_shape) data_nd = tvm.nd.array(data_npy, ctx) out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx) - for _ in range(1): - foo(data_nd, out_nd) + foo(data_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy) - check_device("vulkan") - check_device("opencl") - check_device("cuda") - check_device("metal") - check_device("rocm") - check_device("nvptx") + for target in get_all_backend(): + check_device(target) check_device("sdaccel") @@ -45,9 +41,10 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, B = (tvm.var("B", dtype=dtype) if rhs_shape is None else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype)) C = ftopi(A, B) - if (isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr)): + if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr): assert(isinstance(C, tvm.expr.Expr)) return + def check_device(device): ctx = tvm.context(device, 0) if not ctx.exist: @@ -82,12 +79,8 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, foo(lhs_nd, rhs_nd, out_nd) np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) - check_device("opencl") - check_device("vulkan") - check_device("cuda") - check_device("metal") - check_device("rocm") - check_device("nvptx") + for target in get_all_backend(): + check_device(target) check_device("sdaccel") def test_broadcast_to(): diff --git a/topi/tests/python/test_topi_clip.py b/topi/tests/python/test_topi_clip.py index ffc89aeb9..f1367463e 100644 --- a/topi/tests/python/test_topi_clip.py +++ b/topi/tests/python/test_topi_clip.py @@ -5,6 +5,7 @@ import topi from topi.util import get_const_tuple from tvm.contrib.pickle_memoize import memoize +from common import get_all_backend def verify_clip(N, a_min, a_max, dtype): A = tvm.placeholder((N, N), dtype=dtype, name='A') @@ -34,7 +35,7 @@ def verify_clip(N, a_min, a_max, dtype): f(a, b) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - for device in ['llvm', 'opencl', 'sdaccel']: + for device in get_all_backend(): check_device(device) def test_clip(): diff --git a/topi/tests/python/test_topi_conv2d.py b/topi/tests/python/test_topi_conv2d.py deleted file mode 100644 index 365fdf551..000000000 --- a/topi/tests/python/test_topi_conv2d.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Example code to do 
conv2d.""" -import os -import numpy as np -import tvm -from tvm import autotvm -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - - -def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, padding): - in_height = in_width = in_size - - with tvm.target.arm_cpu(): - A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') - W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') - B = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), 'NCHW', 'float32') - s = topi.generic.schedule_conv2d_nchw([B]) - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d.verify_conv2d") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - return a_np, w_np, b_np - - a_np, w_np, b_np = get_ref_data() - - ctx = tvm.cpu(0) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - func = tvm.build(s, [A, W, B], "llvm") - func(a, w, b) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - -def test_conv2d(): - with autotvm.tophub.context(tvm.target.arm_cpu('rasp3b'), allow_fallback=True): - verify_conv2d(1, 56, 64, 64, 3, 1, 1) - -if __name__ == "__main__": - test_conv2d() diff --git a/topi/tests/python/test_topi_conv2d_hwcn.py b/topi/tests/python/test_topi_conv2d_hwcn.py index 1ff4b0247..af1afcb9e 100644 --- a/topi/tests/python/test_topi_conv2d_hwcn.py +++ b/topi/tests/python/test_topi_conv2d_hwcn.py @@ -43,14 +43,12 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p w = tvm.nd.array(w_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=128, - unroll_explicit=(device != "cuda")): - func1 = tvm.build(s1, [A, W, B], device) - func2 = tvm.build(s2, [A, W, C], device) - func1(a, w, b) - func2(a, w, c) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + func1 = tvm.build(s1, [A, W, B], device) + func2 = tvm.build(s2, [A, W, C], device) + func1(a, w, b) + func2(a, w, c) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: check_device(device) diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py index c663384b8..6f367d10c 100644 --- a/topi/tests/python/test_topi_conv2d_nchw.py +++ b/topi/tests/python/test_topi_conv2d_nchw.py @@ -1,31 +1,41 @@ """Example code to do convolution.""" -import os + import numpy as np import tvm +from tvm import autotvm import topi import topi.testing from tvm.contrib.pickle_memoize import memoize from topi.util import get_const_tuple -def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): +from common import get_all_backend + +def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) in_height = 
in_width = in_size A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W') + bias = tvm.placeholder((num_filter, 1, 1), name='bias') a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) dtype = A.dtype @memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw") def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - c_np = np.maximum(b_np, 0) + c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) return a_np, w_np, b_np, c_np a_np, w_np, b_np, c_np = get_ref_data() @@ -38,66 +48,103 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p print("Running on target: %s" % device) with tvm.target.create(device): dW = topi.nn.dilate(W, (1, 1, dilation, dilation)) - B = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW') - C = topi.nn.relu(B) - s1 = topi.generic.schedule_conv2d_nchw([B]) - s2 = topi.generic.schedule_conv2d_nchw([C]) + C = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW', out_dtype=dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.generic.schedule_conv2d_nchw([C]) + a = tvm.nd.array(a_np, ctx) w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - no_unroll_explicit = device in ["cuda", "nvptx", "rocm"] - with tvm.build_config(auto_unroll_max_step=1400, - unroll_explicit=not no_unroll_explicit): - func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) - func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) - func1(a, w, b) - func2(a, w, c) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - - for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func(a, w, c) + np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in get_all_backend(): check_device(device) def test_conv2d_nchw(): + autotvm.DispatchContext.current.silent = True + # ResNet18 workloads - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 
14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - # ResNet50 workloads - verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0) - verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 512, 1, 1, 0) - verify_conv2d_nchw(1, 256, 56, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 28, 128, 1, 1, 0) - verify_conv2d_nchw(1, 512, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 1024, 1, 1, 0) - verify_conv2d_nchw(1, 512, 28, 1024, 1, 2, 0) - verify_conv2d_nchw(1, 1024, 14, 256, 1, 1, 0) - verify_conv2d_nchw(1, 1024, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 2048, 1, 2, 0) - verify_conv2d_nchw(1, 1024, 14, 2048, 1, 2, 0) - verify_conv2d_nchw(1, 2048, 7, 512, 1, 1, 0) - # Vgg16 workloads - verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1) - # Super resolution workloads - verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1) - verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1) + verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) + verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) + verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) + verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) + verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) + verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) + verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) + verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) + verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) + verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) + verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) + + # bias, relu + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True) + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True) + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True) + # dilation = 2 - verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1, dilation=2) + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2) + + # weird workloads + verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1) + verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2) + + # inception v3 workloads + verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0) + verify_conv2d_nchw(1, 32, 149, 32, 3, 1, 0) + verify_conv2d_nchw(1, 32, 147, 64, 3, 1, 1) + verify_conv2d_nchw(1, 64, 73, 80, 1, 1, 0) + verify_conv2d_nchw(1, 80, 73, 192, 3, 1, 0) + verify_conv2d_nchw(1, 192, 35, 64, 1, 1, 0) + verify_conv2d_nchw(1, 192, 35, 48, 1, 1, 0) + verify_conv2d_nchw(1, 48, 35, 64, 5, 1, 2) + verify_conv2d_nchw(1, 64, 35, 96, 3, 1, 1) + verify_conv2d_nchw(1, 96, 35, 96, 3, 1, 1) + verify_conv2d_nchw(1, 192, 35, 32, 1, 1, 0) + verify_conv2d_nchw(1, 256, 35, 64, 1, 1, 0) + verify_conv2d_nchw(1, 256, 35, 48, 1, 1, 0) + verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0) + verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0) + verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0) + # verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0) + # verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0) + # verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0) + # verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0) + # verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3) + # verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3) + # verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0) + # verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0) + # verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0) + # verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3) + # verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3) + # verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0) + # verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0) + # 
verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3) + # verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0) + # verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0) + verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0) + verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0) + verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0) + verify_conv2d_nchw(1, 384, 8, 384, 3, 1, 1) + verify_conv2d_nchw(1, 1280, 8, 448, 1, 1, 0) + verify_conv2d_nchw(1, 448, 8, 384, 3, 1, 1) + verify_conv2d_nchw(1, 1280, 8, 192, 1, 1, 0) + verify_conv2d_nchw(1, 2048, 8, 320, 1, 1, 0) + verify_conv2d_nchw(1, 2048, 8, 384, 1, 1, 0) + verify_conv2d_nchw(1, 2048, 8, 448, 1, 1, 0) + verify_conv2d_nchw(1, 2048, 8, 192, 1, 1, 0) + if __name__ == "__main__": test_conv2d_nchw() diff --git a/topi/tests/python/test_topi_conv2d_transpose_nchw.py b/topi/tests/python/test_topi_conv2d_transpose_nchw.py index 0c9854000..5f65c038b 100644 --- a/topi/tests/python/test_topi_conv2d_transpose_nchw.py +++ b/topi/tests/python/test_topi_conv2d_transpose_nchw.py @@ -6,14 +6,13 @@ import topi.testing from tvm.contrib.pickle_memoize import memoize from topi.util import get_const_tuple +from common import get_all_backend def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): in_height = in_width = in_size A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W') - B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], padding, A.dtype) - C = topi.nn.relu(B) a_shape = get_const_tuple(A.shape) w_shape = get_const_tuple(W.shape) @@ -36,22 +35,23 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, return print("Running on target: %s" % device) with tvm.target.create(device): + B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype) + C = topi.nn.relu(B) s1 = topi.generic.schedule_conv2d_transpose_nchw([B]) s2 = topi.generic.schedule_conv2d_transpose_nchw([C]) a = tvm.nd.array(a_np, ctx) w = tvm.nd.array(w_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.build_config(auto_unroll_max_step=128, - unroll_explicit=(device != "cuda")): - func1 = tvm.build(s1, [A, W, B], device) - func2 = tvm.build(s2, [A, W, C], device) - func1(a, w, b) - func2(a, w, c) - np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: + func1 = tvm.build(s1, [A, W, B], device) + func2 = tvm.build(s2, [A, W, C], device) + func1(a, w, b) + func2(a, w, c) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in get_all_backend(): check_device(device) diff --git a/topi/tests/python/test_topi_dense.py b/topi/tests/python/test_topi_dense.py index 2df43eb30..92f95f3e0 100644 --- a/topi/tests/python/test_topi_dense.py +++ b/topi/tests/python/test_topi_dense.py @@ -6,13 +6,12 @@ import topi.testing from topi.util import get_const_tuple from tvm.contrib.pickle_memoize import memoize +from common import get_all_backend def verify_dense(batch, in_dim, out_dim, use_bias=True): A = tvm.placeholder((batch, in_dim), name='A') B = tvm.placeholder((out_dim, in_dim), name='B') C = tvm.placeholder((out_dim,), name='C') - D = topi.nn.dense(A, B, C if use_bias else None) - D = topi.nn.relu(D) dtype = A.dtype # use memoize to pickle 
the test data for next time use @@ -36,6 +35,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): return print("Running on target: %s" % device) with tvm.target.create(device): + D = topi.nn.dense(A, B, C if use_bias else None) + D = topi.nn.relu(D) s = topi.generic.schedule_dense(D) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) @@ -45,13 +46,15 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): f(a, b, c, d) np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) - for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']: + for device in get_all_backend(): check_device(device) def test_dense(): verify_dense(1, 1024, 1000, use_bias=True) verify_dense(1, 1024, 1000, use_bias=False) + verify_dense(2, 1024, 1000, use_bias=True) + if __name__ == "__main__": test_dense() diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py index 3086054ba..8c27af839 100644 --- a/topi/tests/python/test_topi_depthwise_conv2d.py +++ b/topi/tests/python/test_topi_depthwise_conv2d.py @@ -2,11 +2,10 @@ import tvm import topi import topi.testing import numpy as np -from scipy import signal from topi.util import get_const_tuple from tvm.contrib.pickle_memoize import memoize -from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nhwc +from common import get_all_backend def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_multiplier, filter_height, stride, padding, dilation=1): in_width = in_height @@ -18,10 +17,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter') Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') - # declare - DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding) - ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) def check_device(device): ctx = tvm.context(device, 0) @@ -30,6 +25,10 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu return print("Running on target: %s" % device) with tvm.target.create(device): + # declare + DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding) + ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) + Relu = topi.nn.relu(ScaleShift) # schedule s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d) s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift) @@ -88,12 +87,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) - check_device("opencl") - check_device("cuda") - check_device("metal") - check_device("rocm") - check_device("vulkan") - check_device("nvptx") + for device in get_all_backend(): + check_device(device) def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding, dilation=1): @@ -107,11 +102,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu DilatedFilter = topi.nn.dilate(Filter, (1, 1, dilation, dilation), name='DilatedFilter') Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Shift = 
tvm.placeholder((in_channel * channel_multiplier,), name='Shift') - # declare - DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding) - ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) - Relu = topi.nn.relu(ScaleShift) - # schedule def check_device(device): ctx = tvm.context(device, 0) @@ -121,6 +111,11 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu print("Running on target: %s" % device) with tvm.target.create(device): + # declare + DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding) + ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) + Relu = topi.nn.relu(ScaleShift) + # schedule s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d) s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift) s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu) @@ -180,12 +175,9 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5) np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5) - check_device("opencl") - check_device("cuda") - check_device("metal") - check_device("rocm") - check_device("vulkan") - check_device("nvptx") + for device in get_all_backend(): + check_device(device) + def test_depthwise_conv2d(): print("testing nchw") diff --git a/tutorials/autotvm/tune_nnvm_arm.py b/tutorials/autotvm/tune_nnvm_arm.py index f3d1c62bd..e85786037 100644 --- a/tutorials/autotvm/tune_nnvm_arm.py +++ b/tutorials/autotvm/tune_nnvm_arm.py @@ -312,7 +312,9 @@ def tune_and_evaluate(): # upload module to device print("Upload...") - remote = autotvm.measure.request_remote(device_key, timeout=10000) + remote = autotvm.measure.request_remote(device_key, + tracker_addr=('localhost', 9190), + timeout=10000) remote.upload(tmp.relpath(filename)) rlib = remote.load_module(filename) @@ -333,7 +335,6 @@ def tune_and_evaluate(): # We do not run the tuning in our webpage server since it takes too long. # Uncomment the following line to run by yourself. - # tune_and_evaluate() ###################################################################### -- GitLab
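
The central mechanism this patch adds to the templates above is the fallback configuration: a schedule template can check `cfg.is_fallback` and supply safe defaults with `cfg.fallback_split`, as the arm_cpu depthwise_conv2d hunk does earlier in this diff. The following is a minimal, hypothetical sketch (not part of the patch; the toy operator, names, and split sizes are illustrative only) of that same pattern in a standalone AutoTVM template:

    import tvm
    from tvm import autotvm

    @autotvm.template
    def vector_add_template(N):
        """Toy template: elementwise add on a length-N vector with one tunable split.

        Illustrative only; mirrors the cfg.is_fallback / cfg.fallback_split
        pattern used by the arm_cpu depthwise_conv2d schedule in this patch.
        """
        A = tvm.placeholder((N,), name='A')
        B = tvm.compute((N,), lambda i: A[i] + 1.0, name='B')
        s = tvm.create_schedule(B.op)

        cfg = autotvm.get_config()
        x = s[B].op.axis[0]
        cfg.define_split('tile_x', x, num_outputs=2)

        if cfg.is_fallback:
            # No tuning record matched this workload, so the dispatcher handed
            # back a fallback config; pick a reasonable default split instead
            # of silently using the first point of the search space.
            cfg.fallback_split('tile_x', [-1, 8])

        xo, xi = cfg['tile_x'].apply(s, B, x)
        s[B].parallel(xo)
        return s, [A, B]

With this guard, running such a template without any tuning logs still yields a valid, if unoptimized, schedule rather than an arbitrary point in the search space, which is what lets the reworked topi tests above compile every workload on every backend from get_all_backend() without pre-tuned parameters.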