From 47b8c36dcfa85c983f4242f8bdacfecbd8f26c1a Mon Sep 17 00:00:00 2001
From: Siju <sijusamuel@gmail.com>
Date: Mon, 15 Oct 2018 06:43:52 +0530
Subject: [PATCH] [RUNTIME][DEBUG]Support remote debugging (#1866)

---
 python/tvm/contrib/debugger/debug_runtime.py  | 11 +++++---
 .../graph/debug/graph_runtime_debug.cc        | 13 +++++++++
 .../unittest/test_runtime_graph_debug.py      | 27 +++++++++++++++++++
 3 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py
index 986a7b167..25d17d528 100644
--- a/python/tvm/contrib/debugger/debug_runtime.py
+++ b/python/tvm/contrib/debugger/debug_runtime.py
@@ -5,8 +5,9 @@ import tempfile
 import shutil
 from datetime import datetime
 from tvm._ffi.base import string_types
-from tvm.contrib import graph_runtime
 from tvm._ffi.function import get_global_func
+from tvm.contrib import graph_runtime
+from tvm.rpc import base as rpc_base
 from . import debug_result
 
 _DUMP_ROOT_PREFIX = "tvmdbg_"
@@ -49,8 +50,12 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
 
     ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
     if num_rpc_ctx == len(ctx):
-        raise NotSupportedError("Remote graph debugging is not supported.")
-
+        libmod = rpc_base._ModuleHandle(libmod)
+        try:
+            fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_debug.remote_create")
+        except ValueError:
+            raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \
+                             "config.cmake and rebuild TVM to enable debug mode")
     func_obj = fcreate(graph_json_str, libmod, *device_type_id)
     return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)
 
diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc
index 7faee4420..452a48408 100644
--- a/src/runtime/graph/debug/graph_runtime_debug.cc
+++ b/src/runtime/graph/debug/graph_runtime_debug.cc
@@ -146,5 +146,18 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
         << args.num_args;
     *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
   });
+
+TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
+  .set_body([](TVMArgs args, TVMRetValue* rv) {
+    CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
+                                  "graph_runtime.remote_create is "
+                                  "at least 4, but it has "
+                               << args.num_args;
+    void* mhandle = args[1];
+    const auto& contexts = GetAllContext(args);
+    *rv = GraphRuntimeDebugCreate(
+        args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
+  });
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py
index ab6b72997..b9d8b689c 100644
--- a/tests/python/unittest/test_runtime_graph_debug.py
+++ b/tests/python/unittest/test_runtime_graph_debug.py
@@ -2,6 +2,8 @@ import os
 import tvm
 import numpy as np
 import json
+from tvm import rpc
+from tvm.contrib import util
 from tvm.contrib.debugger import debug_runtime as graph_runtime
 
 def test_graph_simple():
@@ -70,7 +72,32 @@ def test_graph_simple():
         #verify dump root delete after cleanup
         assert(not os.path.exists(directory))
 
+    def check_remote():
+        if not tvm.module.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        mlib = tvm.build(s, [A, B], "llvm", name="myadd")
+        server = rpc.Server("localhost")
+        remote = rpc.connect(server.host, server.port)
+        temp = util.tempdir()
+        ctx = remote.cpu(0)
+        path_dso = temp.relpath("dev_lib.so")
+        mlib.export_library(path_dso)
+        remote.upload(path_dso)
+        mlib = remote.load_module("dev_lib.so")
+        try:
+            mod = graph_runtime.create(graph, mlib, remote.cpu(0))
+        except ValueError:
+            print("Skip because debug graph_runtime not enabled")
+            return
+        a = np.random.uniform(size=(n,)).astype(A.dtype)
+        mod.run(x=tvm.nd.array(a, ctx))
+        out = tvm.nd.empty((n,), ctx=ctx)
+        out = mod.get_output(0, out)
+        np.testing.assert_equal(out.asnumpy(), a + 1)
+
     check_verify()
+    check_remote()
 
 if __name__ == "__main__":
     test_graph_simple()
-- 
GitLab