From 5912ed034e4dfe7e1a1a538e44e47b7115e3fee8 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sat, 3 Jun 2017 09:06:28 -0700
Subject: [PATCH] [PERF/TIMER] Add builtin timing logic (#168)

* [PERF/TIMER] Add buildin timing logic

* fix lint
---
 python/tvm/_ffi/function.py               |  6 ++-
 python/tvm/module.py                      | 33 ++++++++++++-
 src/runtime/rpc/rpc_module.cc             | 57 ++++++++++++++++++-----
 src/runtime/rpc/rpc_session.cc            | 35 ++++++++++++++
 src/runtime/rpc/rpc_session.h             | 23 ++++++++-
 tests/python/integration/test_ewise.py    |  3 +-
 tests/python/integration/test_gemm.py     | 11 ++---
 tests/python/unittest/test_runtime_rpc.py |  4 +-
 8 files changed, 146 insertions(+), 26 deletions(-)

diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py
index 65e292aff..f780ed4eb 100644
--- a/python/tvm/_ffi/function.py
+++ b/python/tvm/_ffi/function.py
@@ -56,10 +56,12 @@ class Function(_FunctionBase):
 
 class ModuleBase(object):
     """Base class for module"""
-    __slots__ = ["handle", "_entry"]
+    __slots__ = ["handle", "_entry", "entry_name"]
+
     def __init__(self, handle):
         self.handle = handle
         self._entry = None
+        self.entry_name = "__tvm_main__"
 
     def __del__(self):
         check_call(_LIB.TVMModFree(self.handle))
@@ -75,7 +77,7 @@ class ModuleBase(object):
         """
         if self._entry:
             return self._entry
-        self._entry = self.get_function("__tvm_main__")
+        self._entry = self.get_function(self.entry_name)
         return self._entry
 
     def get_function(self, name, query_imports=False):
diff --git a/python/tvm/module.py b/python/tvm/module.py
index ff0df4397..a088e6408 100644
--- a/python/tvm/module.py
+++ b/python/tvm/module.py
@@ -72,7 +72,7 @@ class Module(ModuleBase):
             The name of the shared library.
         """
         if self.type_key != "llvm":
-            raise ValueError("Only llvm support export shared")
+            raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key)
         temp = _util.tempdir()
         path_obj = temp.relpath("lib.o")
         self.save(path_obj)
@@ -84,6 +84,37 @@ class Module(ModuleBase):
             files.append(path_cc)
         _cc.create_shared(file_name, files)
 
+    def time_evaluator(self, func_name, ctx, number):
+        """Get an evaluator that measures time cost of running function.
+
+        Parameters
+        ----------
+        func_name: str
+            The name of the function in the module.
+
+        ctx: TVMContext
+            The context we should run this function on.
+
+        number: int
+            The number of repeative times to run evaluation.
+
+        Note
+        ----
+        The function will be invoked number + 1 times,
+        with the first call discarded in case there is lazy initialization.
+
+        Returns
+        -------
+        ftimer : Function
+            The function that takes same argument as func
+            and return a float representing seconds per function call.
+        """
+        try:
+            return _RPCTimeEvaluator(
+                self, func_name, ctx.device_type, ctx.device_id, number)
+        except NameError:
+            raise NameError("time_evaluate is only supported when RPC is enabled")
+
 
 def load(path, fmt=""):
     """Load module from file
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index c946d129c..d3606f0a2 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -51,18 +51,8 @@ class RPCModuleNode final : public ModuleNode {
   PackedFunc GetFunction(
       const std::string& name,
       const std::shared_ptr<ModuleNode>& sptr_to_self) final {
-    RPCFuncHandle handle = nullptr;
-    if (module_handle_ == nullptr) {
-      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
-    } else {
-      handle = sess_->CallRemote(
-          RPCCode::kModuleGetFunc, module_handle_, name);
-    }
-    if (handle == nullptr) return PackedFunc();
-    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
-    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
-        return wf->operator()(args, rv);
-      });
+    RPCFuncHandle handle = GetFuncHandle(name);
+    return WrapRemote(handle);
   }
 
   void SaveToFile(const std::string& file_name,
@@ -86,7 +76,34 @@ class RPCModuleNode final : public ModuleNode {
     return sess_;
   }
 
+  PackedFunc GetTimeEvaluator(const std::string& name,
+                              TVMContext ctx,
+                              int nstep) {
+    RPCFuncHandle handle = GetFuncHandle(name);
+    if (handle == nullptr) return PackedFunc();
+    handle = sess_->GetTimeEvaluator(handle, ctx, nstep);
+    return WrapRemote(handle);
+  }
+
  private:
+  PackedFunc WrapRemote(RPCFuncHandle handle) {
+    if (handle == nullptr) return PackedFunc();
+    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
+    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
+        return wf->operator()(args, rv);
+      });
+  }
+
+  RPCFuncHandle GetFuncHandle(const std::string& name) {
+    RPCFuncHandle handle = nullptr;
+    if (module_handle_ == nullptr) {
+      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
+    } else {
+      handle = sess_->CallRemote(
+          RPCCode::kModuleGetFunc, module_handle_, name);
+    }
+    return handle;
+  }
   // The module handle
   void* module_handle_{nullptr};
   // The local channel
@@ -123,6 +140,22 @@ TVM_REGISTER_GLOBAL("contrib.rpc._Connect")
     *rv = RPCConnect(args[0], args[1]);
   });
 
+TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Module m = args[0];
+    std::string tkey = m->type_key();
+    TVMContext ctx;
+    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
+    ctx.device_id = args[3];
+    if (tkey == "rpc") {
+      *rv = static_cast<RPCModuleNode*>(m.operator->())
+          ->GetTimeEvaluator(args[1], ctx, args[4]);
+    } else {
+      *rv = WrapTimeEvaluator(
+          m.GetFunction(args[1], false), ctx, args[3]);
+    }
+  });
+
 TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     Module m = args[0];
diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc
index 1319889f4..ca77c02af 100644
--- a/src/runtime/rpc/rpc_session.cc
+++ b/src/runtime/rpc/rpc_session.cc
@@ -6,6 +6,7 @@
 #include <tvm/runtime/packed_func.h>
 #include <memory>
 #include <array>
+#include <chrono>
 #include "./rpc_session.h"
 #include "../device_api.h"
 
@@ -181,6 +182,11 @@ void RPCSession::CopyFromRemote(void* from,
   }
 }
 
+RPCFuncHandle RPCSession::GetTimeEvaluator(
+    RPCFuncHandle fhandle, TVMContext ctx, int nstep) {
+  return this->CallRemote(RPCCode::kGetTimeEvaluator, fhandle, ctx, nstep);
+}
+
 void RPCSession::SendReturnValue(
     int succ, TVMValue ret_value, int ret_tcode) {
   if (succ == 0) {
@@ -593,6 +599,13 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
   *rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
 }
 
+void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
+  PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
+  void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2]));
+  delete pf;
+  *rv = fhandle;
+}
+
 RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
   RPCCode code;
   CHECK_EQ(sock_.RecvAll(&code, sizeof(int)), sizeof(int));
@@ -604,6 +617,7 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
     case RPCCode::kCopyToRemote: HandleCopyToRemote(); break;
     case RPCCode::kShutdown: break;
     // system functions
+    case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
     case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
     case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
     case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
@@ -620,5 +634,26 @@ RPCCode RPCSession::HandleNextEvent(TVMRetValue *rv) {
   }
   return code;
 }
+
+PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int nstep) {
+  auto ftimer = [pf, ctx, nstep](TVMArgs args, TVMRetValue *rv) {
+    TVMRetValue temp;
+    // skip first time call, to activate lazy compilation components.
+    pf.CallPacked(args, &temp);
+    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+    // start timing
+    auto tbegin = std::chrono::high_resolution_clock::now();
+    for (int i = 0; i < nstep; ++i) {
+      pf.CallPacked(args, &temp);
+    }
+    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
+    auto tend = std::chrono::high_resolution_clock::now();
+    double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
+        tend - tbegin).count() / nstep;
+    // return the time.
+    *rv = speed;
+  };
+  return PackedFunc(ftimer);
+}
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index 0e30486a5..f56a9f87c 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -31,6 +31,7 @@ enum class RPCCode : int {
   kCopyAck,
   // The following are code that can send over CallRemote
   kGetGlobalFunc,
+  kGetTimeEvaluator,
   kFreeFunc,
   kDevSetDevice,
   kDevGetAttr,
@@ -92,6 +93,18 @@ class RPCSession {
                       size_t to_offset,
                       size_t size,
                       TVMContext ctx_from);
+  /*!
+   * \brief Get a remote timer function on ctx.
+   *  This function consumes fhandle, caller should not call Free on fhandle.
+   *
+   * \param fhandle The function handle.
+   * \param ctx The ctx to run measurement on.
+   * \param nstep Number of steps to run.
+   * \return A remote timer function
+   */
+  RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
+                                 TVMContext ctx,
+                                 int nstep);
   /*!
    * \brief Call a remote defined system function with arguments.
    * \param fcode The function code.
@@ -133,13 +146,13 @@ class RPCSession {
   void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n);
   void RecvPackedSeq(RPCArgBuffer *buf);
   RPCCode HandleNextEvent(TVMRetValue *rv);
+  TVMContext StripSessMask(TVMContext ctx);
   // special handler.
   void HandleCallFunc();
   void HandleException();
   void HandleCopyFromRemote();
   void HandleCopyToRemote();
   void HandleReturn(TVMRetValue* rv);
-  TVMContext StripSessMask(TVMContext ctx);
   // Internal mutex
   std::recursive_mutex mutex_;
   // Internal socket
@@ -152,6 +165,14 @@ class RPCSession {
   int table_index_{0};
 };
 
+/*!
+ * \brief Wrap a timer function for a given packed function.
+ * \param f The function argument.
+ * \param ctx The context.
+ * \param nstep Number of repeative steps.
+ */
+PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int nstep);
+
 // Remote space pointer.
 struct RemoteSpace {
   void* data;
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index e7553b2b8..990c900b3 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -95,7 +95,8 @@ def test_add():
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
         vbias = np.random.uniform()
         vscale = np.random.uniform()
-        fadd(a, b, c, vbias, vscale)
+        ftimer = fadd.time_evaluator(fadd.entry_name, ctx, number=1000)
+        tcost = ftimer(a, b, c, vbias, vscale)
         np.testing.assert_allclose(
             c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)
 
diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py
index 900db1f60..bb6a59afa 100644
--- a/tests/python/integration/test_gemm.py
+++ b/tests/python/integration/test_gemm.py
@@ -78,14 +78,9 @@ def test_gemm():
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
-        f(a, b, c)
-        ctx.sync()
-        tbegin = time.time()
-        f(a, b, c)
-        tpush = time.time()
-        ctx.sync()
-        tend = time.time()
-        print("launch=%g sec, exec=%g sec" % (tpush - tbegin, tend - tbegin))
+        ftimer = f.time_evaluator(f.entry_name, ctx, number=20)
+        tcost = ftimer(a, b, c)
+        print("%s: exec=%g sec/op" % (ctx, tcost))
         np.testing.assert_allclose(
             c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
 
diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py
index a1fe995a9..2078eabf5 100644
--- a/tests/python/unittest/test_runtime_rpc.py
+++ b/tests/python/unittest/test_runtime_rpc.py
@@ -70,7 +70,9 @@ def test_rpc_remote_module():
         f1 = remote.load_module("dev_lib.so")
         a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
         b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
-        f1(a, b)
+        time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10)
+        cost = time_f(a, b)
+        print('%g secs/op' % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
     check_remote()
 
-- 
GitLab