From 38f03f1f78f88a115829444c6e790b68417250ae Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Fri, 18 Nov 2016 21:30:15 -0800
Subject: [PATCH] SSA Pass

---
 HalideIR                                      |   2 +-
 Makefile                                      |   1 +
 include/tvm/expr.h                            |   2 +
 include/tvm/ir_pass.h                         |   2 +-
 src/pass/ir_pass.cc                           | 138 --------------
 src/pass/ssa.cc                               | 171 ++++++++++++++++++
 tests/cpp/{ir_pass_test.cc => ir_ssa_test.cc} |  28 +--
 7 files changed, 191 insertions(+), 153 deletions(-)
 delete mode 100644 src/pass/ir_pass.cc
 create mode 100644 src/pass/ssa.cc
 rename tests/cpp/{ir_pass_test.cc => ir_ssa_test.cc} (52%)

diff --git a/HalideIR b/HalideIR
index 4becbde67..24a7c0357 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit 4becbde67c8aa565941b02648cea90f50211f8dc
+Subproject commit 24a7c0357a6a8db5db782d320aad7f706ebe8507
diff --git a/Makefile b/Makefile
index 7daddbd95..d2f8bd71b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,3 +1,4 @@
+export CXX=g++
 export LDFLAGS = -pthread -lm
 export CFLAGS =  -std=c++11 -Wall -O2  -Wno-unknown-pragmas -funroll-loops\
 	 -Iinclude -Idmlc-core/include -IHalideIR/src  -fPIC
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 0be106346..24e792fea 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -27,7 +27,9 @@ using Halide::abs;
 using Halide::select;
 
 using Halide::Expr;
+
 using Halide::VarExpr;
+using Halide::IR::FunctionRef;
 using Halide::IR::FunctionBaseNode;
 using Halide::Internal::Stmt;
 
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 3baca6bae..c02e57565 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -28,7 +28,7 @@ bool VerifySSA(const IRNodeRef& ir);
  * \param stmt The source statement to be converted.
  * \return The converted form.
  */
-Stmt ConvertSSA(const Stmt stmt);
+Stmt ConvertSSA(const Stmt& stmt);
 
 /*!
  * \brief inline all calls of f in stmt.
diff --git a/src/pass/ir_pass.cc b/src/pass/ir_pass.cc
deleted file mode 100644
index 9abf04bd2..000000000
--- a/src/pass/ir_pass.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/*!
- *  Copyright (c) 2016 by Contributors
- * \file ir_pass.cc
- */
-#include <tvm/ir.h>
-#include <tvm/ir_visitor.h>
-#include <tvm/ir_mutator.h>
-#include <unordered_set>
-
-namespace tvm {
-namespace ir {
-namespace {
-
-struct SetVarDef {
-  // get var definition from node
-  using FType = IRFunctor<const Variable*(const IRNodeRef&)>;
-  static FGetVarDef& vtable_get_var_def() {  // NOLINT(*)
-    static FGetVarDef inst; return inst;
-  }
-  static FSetVarExpr& vtable_set_var_expr() {  // NOLINT(*)
-    static FSetVarExpr inst; return inst;
-  }
-  static FSetVarStmt& vtable_set_var_expr() {  // NOLINT(*)
-    static FSetVarStmt inst; return inst;
-  }
-};
-
-  // return a new node to
-  using FSetVarExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
-  // return a new node to
-  using FSetVarStmt = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
-
-inline const Variable* GetVarDef(const IRNodeRef& n) {
-  if (n.as<Let>()) {
-    return n.as<Let>()->var.get();
-  } else if (n.as<LetStmt>()) {
-    return n.as<LetStmt>()->var.get();
-  } else if (n.as<For>()) {
-    return n.as<For>()->loop_var.get();
-  } else if (n.as<Allocate>()) {
-    return n.as<Allocate>()->buffer_var.get();
-  } else {
-    return nullptr;
-  }
-}
-
-inline Expr ResetVar(const Expr& n, VarExpr var) {
-  if (n.as<Let>()) {
-    std::shared_ptr<Let> x = std::make_shared<Let>(*n.as<Let>());
-    x->var = var;
-    return Expr(x);
-  } else if (n.as<Allocate>()) {
-  }
-}
-
-inline Stmt ResetVarDef(const Stmt& n, VarExpr var) {
-  if (n.as<LetStmt>()) {
-    std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*n.as<Let>());
-    x->var = var;
-    return Expr(x);
-  } else if (n.as<For>()) {
-    std::shared_ptr<For> x = std::make_shared<For>(*n.as<Let>());
-    x->loop_var = var;
-    return Expr(x);
-  } else {
-    LOG(FATAL) << "not reached";
-  }
-}
-
-class IRVerifySSA : public IRVisitor {
- public:
-  bool is_ssa{true};
-  std::unordered_set<const Variable*> defined;
-
-  void Visit(const IRNodeRef& n) final {
-    if (!is_ssa) return;
-    const Variable* v = GetVarDef(n);
-    if (v != nullptr) {
-      if (defined.count(v) != 0) {
-        is_ssa = false; return;
-      } else {
-        defined.insert(v);
-      }
-    }
-    IRVisitor::Visit(n);
-  }
-};
-
-class IRConvertSSA : public IRMutator {
- public:
-  Expr Mutate(Expr expr) final {
-    static const auto& f = IRConvertSSA::vtable_expr();
-    return (f.can_dispatch(expr) ?
-            f(expr, expr, this) : IRMutator::Mutate(expr));
-  }
-  Stmt Mutate(Stmt stmt) final {
-    static const auto& f = IRMutatorExample::vtable_stmt();
-    return (f.can_dispatch(stmt) ?
-            f(stmt, stmt, this) : IRMutator::Mutate(stmt));
-  }
-  using FConvertExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
-  using FConvertStmt = IRFunctor<Stmt(const IRNodeRef&, const Expr&, IRConvertSSA *)>;
-  std::unordered_map<const Variable*, std::vector<VarExpr> > scope;
-  std::unordered_set<const Variable*> defined;
-};
-
-temple<>
-
-TVM_STATIC_IR_FUNCTOR(IRConvertSSA, vtable_expr)
-.set_dispatch<Let>([](const Let* op, const Expr& e, IRConvertSSA* m) {
-    VarExpr var = op->var;
-    if (m->defined.count(var.get()) != 0) {
-      var = Variable::make(var->type, var->name_hint);
-    }
-    // insert scope before recursion.
-    m->scope[var.get()].push_back(var);
-    Expr new_expr = Mutate(e);
-    m->scope[var.get()].pop_back();
-
-    if (!var.same_as(op->var)) {
-      std::shared_ptr<Let> x = std::make_shared<Let>(*new_expr.as<Let>());
-      x->var = var;
-      return Expr(x);
-    } else {
-      return new_expr;
-    }
-  });
-
-}  // namespace
-
-bool VerifySSA(const IRNodeRef& ir) {
-  IRVerifySSA v;
-  v.Visit(ir);
-  return v.is_ssa;
-}
-
-}  // namespace ir
-}  // namespace tvm
diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc
new file mode 100644
index 000000000..556626418
--- /dev/null
+++ b/src/pass/ssa.cc
@@ -0,0 +1,171 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ *  SSA related checks and pass.
+ * \file ssa.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace ir {
+namespace {
+
+// global functor to get var definition from
+struct FGetVarDef {
+  using FType = IRFunctor<VarExpr (const IRNodeRef&)>;
+  static FType& vtable() {  // NOLINT(*)
+    static FType inst; return inst;
+  }
+};
+TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
+.set_dispatch<Let>([](const Let* op) {
+    return op->var;
+  })
+.set_dispatch<LetStmt>([](const LetStmt* op) {
+    return op->var;
+  })
+.set_dispatch<For>([](const For* op) {
+    return op->loop_var;
+  })
+.set_dispatch<Allocate>([](const Allocate* op) {
+    return op->buffer_var;
+  });
+
+struct FSetVarDef {
+  using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
+  using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>;
+  static FTypeExpr& vtable_expr() {  // NOLINT(*)
+    static FTypeExpr inst; return inst;
+  }
+  static FTypeStmt& vtable_stmt() {  // NOLINT(*)
+    static FTypeStmt inst; return inst;
+  }
+};
+TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_expr)
+.set_dispatch<Let>([](const Let* op, VarExpr var) {
+    std::shared_ptr<Let> x = std::make_shared<Let>(*op);
+    x->var = var;
+    return Expr(x);
+  });
+
+TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_stmt)
+.set_dispatch<LetStmt>([](const LetStmt* op, VarExpr var) {
+    std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*op);
+    x->var = var;
+    return Stmt(x);
+  })
+.set_dispatch<For>([](const For* op, VarExpr var) {
+    std::shared_ptr<For> x = std::make_shared<For>(*op);
+    x->loop_var = var;
+    return Stmt(x);
+  });
+
+class IRVerifySSA : public IRVisitor {
+ public:
+  bool is_ssa{true};
+
+  void Visit(const IRNodeRef& n) final {
+    if (!is_ssa) return;
+    static auto& fget_var_def = FGetVarDef::vtable();
+    if (fget_var_def.can_dispatch(n)) {
+      VarExpr v = fget_var_def(n);
+      if (defined_.count(v.get()) != 0) {
+        is_ssa = false; return;
+      } else {
+        defined_[v.get()] = 1;
+      }
+    }
+    IRVisitor::Visit(n);
+  }
+
+ private:
+  std::unordered_map<const Variable*, int> defined_;
+};
+
+class IRConvertSSA : public IRMutator {
+ public:
+  Expr Mutate(Expr expr) final {
+    static auto& fget_var_def = FGetVarDef::vtable();
+    static auto& fset_var_def = FSetVarDef::vtable_expr();
+    if (fget_var_def.can_dispatch(expr)) {
+      VarExpr v = fget_var_def(expr);
+      VarExpr new_var = v;
+      if (defined_.count(v.get()) != 0) {
+        CHECK(expr.as<Allocate>() == nullptr)
+            << "One allocation in two places, cannot rename buffer in allocate";
+        new_var = Variable::make(v->type, v->name_hint);
+      } else {
+        defined_.insert(v.get());
+      }
+      scope_[v.get()].push_back(new_var);
+      Expr new_expr = IRMutator::Mutate(expr);
+      scope_[v.get()].pop_back();
+
+      if (!new_var.same_as(v)) {
+        return fset_var_def(new_expr, new_var);
+      } else {
+        return new_expr;
+      }
+    } else if (expr.as<Variable>()) {
+      const Variable* v = expr.as<Variable>();
+      if (scope_.count(v) != 0) {
+        return scope_[v].back();
+      } else {
+        return expr;
+      }
+    } else {
+      Expr e = IRMutator::Mutate(expr);
+      return e;
+
+    }
+  }
+
+  Stmt Mutate(Stmt stmt) final {
+    static auto& fget_var_def = FGetVarDef::vtable();
+    static auto& fset_var_def = FSetVarDef::vtable_stmt();
+    if (fget_var_def.can_dispatch(stmt)) {
+      VarExpr v = fget_var_def(stmt);
+      VarExpr new_var = v;
+      if (defined_.count(v.get()) != 0) {
+        new_var = Variable::make(v->type, v->name_hint);
+      } else {
+        defined_.insert(v.get());
+      }
+      scope_[v.get()].push_back(new_var);
+      Stmt new_stmt = IRMutator::Mutate(stmt);
+      scope_[v.get()].pop_back();
+
+      if (!new_var.same_as(v)) {
+        return fset_var_def(new_stmt, new_var);
+      } else {
+        return new_stmt;
+      }
+    } else {
+      return IRMutator::Mutate(stmt);
+    }
+  }
+
+ private:
+  std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
+  std::unordered_set<const Variable*> defined_;
+};
+
+}  // namespace
+
+bool VerifySSA(const IRNodeRef& ir) {
+  IRVerifySSA v;
+  v.Visit(ir);
+  return v.is_ssa;
+}
+
+Stmt ConvertSSA(const Stmt& stmt) {
+  return IRConvertSSA().Mutate(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/cpp/ir_pass_test.cc b/tests/cpp/ir_ssa_test.cc
similarity index 52%
rename from tests/cpp/ir_pass_test.cc
rename to tests/cpp/ir_ssa_test.cc
index 4397cfa12..0f0f9e6da 100644
--- a/tests/cpp/ir_pass_test.cc
+++ b/tests/cpp/ir_ssa_test.cc
@@ -3,23 +3,25 @@
 #include <tvm/tvm.h>
 #include <tvm/ir_pass.h>
 
-TEST(IRPass, Substitute) {
+
+TEST(IRSSA, Convert) {
+  using namespace Halide::Internal;
+  using namespace tvm;
+  Var x("x"), y;
+  Expr let = Let::make(x, 1, x + 1);
+
+  auto z = let + let;
+  CHECK(!ir::VerifySSA(z));
+  auto z_ssa = ir::ConvertSSA(Evaluate::make(z));
+  CHECK(ir::VerifySSA(z_ssa));
+}
+
+TEST(IRSSA, Basic) {
   using namespace Halide::Internal;
   using namespace tvm;
   Var x("x"), y;
   auto z = x + y;
-  {
-    auto zz = ir::Substitute({{y.get(), 11}}, z);
-    std::ostringstream os;
-    os << zz;
-    CHECK(os.str() == "(x + 11)");
-  }
-  {
-    auto zz = ir::Substitute({{z.get(), 11}}, z);
-    std::ostringstream os;
-    os << zz;
-    CHECK(os.str() == "11");
-  }
+  CHECK(ir::VerifySSA(z));
 }
 
 int main(int argc, char ** argv) {
-- 
GitLab