From 7d906c9d7d799e384c59ac1dd819d9babb49ff74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= <lolisa@marisa.moe> Date: Sun, 30 Sep 2018 20:32:21 -0700 Subject: [PATCH] [Relay] Free Variables (#1786) --- include/tvm/relay/pass.h | 30 +++++++ python/tvm/relay/ir_pass.py | 4 + src/relay/pass/type_visitor.h | 12 +-- src/relay/pass/util.cc | 118 +++++++++++++++++++++++++++ tests/python/relay/test_free_vars.py | 29 +++++++ 5 files changed, 187 insertions(+), 6 deletions(-) create mode 100644 src/relay/pass/util.cc create mode 100644 tests/python/relay/test_free_vars.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index d3747c214..8b2a5fafd 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -92,6 +92,36 @@ bool AlphaEqual(const Type& t1, const Type& t2); */ bool WellFormed(const Expr & e); +/*! \brief Get free variables from expression e. + * + * Free variables are variables that are not bound by a let or a function parameter in the context. + * + * \param e the expression. + * + * \return the set of free variable. + */ +tvm::Array<Var> FreeVariables(const Expr & e); + +/*! \brief Get free type parameters from expression e. + * + * Free type parameters are type parameters that are not bound by a function type in the context. + * + * \param e the expression. + * + * \return the set of free type variables. + */ +tvm::Array<TypeParam> FreeTypeVariables(const Expr & e); + +/*! \brief Get free type parameters from type t. + * + * Free type parameters are type parameters that are not bound by a function type in the context. + * + * \param t the type. + * + * \return the set of free type variables. + */ +tvm::Array<TypeParam> FreeTypeVariables(const Type & t); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_H_ diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 8a9612420..339b9f74d 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -14,3 +14,7 @@ check_expr = _ir_pass.check_expr well_formed = _ir_pass.well_formed check_kind = _ir_pass.check_kind + +free_vars = _ir_pass.free_vars + +free_type_vars = _ir_pass.free_type_vars diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index 814894265..646826968 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -95,13 +95,13 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> { type_params, type_constraints); } - Type VisitType_(const TupleTypeNode* op) override { - std::vector<Type> new_fields; - for (const Type& t : op->fields) { - new_fields.push_back(this->VisitType(t)); - } - return TupleTypeNode::make(new_fields); + Type VisitType_(const TupleTypeNode* op) override { + std::vector<Type> new_fields; + for (const Type& t : op->fields) { + new_fields.push_back(this->VisitType(t)); } + return TupleTypeNode::make(new_fields); + } Type VisitType_(const TypeRelationNode* type_rel) override { std::vector<Type> new_args; diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc new file mode 100644 index 000000000..5f87c3d4c --- /dev/null +++ b/src/relay/pass/util.cc @@ -0,0 +1,118 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file util.cc + * + * \brief simple util for relay. + */ +#include <tvm/relay/pass.h> +#include <tvm/relay/expr_functor.h> +#include "./type_visitor.h" + +namespace tvm { +namespace relay { + +class FreeVar; +class FreeTypeVar : private TypeVisitor<> { + std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars; + std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars; + FreeTypeVar(std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars, + std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars) : + free_vars(free_vars), bound_vars(bound_vars) { } + + void VisitType_(const TypeParamNode* tp) final { + auto var = GetRef<TypeParam>(tp); + if (bound_vars->count(var) == 0) { + free_vars->insert(var); + } + } + + void VisitType_(const FuncTypeNode* f) final { + for (auto type_param : f->type_params) { + bound_vars->insert(type_param); + } + + for (auto type_cs : f->type_constraints) { + this->VisitType(type_cs); + } + + for (auto arg_type : f->arg_types) { + this->VisitType(arg_type); + } + this->VisitType(f->ret_type); + } + friend FreeVar; +}; + +class FreeVar : public ExprVisitor { + void VisitExpr_(const VarNode *v) final { + auto var = GetRef<Var>(v); + if (bound_vars.count(var) == 0) { + free_vars.insert(var); + } + } + + void VisitExpr_(const FunctionNode *f) final { + for (const auto& tp : f->type_params) { + bound_types.insert(tp); + } + for (const auto& p : f->params) { + bound_vars.insert(p->var); + } + VisitExpr(f->body); + VisitType(f->ret_type); + } + + void VisitExpr_(const LetNode *l) final { + bound_vars.insert(l->var); + VisitExpr(l->value); + VisitExpr(l->body); + VisitType(l->value_type); + } + + public: + std::unordered_set<Var, NodeHash, NodeEqual> free_vars; + std::unordered_set<Var, NodeHash, NodeEqual> bound_vars; + std::unordered_set<TypeParam, NodeHash, NodeEqual> free_types; + std::unordered_set<TypeParam, NodeHash, NodeEqual> bound_types; + + void VisitType(const Type& t) final { + FreeTypeVar(&free_types, &bound_types)(t); + } +}; + +tvm::Array<Var> FreeVariables(const Expr& e) { + FreeVar fv; + fv.VisitExpr(e); + return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end()); +} + +tvm::Array<TypeParam> FreeTypeVariables(const Expr& e) { + FreeVar fv; + fv.VisitExpr(e); + return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end()); +} + +tvm::Array<TypeParam> FreeTypeVariables(const Type& t) { + FreeVar fv; + fv.VisitType(t); + return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end()); +} + +TVM_REGISTER_API("relay._ir_pass.free_vars") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = FreeVariables(args[0]); + }); + +TVM_REGISTER_API("relay._ir_pass.free_type_vars") +.set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[0]; + if (x.as<TypeNode>()) { + *ret = FreeTypeVariables(Downcast<Type>(x)); + } else { + *ret = FreeTypeVariables(Downcast<Expr>(x)); + } + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_free_vars.py b/tests/python/relay/test_free_vars.py new file mode 100644 index 000000000..002646ada --- /dev/null +++ b/tests/python/relay/test_free_vars.py @@ -0,0 +1,29 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import free_vars, free_type_vars + +def test_free_vars(): + x = relay.Var("x") + fvx = free_vars(x) + assert len(fvx) == 1 + assert fvx[0] == x + v = relay.Constant(tvm.nd.array(10)) + ty = relay.TensorType([], "int32") + let = relay.Let(x, v, x, ty) + fvx = free_vars(let) + assert len(free_vars(let)) == 0 + f = relay.Function([relay.Param(x, ty)], ty, x) + assert len(free_vars(f)) == 0 + +def test_free_type_vars(): + tp = relay.TypeParam("") + ty = relay.TupleType([tp, relay.TensorType([], "int32")]) + x = relay.Var("x") + y = relay.Var("y") + let = relay.Let(x, y, x, ty) + fvl = free_vars(let) + assert len(fvl) == 1 + assert fvl[0] == y + ftvl = free_type_vars(let) + assert len(ftvl) == 1 + assert ftvl[0] == tp -- GitLab