From aa1310d1d9c3733b69134352516984b10aa370be Mon Sep 17 00:00:00 2001
From: Liangfu Chen <liangfu.chen@icloud.com>
Date: Wed, 12 Dec 2018 13:54:39 +0800
Subject: [PATCH] Fix an issue when running with graph_runtime_debug in python
 (#2271)

* fix an issue when running with graph_runtime_debug in python;

* add support to `debug_get_output` in python;

* comply with the linter;
---
 python/tvm/contrib/debugger/debug_result.py  | 15 +++++++++++
 python/tvm/contrib/debugger/debug_runtime.py | 28 ++++++++++++++++++++
 python/tvm/contrib/graph_runtime.py          | 12 ++-------
 3 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py
index 5b563c86e..101af6887 100644
--- a/python/tvm/contrib/debugger/debug_result.py
+++ b/python/tvm/contrib/debugger/debug_result.py
@@ -93,6 +93,21 @@ class DebugResult(object):
         """
         return self._dtype_list
 
+    def get_output_tensors(self):
+        """Dump the outputs to a temporary folder, the tensors are in numpy format
+        """
+        eid = 0
+        order = 0
+        output_tensors = {}
+        for node, time in zip(self._nodes_list, self._time_list):
+            num_outputs = self.get_graph_node_output_num(node)
+            for j in range(num_outputs):
+                order += time[0]
+                key = node['name'] + "_" + str(j)
+                output_tensors[key] = self._output_tensor_list[eid]
+                eid += 1
+        return output_tensors
+
     def dump_output_tensor(self):
         """Dump the outputs to a temporary folder, the tensors are in numpy format
         """
diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py
index 6642a8bdc..d38ee6cf7 100644
--- a/python/tvm/contrib/debugger/debug_runtime.py
+++ b/python/tvm/contrib/debugger/debug_runtime.py
@@ -173,6 +173,34 @@ class GraphModuleDebug(graph_runtime.GraphModule):
             for j in range(num_outputs):
                 out_tensor = self._get_output_by_layer(i, j)
                 self.debug_datum._output_tensor_list.append(out_tensor)
+
+    def debug_get_output(self, node, out):
+        """Run graph upto node and get the output to out
+
+        Parameters
+        ----------
+        node : int / str
+            The node index or name
+
+        out : NDArray
+            The output array container (NOTE: currently unused; the tensor is returned)
+        """
+        ret = None
+        if isinstance(node, str):
+            output_tensors = self.debug_datum.get_output_tensors()
+            try:
+                ret = output_tensors[node]
+            except KeyError:
+                node_list = output_tensors.keys()
+                raise RuntimeError("Node " + node + " not found, available nodes are: "
+                                   + str(node_list) + ".")
+        elif isinstance(node, int):
+            output_tensors = self.debug_datum._output_tensor_list
+            ret = output_tensors[node]
+        else:
+            raise RuntimeError("Require node index or name only.")
+        return ret
+
     def run(self, **input_dict):
         """Run forward execution of the graph with debug
 
diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py
index 1ba402e20..0d62a04a5 100644
--- a/python/tvm/contrib/graph_runtime.py
+++ b/python/tvm/contrib/graph_runtime.py
@@ -112,10 +112,6 @@ class GraphModule(object):
         self._get_output = module["get_output"]
         self._get_input = module["get_input"]
         self._get_num_outputs = module["get_num_outputs"]
-        try:
-            self._debug_get_output = module["debug_get_output"]
-        except AttributeError:
-            pass
         self._load_params = module["load_params"]
 
     def set_input(self, key=None, value=None, **params):
@@ -209,12 +205,8 @@ class GraphModule(object):
         out : NDArray
             The output array container
         """
-        if hasattr(self, '_debug_get_output'):
-            self._debug_get_output(node, out)
-        else:
-            raise RuntimeError(
-                "Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0")
-        return out
+        raise NotImplementedError(
+            "Please use debugger.debug_runtime as graph_runtime instead.")
 
     def load_params(self, params_bytes):
         """Load parameters from serialized byte array of parameter dict.
-- 
GitLab