From 7d906c9d7d799e384c59ac1dd819d9babb49ff74 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
 <lolisa@marisa.moe>
Date: Sun, 30 Sep 2018 20:32:21 -0700
Subject: [PATCH] [Relay] Free Variables (#1786)

---
 include/tvm/relay/pass.h             |  30 +++++++
 python/tvm/relay/ir_pass.py          |   4 +
 src/relay/pass/type_visitor.h        |  12 +--
 src/relay/pass/util.cc               | 118 +++++++++++++++++++++++++++
 tests/python/relay/test_free_vars.py |  29 +++++++
 5 files changed, 187 insertions(+), 6 deletions(-)
 create mode 100644 src/relay/pass/util.cc
 create mode 100644 tests/python/relay/test_free_vars.py

diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h
index d3747c214..8b2a5fafd 100644
--- a/include/tvm/relay/pass.h
+++ b/include/tvm/relay/pass.h
@@ -92,6 +92,36 @@ bool AlphaEqual(const Type& t1, const Type& t2);
  */
 bool WellFormed(const Expr & e);
 
+/*! \brief Get free variables from expression e.
+ *
+ * Free variables are variables that are not bound by a let or a function parameter in the context.
+ *
+ * \param e the expression.
+ *
+ * \return the set of free variable.
+ */
+tvm::Array<Var> FreeVariables(const Expr & e);
+
+/*! \brief Get free type parameters from expression e.
+ *
+ * Free type parameters are type parameters that are not bound by a function type in the context.
+ *
+ * \param e the expression.
+ *
+ * \return the set of free type variables.
+ */
+tvm::Array<TypeParam> FreeTypeVariables(const Expr & e);
+
+/*! \brief Get free type parameters from type t.
+ *
+ * Free type parameters are type parameters that are not bound by a function type in the context.
+ *
+ * \param t the type.
+ *
+ * \return the set of free type variables.
+ */
+tvm::Array<TypeParam> FreeTypeVariables(const Type & t);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_H_
diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 8a9612420..339b9f74d 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -14,3 +14,7 @@ check_expr = _ir_pass.check_expr
 well_formed = _ir_pass.well_formed
 
 check_kind = _ir_pass.check_kind
+
+free_vars = _ir_pass.free_vars
+
+free_type_vars = _ir_pass.free_type_vars
diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h
index 814894265..646826968 100644
--- a/src/relay/pass/type_visitor.h
+++ b/src/relay/pass/type_visitor.h
@@ -95,13 +95,13 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
                               type_params, type_constraints);
   }
 
-    Type VisitType_(const TupleTypeNode* op) override {
-      std::vector<Type> new_fields;
-      for (const Type& t : op->fields) {
-        new_fields.push_back(this->VisitType(t));
-      }
-      return TupleTypeNode::make(new_fields);
+  Type VisitType_(const TupleTypeNode* op) override {
+    std::vector<Type> new_fields;
+    for (const Type& t : op->fields) {
+      new_fields.push_back(this->VisitType(t));
     }
+    return TupleTypeNode::make(new_fields);
+  }
 
   Type VisitType_(const TypeRelationNode* type_rel) override {
     std::vector<Type> new_args;
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
new file mode 100644
index 000000000..5f87c3d4c
--- /dev/null
+++ b/src/relay/pass/util.cc
@@ -0,0 +1,118 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file util.cc
+ *
+ * \brief simple util for relay.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include "./type_visitor.h"
+
+namespace tvm {
+namespace relay {
+
+class FreeVar;
+class FreeTypeVar : private TypeVisitor<> {
+  std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars;
+  std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars;
+  FreeTypeVar(std::unordered_set<TypeParam, NodeHash, NodeEqual> * free_vars,
+              std::unordered_set<TypeParam, NodeHash, NodeEqual> * bound_vars) :
+    free_vars(free_vars), bound_vars(bound_vars) { }
+
+  void VisitType_(const TypeParamNode* tp) final {
+    auto var = GetRef<TypeParam>(tp);
+    if (bound_vars->count(var) == 0) {
+      free_vars->insert(var);
+    }
+  }
+
+  void VisitType_(const FuncTypeNode* f) final {
+    for (auto type_param : f->type_params) {
+      bound_vars->insert(type_param);
+    }
+
+    for (auto type_cs : f->type_constraints) {
+      this->VisitType(type_cs);
+    }
+
+    for (auto arg_type : f->arg_types) {
+      this->VisitType(arg_type);
+    }
+    this->VisitType(f->ret_type);
+  }
+  friend FreeVar;
+};
+
+class FreeVar : public ExprVisitor {
+  void VisitExpr_(const VarNode *v) final {
+    auto var = GetRef<Var>(v);
+    if (bound_vars.count(var) == 0) {
+      free_vars.insert(var);
+    }
+  }
+
+  void VisitExpr_(const FunctionNode *f) final {
+    for (const auto& tp : f->type_params) {
+      bound_types.insert(tp);
+    }
+    for (const auto& p : f->params) {
+      bound_vars.insert(p->var);
+    }
+    VisitExpr(f->body);
+    VisitType(f->ret_type);
+  }
+
+  void VisitExpr_(const LetNode *l) final {
+    bound_vars.insert(l->var);
+    VisitExpr(l->value);
+    VisitExpr(l->body);
+    VisitType(l->value_type);
+  }
+
+ public:
+  std::unordered_set<Var, NodeHash, NodeEqual> free_vars;
+  std::unordered_set<Var, NodeHash, NodeEqual> bound_vars;
+  std::unordered_set<TypeParam, NodeHash, NodeEqual> free_types;
+  std::unordered_set<TypeParam, NodeHash, NodeEqual> bound_types;
+
+  void VisitType(const Type& t) final {
+    FreeTypeVar(&free_types, &bound_types)(t);
+  }
+};
+
+tvm::Array<Var> FreeVariables(const Expr& e) {
+  FreeVar fv;
+  fv.VisitExpr(e);
+  return tvm::Array<Var>(fv.free_vars.begin(), fv.free_vars.end());
+}
+
+tvm::Array<TypeParam> FreeTypeVariables(const Expr& e) {
+  FreeVar fv;
+  fv.VisitExpr(e);
+  return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
+}
+
+tvm::Array<TypeParam> FreeTypeVariables(const Type& t) {
+  FreeVar fv;
+  fv.VisitType(t);
+  return tvm::Array<TypeParam>(fv.free_types.begin(), fv.free_types.end());
+}
+
+TVM_REGISTER_API("relay._ir_pass.free_vars")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = FreeVariables(args[0]);
+  });
+
+TVM_REGISTER_API("relay._ir_pass.free_type_vars")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    NodeRef x = args[0];
+    if (x.as<TypeNode>()) {
+      *ret = FreeTypeVariables(Downcast<Type>(x));
+    } else {
+      *ret = FreeTypeVariables(Downcast<Expr>(x));
+    }
+  });
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_free_vars.py b/tests/python/relay/test_free_vars.py
new file mode 100644
index 000000000..002646ada
--- /dev/null
+++ b/tests/python/relay/test_free_vars.py
@@ -0,0 +1,29 @@
+import tvm
+from tvm import relay
+from tvm.relay.ir_pass import free_vars, free_type_vars
+
+def test_free_vars():
+    x = relay.Var("x")
+    fvx = free_vars(x)
+    assert len(fvx) == 1
+    assert fvx[0] == x
+    v = relay.Constant(tvm.nd.array(10))
+    ty = relay.TensorType([], "int32")
+    let = relay.Let(x, v, x, ty)
+    fvx = free_vars(let)
+    assert len(free_vars(let)) == 0
+    f = relay.Function([relay.Param(x, ty)], ty, x)
+    assert len(free_vars(f)) == 0
+
+def test_free_type_vars():
+    tp = relay.TypeParam("")
+    ty = relay.TupleType([tp, relay.TensorType([], "int32")])
+    x = relay.Var("x")
+    y = relay.Var("y")
+    let = relay.Let(x, y, x, ty)
+    fvl = free_vars(let)
+    assert len(fvl) == 1
+    assert fvl[0] == y
+    ftvl = free_type_vars(let)
+    assert len(ftvl) == 1
+    assert ftvl[0] == tp
-- 
GitLab