Commit 6ea74d41 authored by Lianmin Zheng, committed by Tianqi Chen

[AUTOTVM] Core part of auto-tuning module (#1312)

parent 7e7154f1
Showing 3092 additions and 1 deletion
@@ -96,6 +96,7 @@ assign_source_group("Include" ${GROUP_INCLUDE})
file(GLOB COMPILER_SRCS
src/api/*.cc
src/arithmetic/*.cc
src/autotvm/*.cc
src/codegen/*.cc
src/codegen/stack_vm/*.cc
src/lang/*.cc
tvm.autotvm
-----------
.. automodule:: tvm.autotvm
tvm.autotvm.measure
~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.measure.measure
.. autoclass:: tvm.autotvm.measure.MeasureInput
:members:
.. autoclass:: tvm.autotvm.measure.MeasureResult
:members:
.. autofunction:: tvm.autotvm.measure.measure_option
.. autofunction:: tvm.autotvm.measure.create_measure_batch
tvm.autotvm.tuner
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.tuner
:members:
.. autoclass:: tvm.autotvm.tuner.Tuner
:members:
.. autoclass:: tvm.autotvm.tuner.RandomTuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.GridSearchTuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.GATuner
:members:
:inherited-members:
.. autoclass:: tvm.autotvm.tuner.XGBTuner
:members:
:inherited-members:
.. automodule:: tvm.autotvm.tuner.callback
:members:
tvm.autotvm.task
~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.task
:members:
.. automodule:: tvm.autotvm.task.task
:members:
.. automodule:: tvm.autotvm.task.space
:members:
tvm.autotvm.record
~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.record
:members:
@@ -14,6 +14,7 @@ Python API
ndarray
container
function
autotvm
graph_runtime
rpc
bridge
@@ -191,6 +191,7 @@ gallery_dirs = ["tutorials", "vta/tutorials"]
subsection_order = ExplicitOrder(
['../tutorials/language',
'../tutorials/optimize',
'../tutorials/autotvm',
'../tutorials/vta',
'../tutorials/topi',
'../tutorials/deployment',
@@ -488,7 +488,7 @@ bool VerifyMemory(LoweredFunc func, int device_type);
*
* "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_thread_per_block": Maximum number of threads per block.
* "max_threads_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z.
"""The auto-tuning module of tvm
This module includes:
* Tuning space definition API
* Efficient auto-tuners
* Tuning result and database support
* Distributed measurement to scale up tuning
"""
from . import database
from . import feature
from . import measure
from . import record
from . import task
from . import tuner
from . import util
# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity
from .record import ApplyHistoryBest as apply_history_best
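As a quick orientation, a typical flow with these shortcuts looks roughly like the sketch below. This is a minimal sketch only: `matmul` stands for a user-defined template like the one in the `template` docstring further down, the budgets are illustrative, and the `tune` call signature is assumed from the Tuner base class (not shown in this diff).
# Illustrative sketch: tune a hypothetical `matmul` template with a random
# tuner, measuring locally.
from tvm import autotvm

task = autotvm.task.create(matmul, args=(512, 512, 512, 'float32'),
                           target='llvm')
opt = autotvm.measure_option(mode='local', number=5)
tuner = autotvm.tuner.RandomTuner(task)
tuner.tune(n_trial=20, measure_option=opt)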
# pylint: disable=consider-using-enumerate,invalid-name
"""
Database of MeasureInput/MeasureResult pairs.
This can be used for replaying measurements.
"""
import os
from .record import encode, decode, measure_str_key
class Database(object):
"""
Base class for a record database object.
"""
def load(self, inp, get_all=False):
"""
Load a result based on an input's string key
Parameters
----------
inp: MeasureInput
to be translated into key for RedisDB
get_all: bool, optional
Whether to return all matching results instead of only the latest one
Returns
-------
rec: MeasureResult if previously saved, otherwise None
"""
raise NotImplementedError()
def save(self, inp, res, extend=False):
"""
Save a result based on an input's string key
Parameters
----------
inp: MeasureInput
to be translated into key for RedisDB
res: MeasureResult
to associate with key
extend: bool, optional
Whether to extend existing MeasureResults if they exist
"""
raise NotImplementedError()
def filter_inputs(db, measure_inputs, retry=False):
"""
Filter a measure_inputs batch based on saved db results
Parameters
----------
db: Database
database object
measure_inputs: Array of MeasureInput
measure_inputs as expected in measure_batch
retry: bool
whether to retry if the saved result is a failure
Returns
-------
partial_results: Array of MeasureResult
a full list of results, where None denotes no corresponding saved result
unsaved: Array of MeasureInput
a list that only contains unsaved inputs
"""
partial_results = list()
unsaved = list()
for inp in measure_inputs:
res = db.load(inp)
if res is None or (retry and res.error_no != 0):
unsaved.append(inp)
partial_results.append(None)
else:
partial_results.append(res)
return partial_results, unsaved
class RedisDatabase(Database):
"""
Redis version of record database
"""
REDIS_PROD = 15
REDIS_LOCA = 14
REDIS_TEST = 13 # for unit test
REDIS_NIGHT_TEMP = 12 # for nightly report (will be flushed after every workload)
MAGIC_SPLIT = "$"
def __init__(self, db_index=REDIS_PROD):
import redis
if db_index == RedisDatabase.REDIS_TEST:
host = 'localhost'
else:
host = os.environ.get('TVM_FLEET_HOST')
self.db = redis.StrictRedis(host=host, port=6379, db=db_index)
self.db_index = db_index
def set(self, key, value):
self.db.set(key, value)
def get(self, key):
return self.db.get(key)
def load(self, inp, get_all=False):
current = self.get(measure_str_key(inp))
if current is not None:
current = str(current)
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
results = [rec[1] for rec in records]
if get_all:
return results
return max(results, key=lambda result: result.timestamp)
return current
def save(self, inp, res, extend=False):
current = self.get(measure_str_key(inp))
if not extend or current is None:
self.set(measure_str_key(inp),
RedisDatabase.MAGIC_SPLIT.join([encode(inp, res)]))
else:
current = current.split(RedisDatabase.MAGIC_SPLIT)
self.set(measure_str_key(inp),
RedisDatabase.MAGIC_SPLIT.join(current + [encode(inp, res)]))
def filter(self, func):
"""
Dump all of the records for a particular target
Parameters
----------
func: callable
The signature of the function is bool (MeasureInput, Array of MeasureResult)
Returns
-------
list of records (inp, result) matching the target
Examples
--------
get records for a target
>>> db.filter(lambda inp, results: "cuda" in inp.target.keys)
"""
matched_records = list()
# may consider filtering in iterator in the future
for key in self.db:
current = self.get(key)
try:
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
except TypeError: # got a badly formatted/old format record
continue
inps, results = zip(*records)
inp = inps[0]
if not func(inp, results):
continue
result = max(results, key=lambda res: res.timestamp)
matched_records.append((inp, result))
return matched_records
def flush(self):
self.db.flushdb()
class DummyDatabase(RedisDatabase):
"""
A database based on python dictionary for testing.
"""
def __init__(self):
# pylint: disable=super-init-not-called
self.db = {}
def set(self, key, value):
self.db[key] = value
def get(self, key):
return self.db.get(key)
def flush(self):
self.db = {}
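A small sketch of how a database composes with `filter_inputs` when replaying measurements; here `inp` and `res` stand for a previously measured MeasureInput/MeasureResult pair.
# Sketch: cache one measured result, then filter a new batch against it.
db = DummyDatabase()
db.save(inp, res)
partial_results, unsaved = filter_inputs(db, [inp], retry=False)
# partial_results[0] is the saved MeasureResult and `unsaved` is empty,
# so nothing in this batch needs to be re-measured.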
"""Global configuration/variable scope for autotvm"""
class AutotvmGlobalScope(object):
current = None
def __init__(self):
self._old = AutotvmGlobalScope.current
AutotvmGlobalScope.current = self
self.cuda_target_arch = None
GLOBAL_SCOPE = AutotvmGlobalScope()
# pylint: disable=invalid-name
"""Extract feature of iter vars
There are two types of feature
1) Itervar feature
This feature is extracted based on loop variables.
Different loop structures will result in different shapes of feature
2) Curve sample feature (relation feature)
This feature is extracted by sampling relation curve.
This feature is invariant of loop structure.
"""
import struct
import numpy as np
from tvm import schedule, ir_pass, build_module, get_global_func, target as _target
def ana_lower(sch, args,
binds=None,
simple_mode=True):
"""Do lower while keeping all axes in IR
i.e. do not eliminate loops with extent 1, and do not vectorize, unroll, or inject virtual threads
"""
binds, _ = build_module.get_binds(args, binds)
sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True)
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt)
assert simple_mode
return stmt
try:
_get_buffer_curve_sample_flatten = get_global_func(
"autotvm.feature.GetCurveSampleFeatureFlatten")
_get_itervar_feature = get_global_func("autotvm.feature.GetItervarFeature")
_get_itervar_feature_flatten = get_global_func("autotvm.feature.GetItervarFeatureFlatten")
except ValueError:
def raise_error(*args, **kwargs): # pylint: disable=unused-argument
raise RuntimeError("Cannot load autotvm c++ API")
_get_buffer_curve_sample_flatten = _get_itervar_feature = _get_itervar_feature_flatten = \
raise_error
def get_itervar_feature(sch, args, take_log=False):
"""get features of iter vars
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
take_log: bool
whether to take the log of numerical statistics
Returns
-------
features of every axis in the IR, see doc/features.md for detail
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_itervar_feature(stmt, take_log)
# convert tvm node to python type
ret = []
for row in feas:
tmp = []
tmp.append([row[0][0].value, row[0][1]])
for item in row[1:]:
tmp.append([item[0].value] + [x.value for x in item[1:]])
ret.append(tmp)
return ret
def flatten_itervar_feature(fea):
"""flatten features into one-dimensional feature vectors
Parameters
----------
fea: list
return value of get_itervar_feature
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
flatten = []
for axis in fea:
for pair in axis[1:]:
flatten.append(pair[1:])
return np.concatenate(flatten)
def get_itervar_feature_flatten(sch, args, take_log=True):
"""get flattened features of iter vars
this is equivalent to get_itervar_feature + flatten_itervar_feature, but much faster.
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
take_log: bool
whether to take the log of numerical statistics
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_itervar_feature_flatten(stmt, take_log)
feas = struct.unpack('%df' % (len(feas)//4), feas)
return feas
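For illustration, extracting features from a trivial schedule might look like the following sketch, assuming the autotvm C++ feature functions are compiled in.
# Sketch: a one-op schedule and its itervar features.
import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] * 2.0, name='B')
s = tvm.create_schedule(B.op)
structured = get_itervar_feature(s, [A, B])      # per-axis feature rows
vector = get_itervar_feature_flatten(s, [A, B])  # flat float vector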
def get_flatten_name(fea):
""" Get the names of features after flattening.
Parameters
----------
fea: list or str
return value of get_itervar_feature or a line of logfile
Returns
-------
feature_names: Array of str
"""
feature_name = {
"_attr_": ["length", "nest_level", "topdown", "bottomup"] +
["ann_%d" % i for i in range(20)],
"_arith_": ["add", "mul", "div"],
"buf_touch": ["stride", "mod", "count", "reuse", "T_count", "T_reuse"],
}
if isinstance(fea, str):
from .record import decode
# flatten line to feature
line = fea
inp, _ = decode(line)
target = _target.create(inp.target)
with target:
s, args = inp.template.instantiate(inp.config)
fea = get_itervar_feature(s, args)
names = []
ct = 0
for row in fea:
var_name = str(row[0][1])
for pair in row[1:]:
key = pair[0]
if key in feature_name:
name_list = feature_name[key]
else:
name_list = feature_name["buf_touch"]
for i in range(len(pair[1:])):
names.append(".".join(["f%d" % ct, var_name, key, name_list[i]]))
ct += 1
return names
def get_buffer_curve_sample_flatten(sch, args, sample_n=30):
"""
Get flattened curve sample features (relation features)
Parameters
----------
sch: tvm.schedule.Schedule
args: Array of tvm.tensor.Tensor
the buffer args for lower
sample_n: int
number of sample points along one dimension
Returns
-------
flatten_feature: np.ndarray
one-dimensional vector
"""
stmt = ana_lower(sch, args, simple_mode=True)
feas = _get_buffer_curve_sample_flatten(stmt, sample_n, False)
feas = struct.unpack('%df' % (len(feas)//4), feas)
return feas
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo
from .measure import create_measure_batch, measure_option
from .measure_methods import request_remote
from .local_executor import LocalExecutor
from .executor import Future, Executor
""" Abstraction for asynchronous job execution """
class Executor(object):
"""
Base abstract executor interface for asynchronous job submission.
Allows submitting asynchronous jobs and returns Future objects.
"""
# timeout for jobs that may hang
DEFAULT_TIMEOUT = 60
def submit(self, func, *args, **kwargs):
"""
Pass task (function, arguments) to the Executor.
Parameters
----------
func : callable
function to be run by a worker
args : list or tuple, optional
arguments passed to the function
kwargs : dict, optional
The keyword arguments
Returns
-------
future : Future
Future object wrapping the task which can be used to
collect the task's result.
"""
raise NotImplementedError()
class Future(object):
"""
Base class of the future object.
The implementations can return objects of a subclass of this.
This object encapsulates the asynchronous execution of a task
submitted to another thread or another worker.
Future objects store the state of a task; they can be polled for the
result, or a blocking call can be used to retrieve it.
"""
def done(self):
"""
Return True if job was successfully cancelled or finished running.
"""
raise NotImplementedError()
def get(self, timeout=None):
"""
Get the result. This will block until the result is available.
Parameters
----------
timeout : int or float, optional
Maximum number of seconds to wait before it times out.
If not specified, block until the result is available.
Returns
-------
result : Any
The result returned by the submitted function.
Raises
------
TimeoutError : if the result call times out.
"""
raise NotImplementedError()
class FutureError(RuntimeError):
"""Base error class for all Future-related errors"""
pass
# pylint:disable=redefined-builtin
class TimeoutError(FutureError):
"""Error raised when a task times out."""
pass
class ExecutionError(FutureError):
"""
Error raised when future execution crashes or fails.
"""
pass
"""Local implementation of the executor using multiprocessing"""
import signal
from multiprocessing import Process, Queue
try:
from queue import Empty
except ImportError:
from Queue import Empty
import psutil
from . import executor
def kill_child_processes(parent_pid, sig=signal.SIGTERM):
"""kill all child processes recursively"""
try:
parent = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
for process in children:
try:
process.send_signal(sig)
except psutil.NoSuchProcess:
return
def _execute_func(func, queue, args, kwargs):
"""execute function and return the result or exception to a queue"""
try:
res = func(*args, **kwargs)
except Exception as exc: # pylint: disable=broad-except
res = exc
queue.put(res)
def timeout_monitor(queue, timeout, func, args, kwargs):
"""A wrapper to support timeout of a function call"""
# start a new process for timeout (cannot use a thread because we may call C functions)
p = Process(target=_execute_func, args=(func, queue, args, kwargs))
p.start()
p.join(timeout=timeout)
alive = p.is_alive()
kill_child_processes(p.pid)
p.terminate()
p.join()
if alive:
queue.put(executor.TimeoutError())
else:
if queue.empty():
queue.put(executor.ExecutionError("Fatal error in local executor"))
class LocalFuture(executor.Future):
"""Local wrapper for the future
Parameters
----------
process: multiprocessing.Process
process for running this task
queue: multiprocessing.Queue
queue for receiving the result of this task
"""
def __init__(self, process, queue):
self._done = False
self._process = process
self._queue = queue
def done(self):
self._done = self._done or not self._queue.empty()
return self._done
def get(self, timeout=None):
try:
res = self._queue.get(block=True, timeout=timeout)
except Empty:
raise executor.TimeoutError()
if self._process.is_alive():
kill_child_processes(self._process.pid)
self._process.terminate()
self._process.join()
self._queue.close()
self._queue.join_thread()
self._done = True
del self._queue
del self._process
return res
class LocalFutureNoFork(executor.Future):
"""Local wrapper for the future.
This is a non-fork version of LocalFuture.
Use this for runtimes that do not support fork (like cudnn)
"""
def __init__(self, result):
self._result = result
def done(self):
return True
def get(self, timeout=None):
return self._result
class LocalExecutor(executor.Executor):
"""Local executor that runs workers on the same machine with multiprocessing."""
def __init__(self, timeout=None):
self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT
def submit(self, func, *args, **kwargs):
"""
Note
----
By default, the executor forks a new process for each job.
But some runtimes do not support fork (e.g. the cuda runtime, cudnn).
In that case, set 'fork_new_process' to False in kwargs
"""
fork_new_process = kwargs.pop('fork_new_process', True)
if not fork_new_process:
return LocalFutureNoFork(func(*args, **kwargs))
queue = Queue(1)
process = Process(target=timeout_monitor,
args=(queue, self.timeout, func, args, kwargs))
process.start()
return LocalFuture(process, queue)
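A brief usage sketch of this executor interface:
# Sketch: run a job with a timeout and fetch its result.
def square(x):
    return x * x

ex = LocalExecutor(timeout=10)
fut = ex.submit(square, 7)
res = fut.get()   # 49; on failure this returns a TimeoutError or
                  # ExecutionError instance instead of raising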
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code"""
import time
from collections import namedtuple
import numpy as np
from ... import build, nd, target as _target
from ...contrib.util import tempdir
from ...rpc.tracker import Tracker
from ...rpc.server import Server
from ..util import get_const_tuple
from .local_executor import LocalExecutor
class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
"""
Stores all the necessary inputs for a measurement.
Parameters
----------
target : tvm.target.Target
The target device
task : task.Task
The tuning task
config : ConfigEntity
Specific configuration.
"""
class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
"""
Stores all the results of a measurement
Parameters
----------
costs: Array of float or Array of Exception
If no error occurs for this measurement, it is an array of measured running times.
If some error occurs during the measurement, it is an array of the exception objects.
error_no: int
Denotes the error type, defined by MeasureErrorNo
all_cost: float
Total cost of this measurement, including rpc, compilation, and test runs
timestamp: float
The absolute time stamp when we finish measurement.
"""
class MeasureErrorNo(object):
"""Error type for MeasureResult"""
NO_ERROR = 0 # no error
INSTANTIATION_ERROR = 1 # error when calling template function
COMPILE_HOST = 2 # error when compiling code on host (e.g. tvm.build)
COMPILE_DEVICE = 3 # error when compiling code on device (e.g. opencl JIT on device)
RUNTIME_DEVICE = 4 # error when running the program on device
WRONG_ANSWER = 5 # answer is wrong when compared to a golden output
FLEET_ERROR = 6 # error in the measurement infrastructure
def measure_option(mode,
number=1,
repeat=1,
timeout=60,
parallel_num=1,
pack_size=1,
check_correctness=False,
build_option=None,
replay_db=None,
save_to_replay_db=True,
rpc_device_key=None,
rpc_priority=1,
rpc_timeout=60,
rpc_tracker_addr=None,
use_ndk=False,
custom_measure_batch=None):
"""Configure how to do measurement
Parameters
----------
mode: str
'local': use the local device for measurement. In this mode,
the tuner starts a tracker and an RPC server silently for the user.
'rpc': request devices for measurement from rpc tracker. In this mode,
you should start an rpc tracker in a separate process.
'custom': use custom measure function
'local-nofork': use local device for measure but does not use multiprocessing.
This mode is suitable for debugging, but does not support timeout or parallelism.
number : int, optional
Number of times to do the measurement for averaging
repeat : int, optional
Number of times to repeat the measurement.
In total, the generated code will be run (1 + number x repeat) times,
where the first one is warm up. The returned result contains `repeat` costs,
each of which is the average of `number` test runs.
timeout: int, optional
Timeout for a whole batch. TimeoutError will be returned as the result if a
task times out.
parallel_num: int, optional
The number of measurement tasks that can run in parallel.
Set this according to the number of cpu cores (for compilation) and
the number of devices you have (for measuring generated code).
pack_size : int, optional
Number of configs to measure in one RPC call.
Usually this can be set to 1. If your device has a high cost to establish an rpc connection,
set this higher.
check_correctness: bool
Whether to check correctness after measurement.
build_option: Dict, optional
Build options for tvm.build_config
replay_db : Database, optional
The database that we retrieve saved MeasureResults from
save_to_replay_db: bool, optional
Whether to save measured results to the database. Has no effect when replay_db is None
rpc_priority: int, optional
Priority of this task, used by scheduler in tracker
rpc_device_key: str, optional
The device key of registered devices in tracker
rpc_timeout: int, optional
Timeout of rpc session
rpc_tracker_addr: Tuple(str, int), optional
The address of rpc tracker in Tuple(host, port) format.
If set, this address is used.
Otherwise, the environment variables "TVM_TRACKER_HOST" and "TVM_TRACKER_PORT" are used
use_ndk: bool, optional
Whether the export requires ndk
custom_measure_batch: callable, optional
custom measure function
Returns
-------
options: dict
A dict to store all options
"""
return {
'mode': mode,
'number': number,
'repeat': repeat,
'timeout': timeout,
'parallel_num': parallel_num,
'pack_size': pack_size,
'check_correctness': check_correctness,
'build_option': build_option,
'replay_db': replay_db,
'save_to_replay_db': save_to_replay_db,
'rpc_device_key': rpc_device_key,
'rpc_priority': rpc_priority,
'rpc_timeout': rpc_timeout,
'rpc_tracker_addr': rpc_tracker_addr,
'use_ndk': use_ndk,
'custom_measure_batch': custom_measure_batch
}
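For example, a typical local configuration might be (values are illustrative):
# Sketch: average over 5 runs per config, 4 measurement jobs in parallel,
# with a 20-second timeout per batch.
opt = measure_option(mode='local', number=5, parallel_num=4, timeout=20)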
def create_measure_batch(task, options):
"""Get a standard measure_batch function.
Parameters
----------
task: tvm.autotvm.task.Task
The tuning task
options: dict
The option for measuring generated code.
You should use the return value of :any:`autotvm.measure_option` for this argument
Returns
-------
measure_batch: callable
a callback function to measure a batch of configs
"""
from . import measure_methods
from ..database import filter_inputs
mode = options['mode']
number, repeat = options['number'], options['repeat']
timeout, parallel_num = options['timeout'], options['parallel_num']
pack_size = options['pack_size']
check_correctness = options['check_correctness']
build_option = options['build_option']
replay_db = options['replay_db']
save_to_replay_db = options['save_to_replay_db']
rpc_device_key = options['rpc_device_key']
rpc_priority, rpc_timeout = options['rpc_priority'], options['rpc_timeout']
use_ndk = options['use_ndk']
custom_measure_batch = options['custom_measure_batch']
kwargs = {}
executor = LocalExecutor(timeout=timeout)
if mode == 'local':
# start temporary rpc tracker and rpc server for the user
tracker = Tracker('localhost', port=9000, port_end=10000,
silent=True)
rpc_device_key = '$local$device$%d' % tracker.port
server = Server('localhost', port=9000, port_end=10000,
key=rpc_device_key,
use_popen=True, silent=True,
tracker_addr=(tracker.host, tracker.port))
fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_tracker_addr'] = (tracker.host, tracker.port)
kwargs['rpc_timeout'] = timeout
kwargs['tmp_dir'] = tempdir()
elif mode == 'rpc':
fmeasure = measure_methods.measure_rpc
kwargs['rpc_device_key'] = rpc_device_key
kwargs['rpc_priority'] = rpc_priority
kwargs['rpc_timeout'] = rpc_timeout
kwargs['use_ndk'] = use_ndk
kwargs['tmp_dir'] = tempdir()
assert rpc_device_key, "In rpc mode, a rpc_device_key must be provided"
elif mode == "custom":
assert callable(custom_measure_batch), "In custom mode, custom_measure_batch " \
"must be a callable object"
elif mode == 'local-nofork':
fmeasure = measure_methods.measure_local
kwargs['fork_new_process'] = False
else:
raise RuntimeError("Invalid mode: " + mode)
if 'cuda' in task.target.keys and 'rpc_device_key' in kwargs: # query cuda device info
add_cuda_device_info(kwargs['rpc_device_key'], kwargs.get('rpc_tracker_addr'), kwargs)
if 'opencl' in task.target.keys and 'rpc_device_key' in kwargs:
add_opencl_device_info(kwargs['rpc_device_key'], kwargs.get('rpc_tracker_addr'), kwargs)
if check_correctness:
# use llvm to generate a reference input/output
# this option works for tuning topi, but might not work for your custom op
with _target.create("llvm"):
s, arg_bufs = task.instantiate(task.config_space.get(0))
ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
for x in arg_bufs]
func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in ref_input]
func(*tvm_buf)
ref_output = [x.asnumpy() for x in tvm_buf]
kwargs['ref_input'], kwargs['ref_output'] = ref_input, ref_output
def measure_batch(measure_inputs):
"""measure the time cost for a batch of configs in real machines"""
if replay_db is not None:
partial_results, measure_inputs =\
filter_inputs(replay_db, measure_inputs, retry=False)
# pack configs
input_packs = []
for i in range(0, len(measure_inputs), pack_size):
input_packs.append(measure_inputs[i:i + pack_size])
# send to measure
futures = []
for input_pack in input_packs:
future = executor.submit(
fmeasure, input_pack,
number=number,
repeat=repeat,
build_option=build_option,
**kwargs
)
futures.append(future)
# transform results
results = []
for future in futures:
result = future.get()
if isinstance(result, Exception):
if mode == 'local-nofork':
# debug usage, raise exception
raise result
tstamp = time.time()
results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR,
timeout, tstamp)] * pack_size)
else:
results.extend(result)
if replay_db is not None:
if save_to_replay_db: # save result to database
for measure_input, result in zip(measure_inputs, results):
replay_db.save(measure_input, result)
result_idx = 0
for i in range(len(partial_results)):
if partial_results[i] is None:
partial_results[i] = results[result_idx]
result_idx += 1
return partial_results
return results
if mode == 'custom':
measure_batch = custom_measure_batch
measure_batch.parallel_num = parallel_num
if mode == 'local':
measure_batch.aux_objects = {"server": server, "tracker": tracker}
return measure_batch
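Putting the pieces together, direct use looks roughly like the sketch below; tuners normally drive this loop internally, and `task` here is assumed to come from `autotvm.task.create`.
# Sketch: measure the first config in a task's search space.
opt = measure_option(mode='local', number=3)
measure_batch = create_measure_batch(task, opt)
inp = MeasureInput(task.target, task, task.config_space.get(0))
results = measure_batch([inp])   # list of MeasureResult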
def add_cuda_device_info(device_key, rpc_tracker_addr, kwargs):
"""Query cuda device info. This is used to set the flags for the nvcc compiler
and to check the validity of the generated code."""
from .measure_methods import request_remote
remote = request_remote(device_key, rpc_tracker_addr)
ctx = remote.context('cuda', 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))
def add_opencl_device_info(device_key, rpc_tracker_addr, kwargs):
"""Query opencl device info. This is used to check the validity of the generated code."""
from .measure_methods import request_remote
remote = request_remote(device_key, rpc_tracker_addr)
ctx = remote.context('opencl', 0)
max_dims = ctx.max_thread_dimensions
kwargs['check_gpu'] = {
'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
'max_threads_per_block': ctx.max_threads_per_block,
'max_thread_x': max_dims[0],
'max_thread_y': max_dims[1],
'max_thread_z': max_dims[2],
}
# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args
"""
Functions that run on executor for measurement.
These functions are responsible for building the tvm module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output
"""
import logging
import os
import time
from random import getrandbits
import numpy as np
from ...contrib import ndk, nvcc
from ... import rpc, ir_pass, build, build_config, nd, context, TVMError, register_func
from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
from .measure import MeasureResult, MeasureErrorNo
from ..task.space import InstantiationError
class HashMismatchError(ValueError):
"""Raised when the code hash of a submitted config doesn't match that on the
measure side """
pass
def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
"""request a remote session
Parameters
----------
device_key: string
device key of registered device in tracker
tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format
priority: int, optional
priority of this request; a larger value means higher priority
timeout: float, optional
timeout of this session (units: seconds)
Returns
-------
session: RPCSession
"""
# connect to the tracker
if tracker_addr:
host = tracker_addr[0]
port = tracker_addr[1]
else:
host = os.environ['TVM_TRACKER_HOST']
port = int(os.environ['TVM_TRACKER_PORT'])
tracker = rpc.connect_tracker(host, port)
remote = tracker.request(device_key, priority=priority,
session_timeout=timeout)
return remote
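For example (the device key and tracker address are placeholders):
# Sketch: request a session to a registered device through the tracker.
remote = request_remote('rk3399', tracker_addr=('0.0.0.0', 9190), timeout=120)
ctx = remote.context('opencl', 0)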
def _measure_generic(fbuild, input_pack, ref_input, ref_output):
"""Generic measurement function
Parameters
----------
fbuild : function that takes a MeasureInput and returns a tuple (time_func, ctx, arg_bufs)
The build function used to build each input.
input_pack : list of MeasureInput
The inputs we need to evaluate
ref_input: Array of np.ndarray
Reference input for checking correctness
ref_output: Array of np.ndarray
Reference output for checking correctness
Returns
-------
res_pack : array of MeasureResult
The list of execution results of measurement.
"""
res_pack = []
for inp in input_pack:
tic = time.time()
try:
time_f, ctx, arg_bufs = fbuild(inp)
except TVMError as exc:
tstamp = time.time()
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
if "InstantiationError" in msg:
try:
msg = msg.split('\n')[-2].split(": ")[1]
except Exception: # pylint: disable=broad-except
pass
res_pack.append(MeasureResult((InstantiationError(msg),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
else:
res_pack.append(MeasureResult((RuntimeError(msg),),
MeasureErrorNo.COMPILE_HOST,
tstamp - tic, tstamp))
continue
except InstantiationError as e:
tstamp = time.time()
res_pack.append(MeasureResult((e,),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
continue
# measure time
errno = MeasureErrorNo.NO_ERROR
try:
if ref_input:
args = [nd.array(x, ctx) for x in ref_input]
else:
args = [nd.array(np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype),
ctx) for x in arg_bufs]
costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs)
costs.sort()
costs = tuple(costs[1:-1])
if ref_output:
for expected, real in zip(ref_output, args):
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
logging.warning("Wrong Answer!")
errno = MeasureErrorNo.WRONG_ANSWER
except TVMError as exc:
msg = str(exc)
if "Stack trace returned" in msg:
msg = msg[:msg.index("Stack trace returned")]
costs = (RuntimeError(msg),)
errno = MeasureErrorNo.RUNTIME_DEVICE
tstamp = time.time()
res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp))
return res_pack
def _build_func(inp, build_option, kwargs):
"""Build the function module. An exception is raised when an error occurs"""
with inp.target:
s, args = inp.task.instantiate(inp.config)
if not inp.config.valid():
raise InstantiationError(inp.config.errors)
code_hash = getattr(s, 'code_hash', None)
if inp.config.code_hash != code_hash:
raise HashMismatchError('got {0:s}, expected {1:s}'
.format(str(inp.config.code_hash), str(code_hash)))
opts = build_option or {}
if "check_gpu" in kwargs:
values = kwargs['check_gpu']
# Add gpu verify pass to filter out invalid configs in advance.
# This can accelerate the tuning process
check_keys = ['max_shared_memory_per_block', 'max_threads_per_block',
'max_thread_x', 'max_thread_y', 'max_thread_z']
opts["add_lower_pass"] = [
(2, gpu_verify_pass(**{key: values[key] for key in check_keys}))]
if 'cuda_arch' in kwargs:
set_cuda_target_arch(kwargs['cuda_arch'])
with build_config(**opts):
func = build(s, args, target_host=inp.task.target_host)
return func, args
def measure_rpc(input_pack,
rpc_device_key,
number,
repeat=1,
build_option=None,
rpc_tracker_addr=None,
rpc_priority=1,
rpc_timeout=60,
tmp_dir=None,
**kwargs):
"""Measure the time cost on a device by rpc
Parameters
----------
input_pack : list of MeasureInput
The inputs we need to evaluate
rpc_device_key: str
The device key of registered devices in tracker
number : int
Number of times to run the measurement for averaging
repeat : int, optional
How many times we want to repeat the measurement.
build_option: Dict
build options for tvm.build_config
rpc_tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format
If None, the environment variables TVM_TRACKER_HOST and TVM_TRACKER_PORT are used
rpc_priority: int, optional
priority of this task, used by scheduler in tracker
rpc_timeout: int, optional
timeout of the rpc session
tmp_dir: tvm.contrib.util.TempDirectory, optional
directory to store temporary files
kwargs: dict, optional
Additional key word arguments
Returns
-------
res_pack : Array of MeasureResult
The list of execution results of measurement.
"""
def _fbuild(inp):
""" Local build function."""
func, args = _build_func(inp, build_option, kwargs)
if not kwargs.get('use_ndk', False):
file_name = "tmp_func_%0x.tar" % getrandbits(64)
path = tmp_dir.relpath(file_name)
func.export_library(path)
else:
file_name = "tmp_func_%0x.so" % getrandbits(64)
path = tmp_dir.relpath(file_name)
func.export_library(path, ndk.create_shared)
remote = request_remote(rpc_device_key, rpc_tracker_addr, rpc_priority, rpc_timeout)
remote.upload(path)
func = remote.load_module(file_name)
ctx = remote.context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
return time_f, ctx, args
ret = _measure_generic(_fbuild, input_pack,
kwargs.get("ref_input", None), kwargs.get("ref_output", None))
return ret
def measure_local(input_pack,
number,
repeat=1,
build_option=None,
**kwargs):
"""Measure the time cost on a local machine.
Parameters
----------
input_pack : list of MeasureInput
The inputs we need to evaluate
number : int
Number of times to run the measurement for averaging
repeat : int, optional
How many times we want to repeat the measurement.
build_option: dict, optional
Build options for tvm.build_config
kwargs: dict, optional
Additional key word arguments
Returns
-------
res_pack : Array of MeasureResult
The list of execution results of measurement.
"""
def _fbuild(inp):
""" Local build function """
func, args = _build_func(inp, build_option, kwargs)
ctx = context(str(inp.target), 0)
time_f = func.time_evaluator(
func.entry_name, ctx, number=number, repeat=repeat)
return time_f, ctx, args
ret = _measure_generic(_fbuild, input_pack,
kwargs.get("ref_input", None), kwargs.get("ref_output", None))
return ret
def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel
This pass will check shared memory size and number of threads per block.
"""
def verify_pass(stmt):
valid = ir_pass.VerifyGPUCode(stmt, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt
return verify_pass
@register_func
def tvm_callback_cuda_compile(code):
"""use nvcc to generate ptx code for better optimization"""
ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch)
return ptx
def set_cuda_target_arch(arch):
"""set target architecture of nvcc compiler"""
AutotvmGlobalScope.current.cuda_target_arch = arch
# pylint: disable=superfluous-parens, redefined-outer-name, pointless-string-statement
# pylint: disable=consider-using-enumerate,invalid-name
"""Tuning record and serialization format"""
import argparse
import base64
import logging
import multiprocessing
import pickle
import json
import time
from collections import OrderedDict
import numpy as np
from .. import target, build, lower
from . import task
from .task import DispatchContext, ConfigEntity
from .measure import MeasureInput, MeasureResult
AUTOTVM_LOG_VERSION = 0.1
try: # convert unicode to str for python2
_unicode = unicode
except NameError:
_unicode = ()
def measure_str_key(inp, include_config=True):
""" get unique str key for MeasureInput
Parameters
----------
inp: MeasureInput
input for the measure
include_config: bool, optional
whether to include the config in the str key
Returns
-------
key: str
The str representation of key
"""
config_str = str(inp.config) if include_config else ""
return "".join([str(inp.target), inp.task.name, str(inp.task.args),
str(inp.task.kwargs), config_str])
def encode(inp, result, protocol='json'):
"""encode (MeasureInput, MeasureResult) pair to a string
Parameters
----------
inp: autotvm.tuner.MeasureInput
result: autotvm.tuner.MeasureResult
pair of input/result
protocol: str
log protocol, json or pickle
Returns
-------
row: str
a row in the logger file
"""
if protocol == 'json':
json_dict = {
"i": (str(inp.target),
inp.task.name, inp.task.args, inp.task.kwargs,
inp.task.workload,
inp.config.to_json_dict()),
"r": (result.costs if result.error_no == 0 else (1e9,),
result.error_no,
result.all_cost,
result.timestamp),
"v": AUTOTVM_LOG_VERSION
}
return json.dumps(json_dict)
elif protocol == 'pickle':
row = (str(inp.target),
str(base64.b64encode(pickle.dumps([inp.task.name,
inp.task.args,
inp.task.kwargs,
inp.task.workload])).decode()),
str(base64.b64encode(pickle.dumps(inp.config)).decode()),
str(base64.b64encode(pickle.dumps(tuple(result))).decode()))
return '\t'.join(row)
else:
raise RuntimeError("Invalid log protocol: " + protocol)
def decode(row, protocol='json'):
"""Decode encoded record string to python object
Parameters
----------
row: str
a row in the logger file
protocol: str
log protocol, json or pickle
Returns
-------
input: autotvm.tuner.MeasureInput
result: autotvm.tuner.MeasureResult
"""
# pylint: disable=unused-variable
if protocol == 'json':
row = json.loads(row)
tgt, task_name, task_args, task_kwargs, workload, config = row['i']
tgt = target.create(str(tgt))
def clean_json_to_python(x):
"""1. convert all lists in x to tuples (hashable)
2. convert unicode to str for python2
"""
if isinstance(x, list):
return tuple([clean_json_to_python(a) for a in x])
if isinstance(x, _unicode):
return str(x)
return x
tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
tsk.workload = clean_json_to_python(workload)
config = ConfigEntity.from_json_dict(config)
inp = MeasureInput(tgt, tsk, config)
result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["r"]])
return inp, result
elif protocol == 'pickle':
items = row.split("\t")
tgt = target.create(items[0])
task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
config = pickle.loads(base64.b64decode(items[2].encode()))
result = pickle.loads(base64.b64decode(items[3].encode()))
tsk = task.Task(task_tuple[0], task_tuple[1])
tsk.workload = task_tuple[3]
return MeasureInput(tgt, tsk, config), MeasureResult(*result)
else:
raise RuntimeError("Invalid log protocol: " + protocol)
def load_from_file(filename):
"""Generator: load records from file.
This is a generator that yields the records.
Parameters
----------
filename: str
Yields
------
input: autotvm.tuner.MeasureInput
result: autotvm.tuner.MeasureResult
"""
for row in open(filename):
yield decode(row)
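A round-trip sketch for this format, where `inp` and `result` stand for a measured pair:
# Sketch: serialize a record to one log line and parse it back.
line = encode(inp, result)          # 'json' protocol by default
inp2, result2 = decode(line)
assert measure_str_key(inp) == measure_str_key(inp2)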
class ApplyHistoryBest(DispatchContext):
"""
Apply the history best config
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
If it is a str, it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it should be an iterator.
default: ConfigEntity, optional
default config to return when there are no history records
"""
def __init__(self, records, default=None):
super(ApplyHistoryBest, self).__init__()
if isinstance(records, str):
records = load_from_file(records)
counter = 0
best_map = {}
for inp, res in records:
counter += 1
if res.error_no != 0:
continue
for k in inp.target.keys:
key = (k, inp.task.workload)
if key not in best_map:
best_map[key] = (inp, res)
else:
_, other_res = best_map[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_map[key] = (inp, res)
logging.info(
"Finished loading %d records, %d entries selected", counter, len(best_map))
self._best_map = best_map
self._default = default
def query(self, target, workload):
if target is None:
raise RuntimeError("Need a target context to find the history best. "
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
" above the dispatcher call. The same applies to other targets.")
for k in target.keys:
key = (k, workload)
if key in self._best_map:
return self._best_map[key][0].config
if self._default:
return self._default
raise RuntimeError(
"Cannot find config for target=%s, workload=%s" % (target, workload))
def dump_best(self, out_file):
"""Dump the best records for each workload to a file
Parameters
----------
out_file: str
filename
"""
fout = open(out_file, 'a')
for val in self._best_map.values():
inp, res = val
fout.write(encode(inp, res) + '\n')
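In use, this dispatch context wraps compilation so that template calls pick up their tuned configs. A sketch: `matmul` is a template assumed to be defined elsewhere, and 'matmul.log' is a placeholder filename.
# Sketch: compile a template with the best configs from a tuning log.
import tvm

with tvm.target.create('cuda'):
    with ApplyHistoryBest('matmul.log'):
        s, arg_bufs = matmul(512, 512, 512, 'float32')
        func = tvm.build(s, arg_bufs)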
def split_workload(in_file, clean=True):
"""Split a log file into separate files, each of which contains only a single workload
This function can also delete duplicated records in the log file
Parameters
----------
in_file: str
input filename
clean: bool
whether to delete duplicated items
"""
tic = time.time()
lines = list(open(in_file).readlines())
logging.info("start convert...")
pool = multiprocessing.Pool()
lines = pool.map(decode, lines)
logging.info("map done %.2f", time.time() - tic)
wkl_dict = OrderedDict()
for inp, res in lines:
wkl = measure_str_key(inp, False)
if wkl not in wkl_dict:
wkl_dict[wkl] = []
wkl_dict[wkl].append([inp, res])
if clean:
for i, (k, v) in enumerate(wkl_dict.items()):
# clean duplicated items
added = set()
cleaned = []
for inp, res in v:
str_key = measure_str_key(inp)
if str_key in added:
continue
added.add(str_key)
cleaned.append([inp, res])
# write to file
logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
with open(in_file + ".%03d.wkl" % i, 'w') as fout:
for inp, res in cleaned:
fout.write(encode(inp, res) + '\n')
else:
for i, (k, v) in enumerate(wkl_dict.items()):
logging.info("Key: %s\tNum: %d", k, len(v))
with open(in_file + ".%03d.wkl" % i, 'w') as fout:
for inp, res in v:
fout.write(encode(inp, res) + '\n')
"""
Usage:
This record executable module has three modes.
* Print log file in readable format
e.g. python -m autotvm.record --mode read --i collect_conv.tsv --begin 0 --end 5 --ir --code
* Extract history best from a large log file
e.g. python -m autotvm.record --mode best --i collect.tsv
* Split a log file into separate files, each of which contains only a single wkl
e.g. python -m autotvm.record --mode split --i collect.tsv
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=['read', 'best', 'split'], default='read')
parser.add_argument("--i", type=str, help="input file")
parser.add_argument("--o", type=str, default=None, help='output file')
parser.add_argument("--begin", type=int, default=0)
parser.add_argument("--end", type=int, default=5)
parser.add_argument("--ir", action='store_true')
parser.add_argument("--code", action='store_true')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.mode == 'best':
args.o = args.o or args.i + ".best"
hist_best = ApplyHistoryBest(load_from_file(args.i))
hist_best.dump_best(args.o)
elif args.mode == 'read':
for i, (inp, result) in enumerate(load_from_file(args.i)):
if args.begin <= i < args.end:
with inp.target:
s, arg_bufs = inp.task.instantiate(inp.config)
print("")
print(inp.target, inp.task, inp.config)
print(result)
if args.ir:
with inp.target:
print(lower(s, arg_bufs, simple_mode=True))
if args.code:
with inp.target:
func = build(s, arg_bufs)
print(func.imported_modules[0].get_source())
elif args.mode == 'split':
split_workload(args.i)
"""Task is a tunable composition of template functions.
Tuner takes a tunable task and optimizes the joint configuration
space of all the template functions in the task.
This module defines the task data structure, as well as a collection (zoo)
of typical tasks of interest.
"""
from .task import Task, create, register, template, get_config
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
"""
Decorator functions for hashing schedule code
Code hashing is used to check the consistency of schedule code and the parameters loaded from the log
"""
import inspect
import zlib
from tvm import schedule
def attach_code_hash(s):
"""Decorator for attaching a code hash to a schedule
Parameters
----------
s: Schedule
tvm.schedule.Schedule to attach the hash to
"""
def decorator(func):
def wrapper(*args, **kwargs):
func(*args, **kwargs)
raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
s.code_hash = hex(raw_hash)[2:]
return wrapper
return decorator
def attach_code_hash_to_arg(arg_idx=1):
"""Decorator for attaching a code hash to a schedule
Parameters
----------
arg_idx: int
index of the argument (expected to be a Schedule) to attach the code
hash to
"""
def decorator(func):
def wrapper(*args, **kwargs):
func(*args, **kwargs)
assert isinstance(args[arg_idx], schedule.Schedule)
raw_hash = zlib.crc32(''.join(inspect.getsourcelines(func)[0]).encode())
args[arg_idx].code_hash = hex(raw_hash)[2:]
return wrapper
return decorator
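For instance, a sketch of the first decorator; the schedule body and tensor `C` are illustrative.
# Sketch: stamp a schedule with the crc32 hash of the function that
# scheduled it, so stale log entries can be detected at measure time.
import tvm

s = tvm.create_schedule(C.op)   # C is a compute tensor defined elsewhere

@attach_code_hash(s)
def _schedule():
    s[C].parallel(C.op.axis[0])

_schedule()
print(s.code_hash)   # hex digest attached by the decorator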
"""
Template dispatcher module.
A dispatcher is a function that can contain multiple behaviors.
Its specific behavior is controlled by a DispatchContext.
DispatchContext is used in two ways, usually via different implementations
of the DispatchContext base class.
- During search, we can use it to pass the current proposal from the tuner.
- During evaluation, we can use it to pick the best policy.
"""
from __future__ import absolute_import as _abs
from decorator import decorate
from tvm import target as _target
class DispatchContext(object):
"""
Base class of dispatch context.
DispatchContext enables the target and workload
specific dispatch mechanism for templates.
"""
current = None
def query(self, target, workload):
"""
Query the context to get the specific implementation.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
raise NotImplementedError()
def __enter__(self):
self._old_ctx = DispatchContext.current
DispatchContext.current = self
return self
def __exit__(self, ptype, value, trace):
DispatchContext.current = self._old_ctx
class ApplyConfig(DispatchContext):
"""Apply a specific config entity during query.
Parameters
----------
config : ConfigSpace or ConfigEntity
The specific configuration we care about.
"""
def __init__(self, config):
super(ApplyConfig, self).__init__()
self._config = config
self.workload = None
def query(self, target, workload):
"""Override query"""
self.workload = workload
return self._config
def dispatcher(fworkload):
"""Wrap a workload dispatcher function.
Parameters
----------
fworkload : function
The workload extraction function from arguments.
Returns
-------
fdispatcher : function
A wrapped dispatcher function, which will
dispatch based on DispatchContext and
the current workload.
"""
dispatch_dict = {}
func_name = fworkload.__name__
def register(key, func=None, override=False):
"""Register template function.
Parameters
----------
key : str or List of str
The template key to identify the template
under this dispatcher.
func : function
The function to be registered.
The first argument of the function is always
cfg returned by DispatchContext,
the remaining arguments are the same as those of fworkload.
override : bool
Whether to override an existing registration.
Returns
-------
The register function if necessary.
"""
if isinstance(key, str):
key = [key]
def _do_reg(myf):
for x in key:
if x in dispatch_dict and not override:
raise ValueError(
"Key %s is already registered for %s" % (x, func_name))
dispatch_dict[x] = myf
return myf
if func:
return _do_reg(func)
return _do_reg
def dispatch_func(func, *args, **kwargs):
"""The wrapped dispatch function"""
tgt = _target.current_target()
context = DispatchContext.current
if context is None:
raise RuntimeError("DispatchContext is not initialized")
workload = func(*args, **kwargs)
cfg = context.query(tgt, workload)
return dispatch_dict[cfg.template_key](cfg, *args, **kwargs)
fdecorate = decorate(fworkload, dispatch_func)
fdecorate.register = register
return fdecorate
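A minimal sketch of the pattern; the names and the `build_direct_schedule` helper are illustrative, not part of this module.
# Sketch: a workload extractor plus one registered template implementation.
@dispatcher
def my_op(N):
    return ('my_op', N)                     # the workload key

@my_op.register('direct')
def _direct(cfg, N):
    return build_direct_schedule(cfg, N)    # hypothetical helper

# Under a DispatchContext whose queried config has template_key 'direct',
# calling my_op(16) dispatches to _direct(cfg, 16).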
This diff is collapsed.
# pylint: disable=unused-variable
"""Definition of task function.
Tasks can be constructed from a tuple of func, args, and kwargs.
func is a state-less function, or a string that refers to a registered
standard task.
"""
import numpy as np
from ... import tensor, expr, container, target as _target
from ..util import get_const_int, get_const_tuple, get_func_name
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
from .space import ConfigSpace
def _raise_error(*args, **kwargs): # pylint: disable=unused-argument
raise RuntimeError("The function of this task is not found. Possibly the function "
"of this task is registered in another python file "
"which is not imported in this run")
class Task(object):
"""A Tunable Task
Parameters
----------
name: str
The name of the task.
args: Tuple
Positional arguments of func
"""
def __init__(self, name, args):
self.name = name
self.args = args
self.kwargs = {} # currently unused
# init null config space
self.config_space = None
self.func = TASK_TABLE.get(name, _raise_error)
# auxiliary info, available after `init_space` is called
self.workload = None
self.flop = None
self.target = None
self.target_host = None
def instantiate(self, config):
"""Instantiate this task function (template) with a config.
Returns corresponding schedule.
Parameters
----------
config: template.ConfigEntity
parameter config for this template
Returns
-------
sch: tvm.schedule.Schedule
The tvm schedule
arg_bufs: Array of tvm.tensor.Tensor
The input/output buffers
"""
config.flop = 0
with ApplyConfig(config):
sch, arg_bufs = self.func(*self.args, **self.kwargs)
if not self.flop:
config.flop = config.flop or compute_flop(sch)
self.flop = config.flop
return sch, arg_bufs
def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
self.name, self.args, self.kwargs, self.workload
)
TASK_TABLE = {
}
def register(name, func=None, override=False):
"""Register a task function.
Parameters
----------
name : str
The name to identify the task.
func : callable
The function to be registered.
override : bool
Whether to override an existing registration.
Returns
-------
func: callable
The registered function
"""
def _do_reg(myf):
if name in TASK_TABLE and not override:
raise ValueError(
"Key %s is already registered" % name)
TASK_TABLE[name] = myf
return myf
if func:
return _do_reg(func)
return _do_reg
def create(func_name, args, target, target_host=None, template_key=None):
"""Create a tuning task and initialize its search space
Parameters
----------
func_name : str or callable
The task function
args : List
Positional arguments
target : Target
The compilation target
target_host: Target, optional
The compilation target for host side
template_key: str, optional
The template key used to pick an implementation under the dispatcher
Returns
-------
tsk: Task
a task object
"""
if callable(func_name):
# register this function if it is not registered before
func = func_name
func_name = func.func_name if hasattr(func, 'func_name') else func.__name__
if func_name in TASK_TABLE:
assert func == TASK_TABLE[func_name], "Found a name conflict in task registration. " \
"Consider choosing another name for this task"
else:
register(func_name, func=func)
func = TASK_TABLE[func_name]
ret = Task(func_name, args)
if isinstance(target, str):
target = _target.create(target)
# init config space
ret.config_space = ConfigSpace()
ret.config_space.template_key = template_key or ""
ctx = ApplyConfig(ret.config_space)
with ctx:
with target:
sch, _ = func(*args)
ret.config_space.code_hash = getattr(sch, 'code_hash', None)
ret.workload = ctx.workload
ret.flop = ret.config_space.flop or compute_flop(sch)
ret.target = target
ret.target_host = target_host
return ret
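For example, a sketch where `matmul` refers to a template such as the one in the `template` docstring below:
# Sketch: create a tuning task and inspect its search space.
tsk = create(matmul, args=(512, 512, 512, 'float32'), target='llvm')
print(tsk.config_space)   # the knobs defined by the template
print(tsk.flop)           # estimated FLOP, used to report GFLOPS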
def args_to_workload(x):
"""Convert argument list to hashable workload tuple.
This function will convert lists to tuples, tvm nodes to python values, and
flatten tvm.tensor.Tensor to a shape/dtype tuple
Parameters
----------
x: primitive hashable types or tensor.Tensor
The original value
Returns
-------
ret: hashable
The hashable value
"""
if isinstance(x, tensor.Tensor):
return get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
return tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
return x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
elif x is None:
return None
else:
raise RuntimeError('Do not support type "%s" in argument. Consider using '
'primitive types only' % type(x))
def template(func):
"""
Decorate a function as a tunable schedule template
Parameters
----------
func: callable
A callable template function.
Its arguments should be hashable values.
Its return value should be a Tuple(Schedule, Array of Tensor)
Returns
-------
func: callable
The decorated function
Examples
--------
The following code is a tunable template for a blocked matrix multiplication
.. code-block:: python
@autotvm.template
def matmul(N, L, M, dtype):
A = tvm.placeholder((N, L), name='A', dtype=dtype)
B = tvm.placeholder((L, M), name='B', dtype=dtype)
k = tvm.reduce_axis((0, L), name='k')
C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
s = tvm.create_schedule(C.op)
# schedule
y, x = s[C].op.axis
k = s[C].op.reduce_axis[0]
##### define space begin #####
cfg = autotvm.get_config()
cfg.define_split("tile_y", y, num_outputs=2)
cfg.define_split("tile_x", x, num_outputs=2)
##### define space end #####
# schedule according to config
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yo, xo, k, yi, xi)
return s, [A, B, C]
"""
# pylint: disable=unused-variable
fname = get_func_name(func)
@register(fname)
@dispatcher
def config_dispatcher(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
return (fname, ) + args_to_workload(args)
@config_dispatcher.register("")
def template_call(cfg, *args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
with ApplyConfig(cfg):
return func(*args, **kwargs)
config_dispatcher.func_name = fname
return config_dispatcher
def get_config():
"""Get current config object
Returns
-------
cfg: ConfigSpace or ConfigEntity
The current config
"""
return DispatchContext.current.query(None, None)
class FlopCalculationError(RuntimeError):
"""Error raised when FLOP estimation for a compute op fails"""
pass
def compute_flop(sch):
"""Calculate number of FLOP (floating point operations) of the compute ops in a schedule
Parameters
----------
sch: tvm.schedule.Schedule
schedule
Returns
-------
flop: int
number of FLOP in this schedule
"""
def _prod_length(axes):
"""compute product of the lengths of a list of axes"""
try:
num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
except ValueError:
raise FlopCalculationError("The length of an axis is not constant. ")
return num_iter
def _count_flop(exp):
"""compute flop for a single expression"""
if isinstance(exp, expr.Reduce):
num_iter = _prod_length(exp.axis)
combiner = exp.combiner.result
source = exp.source
if len(combiner) != 1:
raise FlopCalculationError("Found multiple outputs in the combiner of reduce op")
if len(source) != 1:
raise FlopCalculationError("Found multiple outputs in the source of reduce op")
return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
elif isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
return 0
elif isinstance(exp, expr.Cast):
return _count_flop(exp.value)
elif isinstance(exp, expr.Var):
return 0
elif isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)):
base = 1 if "float" in exp.a.dtype else 0
if isinstance(exp, expr.Not): # unary
return base + _count_flop(exp.a)
return base + _count_flop(exp.a) + _count_flop(exp.b)
elif isinstance(exp, expr.Select):
return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
_count_flop(exp.false_value))
elif isinstance(exp, expr.Call):
return sum([_count_flop(x) for x in exp.args])
else:
raise FlopCalculationError("Found unsupported operator in the compute expr")
def traverse(ops):
"""accumulate flops"""
ret = 0
for op in ops:
if isinstance(op, tensor.ComputeOp):
num_element = _prod_length(op.axis)
body = op.body
if len(body) != 1:
raise FlopCalculationError("Found multiple outputs in the compute")
exp = body[0]
ret += num_element * _count_flop(exp)
ret += traverse([sch[t].op for t in op.input_tensors])
elif isinstance(op, tensor.PlaceholderOp):
pass
else:
raise FlopCalculationError("Only support tvm.compute currently. "
"Other ops like tvm.scan are not supported")
return ret
try:
ret = traverse(sch.outputs)
except FlopCalculationError as exc:
raise RuntimeError("FLOP estimator fails for this operator. Error msg: "
+ str(exc) + ". Please use `cfg.add_flop` to manually set "
"FLOP for this operator")
if ret == 0:
raise RuntimeError("Cannot find float number operation in this operator. "
"Please use `cfg.add_flop` to manually set "
"FLOP for this operator")
return ret