From aa1310d1d9c3733b69134352516984b10aa370be Mon Sep 17 00:00:00 2001 From: Liangfu Chen <liangfu.chen@icloud.com> Date: Wed, 12 Dec 2018 13:54:39 +0800 Subject: [PATCH] Fix a issue when running with graph_runtime_debug in python (#2271) * fix a issue when running with graph_runtime_debug in python; * add support to `debug_get_output` in python; * comply with the linter; --- python/tvm/contrib/debugger/debug_result.py | 15 +++++++++++ python/tvm/contrib/debugger/debug_runtime.py | 28 ++++++++++++++++++++ python/tvm/contrib/graph_runtime.py | 12 ++------- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index 5b563c86e..101af6887 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -93,6 +93,21 @@ class DebugResult(object): """ return self._dtype_list + def get_output_tensors(self): + """Dump the outputs to a temporary folder, the tensors are in numpy format + """ + eid = 0 + order = 0 + output_tensors = {} + for node, time in zip(self._nodes_list, self._time_list): + num_outputs = self.get_graph_node_output_num(node) + for j in range(num_outputs): + order += time[0] + key = node['name'] + "_" + str(j) + output_tensors[key] = self._output_tensor_list[eid] + eid += 1 + return output_tensors + def dump_output_tensor(self): """Dump the outputs to a temporary folder, the tensors are in numpy format """ diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 6642a8bdc..d38ee6cf7 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -173,6 +173,34 @@ class GraphModuleDebug(graph_runtime.GraphModule): for j in range(num_outputs): out_tensor = self._get_output_by_layer(i, j) self.debug_datum._output_tensor_list.append(out_tensor) + + def debug_get_output(self, node, out): + """Run graph upto node and get the output to out + + Parameters + ---------- + node : int / str + The node index or name + + out : NDArray + The output array container + """ + ret = None + if isinstance(node, str): + output_tensors = self.debug_datum.get_output_tensors() + try: + ret = output_tensors[node] + except: + node_list = output_tensors.keys() + raise RuntimeError("Node " + node + " not found, available nodes are: " + + str(node_list) + ".") + elif isinstance(node, int): + output_tensors = self.debug_datum._output_tensor_list + ret = output_tensors[node] + else: + raise RuntimeError("Require node index or name only.") + return ret + def run(self, **input_dict): """Run forward execution of the graph with debug diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 1ba402e20..0d62a04a5 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -112,10 +112,6 @@ class GraphModule(object): self._get_output = module["get_output"] self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] - try: - self._debug_get_output = module["debug_get_output"] - except AttributeError: - pass self._load_params = module["load_params"] def set_input(self, key=None, value=None, **params): @@ -209,12 +205,8 @@ class GraphModule(object): out : NDArray The output array container """ - if hasattr(self, '_debug_get_output'): - self._debug_get_output(node, out) - else: - raise RuntimeError( - "Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0") - return out + raise NotImplementedError( + "Please use debugger.debug_runtime as graph_runtime instead.") def load_params(self, params_bytes): """Load parameters from serialized byte array of parameter dict. -- GitLab