Skip to content
Snippets Groups Projects
Commit 136061dc authored by Lianmin Zheng's avatar Lianmin Zheng Committed by Tianqi Chen
Browse files

[AUTOTVM] Improve tutorial and logging (#1544)

parent 33606741
No related branches found
No related tags found
No related merge requests found
Showing
with 200 additions and 116 deletions
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, create_measure_batch, use_rpc
from .measure_methods import request_remote, check_remote, create_measure_batch, use_rpc
from .local_executor import LocalExecutor
from .executor import Future, Executor
......@@ -9,6 +9,7 @@ import logging
import os
import time
from random import getrandbits
import threading
import numpy as np
......@@ -23,6 +24,7 @@ from ..task.space import InstantiationError
from .measure import MeasureResult, MeasureErrorNo
from .local_executor import LocalExecutor
logger = logging.getLogger('autotvm')
class HashMismatchError(ValueError):
"""Raised when the code hash of a submitted config doesn't match that on the
......@@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
priority: int, optional
priority of this request, larger is more prior
The priority of this request, larger is more prior
timeout: float, optional
timeout of this session (units: seconds)
The timeout of this session (units: seconds)
Returns
    -------
......@@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
session_timeout=timeout)
return remote
def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
    """
    Check the availability of a remote device.

    Parameters
    ----------
    target: Target
        The wanted compilation target
    device_key: string
        device key of registered device in tracker
    tracker_addr: Tuple(string, int), optional
        The address of rpc tracker in (host, port) format.
        If is none, will use environment variable "TVM_TRACKER_HOST"
        and "TVM_TRACKER_PORT"
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this check (units: seconds).

    Returns
    -------
    available: bool
        True if a matching device could be requested from the tracker and a
        context for `target` created within `timeout` seconds, False otherwise.
    """
    def _check():
        # Request a session and create a device context to verify the
        # device actually works for the given compilation target.
        remote = request_remote(device_key, tracker_addr, priority)
        remote.context(str(target))

    t = threading.Thread(target=_check,)
    # Daemonize so a hung connection attempt cannot block interpreter exit
    # after the timeout expires.
    t.daemon = True
    t.start()
    t.join(timeout)
    # If the worker thread is still alive after `timeout`, the check failed.
    return not t.is_alive()
def create_measure_batch(task, option):
"""Get a standard measure_batch function.
......@@ -115,6 +144,17 @@ def create_measure_batch(task, option):
build_func = default_build_func
build_kwargs['use_ndk'] = True
# check the availability of remote devices
if hasattr(measure_func, 'rpc_info'):
rpc_info = measure_func.rpc_info
if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
logger.info("Get devices for measurement successfully!")
else:
raise RuntimeError("Cannot get remote devices from the tracker. "
"Please check the status of tracker by "
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status.")
# add device info of cuda and opencl target
if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
and hasattr(measure_func, 'rpc_info'):
......@@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
continue
except InstantiationError as e:
tstamp = time.time()
res_pack.append(MeasureResult((e,),
res_pack.append(MeasureResult((InstantiationError(str(e)),),
MeasureErrorNo.INSTANTIATION_ERROR,
tstamp - tic, tstamp))
continue
......@@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
if ref_output:
for expected, real in zip(ref_output, args):
if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
logging.warning("Wrong Answer!")
logger.warning("Wrong Answer!")
errno = MeasureErrorNo.WRONG_ANSWER
except TVMError as exc:
msg = str(exc)
......
......@@ -18,6 +18,7 @@ from .task import ConfigEntity, ApplyHistoryBest
from .measure import MeasureInput, MeasureResult
AUTOTVM_LOG_VERSION = 0.1
logger = logging.getLogger('autotvm')
try: # convert unicode to str for python2
_unicode = unicode
......@@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
tic = time.time()
lines = list(open(in_file).readlines())
logging.info("start converting...")
logger.info("start converting...")
pool = multiprocessing.Pool()
lines = pool.map(decode, lines)
logging.info("map done %.2f", time.time() - tic)
logger.info("map done %.2f", time.time() - tic)
wkl_dict = OrderedDict()
for inp, res in lines:
......@@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
cleaned.append([inp, res])
# write to file
logging.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in cleaned:
fout.write(encode(inp, res) + '\n')
else:
for i, (k, v) in enumerate(wkl_dict.items()):
logging.info("Key: %s\tNum: %d", k, len(v))
logger.info("Key: %s\tNum: %d", k, len(v))
with open(args.i + ".%03d.wkl" % i, 'w') as fout:
for inp, res in v:
fout.write(encode(inp, res) + '\n')
......@@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
for v in best_context.best_by_targetkey.values():
best_set.add(measure_str_key(v[0]))
logging.info("Extract %d best records from the %s", len(best_set), in_file)
logger.info("Extract %d best records from the %s", len(best_set), in_file)
fout = open(out_file, 'w') if isinstance(out_file, str) else out_file
for inp, res in load_from_file(in_file):
......@@ -270,7 +271,7 @@ if __name__ == '__main__':
parser.add_argument("--code", action='store_true')
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
logger.basicConfig(level=logger.INFO)
if args.mode == 'pick':
args.o = args.o or args.i + ".best.log"
......
......@@ -10,6 +10,8 @@ of the DispatchContext base class.
- During search, we can use it to pass the current proposal from tuner.
- During evaluation, we can use it to set pick the best policy.
"""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import logging
......@@ -19,6 +21,8 @@ import numpy as np
from tvm import target as _target
logger = logging.getLogger('autotvm')
class DispatchContext(object):
"""
Base class of dispatch context.
......@@ -216,7 +220,7 @@ class ApplyHistoryBest(DispatchContext):
best_by_model[key] = (inp, res)
break
logging.debug("Finish loading %d records", counter)
logger.debug("Finish loading %d records", counter)
def query(self, target, workload):
if target is None:
......
......@@ -4,6 +4,7 @@ To get the best performance, we typically need auto-tuning for the specific devi
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you create the target for the first time.
"""
# pylint: disable=invalid-name
import logging
import os
......@@ -16,6 +17,7 @@ from ..contrib.download import download
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
logger = logging.getLogger('autotvm')
def _alias(name):
"""convert alias for some packages"""
......@@ -79,7 +81,7 @@ def download_package(backend):
os.mkdir(path)
backend = _alias(backend)
logging.info("Download pre-tuned parameters for %s", backend)
logger.info("Download pre-tuned parameters for %s", backend)
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s.log" % backend,
os.path.join(rootpath, backend + ".log"), True, verbose=0)
......@@ -110,7 +112,7 @@ def list_packages():
"""
path = tempdir()
filename = path.relpath("info.json")
logging.info("Download meta info for pre-tuned parameters")
logger.info("Download meta info for pre-tuned parameters")
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json",
filename, True, verbose=0)
......
......@@ -2,11 +2,13 @@
"""Namespace of callback utilities of AutoTVM"""
import sys
import time
import logging
import numpy as np
from .. import record
logger = logging.getLogger('autotvm')
def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file.
......@@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
prefix: str
The prefix of output message
"""
class _Context:
class _Context(object):
"""Context to store local variables"""
def __init__(self):
self.best_flops = 0
......@@ -112,13 +114,14 @@ def progress_bar(total, prefix=''):
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
ctx.cur_flops = flops
ctx.best_flops = tuner.best_flops
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
ctx.cur_flops = flops
ctx.best_flops = tuner.best_flops
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s' %
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
time.time() - tic))
sys.stdout.flush()
sys.stdout.write('%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s\r' %
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
time.time() - tic))
sys.stdout.flush()
return _callback
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate, invalid-name
"""
Cost model optimizer based on simulated annealing
"""
......@@ -12,6 +12,8 @@ import numpy as np
from ..util import sample_ints
from .model_based_tuner import ModelOptimizer, knob2point, point2knob
logger = logging.getLogger('autotvm')
class SimulatedAnnealingOptimizer(ModelOptimizer):
"""parallel simulated annealing optimization algorithm
......@@ -103,16 +105,16 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
if log_interval and k % log_interval == 0:
t_str = "%.2f" % t
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
heap_items.sort(key=lambda item: -item[0])
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logging.debug("SA Maximums: %s", heap_items)
logger.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logger.debug("SA Maximums: %s", heap_items)
if self.persistent:
self.points = points
......
......@@ -4,11 +4,12 @@ import logging
import numpy as np
from ..measure import MeasureInput
from ..measure import create_measure_batch
from ..measure import MeasureInput, create_measure_batch
from ..env import GLOBAL_SCOPE
logger = logging.getLogger('autotvm')
class Tuner(object):
"""Base class for tuners
......@@ -86,9 +87,10 @@ class Tuner(object):
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stopping = early_stopping or 1e9
old_level = logger.level
GLOBAL_SCOPE.in_tuning = True
i = 0
i = error_ct = 0
while i < n_trial:
if not self.has_next():
break
......@@ -103,17 +105,20 @@ class Tuner(object):
config = inp.config
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
error_ct = 0
else:
flops = 0
error_ct += 1
if flops > self.best_flops:
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
i += len(results)
......@@ -123,11 +128,16 @@ class Tuner(object):
callback(self, inputs, results)
if i > self.best_iter + early_stopping:
logging.debug("Early stopped. Best iter: %d.", self.best_iter)
logger.debug("Early stopped. Best iter: %d.", self.best_iter)
break
GLOBAL_SCOPE.in_tuning = False
if error_ct > 50:
logger.warning("Too many errors happen in the tuning. Now is in debug mode")
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(old_level)
GLOBAL_SCOPE.in_tuning = False
del measure_batch
def reset(self):
......
......@@ -16,6 +16,8 @@ from ..util import get_rank
from .metric import max_curve, recall_curve, cover_curve
from .model_based_tuner import CostModel, FeatureCache
logger = logging.getLogger('autotvm')
class XGBoostCostModel(CostModel):
"""XGBoost as cost model
......@@ -163,17 +165,17 @@ class XGBoostCostModel(CostModel):
],
verbose_eval=self.log_interval)])
logging.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
time.time() - tic, len(xs),
len(xs) - np.sum(valid_index),
self.feature_cache.size(self.fea_type))
logger.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
time.time() - tic, len(xs),
len(xs) - np.sum(valid_index),
self.feature_cache.size(self.fea_type))
def fit_log(self, records, plan_size):
tic = time.time()
self._reset_pool()
args = list(records)
logging.debug("XGB load %d entries from history log file", len(args))
logger.debug("XGB load %d entries from history log file", len(args))
if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log
......@@ -208,7 +210,7 @@ class XGBoostCostModel(CostModel):
],
verbose_eval=self.log_interval)])
logging.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
logger.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
def predict(self, xs, output_margin=False):
feas = self._get_feature(xs)
......@@ -403,7 +405,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
infos.append("%s: %.6f" % (item[0], item[1]))
if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
logging.debug("\t".join(infos))
logger.debug("\t".join(infos))
if log_file:
with open(log_file, "a") as fout:
fout.write("\t".join(infos) + '\n')
......@@ -435,7 +437,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
elif env.iteration - best_iteration >= stopping_rounds:
best_msg = state['best_msg']
if verbose_eval and env.rank == 0:
logging.debug("XGB stopped. Best iteration: %s ", best_msg)
logger.debug("XGB stopped. Best iteration: %s ", best_msg)
raise EarlyStopException(best_iteration)
return callback
......
......@@ -8,6 +8,7 @@ import numpy as np
from .. import expr, ir_pass
logger = logging.getLogger('autotvm')
class EmptyContext(object):
"""An empty context"""
......@@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
tic = time.time()
local_pool = pool or multiprocessing.Pool()
if verbose:
logging.info("mapping begin")
logger.info("mapping begin")
for i in range(0, len(args), batch_size):
if verbose:
logging.info("mapping %d/%d elapsed %.2f", i, len(args),
time.time() - tic)
logger.info("mapping %d/%d elapsed %.2f", i, len(args),
time.time() - tic)
tmp = np.array(local_pool.map(func, args[i:i+batch_size]))
ret = tmp if ret is None else np.concatenate((ret, tmp))
if verbose:
logging.info("mapping done")
logger.info("mapping done")
if not pool:
local_pool.close()
return ret
......
"""Base definitions for RPC."""
# pylint: disable=invalid-name
from __future__ import absolute_import
import socket
......@@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# cannot found matched key in server
RPC_CODE_MISMATCH = RPC_MAGIC + 2
logger = logging.getLogger('RPCServer')
class TrackerCode(object):
"""Enumeration code for the RPC tracker"""
......@@ -120,7 +123,7 @@ def random_key(prefix, cmap=None):
return prefix + str(random.random())
def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
def connect_with_retry(addr, timeout=60, retry_period=5):
    """Connect to a TCP address with retry
This function is only reliable to short period of server restart.
......@@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
retry_period : float
Number of seconds before we retry again.
silent: bool
whether run in silent mode
"""
tstart = time.time()
while True:
......@@ -152,9 +152,8 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
if period > timeout:
raise RuntimeError(
"Failed to connect to server %s" % str(addr))
if not silent:
logging.info("Cannot connect to tracker%s, retry in %g secs...",
str(addr), retry_period)
logger.warning("Cannot connect to tracker %s, retry in %g secs...",
str(addr), retry_period)
time.sleep(retry_period)
......
......@@ -23,7 +23,8 @@ try:
from tornado import ioloop
from . import tornado_util
except ImportError as error_msg:
raise ImportError("RPCProxy module requires tornado package %s" % error_msg)
raise ImportError(
"RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg)
from . import base
from .base import TrackerCode
......@@ -540,7 +541,7 @@ def websocket_proxy_server(url, key=""):
def _connect(key):
conn = yield websocket.websocket_connect(url)
on_message = create_on_message(conn)
temp = _server_env(None, None)
temp = _server_env(None)
            # Start connection
conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True)
key = "server:" + key
......
......@@ -8,6 +8,8 @@ Server is TCP based with the following protocol:
- The key is in format
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
from __future__ import absolute_import
import os
......@@ -30,11 +32,11 @@ from ..contrib import util
from . import base
from . base import TrackerCode
def _server_env(load_library, logger):
logger = logging.getLogger('RPCServer')
def _server_env(load_library):
"""Server environment function return temp dir"""
temp = util.tempdir()
if logger is None:
logger = logging.getLogger()
# pylint: disable=unused-variable
@register_func("tvm.rpc.server.workpath")
......@@ -59,13 +61,10 @@ def _server_env(load_library, logger):
return temp
def _serve_loop(sock, addr, load_library, silent):
def _serve_loop(sock, addr, load_library):
"""Server loop"""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
sockfd = sock.fileno()
temp = _server_env(load_library, logger)
temp = _server_env(load_library)
base._ServerLoop(sockfd)
temp.remove()
logger.info("Finish serving %s", addr)
......@@ -79,12 +78,8 @@ def _parse_server_opt(opts):
ret["timeout"] = float(kv[9:])
return ret
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, silent):
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Listening loop of the server master."""
logger = logging.getLogger("RPCServer")
if silent:
logger.disabled = True
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.
......@@ -148,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
if arr[0] != expect_header:
conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
conn.close()
logger.info("mismatch key from %s", addr)
logger.warning("mismatch key from %s", addr)
continue
else:
conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
......@@ -162,7 +157,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
try:
# step 1: setup tracker and report to tracker
if tracker_addr and tracker_conn is None:
tracker_conn = base.connect_with_retry(tracker_addr, silent=silent)
tracker_conn = base.connect_with_retry(tracker_addr)
tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
if magic != base.RPC_TRACKER_MAGIC:
......@@ -182,15 +177,12 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
tracker_conn = None
continue
except RuntimeError as exc:
if silent:
return
else:
raise exc
raise exc
# step 3: serving
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library, silent))
args=(conn, addr, load_library))
server_proc.deamon = True
server_proc.start()
# close from our side.
......@@ -202,10 +194,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
server_proc.terminate()
def _connect_proxy_loop(addr, key, load_library, silent):
logger = logging.getLogger("RPCProxy")
if silent:
logger.disabled = True
def _connect_proxy_loop(addr, key, load_library):
key = "server:" + key
retry_count = 0
max_retry = 5
......@@ -221,7 +210,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
if magic == base.RPC_CODE_DUPLICATE:
raise RuntimeError("key: %s has already been used in proxy" % key)
elif magic == base.RPC_CODE_MISMATCH:
logger.info("RPCProxy do not have matching client key %s", key)
logger.warning("RPCProxy do not have matching client key %s", key)
elif magic != base.RPC_CODE_SUCCESS:
raise RuntimeError("%s is not RPC Proxy" % str(addr))
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
......@@ -229,7 +218,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
opts = _parse_server_opt(remote_key.split()[1:])
logger.info("connected to %s", str(addr))
process = multiprocessing.Process(
target=_serve_loop, args=(sock, addr, load_library, silent))
target=_serve_loop, args=(sock, addr, load_library))
process.deamon = True
process.start()
sock.close()
......@@ -240,7 +229,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
retry_count = 0
except (socket.error, IOError) as err:
retry_count += 1
logger.info("Error encountered %s, retry in %g sec", str(err), retry_period)
logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
if retry_count > max_retry:
raise RuntimeError("Maximum retry error: last error: %s" % str(err))
time.sleep(retry_period)
......@@ -323,9 +312,8 @@ class Server(object):
self.custom_addr = custom_addr
self.use_popen = use_popen
self.logger = logging.getLogger("RPCServer")
if silent:
self.logger.disabled = True
logger.setLevel(logging.WARN)
if use_popen:
cmd = [sys.executable,
......@@ -360,18 +348,18 @@ class Server(object):
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
self.logger.info("bind to %s:%d", host, self.port)
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.sock = sock
self.proc = multiprocessing.Process(
target=_listen_loop, args=(
self.sock, self.port, key, tracker_addr, load_library,
self.custom_addr, silent))
self.custom_addr))
self.proc.deamon = True
self.proc.start()
else:
self.proc = multiprocessing.Process(
target=_connect_proxy_loop, args=((host, port), key, load_library, silent))
target=_connect_proxy_loop, args=((host, port), key, load_library))
self.proc.deamon = True
self.proc.start()
......
......@@ -23,6 +23,8 @@ List of available APIs:
- input: [TrackerCode.REQUEST, [key, user, priority]]
- return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
# pylint: disable=invalid-name
import heapq
import time
import logging
......@@ -37,12 +39,13 @@ try:
from . import tornado_util
except ImportError as error_msg:
raise ImportError(
"RPCTracker module requires tornado package %s" % error_msg)
"RPCTracker module requires tornado package %s. Try 'pip install tornado'." % error_msg)
from .._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode
logger = logging.getLogger("RPCTracker")
class Scheduler(object):
    """Abstract interface of scheduler."""
......@@ -141,11 +144,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
def _init_conn(self, message):
        """Initialize the connection"""
if len(message) != 4:
logging.info("Invalid connection from %s", self.name())
logger.warning("Invalid connection from %s", self.name())
self.close()
magic = struct.unpack('<i', message)[0]
if magic != RPC_TRACKER_MAGIC:
logging.info("Invalid magic from %s", self.name())
logger.warning("Invalid magic from %s", self.name())
self.close()
self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True)
self._init_req_nbytes = 0
......@@ -232,14 +235,14 @@ class TCPEventHandler(tornado_util.TCPHandler):
status = self._tracker.summary()
self.ret_value([TrackerCode.SUCCESS, status])
else:
logging.info("Unknown tracker code %d", code)
logger.warning("Unknown tracker code %d", code)
self.close()
def on_close(self):
self._tracker._connections.remove(self)
def on_error(self, err):
logging.info("%s: Error in RPC Tracker: %s", self.name(), err)
logger.warning("%s: Error in RPC Tracker: %s", self.name(), err)
self.close()
......@@ -335,9 +338,8 @@ class Tracker(object):
port=9190,
port_end=9199,
silent=False):
self.logger = logging.getLogger("RPCTracker")
if silent:
self.logger.disabled = True
logger.setLevel(logging.WARN)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = None
......@@ -354,7 +356,7 @@ class Tracker(object):
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
self.logger.info("bind to %s:%d", host, self.port)
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(
target=_tracker_server, args=(sock, self.stop_key))
......@@ -380,7 +382,7 @@ class Tracker(object):
self._stop_tracker()
self.proc.join(1)
if self.proc.is_alive():
self.logger.info("Terminating Tracker Server...")
logger.info("Terminating Tracker Server...")
self.proc.terminate()
self.proc = None
......
......@@ -154,7 +154,8 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# for this template
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
# the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
......
......@@ -163,8 +163,10 @@ def get_network(name, batch_size):
# Set Tuning Options
# ------------------
# Before tuning, we should do some configurations. Here I use an RK3399 board
# in our environment as example. In your setting, you should modify the target
# and device_key accordingly.
# as example. In your setting, you should modify the target and device_key accordingly.
# set :code:`use_android` to True if you use android phone.
#### DEVICE CONFIG ####
# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
......@@ -173,7 +175,10 @@ target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
# Also replace this with the device key in your tracker
device_key = 'rk3399'
# tuning option
# Set this to True if you use android phone
use_android = False
#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'
......@@ -181,17 +186,17 @@ dtype = 'float32'
tuning_option = {
'log_filename': log_file,
'tuner':'xgb',
'tuner': 'xgb',
'n_trial': 1000,
'early_stopping': 200,
'early_stopping': 250,
'measure_option': autotvm.measure_option(
autotvm.use_rpc(device_key, host='localhost', port=9190),
number=4,
parallel_num=1,
timeout=10),
'use_transfer_learning': True,
timeout=10,
build_func='ndk' if use_android else 'default',
),
}
####################################################################
......@@ -208,9 +213,6 @@ tuning_option = {
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
# consider setting timeout larger.
#
# **For android phone**, add :code:`build_func='ndk'` to the argument list of
# :code:`autotvm.measure_option` to use Android NDK for creating shared library.
#
###################################################################
# Begin Tuning
......@@ -280,12 +282,14 @@ def tune_tasks(tasks,
def tune_and_evaluate():
# extract workloads from nnvm graph
print("Extract tasks...")
net, params, shape, out_shape = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_graph(net, shape=shape, dtype=dtype,
symbols=(nnvm.sym.conv2d,),
target=target)
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_option)
# compile kernels with history best records
......@@ -325,10 +329,11 @@ def tune_and_evaluate():
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
(np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself.
# tune_and_evaluate()
######################################################################
......@@ -341,6 +346,8 @@ def tune_and_evaluate():
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/16] Current/Best: 13.15/ 20.49 GFLOPS | Progress: (297/1000) | 348.51 s Done.
# [Task 2/16] Current/Best: 16.66/ 22.64 GFLOPS | Progress: (475/1000) | 415.42 s Done.
# [Task 3/16] Current/Best: 10.33/ 14.19 GFLOPS | Progress: (306/1000) | 239.61 s Done.
......@@ -362,3 +369,23 @@ def tune_and_evaluate():
# Evaluate inference time cost...
# Mean inference time (std dev): 156.51 ms (0.89 ms)
#
######################################################################
#
# .. note:: **Meet some problems?**
#
# The auto tuning module is error prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines in the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
......@@ -267,8 +267,9 @@ print(task.config_space)
# We will log the tuning results into a log file. This file can be
# used to get the best config later.
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
# logging config (for printing tuning log to the screen)
logging.getLogger('autotvm').setLevel(logging.DEBUG)
logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))
# use local cpu, measure 5 times for every config to reduce variance
measure_option = autotvm.measure_option('local',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment