diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index 5b563c86e6e4b2c87e37331c942b7216e7bb16c6..101af6887c47eeb050161ea239b36de78e22344a 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -93,6 +93,21 @@ class DebugResult(object): """ return self._dtype_list + def get_output_tensors(self): + """Dump the outputs to a temporary folder, the tensors are in numpy format + """ + eid = 0 + order = 0 + output_tensors = {} + for node, time in zip(self._nodes_list, self._time_list): + num_outputs = self.get_graph_node_output_num(node) + for j in range(num_outputs): + order += time[0] + key = node['name'] + "_" + str(j) + output_tensors[key] = self._output_tensor_list[eid] + eid += 1 + return output_tensors + def dump_output_tensor(self): """Dump the outputs to a temporary folder, the tensors are in numpy format """ diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 6642a8bdc82256f44897079ad37ad914501c1781..d38ee6cf7982e63429124f2a26fec713342f05e3 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -173,6 +173,34 @@ class GraphModuleDebug(graph_runtime.GraphModule): for j in range(num_outputs): out_tensor = self._get_output_by_layer(i, j) self.debug_datum._output_tensor_list.append(out_tensor) + + def debug_get_output(self, node, out): + """Run graph upto node and get the output to out + + Parameters + ---------- + node : int / str + The node index or name + + out : NDArray + The output array container + """ + ret = None + if isinstance(node, str): + output_tensors = self.debug_datum.get_output_tensors() + try: + ret = output_tensors[node] + except: + node_list = output_tensors.keys() + raise RuntimeError("Node " + node + " not found, available nodes are: " + + str(node_list) + ".") + elif isinstance(node, int): + output_tensors = self.debug_datum._output_tensor_list + ret = output_tensors[node] + else: + raise RuntimeError("Require node index or name only.") + return ret + def run(self, **input_dict): """Run forward execution of the graph with debug diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 1ba402e20e7e3ed4ff278d2e862b0c3191dc8f24..0d62a04a55710533820b1c2e0d5d2a6ff645d233 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -112,10 +112,6 @@ class GraphModule(object): self._get_output = module["get_output"] self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] - try: - self._debug_get_output = module["debug_get_output"] - except AttributeError: - pass self._load_params = module["load_params"] def set_input(self, key=None, value=None, **params): @@ -209,12 +205,8 @@ class GraphModule(object): out : NDArray The output array container """ - if hasattr(self, '_debug_get_output'): - self._debug_get_output(node, out) - else: - raise RuntimeError( - "Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0") - return out + raise NotImplementedError( + "Please use debugger.debug_runtime as graph_runtime instead.") def load_params(self, params_bytes): """Load parameters from serialized byte array of parameter dict.