From 395804e52447cd7e1812f0c3959ad990e3cbfccc Mon Sep 17 00:00:00 2001 From: Jared Roesch <roeschinc@gmail.com> Date: Sat, 22 Dec 2018 09:42:48 -0800 Subject: [PATCH] Small refactors and bug fixes. (#2281) --- include/tvm/relay/expr.h | 7 + python/tvm/relay/__init__.py | 5 + .../relay/backend/graph_runtime_codegen.py | 6 +- python/tvm/relay/expr.py | 130 +-------------- python/tvm/relay/expr_functor.py | 155 ++++++++++++++++++ src/relay/backend/compile_engine.cc | 8 +- src/relay/backend/interpreter.cc | 9 +- src/relay/ir/expr.cc | 8 +- src/relay/pass/fuse_ops.cc | 4 +- 9 files changed, 188 insertions(+), 144 deletions(-) create mode 100644 python/tvm/relay/expr_functor.py diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 37c91ffe4..14b3cd917 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -248,6 +248,13 @@ class FunctionNode : public ExprNode { */ TVM_DLL FuncType func_type_annotation() const; + /*! + * \brief Check whether the function is a primitive function. + * + * \return Whether the function is primitive or not. + */ + bool IsPrimitive() const; + TVM_DLL static Function make(tvm::Array<Var> params, Expr body, Type ret_type, diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index b66132f27..69180837b 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -5,6 +5,7 @@ from ..api import register_func from . import base from . import ty from . import expr +from . import expr_functor from . import module from . import ir_pass from .build_module import build, build_config, create_executor @@ -53,6 +54,10 @@ Let = expr.Let If = expr.If TupleGetItem = expr.TupleGetItem +# ExprFunctor +ExprFunctor = expr_functor.ExprFunctor +ExprMutator = expr_functor.ExprMutator + # helper functions var = expr.var const = expr.const diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 0da9b8126..91d09973e 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -24,7 +24,8 @@ import attr from . import _backend from . import compile_engine from ..op import Op -from ..expr import Function, GlobalVar, ExprFunctor +from ..expr import Function, GlobalVar +from ..expr_functor import ExprFunctor from ..ty import TupleType, TensorType @@ -251,6 +252,9 @@ class GraphRuntimeCodegen(ExprFunctor): op_name, inputs, {}) return self.add_node(op_node, call) + def visit_op(self, _): + raise Exception("can not compile op in non-eta expanded form") + def _get_json(self): """ Convert the sequence of nodes stored by the compiler into the diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 4725c0a7a..e0c1f68ad 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -222,12 +222,13 @@ class Function(Expr): params, body, ret_type=None, - type_params=None): + type_params=None, + attrs=None): if type_params is None: type_params = convert([]) self.__init_handle_by_constructor__( - _make.Function, params, body, ret_type, type_params) + _make.Function, params, body, ret_type, type_params, attrs) def __call__(self, *args): """Invoke the gobal function. @@ -343,131 +344,6 @@ class TempExpr(Expr): return _expr.TempExprRealize(self) -class ExprFunctor(object): - """ - An abstract visitor defined over Expr. - - Defines the default dispatch over expressions, and - implements memoization. - """ - def __init__(self): - self.memo_map = {} - - # pylint: disable=no-else-return - def visit(self, expr): - """Apply the visitor to an expression.""" - found = self.memo_map.get(expr) - if found: - return found - - if isinstance(expr, Function): - res = self.visit_function(expr) - elif isinstance(expr, Call): - res = self.visit_call(expr) - elif isinstance(expr, Let): - res = self.visit_let(expr) - elif isinstance(expr, Var): - res = self.visit_var(expr) - elif isinstance(expr, GlobalVar): - res = self.visit_global_var(expr) - elif isinstance(expr, If): - res = self.visit_if(expr) - elif isinstance(expr, Tuple): - res = self.visit_tuple(expr) - elif isinstance(expr, TupleGetItem): - res = self.visit_tuple_getitem(expr) - elif isinstance(expr, Constant): - res = self.visit_constant(expr) - else: - raise Exception("warning unhandled case: {0}".format(type(expr))) - - self.memo_map[expr] = res - return res - - def visit_function(self, _): - raise NotImplementedError() - - def visit_let(self, _): - raise NotImplementedError() - - def visit_call(self, _): - raise NotImplementedError() - - def visit_var(self, _): - raise NotImplementedError() - - def visit_type(self, typ): - return typ - - def visit_if(self, _): - raise NotImplementedError() - - def visit_tuple(self, _): - raise NotImplementedError() - - def visit_tuple_getitem(self, _): - raise NotImplementedError() - - def visit_constant(self, _): - raise NotImplementedError() - - def visit_global_var(self, _): - raise NotImplementedError() - - -class ExprMutator(ExprFunctor): - """ - A functional visitor over Expr. - - The default behavior recursively traverses the AST - and reconstructs the AST. - """ - def visit_function(self, fn): - new_body = self.visit(fn.body) - return Function( - list(fn.params), - fn.ret_type, new_body, - fn.type_params) - - def visit_let(self, let): - new_var = self.visit(let.var) - new_val = self.visit(let.value) - new_body = self.visit(let.body) - return Let(new_var, new_val, new_body) - - def visit_call(self, call): - new_fn = self.visit(call.op) - new_args = [self.visit(arg) for arg in call.args] - return Call(new_fn, new_args, call.attrs) - - def visit_var(self, rvar): - return rvar - - def visit_global_id(self, global_var): - return global_var - - def visit_if(self, ite): - return If( - self.visit(ite.guard), - self.visit(ite.true_b), - self.visit(ite.false_b)) - - def visit_tuple(self, tup): - return Tuple([self.visit(field) for field in tup.fields]) - - def visit_tuple_getitem(self, op): - tuple_value = self.visit(op.tuple_value) - if not tuple_value.same_as(op.tuple_value): - return TupleGetItem(tuple_value, op.index) - return op - - def visit_global_var(self, gvar): - return gvar - - def visit_constant(self, rconst): - return rconst - - class TupleWrapper(object): """TupleWrapper. diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py new file mode 100644 index 000000000..eafe5f093 --- /dev/null +++ b/python/tvm/relay/expr_functor.py @@ -0,0 +1,155 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The expression functor of Relay.""" + +from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant +from .op import Op + +class ExprFunctor: + """ + An abstract visitor defined over Expr. + + Defines the default dispatch over expressions, and + implements memoization. + """ + def __init__(self): + self.memo_map = {} + + # pylint: disable=no-else-return + def visit(self, expr): + """Apply the visitor to an expression.""" + found = self.memo_map.get(expr) + if found: + return found + + if isinstance(expr, Function): + res = self.visit_function(expr) + elif isinstance(expr, Call): + res = self.visit_call(expr) + elif isinstance(expr, Let): + res = self.visit_let(expr) + elif isinstance(expr, Var): + res = self.visit_var(expr) + elif isinstance(expr, GlobalVar): + res = self.visit_global_var(expr) + elif isinstance(expr, If): + res = self.visit_if(expr) + elif isinstance(expr, Tuple): + res = self.visit_tuple(expr) + elif isinstance(expr, TupleGetItem): + res = self.visit_tuple_getitem(expr) + elif isinstance(expr, Constant): + res = self.visit_constant(expr) + elif isinstance(expr, Op): + res = self.visit_op(expr) + else: + raise Exception("warning unhandled case: {0}".format(type(expr))) + + self.memo_map[expr] = res + + return res + + def visit_function(self, _): + raise NotImplementedError() + + def visit_let(self, _): + raise NotImplementedError() + + def visit_call(self, _): + raise NotImplementedError() + + def visit_var(self, _): + raise NotImplementedError() + + def visit_type(self, typ): + return typ + + def visit_if(self, _): + raise NotImplementedError() + + def visit_tuple(self, _): + raise NotImplementedError() + + def visit_tuple_getitem(self, _): + raise NotImplementedError() + + def visit_global_var(self, _): + raise NotImplementedError() + + def visit_op(self, _): + raise NotImplementedError() + + def visit_constant(self, _): + raise NotImplementedError() + + +class ExprMutator(ExprFunctor): + """ + A functional visitor over Expr. + + The default behavior recursively traverses the AST + and reconstructs the AST. + """ + def visit_function(self, fn): + new_body = self.visit(fn.body) + return Function( + list(fn.params), + new_body, + fn.ret_type, + fn.type_params, + fn.attrs) + + def visit_let(self, let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return Let(new_var, new_val, new_body) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return Call(new_fn, new_args, call.attrs) + + def visit_var(self, rvar): + return rvar + + def visit_global_id(self, global_var): + return global_var + + def visit_if(self, ite): + return If( + self.visit(ite.guard), + self.visit(ite.true_b), + self.visit(ite.false_b)) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_tuple_getitem(self, op): + tuple_value = self.visit(op.tuple_value) + if not tuple_value.same_as(op.tuple_value): + return TupleGetItem(tuple_value, op.index) + return op + + def visit_global_var(self, gvar): + return gvar + + def visit_op(self, op): + return op + + def visit_constant(self, const): + return const + + def visit_constructor(self, con): + return con + + def visit_match(self, m): + return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern]) + + def visit_ref_new(self, r): + return RefNew(self.visit(r.value)) + + def visit_ref_write(self, r): + return RefWrite(self.visit(r.ref), self.visit(r.value)) + + def visit_ref_read(self, r): + return RefRead(self.visit(r.ref)) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index b8938bd34..42394955c 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -157,14 +157,14 @@ class ScheduleGetter : int op_pattern = fpattern[op]; if (op_pattern >= kCommReduce) { - CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce) + CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce) << "Two complicated op in a primitive function " << " master=" << master_op_ << " current=" << op; } - if (op_pattern >= master_op_patetrn_) { + if (op_pattern >= master_op_pattern_) { master_op_ = op; master_attrs_ = call_node->attrs; - master_op_patetrn_ = op_pattern; + master_op_pattern_ = op_pattern; } if (outputs.size() != 1) { const auto* tuple_type = @@ -213,7 +213,7 @@ class ScheduleGetter : tvm::Target target_; Op master_op_; Attrs master_attrs_; - int master_op_patetrn_{0}; + int master_op_pattern_{0}; std::ostringstream readable_name_stream_; std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_; }; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 5bef4a22f..33d06e9c6 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -292,17 +292,10 @@ class Interpreter : } } - // Check if function is a primitive function. - bool IsPrimitive(const Function& func) const { - NodeRef res = FunctionGetAttr(func, "Primitive"); - const ir::IntImm* pval = res.as<ir::IntImm>(); - return pval && pval->value != 0; - } - // Invoke the closure Value Invoke(const Closure& closure, const tvm::Array<Value>& args) { // Get a reference to the function inside the closure. - if (IsPrimitive(closure->func)) { + if (closure->func->IsPrimitive()) { return InvokePrimitiveOp(closure->func, args); } auto func = closure->func; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 6f1260b05..cdb2a32a0 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -135,6 +135,12 @@ FuncType FunctionNode::func_type_annotation() const { return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); } +bool FunctionNode::IsPrimitive() const { + NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive"); + const ir::IntImm* pval = res.as<ir::IntImm>(); + return pval && pval->value != 0; +} + NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (!func->attrs.defined()) { return NodeRef(); } @@ -172,7 +178,7 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_API("relay._make.Function") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); + *ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 79ea3e22b..b2b35c51a 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -699,9 +699,7 @@ class FuseMutator : private ExprMutator { std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_; // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { - NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive"); - const ir::IntImm* pval = res.as<ir::IntImm>(); - if (pval && pval->value != 0) { + if (fn_node->IsPrimitive()) { return GetRef<Expr>(fn_node); } else { return ExprMutator::VisitExpr_(fn_node); -- GitLab