diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 85a6b502d845c16c6a649ea6a0162184dcc0a9f7..1681f9b87d2f732e4e038c56192b46170f2e3ab4 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -182,6 +182,17 @@ class ExprMutator std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; }; +/* + * \brief Bind function parameters or free variables. + * + * Parameter binding can only happen if expr is a Function. + * binds cannot change internal arguments of internal functions. + * + * \param expr The function to be binded. + * \param binds The map of arguments to + */ +Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_FUNCTOR_H_ diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 2c9fa2808f850269e814b70d124897c47e58140d..f80d51772ae2117a3861c966d643979e108273fd 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -38,6 +38,16 @@ enum OpPatternKind { /*! \brief the operator pattern */ using TOpPattern = int; +/*! + * \brief Whether operator is stateful or contain internal state. + * + * All the primitive ops we registered so far are pure. + * This attribute is left for potential future compatible reasons. + * We can always work around the stateful ops by adding an additional + * handle argument and return it. + */ +using TOpIsStateful = bool; + /*! * \brief Computation description interface. * diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 5ff60c7035d35faae96a51f205fed2ad46dfbdbf..3ca81ebd027da2c14ec4db81844a32322b90f6a3 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -143,6 +143,22 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); */ Expr DeadCodeElimination(const Expr& e); +/*! + * \brief Fold constant expressions. + * \param expr the expression to be optimized. + * \return The optimized expression. + */ +Expr FoldConstant(const Expr& expr); + +/*! + * \brief Fuse operations into expr into seperate functions. + * \param expr The expression. + * \param fuse_opt_level Optimization level. + * \return The optimized expression. + */ +Expr FuseOps(const Expr& expr, int fuse_opt_level); + + /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 19f3a55d491a32431924465ae1b911fb84ab45be..92e1e72fdac2f7014387fb5057b0562b45236747 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -54,7 +54,7 @@ TupleGetItem = expr.TupleGetItem # helper functions var = expr.var const = expr.const - +bind = expr.bind # pylint: disable=unused-argument @register_func("relay.debug") diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 9bd03945c8474ec7512303306973625cfe6ea216..4bbab957ab1d1f1d41082997d64d21bb039e086d 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -102,6 +102,7 @@ class GraphRuntimeCodegen(ExprFunctor): self.target = target self.nodes = [] self.var_map = {} + self.params = {} self.compile_engine = compile_engine.get() self.lowered_funcs = set() self._name_map = {} @@ -162,8 +163,12 @@ class GraphRuntimeCodegen(ExprFunctor): assert isinstance(vtuple, tuple) return vtuple[op.index] - def visit_constant(self, _): - raise RuntimeError("constant not supported") + def visit_constant(self, op): + index = len(self.params) + name = "p%d" % index + self.params[name] = op.data + node = InputNode(name, {}) + return self.add_node(node, op.checked_type) def visit_function(self, _): raise RuntimeError("function not supported") @@ -312,6 +317,9 @@ class GraphRuntimeCodegen(ExprFunctor): lowered_funcs : List[tvm.LoweredFunc] The lowered functions. + + params : Dict[str, tvm.nd.NDArray] + Additional constant parameters. """ # First we convert all the parameters into input nodes. for param in func.params: @@ -324,7 +332,7 @@ class GraphRuntimeCodegen(ExprFunctor): self.heads = self.visit(func.body) graph_json = self._get_json() lowered_funcs = list(self.lowered_funcs) - return graph_json, lowered_funcs + return graph_json, lowered_funcs, self.params def _get_unique_name(self, name): if name not in self._name_map: diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 0f33e86ab5cdcfbe3f8a8c2f3d8c81cabe77e470..557e4edac681c893f3e1c29c8e889e34ce51642b 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -6,6 +6,7 @@ from ..build_module import build as _tvm_build_module from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt from . import ir_pass +from . import expr from .backend import interpreter as _interpreter from .backend import graph_runtime_codegen as _graph_gen @@ -13,6 +14,7 @@ from .backend import graph_runtime_codegen as _graph_gen OPT_PASS_LEVEL = { "SimplifyInference": 0, "OpFusion": 1, + "FoldConstant": 2, "FoldScaleAxis": 3, } @@ -95,7 +97,27 @@ def build_config(**kwargs): return BuildConfig(**kwargs) -def optimize(func): +def _bind_params_by_name(func, params): + """Bind parameters of function by its name.""" + name_dict = {} + for arg in func.params: + name = arg.name_hint + if name in name_dict: + name_dict[name] = None + else: + name_dict[name] = arg + bind_dict = {} + for k, v in params.items(): + if k not in name_dict: + continue + arg = name_dict[k] + if arg is None: + raise ValueError("Multiple args in the function have name %s" % k) + bind_dict[arg] = expr.const(v) + return expr.bind(func, bind_dict) + + +def optimize(func, params=None): """Perform target invariant optimizations. Parameters @@ -103,6 +125,10 @@ def optimize(func): func : tvm.relay.Function The input to optimization. + params : Optional[Dict[str, tvm.nd.NDArray]] + Input parameters to the graph that do not change + during inference time. used for constant folding. + Returns ------- opt_func : tvm.relay.Function @@ -110,7 +136,11 @@ def optimize(func): """ cfg = BuildConfig.current - if cfg.pass_enabled("FoldScaleAxis"): + # bind expressions + if params: + func = _bind_params_by_name(func, params) + + if cfg.pass_enabled("SimplifyInference"): func = ir_pass.infer_type(func) func = ir_pass.simplify_inference(func) @@ -119,6 +149,10 @@ def optimize(func): func = ir_pass.backward_fold_scale_axis(func) func = ir_pass.infer_type(func) func = ir_pass.forward_fold_scale_axis(func) + + if cfg.pass_enabled("FoldConstant"): + func = ir_pass.fold_constant(func) + return func @@ -147,8 +181,7 @@ def build(func, params : dict of str to NDArray Input parameters to the graph that do not change - during inference time. Used for pre-compute - folding optimization. + during inference time. Used for constant folding. Returns ------- @@ -176,14 +209,14 @@ def build(func, cfg = BuildConfig.current with tophub_context: - func = optimize(func) + func = optimize(func, params) # Fuse ops before running code gen func = ir_pass.infer_type(func) func = ir_pass.fuse_ops(func, cfg.opt_level) # Graph code generation func = ir_pass.infer_type(func) graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) - graph_json, lowered_funcs = graph_gen.codegen(func) + graph_json, lowered_funcs, params = graph_gen.codegen(func) mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host) return graph_json, mod, params @@ -210,21 +243,22 @@ class GraphExecutor(_interpreter.Executor): self.target = target def _make_executor(self, func): + graph_json, mod, params = build(func, target=self.target) + gmodule = _graph_rt.create(graph_json, mod, self.ctx) + if params: + gmodule.set_input(*params) def _graph_wrapper(*args): - graph_json, mod, params = build(func, target=self.target) - assert params is None - gmodule = _graph_rt.create(graph_json, mod, self.ctx) # Create map of inputs. for i, arg in enumerate(args): gmodule.set_input(i, arg) # Run the module, and fetch the output. gmodule.run() - return gmodule.get_output(0) + # make a copy so multiple invocation won't hurt perf. + return gmodule.get_output(0).copyto(_nd.cpu(0)) return _graph_wrapper - def create_executor(kind="debug", mod=None, ctx=None, diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index f82ea09a102ae75ec51124b5d891bc3a34cfcdba..d71db0036f203b30b4c7e6e884a038e6d10fbd77 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -6,6 +6,7 @@ from numbers import Number as _Number import numpy as _np from .base import RelayNode, register_relay_node from . import _make +from . import _expr from . import ty as _ty from .._ffi import base as _base from .. import nd as _nd @@ -577,3 +578,24 @@ def const(value, dtype=None): if not isinstance(value, _nd.NDArray): raise ValueError("value has to be scalar or NDArray") return Constant(value) + + +def bind(expr, binds): + """Bind an free variables in expr or function arguments. + + We can bind parameters expr if it is a function. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]] + The specific bindings. + + Returns + ------- + result : tvm.relay.Expr + The expression or function after binding. + """ + return _expr.Bind(expr, binds) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index b1a76d6fae6fb0ad4a126407d7f49299431b30f8..9d59980f61274f514370380b054f54d374c6c813 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -259,6 +259,22 @@ def structural_hash(value): raise TypeError(msg) +def fold_constant(expr): + """Fold the constant expression in expr. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + transformed_expr : tvm.relay.Expr + The transformed expression. + """ + return _ir_pass.FoldConstant(expr) + + def fuse_ops(expr, opt_level=1): """Fuse operators in expr together. diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 08f903a26d3e99b45d5fa214b04372db1f8dafac..5e3ee1761c38e3e8ab9496a84cc8152cefda7a07 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -6,8 +6,8 @@ * ExprMutator uses memoization and self return in order to amortize * the cost of using functional updates. */ - #include <tvm/relay/expr_functor.h> +#include "type_functor.h" namespace tvm { namespace relay { @@ -228,5 +228,74 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { void ExprVisitor::VisitType(const Type& t) { return; } +// Implement bind. +class ExprBinder : public ExprMutator { + public: + explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) + : args_map_(args_map) { + } + + Expr VisitExpr_(const LetNode* op) final { + CHECK(!args_map_.count(op->var)) + << "Cannot bind an internel variable in let"; + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const FunctionNode* op) final { + for (Var param : op->params) { + CHECK(!args_map_.count(param)) + << "Cannnot bind an internal function parameter"; + } + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const VarNode* op) final { + auto id = GetRef<Var>(op); + auto it = args_map_.find(id); + if (it != args_map_.end()) { + return (*it).second; + } else { + return id; + } + } + + private: + const tvm::Map<Var, Expr>& args_map_; +}; + +Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { + if (const FunctionNode* func = expr.as<FunctionNode>()) { + Expr new_body = ExprBinder(args_map).Mutate(func->body); + Array<Var> new_params; + for (Var param : func->params) { + if (!args_map.count(param)) { + new_params.push_back(param); + } + } + if (new_body.same_as(func->body) && + new_params.size() == func->params.size()) { + return expr; + } + return FunctionNode::make(new_params, + new_body, + func->ret_type, + func->type_params, + func->attrs); + } else { + return ExprBinder(args_map).Mutate(expr); + } +} + + +TVM_REGISTER_API("relay._expr.Bind") +.set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef input = args[0]; + if (input->derived_from<ExprNode>()) { + *ret = Bind(Downcast<Expr>(input), args[1]); + } else { + CHECK(input->derived_from<TypeNode>()); + *ret = Bind(Downcast<Type>(input), args[1]); + } + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 25651286ed9e5a6fdf9db8d071d4a059433a71ea..d0ae57bb01e19852176db72f844fd22fe5ca0a56 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -11,8 +11,6 @@ #include <memory> #include <mutex> -#include "./../pass/type_subst.h" - namespace dmlc { // enable registry DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc0daa3cb9c61e50495fdeef39cc4d66e1208826 --- /dev/null +++ b/src/relay/ir/type_functor.cc @@ -0,0 +1,159 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file type_functor.cc + * \brief Implementations of type functors. + */ +#include "type_functor.h" + +namespace tvm { +namespace relay { + +void TypeVisitor::VisitType_(const TypeVarNode* op) { +} + +void TypeVisitor::VisitType_(const TensorTypeNode* op) { +} + +void TypeVisitor::VisitType_(const IncompleteTypeNode* op) { +} + +void TypeVisitor::VisitType_(const FuncTypeNode* op) { + for (auto type_param : op->type_params) { + this->VisitType(type_param); + } + + for (auto type_cs : op->type_constraints) { + this->VisitType(type_cs); + } + + for (auto arg_type : op->arg_types) { + this->VisitType(arg_type); + } + this->VisitType(op->ret_type); +} + +void TypeVisitor::VisitType_(const TupleTypeNode* op) { + for (const Type& t : op->fields) { + this->VisitType(t); + } +} + +void TypeVisitor::VisitType_(const TypeRelationNode* op) { + for (const Type& t : op->args) { + this->VisitType(t); + } +} + + +// Type Mutator. +Array<Type> TypeMutator::MutateArray(Array<Type> arr) { + // The array will do copy on write + // If no changes are made, the original array will be returned. + for (size_t i = 0; i < arr.size(); ++i) { + Type ty = arr[i]; + Type new_ty = VisitType(ty); + if (!ty.same_as(new_ty)) { + arr.Set(i, new_ty); + } + } + return arr; +} + +Type TypeMutator::VisitType_(const TypeVarNode* op) { + return GetRef<TypeVar>(op); +} + +Type TypeMutator::VisitType_(const TensorTypeNode* op) { + // TODO(tvm-team) recursively visit to replace Var + return GetRef<Type>(op); +} + +Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { + return GetRef<Type>(op); +} + +Type TypeMutator::VisitType_(const FuncTypeNode* op) { + bool changed = false; + Array<TypeVar> type_params; + for (auto type_param : op->type_params) { + auto new_type_param = VisitType(type_param); + changed = changed || !new_type_param.same_as(type_param); + if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) { + type_params.push_back(GetRef<TypeVar>(tin)); + } else { + LOG(FATAL) << new_type_param << std::endl; + } + } + + Array<TypeConstraint> type_constraints; + for (auto type_cs : op->type_constraints) { + auto new_type_cs = VisitType(type_cs); + changed = changed || !new_type_cs.same_as(type_cs); + if (const TypeConstraintNode* tin = + new_type_cs.as_derived<TypeConstraintNode>()) { + type_constraints.push_back(GetRef<TypeConstraint>(tin)); + } else { + LOG(FATAL) << new_type_cs << std::endl; + } + } + + Array<Type> new_args = MutateArray(op->arg_types); + changed = changed || new_args.same_as(op->arg_types); + + Type new_ret_type = VisitType(op->ret_type); + changed = changed || new_ret_type.same_as(op->ret_type); + + if (!changed) return GetRef<Type>(op); + return FuncTypeNode::make(new_args, + new_ret_type, + type_params, + type_constraints); +} + +Type TypeMutator::VisitType_(const TupleTypeNode* op) { + Array<Type> new_fields = MutateArray(op->fields); + if (new_fields.same_as(op->fields)) { + return GetRef<Type>(op); + } else { + return TupleTypeNode::make(new_fields); + } +} + +Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { + Array<Type> new_args = MutateArray(type_rel->args); + if (new_args.same_as(type_rel->args)) { + return GetRef<Type>(type_rel); + } else { + return TypeRelationNode::make(type_rel->func, + new_args, + type_rel->num_inputs, + type_rel->attrs); + } +} + +// Implements bind. +class TypeBinder : public TypeMutator { + public: + explicit TypeBinder(const tvm::Map<TypeVar, Type>& args_map) + : args_map_(args_map) {} + + Type VisitType_(const TypeVarNode* op) override { + auto id = GetRef<TypeVar>(op); + auto it = args_map_.find(id); + if (it != args_map_.end()) { + return (*it).second; + } else { + return id; + } + } + + private: + const tvm::Map<TypeVar, Type>& args_map_; +}; + +Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) { + return TypeBinder(args_map).VisitType(type); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index f51c8c746eb9192c87f26fc4701f3d869847d619..e8dfd2b7cd7cdc2a6c498bd57781e149fa217537 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -91,113 +91,39 @@ class TypeFunctor<R(const Type& n, Args...)> { }; /*! - * \brief A type visitor for vistiors which make use of internal - * mutable state. - * - * We recursively visit each type contained inside the visitor. + * \brief A type visitor that recursively visit types. */ -class TypeVisitor : - public ::tvm::relay::TypeFunctor<void(const Type& n)> { +class TypeVisitor : public TypeFunctor<void(const Type& n)> { public: - void VisitType_(const TypeVarNode* op) override {} - - void VisitType_(const FuncTypeNode* op) override { - for (auto type_param : op->type_params) { - this->VisitType(type_param); - } - - for (auto type_cs : op->type_constraints) { - this->VisitType(type_cs); - } - - for (auto arg_type : op->arg_types) { - this->VisitType(arg_type); - } - this->VisitType(op->ret_type); - } - - void VisitType_(const TensorTypeNode* op) override {} - - void VisitType_(const TupleTypeNode* op) override { - for (const Type& t : op->fields) { - this->VisitType(t); - } - } - - void VisitType_(const TypeRelationNode* op) override { - for (const Type& t : op->args) { - this->VisitType(t); - } - } - - void VisitType_(const IncompleteTypeNode* op) override {} + void VisitType_(const TypeVarNode* op) override; + void VisitType_(const IncompleteTypeNode* op) override; + void VisitType_(const TensorTypeNode* op) override; + void VisitType_(const FuncTypeNode* op) override; + void VisitType_(const TupleTypeNode* op) override; + void VisitType_(const TypeRelationNode* op) override; }; -// A functional visitor for rebuilding an AST in place. -struct TypeMutator : TypeFunctor<Type(const Type& n)> { - Type VisitType_(const TensorTypeNode* op) override { - // TODO(@jroesch): maybe we should recursively visit - return TensorTypeNode::make(op->shape, op->dtype); - } - - Type VisitType_(const TypeVarNode* op) override { - return GetRef<TypeVar>(op); - } - - Type VisitType_(const FuncTypeNode* op) override { - Array<TypeVar> type_params; - for (auto type_param : op->type_params) { - auto new_type_param = VisitType(type_param); - if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) { - type_params.push_back(GetRef<TypeVar>(tin)); - } else { - CHECK(false) << new_type_param << std::endl; - } - } - - Array<TypeConstraint> type_constraints; - for (auto type_cs : op->type_constraints) { - auto new_type_cs = VisitType(type_cs); - if (const TypeConstraintNode* tin = - new_type_cs.as_derived<TypeConstraintNode>()) { - type_constraints.push_back(GetRef<TypeConstraint>(tin)); - } else { - CHECK(false) << new_type_cs << std::endl; - } - } - - std::vector<Type> args; - for (auto arg_type : op->arg_types) { - args.push_back(VisitType(arg_type)); - } - - return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type), - type_params, type_constraints); - } +// Mutator that transform a type to another one. +class TypeMutator : public TypeFunctor<Type(const Type& n)> { + public: + Type VisitType_(const TypeVarNode* op) override; + Type VisitType_(const TensorTypeNode* op) override; + Type VisitType_(const IncompleteTypeNode* op) override; + Type VisitType_(const FuncTypeNode* op) override; + Type VisitType_(const TupleTypeNode* op) override; + Type VisitType_(const TypeRelationNode* type_rel) override; - Type VisitType_(const TupleTypeNode* op) override { - std::vector<Type> new_fields; - for (const Type& t : op->fields) { - new_fields.push_back(this->VisitType(t)); - } - return TupleTypeNode::make(new_fields); - } + private: + Array<Type> MutateArray(Array<Type> arr); +}; - Type VisitType_(const TypeRelationNode* type_rel) override { - std::vector<Type> new_args; - for (const Type& t : type_rel->args) { - new_args.push_back(this->VisitType(t)); - } - return TypeRelationNode::make(type_rel->func, - new_args, - type_rel->num_inputs, - type_rel->attrs); - } +/*! + * \brief Bind free type variables in the type. + * \param type The type to be updated. + * \param args_map The binding map. + */ +Type Bind(const Type& type, const Map<TypeVar, Type>& args_map); - Type VisitType_(const IncompleteTypeNode* op) override { - return GetRef<Type>(op); - } -}; } // namespace relay } // namespace tvm #endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_ diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 6f8dce3875aef55302217f044553c916eec97371..4c814bc1614f1d5e534816db22d46c7fa2d5bde0 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -71,7 +71,8 @@ std::vector<T> AsVector(const Array<T> &array) { .add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_type_rel("Broadcast", BroadcastRel) \ - .set_attr<TOpPattern>("TOpPattern", kBroadcast) + .set_attr<TOpPattern>("TOpPattern", kBroadcast) \ + .set_attr<TOpIsStateful>("TOpIsStateful", false) } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5d514b76556e73188d184abc5dc8c162a9923dd --- /dev/null +++ b/src/relay/pass/fold_constant.cc @@ -0,0 +1,120 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file constant_folding.cc + */ +#include <tvm/relay/pass.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/relay/op_attr_types.h> +#include <tvm/relay/interpreter.h> + +namespace tvm { +namespace relay { + +using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>; + + +// TODO(tvm-team) consider combine dead-code with constant folder. +// or make a more powerful partial evaluator. +class ConstantFolder : public ExprMutator { + public: + explicit ConstantFolder(FInterpreter executor) + : executor_(executor) { + } + + Expr VisitExpr_(const LetNode* op) final { + Expr value = this->Mutate(op->value); + if (value.as<ConstantNode>()) { + memo_[op->var] = value; + return this->Mutate(op->body); + } else { + Var var = Downcast<Var>(this->Mutate(op->var)); + Expr body = this->Mutate(op->body); + if (var.same_as(op->var) && + value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef<Expr>(op); + } else { + return LetNode::make(var, value, body); + } + } + } + + Expr VisitExpr_(const CallNode* call) final { + static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful"); + Expr res = ExprMutator::VisitExpr_(call); + call = res.as<CallNode>(); + // We don't constant fold function with zero arguments. + // This is a heuristic that is useful. + // For example it is harmful to fold ones(shape=(4, 5)). + if (call->args.size() == 0) return res; + const OpNode* op = call->op.as<OpNode>(); + if (op == nullptr) return res; + // skip stateful ops. + if (op_stateful.get(GetRef<Op>(op), false)) return res; + bool all_const_args = true; + for (Expr arg : call->args) { + if (arg.as<ConstantNode>() == nullptr) { + all_const_args = false; + } + } + if (all_const_args) { + return ConstEvaluate(res); + } else { + return res; + } + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr res = ExprMutator::VisitExpr_(op); + op = res.as<TupleGetItemNode>(); + if (const auto* tuple = op->tuple.as<TupleNode>()) { + return tuple->fields[op->index]; + } else { + return res; + } + } + + private: + // Internal interepreter. + FInterpreter executor_; + // Convert value to expression. + Expr ValueToExpr(Value value) { + if (const auto* val = value.as<TensorValueNode>()) { + return ConstantNode::make(val->data); + } else if (const auto* val = value.as<TupleValueNode>()) { + Array<Expr> fields; + for (Value field : val->fields) { + fields.push_back(ValueToExpr(field)); + } + return TupleNode::make(fields); + } else { + LOG(FATAL) << "Cannot handle " << value->type_key(); + return Expr(); + } + } + // Constant evaluate a expression. + Expr ConstEvaluate(Expr expr) { + expr = InferType(expr, Module(nullptr)); + expr = FuseOps(expr, 0); + expr = InferType(expr, Module(nullptr)); + return ValueToExpr(executor_(expr)); + } +}; + + +Expr FoldConstant(const Expr& expr) { + DLContext ctx; + ctx.device_type = kDLCPU; + ctx.device_id = 0; + Target target = Target::create("llvm"); + return ConstantFolder(CreateInterpreter( + Module(nullptr), ctx, target)).Mutate(expr); +} + +TVM_REGISTER_API("relay._ir_pass.FoldConstant") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FoldConstant(args[0]); +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h index bf52297e8930f46feb0df111e196f3d247dbf0ae..d42494409b53227f74d1a76537cb411156863756 100644 --- a/src/relay/pass/pass_util.h +++ b/src/relay/pass/pass_util.h @@ -22,6 +22,23 @@ namespace relay { std::unordered_map<const Node*, size_t> GetExprRefCount(const Expr& body); +/*! + * \brief Substitute var with subst. + * \param type The type to be substituted. + * \param tvar The type variable to be substituted. + * \param subst The target of substitution. + * \return The substituted result. + */ +Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst); + +/*! + * \brief Substitute type vars in type. + * \param type The type to be substituted. + * \param subst_map The map of substitution. + * \return The substituted result. + */ +Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PASS_UTIL_H_ diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 5cabfbdabc49e6c1b884d76fe675654ac14e7550..13da159e99a85b54927e3a1fcc3730d0e2ee5d97 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -24,7 +24,7 @@ #include <tvm/relay/expr_functor.h> #include <tvm/relay/pass.h> #include "type_solver.h" -#include "type_subst.h" +#include "../ir/type_functor.h" namespace tvm { namespace relay { @@ -278,7 +278,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); - inst_ty = TypeSubst(inst_ty, subst_map); + inst_ty = Bind(inst_ty, subst_map); return Downcast<FuncType>(inst_ty); } diff --git a/src/relay/pass/type_subst.cc b/src/relay/pass/type_subst.cc deleted file mode 100644 index 76507058f05965f20f586a3a1f8c5b0380ceaa57..0000000000000000000000000000000000000000 --- a/src/relay/pass/type_subst.cc +++ /dev/null @@ -1,39 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file type_subst.cc - * \brief Function for substituting a concrete type in place of a type ID - */ -#include "./type_subst.h" -#include "../ir/type_functor.h" - -namespace tvm { -namespace relay { - -struct TypeSubstV : TypeMutator { - tvm::Map<TypeVar, Type> subst_map; - - explicit TypeSubstV(tvm::Map<TypeVar, Type> subst_map) - : subst_map(subst_map) {} - - Type VisitType_(const TypeVarNode* op) override { - auto id = GetRef<TypeVar>(op); - if (subst_map.find(id) != subst_map.end()) { - return this->subst_map[id]; - } else { - return id; - } - } -}; - -Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst) { - TypeSubstV ty_sub({ {target, subst} }); - return ty_sub.VisitType(type); -} - -Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map) { - TypeSubstV ty_sub(subst_map); - return ty_sub.VisitType(type); -} - -} // namespace relay -} // namespace tvm diff --git a/src/relay/pass/type_subst.h b/src/relay/pass/type_subst.h deleted file mode 100644 index 808e3536ae3063d76ac0b66da3e339db46aff1f3..0000000000000000000000000000000000000000 --- a/src/relay/pass/type_subst.h +++ /dev/null @@ -1,19 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file src/tvm/relay/pass/type_subst.h - * \brief Utility functions for substituting types. - */ -#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_ -#define TVM_RELAY_PASS_TYPE_SUBST_H_ - -#include <tvm/relay/expr.h> - -namespace tvm { -namespace relay { - -Type TypeSubst(const Type& type, const TypeVar& target, const Type& subst); -Type TypeSubst(const Type& type, tvm::Map<TypeVar, Type> subst_map); - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_PASS_TYPE_SUBST_H_ diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index ebc4e6fc16e614b82a866fc5a3fe43b4a29a8c1c..8f7179deea5351a1be3b7c16f4ad427fc1490339 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -13,7 +13,6 @@ namespace tvm { namespace relay { // FreeTypeVar - class FreeTypeVarTVisitor : public TypeVisitor { public: FreeTypeVarTVisitor( diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index 7f857b72ad1cf7ff0ddd8f3bd9fc49cf2a1e1950..7b610f82f6a53d799c0121128095e39f8a391a20 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -1,6 +1,8 @@ import numpy as np +import tvm from tvm import relay +from tvm.contrib import graph_runtime from tvm.relay.ir_pass import infer_type from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.op import add @@ -27,7 +29,7 @@ def check_rts(expr, args, expected_result, mod=None): graph = relay.create_executor('graph', mod=mod) eval_result = intrp.evaluate(expr)(*args) rts_result = graph.evaluate(expr)(*args) - np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) + tvm.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) def test_add_op_scalar(): """ @@ -71,7 +73,26 @@ def test_add_op_broadcast(): y_data = np.random.rand(1, 5).astype('float32') check_rts(func, [x_data, y_data], x_data + y_data) + +def test_with_params(): + x = relay.var('x', shape=(10, 5)) + y = relay.var('y', shape=(1, 5)) + func = relay.Function([x, y], add(x, y)) + x_data = np.random.rand(10, 5).astype('float32') + y_data = np.random.rand(1, 5).astype('float32') + params = {"y": y_data} + graph, lib, params = relay.build(func, "llvm", params=params) + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input(**params) + mod.set_input(x=x_data) + mod.run() + res = mod.get_output(0).asnumpy() + ref_res = y_data + x_data + tvm.testing.assert_allclose(res, ref_res) + + if __name__ == "__main__": + test_with_params() test_add_op_scalar() test_add_op_tensor() test_add_op_broadcast() diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py new file mode 100644 index 0000000000000000000000000000000000000000..8377bb9fb953b68264665b2e71578be7760f62d1 --- /dev/null +++ b/tests/python/relay/test_ir_bind.py @@ -0,0 +1,23 @@ +""" test bind function.""" +import tvm +from tvm import relay + + +def test_bind_params(): + x = relay.var("x") + y = relay.var("y") + z = relay.add(x, y) + f = relay.Function([x, y], z) + fbinded = relay.bind(f, {x : relay.const(1, "float32")}) + fexpected =relay.Function( + [y], + relay.add(relay.const(1, "float32"), y)) + assert relay.ir_pass.alpha_equal(fbinded, fexpected) + + zbinded = relay.bind(z, {y: x}) + zexpected = relay.add(x, x) + assert relay.ir_pass.alpha_equal(zbinded, zexpected) + + +if __name__ == "__main__": + test_bind_params() diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9e397be975a7ebfcee3b97f68bc3e42fa4d84a --- /dev/null +++ b/tests/python/relay/test_pass_fold_constant.py @@ -0,0 +1,75 @@ +import numpy as np +from tvm import relay + + +def test_fold_const(): + c_data = np.array([1, 2, 3]).astype("float32") + def before(): + c = relay.const(c_data) + x = relay.var("x") + y = relay.add(c, c) + y = relay.multiply(y, relay.const(2, "float32")) + y = relay.add(x, y) + z = relay.add(y, c) + return relay.Function([x], z) + + def expected(): + x = relay.var("x") + c_folded = (c_data + c_data) * 2 + y = relay.add(x, relay.const(c_folded)) + z = relay.add(y, relay.const(c_data)) + return relay.Function([x], z) + zz = relay.ir_pass.fold_constant(before()) + zexpected = expected() + assert relay.ir_pass.alpha_equal(zz, zexpected) + + +def test_fold_let(): + c_data = np.array(1).astype("float32") + def before(): + sb = relay.ScopeBuilder() + x = relay.var("x") + t1 = sb.let("t1", relay.const(c_data)) + t2 = sb.let("t2", relay.add(t1, t1)) + t3 = sb.let("t3", relay.add(t2, x)) + sb.ret(t3) + return relay.Function([x], sb.get()) + + def expected(): + sb = relay.ScopeBuilder() + x = relay.var("x") + c_folded = (c_data + c_data) + t3 = sb.let("t3", relay.add(relay.const(c_folded), x)) + sb.ret(t3) + return relay.Function([x], sb.get()) + + zz = relay.ir_pass.fold_constant(before()) + zexpected = expected() + assert relay.ir_pass.graph_equal(zz, zexpected) + + +def test_fold_tuple(): + c_data = np.array(1).astype("float32") + def before(): + c = relay.const(c_data) + x = relay.var("x") + y = relay.Tuple([x, c]) + z = relay.add(y[1], c) + z = relay.add(z, y[0]) + return relay.Function([x], z) + + def expected(): + c = relay.const(c_data + c_data) + x = relay.var("x") + z = relay.add(c, x) + return relay.Function([x], z) + + zz = relay.ir_pass.fold_constant(before()) + zexpected = expected() + assert relay.ir_pass.graph_equal(zz, zexpected) + + +if __name__ == "__main__": + test_fold_const() + test_fold_let() + test_fold_tuple()