diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 37c91ffe4ed22ef5d46d2aff1cea1838b79aa774..14b3cd91701c6203afcf3828b273f1dceff8ccd1 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -248,6 +248,13 @@ class FunctionNode : public ExprNode { */ TVM_DLL FuncType func_type_annotation() const; + /*! + * \brief Check whether the function is a primitive function. + * + * \return Whether the function is primitive or not. + */ + bool IsPrimitive() const; + TVM_DLL static Function make(tvm::Array<Var> params, Expr body, Type ret_type, diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index b66132f277753e3b2fa2a090bcfb786cc4590435..69180837b72451696dd6fa469c84b3d1c2ed2322 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -5,6 +5,7 @@ from ..api import register_func from . import base from . import ty from . import expr +from . import expr_functor from . import module from . import ir_pass from .build_module import build, build_config, create_executor @@ -53,6 +54,10 @@ Let = expr.Let If = expr.If TupleGetItem = expr.TupleGetItem +# ExprFunctor +ExprFunctor = expr_functor.ExprFunctor +ExprMutator = expr_functor.ExprMutator + # helper functions var = expr.var const = expr.const diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 0da9b81269aa5dc69fa8ca453c1a39a33df957c6..91d09973ea8fbf6a2a4c6972c3c1c5d9dbf7c963 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -24,7 +24,8 @@ import attr from . import _backend from . import compile_engine from ..op import Op -from ..expr import Function, GlobalVar, ExprFunctor +from ..expr import Function, GlobalVar +from ..expr_functor import ExprFunctor from ..ty import TupleType, TensorType @@ -251,6 +252,9 @@ class GraphRuntimeCodegen(ExprFunctor): op_name, inputs, {}) return self.add_node(op_node, call) + def visit_op(self, _): + raise Exception("can not compile op in non-eta expanded form") + def _get_json(self): """ Convert the sequence of nodes stored by the compiler into the diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 4725c0a7a07d96b10ec93115acda49ca1ba18b9e..e0c1f68ad4317a9a2b0ade88a817b2f288267ae1 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -222,12 +222,13 @@ class Function(Expr): params, body, ret_type=None, - type_params=None): + type_params=None, + attrs=None): if type_params is None: type_params = convert([]) self.__init_handle_by_constructor__( - _make.Function, params, body, ret_type, type_params) + _make.Function, params, body, ret_type, type_params, attrs) def __call__(self, *args): """Invoke the gobal function. @@ -343,131 +344,6 @@ class TempExpr(Expr): return _expr.TempExprRealize(self) -class ExprFunctor(object): - """ - An abstract visitor defined over Expr. - - Defines the default dispatch over expressions, and - implements memoization. - """ - def __init__(self): - self.memo_map = {} - - # pylint: disable=no-else-return - def visit(self, expr): - """Apply the visitor to an expression.""" - found = self.memo_map.get(expr) - if found: - return found - - if isinstance(expr, Function): - res = self.visit_function(expr) - elif isinstance(expr, Call): - res = self.visit_call(expr) - elif isinstance(expr, Let): - res = self.visit_let(expr) - elif isinstance(expr, Var): - res = self.visit_var(expr) - elif isinstance(expr, GlobalVar): - res = self.visit_global_var(expr) - elif isinstance(expr, If): - res = self.visit_if(expr) - elif isinstance(expr, Tuple): - res = self.visit_tuple(expr) - elif isinstance(expr, TupleGetItem): - res = self.visit_tuple_getitem(expr) - elif isinstance(expr, Constant): - res = self.visit_constant(expr) - else: - raise Exception("warning unhandled case: {0}".format(type(expr))) - - self.memo_map[expr] = res - return res - - def visit_function(self, _): - raise NotImplementedError() - - def visit_let(self, _): - raise NotImplementedError() - - def visit_call(self, _): - raise NotImplementedError() - - def visit_var(self, _): - raise NotImplementedError() - - def visit_type(self, typ): - return typ - - def visit_if(self, _): - raise NotImplementedError() - - def visit_tuple(self, _): - raise NotImplementedError() - - def visit_tuple_getitem(self, _): - raise NotImplementedError() - - def visit_constant(self, _): - raise NotImplementedError() - - def visit_global_var(self, _): - raise NotImplementedError() - - -class ExprMutator(ExprFunctor): - """ - A functional visitor over Expr. - - The default behavior recursively traverses the AST - and reconstructs the AST. - """ - def visit_function(self, fn): - new_body = self.visit(fn.body) - return Function( - list(fn.params), - fn.ret_type, new_body, - fn.type_params) - - def visit_let(self, let): - new_var = self.visit(let.var) - new_val = self.visit(let.value) - new_body = self.visit(let.body) - return Let(new_var, new_val, new_body) - - def visit_call(self, call): - new_fn = self.visit(call.op) - new_args = [self.visit(arg) for arg in call.args] - return Call(new_fn, new_args, call.attrs) - - def visit_var(self, rvar): - return rvar - - def visit_global_id(self, global_var): - return global_var - - def visit_if(self, ite): - return If( - self.visit(ite.guard), - self.visit(ite.true_b), - self.visit(ite.false_b)) - - def visit_tuple(self, tup): - return Tuple([self.visit(field) for field in tup.fields]) - - def visit_tuple_getitem(self, op): - tuple_value = self.visit(op.tuple_value) - if not tuple_value.same_as(op.tuple_value): - return TupleGetItem(tuple_value, op.index) - return op - - def visit_global_var(self, gvar): - return gvar - - def visit_constant(self, rconst): - return rconst - - class TupleWrapper(object): """TupleWrapper. diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py new file mode 100644 index 0000000000000000000000000000000000000000..eafe5f09309fffe2eee78c30362c71201e5cb5d4 --- /dev/null +++ b/python/tvm/relay/expr_functor.py @@ -0,0 +1,155 @@ +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The expression functor of Relay.""" + +from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant +from .op import Op + +class ExprFunctor: + """ + An abstract visitor defined over Expr. + + Defines the default dispatch over expressions, and + implements memoization. + """ + def __init__(self): + self.memo_map = {} + + # pylint: disable=no-else-return + def visit(self, expr): + """Apply the visitor to an expression.""" + found = self.memo_map.get(expr) + if found: + return found + + if isinstance(expr, Function): + res = self.visit_function(expr) + elif isinstance(expr, Call): + res = self.visit_call(expr) + elif isinstance(expr, Let): + res = self.visit_let(expr) + elif isinstance(expr, Var): + res = self.visit_var(expr) + elif isinstance(expr, GlobalVar): + res = self.visit_global_var(expr) + elif isinstance(expr, If): + res = self.visit_if(expr) + elif isinstance(expr, Tuple): + res = self.visit_tuple(expr) + elif isinstance(expr, TupleGetItem): + res = self.visit_tuple_getitem(expr) + elif isinstance(expr, Constant): + res = self.visit_constant(expr) + elif isinstance(expr, Op): + res = self.visit_op(expr) + else: + raise Exception("warning unhandled case: {0}".format(type(expr))) + + self.memo_map[expr] = res + + return res + + def visit_function(self, _): + raise NotImplementedError() + + def visit_let(self, _): + raise NotImplementedError() + + def visit_call(self, _): + raise NotImplementedError() + + def visit_var(self, _): + raise NotImplementedError() + + def visit_type(self, typ): + return typ + + def visit_if(self, _): + raise NotImplementedError() + + def visit_tuple(self, _): + raise NotImplementedError() + + def visit_tuple_getitem(self, _): + raise NotImplementedError() + + def visit_global_var(self, _): + raise NotImplementedError() + + def visit_op(self, _): + raise NotImplementedError() + + def visit_constant(self, _): + raise NotImplementedError() + + +class ExprMutator(ExprFunctor): + """ + A functional visitor over Expr. + + The default behavior recursively traverses the AST + and reconstructs the AST. + """ + def visit_function(self, fn): + new_body = self.visit(fn.body) + return Function( + list(fn.params), + new_body, + fn.ret_type, + fn.type_params, + fn.attrs) + + def visit_let(self, let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return Let(new_var, new_val, new_body) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return Call(new_fn, new_args, call.attrs) + + def visit_var(self, rvar): + return rvar + + def visit_global_id(self, global_var): + return global_var + + def visit_if(self, ite): + return If( + self.visit(ite.guard), + self.visit(ite.true_b), + self.visit(ite.false_b)) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_tuple_getitem(self, op): + tuple_value = self.visit(op.tuple_value) + if not tuple_value.same_as(op.tuple_value): + return TupleGetItem(tuple_value, op.index) + return op + + def visit_global_var(self, gvar): + return gvar + + def visit_op(self, op): + return op + + def visit_constant(self, const): + return const + + def visit_constructor(self, con): + return con + + def visit_match(self, m): + return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern]) + + def visit_ref_new(self, r): + return RefNew(self.visit(r.value)) + + def visit_ref_write(self, r): + return RefWrite(self.visit(r.ref), self.visit(r.value)) + + def visit_ref_read(self, r): + return RefRead(self.visit(r.ref)) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index b8938bd34804dd6bcdd38ea08cd2ac0a6e29c930..42394955cc64d751ec60e9de0f282ad9d72c105b 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -157,14 +157,14 @@ class ScheduleGetter : int op_pattern = fpattern[op]; if (op_pattern >= kCommReduce) { - CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce) + CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce) << "Two complicated op in a primitive function " << " master=" << master_op_ << " current=" << op; } - if (op_pattern >= master_op_patetrn_) { + if (op_pattern >= master_op_pattern_) { master_op_ = op; master_attrs_ = call_node->attrs; - master_op_patetrn_ = op_pattern; + master_op_pattern_ = op_pattern; } if (outputs.size() != 1) { const auto* tuple_type = @@ -213,7 +213,7 @@ class ScheduleGetter : tvm::Target target_; Op master_op_; Attrs master_attrs_; - int master_op_patetrn_{0}; + int master_op_pattern_{0}; std::ostringstream readable_name_stream_; std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_; }; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 5bef4a22f371531560705a930cf42f26ccc6b1e5..33d06e9c6c282e2832a3dccc30a272b1269e2b44 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -292,17 +292,10 @@ class Interpreter : } } - // Check if function is a primitive function. - bool IsPrimitive(const Function& func) const { - NodeRef res = FunctionGetAttr(func, "Primitive"); - const ir::IntImm* pval = res.as<ir::IntImm>(); - return pval && pval->value != 0; - } - // Invoke the closure Value Invoke(const Closure& closure, const tvm::Array<Value>& args) { // Get a reference to the function inside the closure. - if (IsPrimitive(closure->func)) { + if (closure->func->IsPrimitive()) { return InvokePrimitiveOp(closure->func, args); } auto func = closure->func; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 6f1260b05b996e1086de5cd096512ca961f0963e..cdb2a32a0009b7cf1ec83c047f426d9f1e5e992a 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -135,6 +135,12 @@ FuncType FunctionNode::func_type_annotation() const { return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); } +bool FunctionNode::IsPrimitive() const { + NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive"); + const ir::IntImm* pval = res.as<ir::IntImm>(); + return pval && pval->value != 0; +} + NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (!func->attrs.defined()) { return NodeRef(); } @@ -172,7 +178,7 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_API("relay._make.Function") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = FunctionNode::make(args[0], args[1], args[2], args[3]); + *ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 79ea3e22b1390747afdbcc0f5e47445d9a821e56..b2b35c51a1ca2cedcac4157ea9d8f69d003bcacb 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -699,9 +699,7 @@ class FuseMutator : private ExprMutator { std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_; // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { - NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive"); - const ir::IntImm* pval = res.as<ir::IntImm>(); - if (pval && pval->value != 0) { + if (fn_node->IsPrimitive()) { return GetRef<Expr>(fn_node); } else { return ExprMutator::VisitExpr_(fn_node);