From 00e6108eac57e0e4d7118e56522566d6097905f9 Mon Sep 17 00:00:00 2001
From: Jared Roesch <roeschinc@gmail.com>
Date: Tue, 25 Dec 2018 10:32:51 -0800
Subject: [PATCH] Add a the ability to trigger debugging in the interpreter
 without recompiling (#2219)

---
 include/tvm/relay/attrs/debug.h   | 29 ++++++++++++++
 include/tvm/relay/op_attr_types.h |  5 +++
 python/tvm/relay/__init__.py      |  7 +---
 python/tvm/relay/debug.py         | 25 ++++++++++++
 python/tvm/relay/op/__init__.py   |  2 +
 python/tvm/relay/op/op.py         | 16 ++++++++
 src/relay/backend/interpreter.cc  | 63 ++++++++++++++++++++++++++++++-
 src/relay/op/debug.cc             | 54 ++++++++++++++++++++++++++
 tests/python/relay/test_debug.py  | 32 ++++++++++++++++
 9 files changed, 226 insertions(+), 7 deletions(-)
 create mode 100644 include/tvm/relay/attrs/debug.h
 create mode 100644 python/tvm/relay/debug.py
 create mode 100644 src/relay/op/debug.cc
 create mode 100644 tests/python/relay/test_debug.py

diff --git a/include/tvm/relay/attrs/debug.h b/include/tvm/relay/attrs/debug.h
new file mode 100644
index 000000000..8243dc0a3
--- /dev/null
+++ b/include/tvm/relay/attrs/debug.h
@@ -0,0 +1,29 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file tvm/relay/attrs/debug.h
+ * \brief Auxiliary attributes for debug operators.
+ */
+#ifndef TVM_RELAY_ATTRS_DEBUG_H_
+#define TVM_RELAY_ATTRS_DEBUG_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Options for the debug operators.
+ */
+struct DebugAttrs : public tvm::AttrsNode<DebugAttrs> {
+  EnvFunc debug_func;
+
+  TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") {
+    TVM_ATTR_FIELD(debug_func)
+        .describe("The function to use when debugging.");
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_ATTRS_DEBUG_H_
diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index 1f37e9947..c2839a471 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -48,6 +48,11 @@ using TOpPattern = int;
  */
 using TOpIsStateful = bool;
 
+/*!
+ * \brief Mark the operator as non-computational.
+ */
+using TNonComputational = bool;
+
 /*!
  * \brief Computation description interface.
  *
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 69180837b..572589921 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -10,6 +10,7 @@ from . import module
 from . import ir_pass
 from .build_module import build, build_config, create_executor
 from . import parser
+from . import debug
 
 # Root operators
 from .op import Op
@@ -63,11 +64,5 @@ var = expr.var
 const = expr.const
 bind = expr.bind
 
-# pylint: disable=unused-argument
-@register_func("relay.debug")
-def _debug(*args):
-    import pdb
-    pdb.set_trace()
-
 # Parser
 fromtext = parser.fromtext
diff --git a/python/tvm/relay/debug.py b/python/tvm/relay/debug.py
new file mode 100644
index 000000000..00ad7b440
--- /dev/null
+++ b/python/tvm/relay/debug.py
@@ -0,0 +1,25 @@
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+"""The Relay IR namespace containing the IR definition and compiler."""
+from __future__ import absolute_import
+from .base import NodeBase, register_relay_node
+from ..api import register_func
+
+@register_relay_node
+class InterpreterState(NodeBase):
+    pass
+
+# pylint: disable=unused-argument
+def _debugger_init(expr, stack):
+    import pdb
+    pdb.set_trace()
+
+# pylint: disable=unused-argument
+@register_func("relay.debug")
+def _debug(*args):
+    _, _, _, ist = args
+    print("Relay Debugger")
+    print("  You can manipulate the expression under evaluation with the name `expr`.")
+    print("  You can manipulate the call stack with the name `stack`.")
+    print("--------------")
+    print("--------------")
+    _debugger_init(ist.current_expr, ist.stack)
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 4a6dfd9f7..63baa5128 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -3,6 +3,7 @@
 # operator defs
 from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \
     Op
+from .op import debug
 
 # Operators
 from .reduce import *
@@ -13,6 +14,7 @@ from . import image
 from . import vision
 from . import op_attrs
 
+
 # operator registry
 from . import _tensor
 from . import _transform
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index dd3af9c44..b027211ac 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -8,6 +8,7 @@ from ..base import register_relay_node
 from ..expr import Expr
 from ...api import register_func
 from ...build_module import lower, build
+from . import _make
 
 @register_relay_node
 class Op(Expr):
@@ -183,3 +184,18 @@ def schedule_injective(attrs, outputs, target):
     """Generic schedule for binary broadcast."""
     with target:
         return topi.generic.schedule_injective(outputs)
+
+__DEBUG_COUNTER__ = 0
+
+def debug(expr, debug_func=None):
+    """The main entry point to the debugger."""
+    global __DEBUG_COUNTER__
+
+    if debug_func:
+        name = "debugger_func{}".format(__DEBUG_COUNTER__)
+        register_func(name, debug_func)
+        __DEBUG_COUNTER__ += 1
+    else:
+        name = ''
+
+    return _make.debug(expr, name)
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 33d06e9c6..734180c53 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -8,6 +8,7 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/pass.h>
+#include <tvm/relay/attrs/debug.h>
 #include "compile_engine.h"
 
 namespace tvm {
@@ -124,13 +125,48 @@ struct Stack {
   };
 };
 
+/*! \brief A representation of the interpreter state which can be passed back to Python. */
+class InterpreterState;
+
+/*! \brief A container capturing the state of the interpreter. */
+class InterpreterStateNode : public Node {
+ public:
+  using Frame = tvm::Map<Var, Value>;
+  using Stack = tvm::Array<Frame>;
+
+  /*! \brief The current expression under evaluation. */
+  Expr current_expr;
+
+  /*! \brief The call stack of the interpreter. */
+  Stack stack;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("current_expr", &current_expr);
+    v->Visit("stack", &stack);
+  }
+
+  TVM_DLL static InterpreterState make(Expr current_expr, Stack stack);
+
+  static constexpr const char* _type_key = "relay.InterpreterState";
+  TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node);
+};
+
+RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef);
+
+InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
+  NodePtr<InterpreterStateNode> n = make_node<InterpreterStateNode>();
+  n->current_expr = std::move(current_expr);
+  n->stack = std::move(stack);
+  return InterpreterState(n);
+}
+
 // NOTE: the current interpreter assumes A-normal form.
 // which is better for execution.
 //
 // It will run duplicated computations when taking program that
 // contains DAG in dataflow-form.
-// Conversion to ANF is recommended before running the interpretation.
 //
+// Conversion to ANF is recommended before running the interpretation.
 class Interpreter :
       public ExprFunctor<Value(const Expr& n)> {
  public:
@@ -209,6 +245,21 @@ class Interpreter :
 
   Value InvokePrimitiveOp(Function func,
                           const Array<Value>& args) {
+    auto call_node = func->body.as<CallNode>();
+
+    if (call_node && call_node->op == Op::Get("debug")) {
+      auto dattrs = call_node->attrs.as<DebugAttrs>();
+      auto interp_state = this->get_state(call_node->args[0]);
+
+      if (dattrs->debug_func.defined()) {
+        dattrs->debug_func(interp_state);
+      } else {
+        RELAY_DEBUG(interp_state);
+      }
+
+      return args[0];
+    }
+
     // Marshal the arguments.
     // Handle tuple input/output by flattening them.
     size_t arg_len = 0;
@@ -381,6 +432,16 @@ class Interpreter :
     }
   }
 
+  InterpreterState get_state(Expr e = Expr()) const {
+    InterpreterStateNode::Stack stack;
+    for (auto fr : this->stack_.frames) {
+      InterpreterStateNode::Frame frame = fr.locals;
+      stack.push_back(frame);
+    }
+    auto state = InterpreterStateNode::make(e, stack);
+    return state;
+  }
+
  private:
   // module
   Module mod_;
diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc
new file mode 100644
index 000000000..4c9b0a5ca
--- /dev/null
+++ b/src/relay/op/debug.cc
@@ -0,0 +1,54 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file nn.cc
+ * \brief Property def of nn operators.
+ */
+
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/debug.h>
+#include <topi/elemwise.h>
+#include <vector>
+#include "./type_relations.h"
+#include "./op_common.h"
+#include "./layout.h"
+
+namespace tvm {
+namespace relay {
+
+Array<Tensor> DebugCompute(const Attrs& attrs,
+                               const Array<Tensor>& inputs,
+                               const Type& out_type,
+                               const Target& target) {
+  return Array<Tensor>{ topi::identity(inputs[0]) };
+}
+
+RELAY_REGISTER_OP("debug")
+.describe(R"code(Enter the interpreter's debugger.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("program", "Tuple", "The program to execute before debugging.")
+.set_support_level(1)
+.add_type_rel("Debug", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<FTVMCompute>("FTVMCompute", DebugCompute);
+
+Expr MakeDebug(Expr expr, std::string name) {
+  auto dattrs = make_node<DebugAttrs>();
+  if (name.size() > 0) {
+    dattrs->debug_func = EnvFunc::Get(name);
+  } else {
+    dattrs->debug_func = EnvFunc();
+  }
+  static const Op& op = Op::Get("debug");
+  return CallNode::make(op, {expr}, Attrs(dattrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.debug")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv);
+  });
+
+}  // namespace relay
+}  // namespace tvm
+
diff --git a/tests/python/relay/test_debug.py b/tests/python/relay/test_debug.py
new file mode 100644
index 000000000..3463e2916
--- /dev/null
+++ b/tests/python/relay/test_debug.py
@@ -0,0 +1,32 @@
+from tvm.relay import var, const, create_executor
+from tvm.relay.op import debug
+
+
+_test_debug_hit = False
+
+def test_debug():
+    global _test_debug_hit
+    ex = create_executor()
+    x = var('x', shape=(), dtype='int32')
+    _test_debug_hit = False
+    def did_exec(x):
+        global _test_debug_hit
+        _test_debug_hit = True
+    prog = debug(x, debug_func=did_exec)
+    result = ex.evaluate(prog, { x: const(1) })
+    assert _test_debug_hit
+    assert result.asnumpy() == 1
+
+def test_debug_with_expr():
+    global _test_debug_hit
+    _test_debug_hit = False
+    ex = create_executor()
+    x = var('x', shape=(), dtype='int32')
+    _test_debug_hit = False
+    def did_exec(x):
+        global _test_debug_hit
+        _test_debug_hit = True
+    prog = debug(x + x * x, debug_func=did_exec)
+    result = ex.evaluate(prog, { x: const(2) })
+    assert _test_debug_hit
+    assert result.asnumpy() == 6
-- 
GitLab