From 00e6108eac57e0e4d7118e56522566d6097905f9 Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Tue, 25 Dec 2018 10:32:51 -0800 Subject: [PATCH] Add a the ability to trigger debugging in the interpreter without recompiling (#2219) --- include/tvm/relay/attrs/debug.h | 29 ++++++++++++++ include/tvm/relay/op_attr_types.h | 5 +++ python/tvm/relay/__init__.py | 7 +--- python/tvm/relay/debug.py | 25 ++++++++++++ python/tvm/relay/op/__init__.py | 2 + python/tvm/relay/op/op.py | 16 ++++++++ src/relay/backend/interpreter.cc | 63 ++++++++++++++++++++++++++++++- src/relay/op/debug.cc | 54 ++++++++++++++++++++++++++ tests/python/relay/test_debug.py | 32 ++++++++++++++++ 9 files changed, 226 insertions(+), 7 deletions(-) create mode 100644 include/tvm/relay/attrs/debug.h create mode 100644 python/tvm/relay/debug.py create mode 100644 src/relay/op/debug.cc create mode 100644 tests/python/relay/test_debug.py diff --git a/include/tvm/relay/attrs/debug.h b/include/tvm/relay/attrs/debug.h new file mode 100644 index 000000000..8243dc0a3 --- /dev/null +++ b/include/tvm/relay/attrs/debug.h @@ -0,0 +1,29 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/attrs/debug.h + * \brief Auxiliary attributes for debug operators. + */ +#ifndef TVM_RELAY_ATTRS_DEBUG_H_ +#define TVM_RELAY_ATTRS_DEBUG_H_ + +#include <tvm/attrs.h> +#include <string> + +namespace tvm { +namespace relay { + +/*! + * \brief Options for the debug operators. + */ +struct DebugAttrs : public tvm::AttrsNode<DebugAttrs> { + EnvFunc debug_func; + + TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") { + TVM_ATTR_FIELD(debug_func) + .describe("The function to use when debugging."); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_DEBUG_H_ diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 1f37e9947..c2839a471 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -48,6 +48,11 @@ using TOpPattern = int; */ using TOpIsStateful = bool; +/*! + * \brief Mark the operator as non-computational. + */ +using TNonComputational = bool; + /*! * \brief Computation description interface. * diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 69180837b..572589921 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -10,6 +10,7 @@ from . import module from . import ir_pass from .build_module import build, build_config, create_executor from . import parser +from . import debug # Root operators from .op import Op @@ -63,11 +64,5 @@ var = expr.var const = expr.const bind = expr.bind -# pylint: disable=unused-argument -@register_func("relay.debug") -def _debug(*args): - import pdb - pdb.set_trace() - # Parser fromtext = parser.fromtext diff --git a/python/tvm/relay/debug.py b/python/tvm/relay/debug.py new file mode 100644 index 000000000..00ad7b440 --- /dev/null +++ b/python/tvm/relay/debug.py @@ -0,0 +1,25 @@ +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The Relay IR namespace containing the IR definition and compiler.""" +from __future__ import absolute_import +from .base import NodeBase, register_relay_node +from ..api import register_func + +@register_relay_node +class InterpreterState(NodeBase): + pass + +# pylint: disable=unused-argument +def _debugger_init(expr, stack): + import pdb + pdb.set_trace() + +# pylint: disable=unused-argument +@register_func("relay.debug") +def _debug(*args): + _, _, _, ist = args + print("Relay Debugger") + print(" You can manipulate the expression under evaluation with the name `expr`.") + print(" You can manipulate the call stack with the name `stack`.") + print("--------------") + print("--------------") + _debugger_init(ist.current_expr, ist.stack) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 4a6dfd9f7..63baa5128 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -3,6 +3,7 @@ # operator defs from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \ Op +from .op import debug # Operators from .reduce import * @@ -13,6 +14,7 @@ from . import image from . import vision from . import op_attrs + # operator registry from . import _tensor from . import _transform diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index dd3af9c44..b027211ac 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -8,6 +8,7 @@ from ..base import register_relay_node from ..expr import Expr from ...api import register_func from ...build_module import lower, build +from . import _make @register_relay_node class Op(Expr): @@ -183,3 +184,18 @@ def schedule_injective(attrs, outputs, target): """Generic schedule for binary broadcast.""" with target: return topi.generic.schedule_injective(outputs) + +__DEBUG_COUNTER__ = 0 + +def debug(expr, debug_func=None): + """The main entry point to the debugger.""" + global __DEBUG_COUNTER__ + + if debug_func: + name = "debugger_func{}".format(__DEBUG_COUNTER__) + register_func(name, debug_func) + __DEBUG_COUNTER__ += 1 + else: + name = '' + + return _make.debug(expr, name) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 33d06e9c6..734180c53 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -8,6 +8,7 @@ #include <tvm/relay/expr_functor.h> #include <tvm/relay/interpreter.h> #include <tvm/relay/pass.h> +#include <tvm/relay/attrs/debug.h> #include "compile_engine.h" namespace tvm { @@ -124,13 +125,48 @@ struct Stack { }; }; +/*! \brief A representation of the interpreter state which can be passed back to Python. */ +class InterpreterState; + +/*! \brief A container capturing the state of the interpreter. */ +class InterpreterStateNode : public Node { + public: + using Frame = tvm::Map<Var, Value>; + using Stack = tvm::Array<Frame>; + + /*! \brief The current expression under evaluation. */ + Expr current_expr; + + /*! \brief The call stack of the interpreter. */ + Stack stack; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("current_expr", ¤t_expr); + v->Visit("stack", &stack); + } + + TVM_DLL static InterpreterState make(Expr current_expr, Stack stack); + + static constexpr const char* _type_key = "relay.InterpreterState"; + TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node); +}; + +RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef); + +InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { + NodePtr<InterpreterStateNode> n = make_node<InterpreterStateNode>(); + n->current_expr = std::move(current_expr); + n->stack = std::move(stack); + return InterpreterState(n); +} + // NOTE: the current interpreter assumes A-normal form. // which is better for execution. // // It will run duplicated computations when taking program that // contains DAG in dataflow-form. -// Conversion to ANF is recommended before running the interpretation. // +// Conversion to ANF is recommended before running the interpretation. class Interpreter : public ExprFunctor<Value(const Expr& n)> { public: @@ -209,6 +245,21 @@ class Interpreter : Value InvokePrimitiveOp(Function func, const Array<Value>& args) { + auto call_node = func->body.as<CallNode>(); + + if (call_node && call_node->op == Op::Get("debug")) { + auto dattrs = call_node->attrs.as<DebugAttrs>(); + auto interp_state = this->get_state(call_node->args[0]); + + if (dattrs->debug_func.defined()) { + dattrs->debug_func(interp_state); + } else { + RELAY_DEBUG(interp_state); + } + + return args[0]; + } + // Marshal the arguments. // Handle tuple input/output by flattening them. size_t arg_len = 0; @@ -381,6 +432,16 @@ class Interpreter : } } + InterpreterState get_state(Expr e = Expr()) const { + InterpreterStateNode::Stack stack; + for (auto fr : this->stack_.frames) { + InterpreterStateNode::Frame frame = fr.locals; + stack.push_back(frame); + } + auto state = InterpreterStateNode::make(e, stack); + return state; + } + private: // module Module mod_; diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc new file mode 100644 index 000000000..4c9b0a5ca --- /dev/null +++ b/src/relay/op/debug.cc @@ -0,0 +1,54 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file nn.cc + * \brief Property def of nn operators. + */ + +#include <tvm/relay/op.h> +#include <tvm/relay/attrs/debug.h> +#include <topi/elemwise.h> +#include <vector> +#include "./type_relations.h" +#include "./op_common.h" +#include "./layout.h" + +namespace tvm { +namespace relay { + +Array<Tensor> DebugCompute(const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + return Array<Tensor>{ topi::identity(inputs[0]) }; +} + +RELAY_REGISTER_OP("debug") +.describe(R"code(Enter the interpreter's debugger. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("program", "Tuple", "The program to execute before debugging.") +.set_support_level(1) +.add_type_rel("Debug", IdentityRel) +.set_attr<TOpPattern>("TOpPattern", kOpaque) +.set_attr<FTVMCompute>("FTVMCompute", DebugCompute); + +Expr MakeDebug(Expr expr, std::string name) { + auto dattrs = make_node<DebugAttrs>(); + if (name.size() > 0) { + dattrs->debug_func = EnvFunc::Get(name); + } else { + dattrs->debug_func = EnvFunc(); + } + static const Op& op = Op::Get("debug"); + return CallNode::make(op, {expr}, Attrs(dattrs), {}); +} + +TVM_REGISTER_API("relay.op._make.debug") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv); + }); + +} // namespace relay +} // namespace tvm + diff --git a/tests/python/relay/test_debug.py b/tests/python/relay/test_debug.py new file mode 100644 index 000000000..3463e2916 --- /dev/null +++ b/tests/python/relay/test_debug.py @@ -0,0 +1,32 @@ +from tvm.relay import var, const, create_executor +from tvm.relay.op import debug + + +_test_debug_hit = False + +def test_debug(): + global _test_debug_hit + ex = create_executor() + x = var('x', shape=(), dtype='int32') + _test_debug_hit = False + def did_exec(x): + global _test_debug_hit + _test_debug_hit = True + prog = debug(x, debug_func=did_exec) + result = ex.evaluate(prog, { x: const(1) }) + assert _test_debug_hit + assert result.asnumpy() == 1 + +def test_debug_with_expr(): + global _test_debug_hit + _test_debug_hit = False + ex = create_executor() + x = var('x', shape=(), dtype='int32') + _test_debug_hit = False + def did_exec(x): + global _test_debug_hit + _test_debug_hit = True + prog = debug(x + x * x, debug_func=did_exec) + result = ex.evaluate(prog, { x: const(2) }) + assert _test_debug_hit + assert result.asnumpy() == 6 -- GitLab