diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8b2a5fafd8f0bfecbb1b4fbff10865ea7be55448..3678aee32850119e201e16587da3825e091bad38 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -80,7 +80,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2); */ bool AlphaEqual(const Type& t1, const Type& t2); -/*! brief Check that each Var is only bind once. +/*! \brief Check that each Var is only bound once. * * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * @@ -88,9 +88,9 @@ bool AlphaEqual(const Type& t1, const Type& t2); * * \param e the expression to check. * - * \return true iff all Var in e is bind at most once. + * \return true iff all Var in e is bound at most once. */ -bool WellFormed(const Expr & e); +bool WellFormed(const Expr& e); /*! \brief Get free variables from expression e. * @@ -100,7 +100,7 @@ bool WellFormed(const Expr & e); * * \return the set of free variable. */ -tvm::Array<Var> FreeVariables(const Expr & e); +tvm::Array<Var> FreeVariables(const Expr& e); /*! \brief Get free type parameters from expression e. * @@ -110,7 +110,7 @@ tvm::Array<Var> FreeVariables(const Expr & e); * * \return the set of free type variables. */ -tvm::Array<TypeParam> FreeTypeVariables(const Expr & e); +tvm::Array<TypeParam> FreeTypeVariables(const Expr& e); /*! \brief Get free type parameters from type t. * @@ -120,7 +120,20 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr & e); * * \return the set of free type variables. */ -tvm::Array<TypeParam> FreeTypeVariables(const Type & t); +tvm::Array<TypeParam> FreeTypeVariables(const Type& t); + +/*! \brief Remove expressions which does not effect the program result. + * + * It will remove let binding that are not referenced, and if branch that are not entered. + * + * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a. + * Another example is `if (true) then 1 else 2` will be optimized into 1. + * + * \param e the expression to optimize. + * + * \return the optimized expression. + */ +Expr DeadCodeElimination(const Expr& e); } // namespace relay } // namespace tvm diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 313e0a5c3da862a7c01d6e9fc3e821b760a0ac47..0fc8e42b8bcb58e743b1a944a83620f1091495d3 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -282,6 +282,21 @@ inline void NDArray::reset() { } } +/*! \brief return the size of data the DLTensor hold, in term of number of bytes + * + * \param arr the input DLTensor + * + * \return number of bytes of data in the DLTensor. + */ +inline size_t GetDataSize(const DLTensor& arr) { + size_t size = 1; + for (tvm_index_t i = 0; i < arr.ndim; ++i) { + size *= static_cast<size_t>(arr.shape[i]); + } + size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; + return size; +} + inline void NDArray::CopyFrom(DLTensor* other) { CHECK(data_ != nullptr); CopyFromTo(other, &(data_->dl_tensor)); diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi index f321083aa44346247ba3805b5a6082a84e4c01bd..f1432803e9e209715f4bb94b37b3ead6ad079870 100644 --- a/python/tvm/relay/_ir_pass.pyi +++ b/python/tvm/relay/_ir_pass.pyi @@ -4,4 +4,5 @@ from . import ir def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ... -def well_formed(expr: ir.Expr) -> bool: ... \ No newline at end of file +def well_formed(expr: ir.Expr) -> bool: ... +def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ... \ No newline at end of file diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 6e52f209d0c6b82d6b6fd82d111e333ac8ef771a..accb782659dffecc58b40cabc57e966678157f04 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -16,12 +16,12 @@ def _convert_to_value(arg, ctxt=tvm.cpu(0)): """Convert Python values into the appropriate types for the Relay evaluator. """ - if isinstance(arg, int): + if isinstance(arg, bool): # bool is subclass of int + return tvm.nd.array(np.array(arg, dtype='uint8'), ctxt) + elif isinstance(arg, int): return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) elif isinstance(arg, float): return tvm.nd.array(arg, ctxt) - elif isinstance(arg, bool): - return tvm.nd.array(np.array(arg, dtype='float32'), ctxt) elif isinstance(arg, np.ndarray): return tvm.nd.array(arg, ctxt) elif isinstance(arg, tvm.ndarray.NDArray): diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 78cc5027c32cae40161d71587493586cc2ccabd9..6de6437b9eb9aad573e7603f12fc20fde1da7c86 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -6,15 +6,16 @@ Exposes an interface for configuring the passes and scripting them in Python. """ from . import _ir_pass +from . import _make # pylint: disable=invalid-name def infer_type(env, expr): - """Infer the type of expr under the context of env + """Infer the type of expr under the context of env. Parameters ---------- env : relay.Environment - The global environmemt. + The global environment. expr : relay.Expr The input expression. @@ -34,3 +35,37 @@ check_kind = _ir_pass.check_kind free_vars = _ir_pass.free_vars free_type_vars = _ir_pass.free_type_vars + +def dead_code_elimination(e): + """ Remove expressions which does not effect the program result (dead code). + + Parameters + ---------- + e: relay.Expr + The input Expression + + Returns + ------- + result: relay.Expr + An expression which is semantically equal to the input expression, + but with dead code removed. + """ + return _ir_pass.dead_code_elimination(e) + +def alpha_equal(lhs, rhs): + """Compare two Relay expr for structural equivalence (alpha equivalence). + + Parameters + ---------- + lhs: relay.Expr + One of the input Expression. + rhs: relay.Expr + One of the input Expression. + + + Returns + ------- + result: bool + True iff lhs is alpha equal to rhs. + """ + return bool(_make._alpha_equal(lhs, rhs)) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index c7cf9a346b6888b4b4bda3ac66c65cb17dc976d7..a6ac1857bfa826bf1693aebf24cb4bc76c844f06 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -12,7 +12,7 @@ class Type(NodeBase): """Compare two Relay types for structural equivalence using alpha equivalence. """ - return bool(_make._type_alpha_eq(self, other)) + return bool(_make._type_alpha_equal(self, other)) def __ne__(self, other): return not self.__eq__(other) diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 39f55af6fe70eb7e4c4e1e5628c1180eb346c17e..3c4c3d78063f92a0c0cc64231dfbe2ee86a2a748 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -1,10 +1,11 @@ /*! * Copyright (c) 2018 by Contributors * \file src/tvm/relay/pass/alpha_eq.cc - * \brief The structral equivalence comparison. + * \brief Check that two type are syntactically equal up to alpha equivalence. */ #include <tvm/ir_pass.h> #include <tvm/relay/expr_functor.h> +#include <tvm/runtime/ndarray.h> #include "./type_visitor.h" #include "tvm/relay/pass.h" @@ -13,6 +14,25 @@ namespace relay { using namespace tvm::runtime; +bool SameNDArray(const NDArray& lhs, const NDArray& rhs) { + if (lhs.defined() != rhs.defined()) { + return false; + } else if (lhs.same_as(rhs)) { + return true; + } else { + auto ldt = lhs->dtype; + auto rdt = rhs->dtype; + CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t s = GetDataSize(*lhs.operator->()); + return memcmp(lhs->data, rhs->data, s) == 0; + } else { + return false; + } + } +} + struct TypeAlphaEq : TypeVisitor<const Type&> { tvm::Map<TypeParam, TypeParam> eq_map; bool equal; @@ -38,8 +58,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { } } - void VisitType_(const TensorTypeNode *tt1, const Type& t2) final { - if (const TensorTypeNode *tt2 = t2.as<TensorTypeNode>()) { + void VisitType_(const TensorTypeNode* tt1, const Type& t2) final { + if (const TensorTypeNode* tt2 = t2.as<TensorTypeNode>()) { DataTypeEqual(tt1->dtype, tt2->dtype); ShapeEqual(tt1->shape, tt2->shape); } else { @@ -47,8 +67,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { } } - void VisitType_(const IncompleteTypeNode *bt1, const Type& t2) final { - if (const IncompleteTypeNode *bt2 = t2.as<IncompleteTypeNode>()) { + void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final { + if (const IncompleteTypeNode* bt2 = t2.as<IncompleteTypeNode>()) { equal = equal && bt1 == bt2; return; } else { @@ -56,8 +76,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { } } - void VisitType_(const TypeParamNode *ti1, const Type& t2) final { - if (const TypeParamNode *ti2 = t2.as<TypeParamNode>()) { + void VisitType_(const TypeParamNode* ti1, const Type& t2) final { + if (const TypeParamNode* ti2 = t2.as<TypeParamNode>()) { auto tid1 = GetRef<TypeParam>(ti1); auto tid2 = GetRef<TypeParam>(ti2); @@ -86,8 +106,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { } } - void VisitType_(const FuncTypeNode *op, const Type& t2) final { - if (const FuncTypeNode *ta2 = t2.as<FuncTypeNode>()) { + void VisitType_(const FuncTypeNode* op, const Type& t2) final { + if (const FuncTypeNode* ta2 = t2.as<FuncTypeNode>()) { if (op->arg_types.size() != ta2->arg_types.size() || op->type_params.size() != ta2->type_params.size() || op->type_constraints.size() != ta2->type_constraints.size()) { @@ -128,8 +148,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { } } - void VisitType_(const TypeRelationNode *tr1, const Type& t2) final { - if (const TypeRelationNode *tr2 = t2.as<TypeRelationNode>()) { + void VisitType_(const TypeRelationNode* tr1, const Type& t2) final { + if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) { if (tr1->func != tr2->func || tr1->num_inputs != tr2->num_inputs || tr1->attrs != tr2->attrs) { @@ -153,8 +173,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> { } } - void VisitType_(const TupleTypeNode *op, const Type& t2) final { - if (const TupleTypeNode *pt = t2.as<TupleTypeNode>()) { + void VisitType_(const TupleTypeNode* op, const Type& t2) final { + if (const TupleTypeNode* pt = t2.as<TupleTypeNode>()) { if (op->fields.size() != pt->fields.size()) { equal = false; return; @@ -185,8 +205,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { bool equal; AlphaEq() : eq_map(), equal(true) {} - void VisitExpr_(const VarNode *e1, const Expr& e2) final { - if (const VarNode *id2 = e2.as<VarNode>()) { + void VisitExpr_(const VarNode* e1, const Expr& e2) final { + if (const VarNode* id2 = e2.as<VarNode>()) { auto local1 = GetRef<Var>(e1); auto local2 = GetRef<Var>(id2); // We handle open terms with this rule assuming variables are identical. @@ -207,17 +227,17 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { } } - void VisitExpr_(const GlobalVarNode *g1, const Expr& e2) final { - if (const GlobalVarNode *g2 = e2.as<GlobalVarNode>()) { + void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final { + if (const GlobalVarNode* g2 = e2.as<GlobalVarNode>()) { equal = equal && g1 == g2; } else { equal = false; } } - void VisitExpr_(const TupleNode *pl1, const Expr& e2) final { + void VisitExpr_(const TupleNode* pl1, const Expr& e2) final { Tuple prod1 = GetRef<Tuple>(pl1); - if (const TupleNode *pl2 = e2.as<TupleNode>()) { + if (const TupleNode* pl2 = e2.as<TupleNode>()) { Tuple prod2 = GetRef<Tuple>(pl2); if (prod1->fields.size() != prod2->fields.size()) { equal = false; @@ -232,8 +252,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { } } - void VisitExpr_(const ParamNode *p1, const Expr& e2) final { - if (const ParamNode *p2 = e2.as<ParamNode>()) { + void VisitExpr_(const ParamNode* p1, const Expr& e2) final { + if (const ParamNode* p2 = e2.as<ParamNode>()) { eq_map.Set(p1->var, p2->var); equal = equal && AlphaEqual(p1->type, p2->type); } else { @@ -241,8 +261,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { } } - void VisitExpr_(const FunctionNode *func1, const Expr& e2) final { - if (const FunctionNode *func2 = e2.as<FunctionNode>()) { + void VisitExpr_(const FunctionNode* func1, const Expr& e2) final { + if (const FunctionNode* func2 = e2.as<FunctionNode>()) { if (func1->params.size() != func2->params.size()) { equal = false; return; @@ -258,8 +278,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { } } - void VisitExpr_(const CallNode *op, const Expr& e2) final { - if (const CallNode *call = e2.as<CallNode>()) { + void VisitExpr_(const CallNode* op, const Expr& e2) final { + if (const CallNode* call = e2.as<CallNode>()) { this->VisitExpr(op->op, call->op); if (op->args.size() != call->args.size()) { @@ -276,8 +296,8 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { } } - void VisitExpr_(const LetNode *op, const Expr& e2) final { - if (const LetNode *let = e2.as<LetNode>()) { + void VisitExpr_(const LetNode* op, const Expr& e2) final { + if (const LetNode* let = e2.as<LetNode>()) { eq_map.Set(op->var, let->var); this->VisitExpr(op->value, let->value); this->VisitExpr(op->body, let->body); @@ -285,6 +305,36 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> { equal = false; } } + + void VisitExpr_(const IfNode* op, const Expr& e2) final { + if (const IfNode* i = e2.as<IfNode>()) { + VisitExpr(op->cond, i->cond); + VisitExpr(op->true_branch, i->true_branch); + VisitExpr(op->false_branch, i->false_branch); + } else { + equal = false; + } + } + + void VisitExpr_(const OpNode* op, const Expr& e2) final { + if (const OpNode* o = e2.as<OpNode>()) { + equal = equal && op->name == o->name; + } else { + equal = false; + } + } + + void VisitExpr_(const ConstantNode* op, const Expr& e2) final { + if (const ConstantNode* c = e2.as<ConstantNode>()) { + if (AlphaEqual(op->tensor_type(), c->tensor_type())) { + equal = equal && SameNDArray(op->data, c->data); + } else { + equal = false; + } + } else { + equal = false; + } + } }; bool AlphaEqual(const Expr& e1, const Expr& e2) { @@ -294,15 +344,15 @@ bool AlphaEqual(const Expr& e1, const Expr& e2) { } // TODO(@jroesch): move to correct namespace? -TVM_REGISTER_API("relay._make._alpha_eq") - .set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_API("relay._make._alpha_equal") + .set_body([](TVMArgs args, TVMRetValue* ret) { Expr e1 = args[0]; Expr e2 = args[1]; *ret = AlphaEqual(e1, e2); }); -TVM_REGISTER_API("relay._make._type_alpha_eq") - .set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_API("relay._make._type_alpha_equal") + .set_body([](TVMArgs args, TVMRetValue* ret) { Type t1 = args[0]; Type t2 = args[1]; *ret = AlphaEqual(t1, t2); diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc new file mode 100644 index 0000000000000000000000000000000000000000..05036042a6354a2caf9335a6fa77238da385555d --- /dev/null +++ b/src/relay/pass/dead_code.cc @@ -0,0 +1,119 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file dead_code.cc + * + * \brief Remove code that does not effect the program result. + * + * The algorithm is implemented by two visitor: + * CalcDep turn an expr into a dependency graph of expr, + * GenLet turn the dependency graph into a let list, taking only the used value. + */ +#include <tvm/relay/pass.h> +#include <tvm/relay/expr_functor.h> +#include "let_list.h" + +namespace tvm { +namespace relay { + +bool IsBoolLit(const Expr& e, bool b) { + if (const ConstantNode* c = e.as<ConstantNode>()) { + if (c->is_scalar()) { + auto dt = c->tensor_type()->dtype; + if (dt == UInt(8)) { + return *reinterpret_cast<const uint8_t*>(c->data->data) == b; + } else if (dt == UInt(16)) { + return *reinterpret_cast<const uint16_t*>(c->data->data) == b; + } else if (dt == UInt(32)) { + return *reinterpret_cast<const uint32_t*>(c->data->data) == b; + } else if (dt == UInt(64)) { + return *reinterpret_cast<const uint64_t*>(c->data->data) == b; + } else if (dt == Int(8)) { + return *reinterpret_cast<const int8_t*>(c->data->data) == b; + } else if (dt == Int(16)) { + return *reinterpret_cast<const int16_t*>(c->data->data) == b; + } else if (dt == Int(32)) { + return *reinterpret_cast<const int32_t*>(c->data->data) == b; + } else if (dt == Int(64)) { + return *reinterpret_cast<const int64_t*>(c->data->data) == b; + } + } + } + return false; +} + +// calculate the dependency graph from expression +class CalcDep : private ExprMutator { + public: + static Expr Eliminate(const Expr& e) { + CalcDep cd; + auto res = cd(e); + GenLet gl(cd.var_map_); + gl(res); + return gl.lets_.Get(res); + } + + private: + struct Binder { + Type t; + Expr e; + Binder(const Type& t, const Expr& e) : t(t), e(e) { } + }; + using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>; + VarMap var_map_; + + Expr VisitExpr_(const IfNode* i) final { + auto cond = VisitExpr(i->cond); + if (IsBoolLit(cond, true)) { + return Eliminate(i->true_branch); + } else if (IsBoolLit(cond, false)) { + return Eliminate(i->false_branch); + } else { + return IfNode::make(cond, Eliminate(i->true_branch), Eliminate(i->false_branch)); + } + } + + Expr VisitExpr_(const LetNode* l) final { + var_map_.insert(std::pair<Var, Binder>(l->var, + Binder(l->value_type, + Eliminate(l->value)))); + return VisitExpr(l->body); + } + + Expr VisitExpr_(const FunctionNode* f) final { + return FunctionNode::make(f->params, f->ret_type, Eliminate(f->body), f->type_params); + } + + // generate the let list from dependency graph + class GenLet : private ExprVisitor { + private: + LetList lets_; + VarMap var_map_; + explicit GenLet(const VarMap& var_map) : var_map_(var_map) { } + friend CalcDep; + + void VisitExpr_(const VarNode* vn) final { + Var v = GetRef<Var>(vn); + if (var_map_.count(v) != 0) { + auto val = var_map_.at(v); + var_map_.erase(v); + // erase before visit to handle letrec + VisitExpr(val.e); + // visit before push back so the dependency of dependency is before the dependency + lets_.Push(v, val.t, val.e); + } + } + }; +}; + +Expr DeadCodeElimination(const Expr& e) { + return CalcDep::Eliminate(e); +} + +TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = DeadCodeElimination(args[0]); + }); + +} // namespace relay +} // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 04c178f25dfaaeb2f1480ff0d3e29695af3c5c63..574111e39b64717d2280b894fc0e5485db247c38 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -25,15 +25,6 @@ inline void VerifyDataType(DLDataType dtype) { CHECK_EQ(dtype.bits & (dtype.bits - 1), 0); } -inline size_t GetDataSize(const DLTensor& arr) { - size_t size = 1; - for (tvm_index_t i = 0; i < arr.ndim; ++i) { - size *= arr.shape[i]; - } - size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; - return size; -} - inline size_t GetDataAlignment(const DLTensor& arr) { size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; if (align < kAllocAlignment) return kAllocAlignment; @@ -129,8 +120,8 @@ DLManagedTensor* NDArray::ToDLPack() const { } NDArray NDArray::Empty(std::vector<int64_t> shape, - DLDataType dtype, - DLContext ctx) { + DLDataType dtype, + DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.data_->dl_tensor); diff --git a/tests/python/relay/test_dead_code_elimination.py b/tests/python/relay/test_dead_code_elimination.py new file mode 100644 index 0000000000000000000000000000000000000000..10f60be32f55739be6815ee7b5302c1a4ec7cc1e --- /dev/null +++ b/tests/python/relay/test_dead_code_elimination.py @@ -0,0 +1,77 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import dead_code_elimination, alpha_equal +from tvm.relay.ir_builder import convert, IRBuilder +from tvm.relay.op import log, add, equal, subtract, concat + +class env: + def __init__(self): + self.a = relay.Var("a") + self.b = relay.Var("b") + self.c = relay.Var("c") + self.d = relay.Var("d") + self.e = relay.Var("e") + self.x = relay.Var("x") + self.y = relay.Var("y") + self.z = relay.Var("z") + self.shape = tvm.convert([1, 2, 3]) + self.tt = relay.TensorType(self.shape, "float32") + self.int32 = relay.TensorType([], "int32") + self.float32 = relay.TensorType([], "float32") + self.one = convert(1.0) + self.two = convert(2.0) + self.three = convert(3.0) + +e = env() + +def test_let(): + orig = relay.Let(e.x, e.y, e.z, e.tt) + assert alpha_equal(dead_code_elimination(orig), e.z) + +def test_used_let(): + orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt) + assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt)) + +def test_chain_unused_let(): + orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt) + assert alpha_equal(dead_code_elimination(orig), e.e) + +# make sure we dont infinite loop +def test_recursion(): + """ + Program: + let f(n: i32, data: f32) -> f32 = { + if (n == 0) { + return data; + } else { + return f(n - 1, log(data)); + } + } + f(2, 10000); + """ + f = relay.Var("f") + n = relay.Var("n") + np = relay.Param(n, e.int32) + data = relay.Var("data") + datap = relay.Param(data, e.float32) + funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) + value = relay.Function([np, datap], e.float32, funcbody, []) + orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32) + assert alpha_equal(dead_code_elimination(orig), orig) + assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three) + +def test_op_let(): + assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two)) + +def test_if(): + orig = relay.If(convert(True), e.a, e.b) + assert alpha_equal(dead_code_elimination(orig), e.a) + + +if __name__ == "__main__": + test_let() + test_used_let() + test_chain_unused_let() + test_recursion() + test_op_let() + test_if() diff --git a/tests/python/relay/test_pass_alpha_eq.py b/tests/python/relay/test_pass_alpha_equal.py similarity index 89% rename from tests/python/relay/test_pass_alpha_eq.py rename to tests/python/relay/test_pass_alpha_equal.py index d925b54d47d2cb94b4ad0e10f45f72c9f14b2b27..93f8a8fbc0b31c700203d7d42c6c90cbcb0e3a26 100644 --- a/tests/python/relay/test_pass_alpha_eq.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -1,8 +1,9 @@ import tvm from tvm import relay +from tvm.relay.ir_pass import alpha_equal +from tvm.relay.ir_builder import convert - -def test_tensor_type_alpha_eq(): +def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") t2 = relay.TensorType((3, 4), "float32") t3 = relay.TensorType((3, 4, 5), "float32") @@ -13,8 +14,14 @@ def test_tensor_type_alpha_eq(): t2 = relay.TensorType((), "float32") assert t1 == t2 +def test_constant_alpha_equal(): + x = convert(1) + y = convert(2) + assert alpha_equal(x, x) + assert not alpha_equal(x, y) + assert alpha_equal(x, convert(1)) -def test_incomplete_type_alpha_eq(): +def test_incomplete_type_alpha_equal(): t1 = relay.IncompleteType(relay.Kind.Shape) t2 = relay.IncompleteType(relay.Kind.Type) t3 = relay.IncompleteType(relay.Kind.Type) @@ -26,7 +33,7 @@ def test_incomplete_type_alpha_eq(): assert t2 != t3 -def test_type_param_alpha_eq(): +def test_type_param_alpha_equal(): t1 = relay.TypeParam("v1", relay.Kind.Type) t2 = relay.TypeParam("v2", relay.Kind.Shape) t3 = relay.TypeParam("v3", relay.Kind.Type) @@ -48,7 +55,7 @@ def test_type_param_alpha_eq(): assert ft1 != ft3 # kinds still do not match -def test_func_type_alpha_eq(): +def test_func_type_alpha_equal(): t1 = relay.TensorType((1, 2), "float32") t2 = relay.TensorType((1, 2, 3), "float32") @@ -108,7 +115,7 @@ def test_func_type_alpha_eq(): assert ft != more_rels -def test_tuple_type_alpha_eq(): +def test_tuple_type_alpha_equal(): t1 = relay.TensorType((1, 2, 3), "float32") t2 = relay.TensorType((1, 2, 3, 4), "float32") tp1 = relay.TypeParam("v1", relay.Kind.Type) @@ -126,7 +133,7 @@ def test_tuple_type_alpha_eq(): assert tup1 != tup4 -def test_type_relation_alpha_eq(): +def test_type_relation_alpha_equal(): t1 = relay.TensorType((1, 2), "float32") t2 = relay.TensorType((1, 2, 3), "float32") t3 = relay.TensorType((1, 2, 3, 4), "float32") @@ -162,9 +169,9 @@ def test_type_relation_alpha_eq(): if __name__ == "__main__": - test_tensor_type_alpha_eq() - test_incomplete_type_alpha_eq() - test_type_param_alpha_eq() - test_func_type_alpha_eq() - test_tuple_type_alpha_eq() - test_type_relation_alpha_eq() + test_tensor_type_alpha_equal() + test_incomplete_type_alpha_equal() + test_type_param_alpha_equal() + test_func_type_alpha_equal() + test_tuple_type_alpha_equal() + test_type_relation_alpha_equal() diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 5b837558042485721a97da035caa1d4b1e6031a2..97baf701347ac9204bbc9efa23dd49784479f01b 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -120,9 +120,9 @@ def test_recursion(): Program: def f(n: i32, data: f32) -> f32 { if (n == 0) { - return f(n - 1, log(data)); - } else { return data; + } else { + return f(n - 1, log(data)); } } f(2, 10000); @@ -133,9 +133,9 @@ def test_recursion(): data = b.param('data', ty='float32') with b.decl(f, n, data): with b.if_scope(equal(n, convert(0))): - b.ret(f(subtract(n, convert(1)), log(data))) - with b.else_scope(): b.ret(data) + with b.else_scope(): + b.ret(f(subtract(n, convert(1)), log(data))) b.ret(f(convert(2.0), convert(10000.0))) assert_decl_has_type(b.env, 'f', func_type( ['int32', 'float32'], 'float32')) @@ -160,11 +160,11 @@ def test_concat(): if __name__ == "__main__": test_dual_op() - test_recursion() test_monomorphic_let() test_single_op() test_add_op() test_add_broadcast_op() test_decl() + test_recursion() test_concat()