From 5912ed034e4dfe7e1a1a538e44e47b7115e3fee8 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sat, 3 Jun 2017 09:06:28 -0700 Subject: [PATCH] [PERF/TIMER] Add builtin timing logic (#168) * [PERF/TIMER] Add buildin timing logic * fix lint --- python/tvm/_ffi/function.py | 6 ++- python/tvm/module.py | 33 ++++++++++++- src/runtime/rpc/rpc_module.cc | 57 ++++++++++++++++++----- src/runtime/rpc/rpc_session.cc | 35 ++++++++++++++ src/runtime/rpc/rpc_session.h | 23 ++++++++- tests/python/integration/test_ewise.py | 3 +- tests/python/integration/test_gemm.py | 11 ++--- tests/python/unittest/test_runtime_rpc.py | 4 +- 8 files changed, 146 insertions(+), 26 deletions(-) diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 65e292aff..f780ed4eb 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -56,10 +56,12 @@ class Function(_FunctionBase): class ModuleBase(object): """Base class for module""" - __slots__ = ["handle", "_entry"] + __slots__ = ["handle", "_entry", "entry_name"] + def __init__(self, handle): self.handle = handle self._entry = None + self.entry_name = "__tvm_main__" def __del__(self): check_call(_LIB.TVMModFree(self.handle)) @@ -75,7 +77,7 @@ class ModuleBase(object): """ if self._entry: return self._entry - self._entry = self.get_function("__tvm_main__") + self._entry = self.get_function(self.entry_name) return self._entry def get_function(self, name, query_imports=False): diff --git a/python/tvm/module.py b/python/tvm/module.py index ff0df4397..a088e6408 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -72,7 +72,7 @@ class Module(ModuleBase): The name of the shared library. """ if self.type_key != "llvm": - raise ValueError("Only llvm support export shared") + raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key) temp = _util.tempdir() path_obj = temp.relpath("lib.o") self.save(path_obj) @@ -84,6 +84,37 @@ class Module(ModuleBase): files.append(path_cc) _cc.create_shared(file_name, files) + def time_evaluator(self, func_name, ctx, number): + """Get an evaluator that measures time cost of running function. + + Parameters + ---------- + func_name: str + The name of the function in the module. + + ctx: TVMContext + The context we should run this function on. + + number: int + The number of repeative times to run evaluation. + + Note + ---- + The function will be invoked number + 1 times, + with the first call discarded in case there is lazy initialization. + + Returns + ------- + ftimer : Function + The function that takes same argument as func + and return a float representing seconds per function call. + """ + try: + return _RPCTimeEvaluator( + self, func_name, ctx.device_type, ctx.device_id, number) + except NameError: + raise NameError("time_evaluate is only supported when RPC is enabled") + def load(path, fmt=""): """Load module from file diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index c946d129c..d3606f0a2 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -51,18 +51,8 @@ class RPCModuleNode final : public ModuleNode { PackedFunc GetFunction( const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) final { - RPCFuncHandle handle = nullptr; - if (module_handle_ == nullptr) { - handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name); - } else { - handle = sess_->CallRemote( - RPCCode::kModuleGetFunc, module_handle_, name); - } - if (handle == nullptr) return PackedFunc(); - auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_); - return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + RPCFuncHandle handle = GetFuncHandle(name); + return WrapRemote(handle); } void SaveToFile(const std::string& file_name, @@ -86,7 +76,34 @@ class RPCModuleNode final : public ModuleNode { return sess_; } + PackedFunc GetTimeEvaluator(const std::string& name, + TVMContext ctx, + int nstep) { + RPCFuncHandle handle = GetFuncHandle(name); + if (handle == nullptr) return PackedFunc(); + handle = sess_->GetTimeEvaluator(handle, ctx, nstep); + return WrapRemote(handle); + } + private: + PackedFunc WrapRemote(RPCFuncHandle handle) { + if (handle == nullptr) return PackedFunc(); + auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_); + return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { + return wf->operator()(args, rv); + }); + } + + RPCFuncHandle GetFuncHandle(const std::string& name) { + RPCFuncHandle handle = nullptr; + if (module_handle_ == nullptr) { + handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name); + } else { + handle = sess_->CallRemote( + RPCCode::kModuleGetFunc, module_handle_, name); + } + return handle; + } // The module handle void* module_handle_{nullptr}; // The local channel @@ -123,6 +140,22 @@ TVM_REGISTER_GLOBAL("contrib.rpc._Connect") *rv = RPCConnect(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Module m = args[0]; + std::string tkey = m->type_key(); + TVMContext ctx; + ctx.device_type = static_cast<DLDeviceType>(args[2].operator int()); + ctx.device_id = args[3]; + if (tkey == "rpc") { + *rv = static_cast<RPCModuleNode*>(m.operator->()) + ->GetTimeEvaluator(args[1], ctx, args[4]); + } else { + *rv = WrapTimeEvaluator( + m.GetFunction(args[1], false), ctx, args[3]); + } + }); + TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule") .set_body([](TVMArgs args, TVMRetValue* rv) { Module m = args[0]; diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 1319889f4..ca77c02af 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -6,6 +6,7 @@ #include <tvm/runtime/packed_func.h> #include <memory> #include <array> +#include <chrono> #include "./rpc_session.h" #include "../device_api.h" @@ -181,6 +182,11 @@ void RPCSession::CopyFromRemote(void* from, } } +RPCFuncHandle RPCSession::GetTimeEvaluator( + RPCFuncHandle fhandle, TVMContext ctx, int nstep) { + return this->CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep); +} + void RPCSession::SendReturnValue( int succ, TVMValue ret_value, int ret_tcode) { if (succ == 0) { @@ -593,6 +599,13 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { *rv = (*static_cast<Module*>(mhandle))->GetSource(fmt); } +void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { + PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*()); + void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2])); + delete pf; + *rv = fhandle; +} + RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) { RPCCode code; CHECK_EQ(sock_.RecvAll(&code, sizeof(int)), sizeof(int)); @@ -604,6 +617,7 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) { case RPCCode::kCopyToRemote: HandleCopyToRemote(); break; case RPCCode::kShutdown: break; // system functions + case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break; case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break; case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break; case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break; @@ -620,5 +634,26 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) { } return code; } + +PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) { + auto ftimer = [pf, ctx, nstep](TVMArgs args, TVMRetValue *rv) { + TVMRetValue temp; + // skip first time call, to activate lazy compilation components. + pf.CallPacked(args, &temp); + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + // start timing + auto tbegin = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < nstep; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + auto tend = std::chrono::high_resolution_clock::now(); + double speed = std::chrono::duration_cast<std::chrono::duration<double> >( + tend - tbegin).count() / nstep; + // return the time. + *rv = speed; + }; + return PackedFunc(ftimer); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 0e30486a5..f56a9f87c 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -31,6 +31,7 @@ enum class RPCCode : int { kCopyAck, // The following are code that can send over CallRemote kGetGlobalFunc, + kGetTimeEvaluator, kFreeFunc, kDevSetDevice, kDevGetAttr, @@ -92,6 +93,18 @@ class RPCSession { size_t to_offset, size_t size, TVMContext ctx_from); + /*! + * \brief Get a remote timer function on ctx. + * This function consumes fhandle, caller should not call Free on fhandle. + * + * \param fhandle The function handle. + * \param ctx The ctx to run measurement on. + * \param nstep Number of steps to run. + * \return A remote timer function + */ + RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle, + TVMContext ctx, + int nstep); /*! * \brief Call a remote defined system function with arguments. * \param fcode The function code. @@ -133,13 +146,13 @@ class RPCSession { void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n); void RecvPackedSeq(RPCArgBuffer *buf); RPCCode HandleNextEvent(TVMRetValue *rv); + TVMContext StripSessMask(TVMContext ctx); // special handler. void HandleCallFunc(); void HandleException(); void HandleCopyFromRemote(); void HandleCopyToRemote(); void HandleReturn(TVMRetValue* rv); - TVMContext StripSessMask(TVMContext ctx); // Internal mutex std::recursive_mutex mutex_; // Internal socket @@ -152,6 +165,14 @@ class RPCSession { int table_index_{0}; }; +/*! + * \brief Wrap a timer function for a given packed function. + * \param f The function argument. + * \param ctx The context. + * \param nstep Number of repeative steps. + */ +PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep); + // Remote space pointer. struct RemoteSpace { void* data; diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index e7553b2b8..990c900b3 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -95,7 +95,8 @@ def test_add(): c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) vbias = np.random.uniform() vscale = np.random.uniform() - fadd(a, b, c, vbias, vscale) + ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1000) + tcost = ftimer(a, b, c, vbias, vscale) np.testing.assert_allclose( c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6) diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py index 900db1f60..bb6a59afa 100644 --- a/tests/python/integration/test_gemm.py +++ b/tests/python/integration/test_gemm.py @@ -78,14 +78,9 @@ def test_gemm(): a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) - f(a, b, c) - ctx.sync() - tbegin = time.time() - f(a, b, c) - tpush = time.time() - ctx.sync() - tend = time.time() - print("launch=%g sec, exec=%g sec" % (tpush - tbegin, tend - tbegin)) + ftimer = f.time_evaluator(f.entry_name, ctx, number=20) + tcost = ftimer(a, b, c) + print("%s: exec=%g sec/op" % (ctx, tcost)) np.testing.assert_allclose( c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index a1fe995a9..2078eabf5 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -70,7 +70,9 @@ def test_rpc_remote_module(): f1 = remote.load_module("dev_lib.so") a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) - f1(a, b) + time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) + cost = time_f(a, b) + print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) check_remote() -- GitLab