From 9ce7f0a353bc092a5a683a38f4904eb83a6e5711 Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Sat, 26 Nov 2016 17:21:38 -0800
Subject: [PATCH] Check in inline and test

---
 include/tvm/ir.h             |  1 +
 include/tvm/ir_pass.h        | 15 ++++++---
 python/tvm/__init__.py       |  1 +
 python/tvm/_ctypes/_api.py   | 24 +++++++--------
 python/tvm/ir_pass.py        |  1 +
 src/c_api/c_api.cc           |  4 ++-
 src/c_api/c_api_ir.cc        |  5 ++-
 src/c_api/c_api_lang.cc      |  1 -
 src/c_api/c_api_pass.cc      | 35 +++++++++++++++++++++
 src/c_api/c_api_registry.h   | 22 +++++++++++--
 src/lang/tensor.cc           |  3 +-
 src/pass/inline.cc           | 60 ++++++++++++++++++++++++++++++++++++
 src/pass/ssa.cc              |  4 +--
 tests/cpp/ir_ssa_test.cc     |  6 ++--
 tests/python/test_inline.py  | 15 +++++++++
 tests/python/test_ir_pass.py | 17 ++++++++++
 16 files changed, 183 insertions(+), 31 deletions(-)
 create mode 100644 python/tvm/ir_pass.py
 create mode 100644 src/c_api/c_api_pass.cc
 create mode 100644 src/pass/inline.cc
 create mode 100644 tests/python/test_inline.py
 create mode 100644 tests/python/test_ir_pass.py

diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index dcf68e4a4..0ba993b02 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -18,6 +18,7 @@ namespace ir {
 
 using Halide::Internal::ExprNode;
 using Halide::Internal::IRNodeType;
+using Halide::Internal::ForType;
 
 /*! \brief Reduction operator operator */
 struct Reduce : public ExprNode<Reduce> {
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 29152cd13..def17377d 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -1,7 +1,10 @@
 /*!
  *  Copyright (c) 2016 by Contributors
  * \file ir_pass.h
- * \brief Collection of IR pass functions and visit functions
+ * \brief Collection of IR pass functions
+ *
+ *  All the pass functions in this file are for Stmt,
+ *  We can use PassFunction(Evaluate(expr)) to apply it to Expr
  */
 #ifndef TVM_IR_PASS_H_
 #define TVM_IR_PASS_H_
@@ -22,14 +25,14 @@ namespace ir {
  * \return Whether IR is in SSA form.
  * \note All the passes in this file uses SSA form and outputs SSA form.
  */
-bool VerifySSA(const IRNodeRef& ir);
+bool VerifySSA(const Stmt& ir);
 
 /*!
  * \brief Convert a IR node to be SSA form.
  * \param stmt The source statement to be converted.
  * \return The converted form.
  */
-Stmt ConvertSSA(const Stmt& stmt);
+Stmt ConvertSSA(Stmt stmt);
 
 /*!
  * \brief inline all calls of f in stmt.
@@ -42,8 +45,10 @@ Stmt ConvertSSA(const Stmt& stmt);
  *
  * \note All the passes in this file uses SSA form and outputs SSA form.
  */
-Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt);
-
+Stmt Inline(FunctionRef f,
+            Array<Var> args,
+            Expr body,
+            Stmt stmt);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 4284b4595..f1c2ea41a 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -6,6 +6,7 @@ from . import tensor as tensor
 from . import expr
 from . import stmt
 from . import make
+from . import ir_pass
 from . import collections
 from . import schedule
 
diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py
index de8468ba5..8a1c0b247 100644
--- a/python/tvm/_ctypes/_api.py
+++ b/python/tvm/_ctypes/_api.py
@@ -224,21 +224,19 @@ def _init_function_module(root_namespace):
 
     module_obj = sys.modules["%s.function" % root_namespace]
     module_internal = sys.modules["%s._function_internal" % root_namespace]
-    module_make = sys.modules["%s.make" % root_namespace]
+    namespace_match = {
+        "_make_" : sys.modules["%s.make" % root_namespace],
+        "_pass_" : sys.modules["%s.ir_pass" % root_namespace]
+    }
 
     for name in op_names:
         hdl = FunctionHandle()
         check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
-        if name.startswith("_make_"):
-            fname = name[6:]
-        else:
-            fname = name
-
+        fname = name
+        target_module = module_internal if name.startswith('_') else module_obj
+        for k, v in namespace_match.items():
+            if name.startswith(k):
+                fname = name[len(k):]
+                target_module = v
         function = _make_function(hdl, fname)
-
-        if name.startswith("_make_"):
-            setattr(module_make, function.__name__, function)
-        elif function.__name__.startswith('_'):
-            setattr(module_internal, function.__name__, function)
-        else:
-            setattr(module_obj, function.__name__, function)
+        setattr(target_module, function.__name__, function)
diff --git a/python/tvm/ir_pass.py b/python/tvm/ir_pass.py
new file mode 100644
index 000000000..3ba8c219a
--- /dev/null
+++ b/python/tvm/ir_pass.py
@@ -0,0 +1 @@
+"""Namespace of IR pass functions"""
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 7e9e32b33..9d540ed63 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -164,14 +164,16 @@ int TVMPushStack(ArgVariant arg,
   API_BEGIN();
   ret->arg_stack.resize(ret->arg_stack.size() + 1);
   APIVariantValue& v = ret->arg_stack.back();
+
   v.type_id = static_cast<ArgVariantID>(type_id);
   if (type_id == kStr) {
-    v = arg.v_str;
+    v.str = arg.v_str;
   }  else if (type_id == kNodeHandle) {
     v.sptr = *static_cast<TVMAPINode*>(arg.v_handle);
   } else {
     v.v_union = arg;
   }
+
   API_END_HANDLE_ERROR(ret->Clear());
 }
 
diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc
index 79c6ac7e4..94b65c230 100644
--- a/src/c_api/c_api_ir.cc
+++ b/src/c_api/c_api_ir.cc
@@ -9,9 +9,7 @@
 #include "./c_api_registry.h"
 
 namespace tvm {
-
-using namespace tvm::ir;
-using namespace Halide::Internal;
+namespace ir {
 
 using ArgStack = const std::vector<APIVariantValue>;
 using RetValue = APIVariantValue;
@@ -135,4 +133,5 @@ REGISTER_MAKE2(Block);
 REGISTER_MAKE3(IfThenElse);
 REGISTER_MAKE1(Evaluate);
 
+}  // namespace ir
 }  // namespace tvm
diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc
index b2d28164a..91bd2ef9c 100644
--- a/src/c_api/c_api_lang.cc
+++ b/src/c_api/c_api_lang.cc
@@ -19,7 +19,6 @@ using RetValue = APIVariantValue;
 TVM_REGISTER_API(_const)
 .set_body([](const ArgStack& args,  RetValue *ret) {
     using Halide::Internal::make_const;
-
     if (args.at(0).type_id == kLong) {
       *ret = make_const(args.at(1), args.at(0).operator int64_t());
     } else if (args.at(0).type_id == kDouble) {
diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc
new file mode 100644
index 000000000..d3046ac91
--- /dev/null
+++ b/src/c_api/c_api_pass.cc
@@ -0,0 +1,35 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ *  Exposre of pass functions.
+ * \file c_api_pass.cc
+ */
+#include <tvm/expr.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include "./c_api_registry.h"
+
+namespace tvm {
+namespace ir {
+
+using ArgStack = const std::vector<APIVariantValue>;
+using RetValue = APIVariantValue;
+
+// make from two arguments
+#define REGISTER_PASS1(PassName)                                  \
+  TVM_REGISTER_API(_pass_## PassName)                             \
+  .set_body([](const ArgStack& args,  RetValue *ret) {            \
+      *ret = PassName(args.at(0));                                \
+    })                                                            \
+
+#define REGISTER_PASS4(PassName)                                        \
+  TVM_REGISTER_API(_pass_## PassName)                                   \
+  .set_body([](const ArgStack& args,  RetValue *ret) {                  \
+      *ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3));  \
+    })                                                                  \
+
+REGISTER_PASS1(ConvertSSA);
+REGISTER_PASS1(VerifySSA);
+REGISTER_PASS4(Inline);
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h
index 97a7e0c14..8004bfe38 100644
--- a/src/c_api/c_api_registry.h
+++ b/src/c_api/c_api_registry.h
@@ -43,6 +43,17 @@ inline Type String2Type(std::string s) {
   return Type(code, bits, lanes);
 }
 
+inline const char* TypeId2Str(ArgVariantID type_id) {
+  switch (type_id) {
+    case kNull: return "Null";
+    case kLong: return "Long";
+    case kDouble: return "Double";
+    case kStr: return "Str";
+    case kNodeHandle: return "NodeHandle";
+    default: LOG(FATAL) << "unknown type_id=" << type_id; return "";
+  }
+}
+
 /*! \brief Variant container for API calls */
 class APIVariantValue {
  public:
@@ -74,6 +85,11 @@ class APIVariantValue {
     v_union.v_long = value;
     return *this;
   }
+  inline APIVariantValue& operator=(bool value) {
+    type_id = kLong;
+    v_union.v_long = value;
+    return *this;
+  }
   inline APIVariantValue& operator=(std::string value) {
     type_id = kStr;
     str = std::move(value);
@@ -130,11 +146,13 @@ class APIVariantValue {
     return v_union.v_long;
   }
   inline operator bool() const {
-    CHECK_EQ(type_id, kLong);
+    CHECK_EQ(type_id, kLong)
+        << "expect boolean(int) but get " << TypeId2Str(type_id);
     return v_union.v_long != 0;
   }
   inline operator std::string() const {
-    CHECK_EQ(type_id, kStr);
+    CHECK_EQ(type_id, kStr)
+        << "expect Str but get " << TypeId2Str(type_id);
     return str;
   }
   inline operator Type() const {
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index 643332701..fb02dde25 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -21,8 +21,9 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   CHECK_EQ(ndim(), indices.size())
       << "Tensor dimension mismatch in read"
       << "ndim = " << ndim() << ", indices.size=" << indices.size();
-  return Call::make(
+  auto n Call::make(
       (*this)->dtype, (*this)->name, indices, Call::Halide, *this);
+  return n;
 }
 
 
diff --git a/src/pass/inline.cc b/src/pass/inline.cc
new file mode 100644
index 000000000..669324225
--- /dev/null
+++ b/src/pass/inline.cc
@@ -0,0 +1,60 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file inline.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+
+namespace tvm {
+namespace ir {
+namespace {
+
+// inliner to inline a function
+// the result may not be SSA,
+// ConvertSSA need to be applied after this pass
+class IRInline : public IRMutator {
+ public:
+  IRInline(FunctionRef f, Array<Var> args, Expr body)
+      : f_(f), args_(args), body_(body) {}
+
+  Expr Mutate(Expr expr) final {
+    const Call* call = expr.as<Call>();
+    if (call != nullptr && call->func == f_) {
+      return InlineCall(call);
+    } else {
+      return IRMutator::Mutate(expr);
+    }
+  }
+
+  Stmt Mutate(Stmt stmt) final {
+    return IRMutator::Mutate(stmt);
+  }
+
+ private:
+  FunctionRef f_;
+  Array<Var> args_;
+  Expr body_;
+
+  Expr InlineCall(const Call* op) {
+    Expr expr = body_;
+
+    CHECK_EQ(args_.size(), op->args.size())
+        << op->args.size() << " vs " << args_.size();
+    for (size_t i = 0; i < args_.size(); ++i) {
+      expr = Let::make(args_[i], op->args[i], expr);
+    }
+    return expr;
+  }
+};
+
+}  // namespace
+
+Stmt Inline(FunctionRef f,
+            Array<Var> args,
+            Expr body,
+            Stmt stmt) {
+  return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
+}
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc
index 12beffeb9..44b2454de 100644
--- a/src/pass/ssa.cc
+++ b/src/pass/ssa.cc
@@ -156,13 +156,13 @@ class IRConvertSSA : public IRMutator {
 
 }  // namespace
 
-bool VerifySSA(const IRNodeRef& ir) {
+bool VerifySSA(const Stmt& ir) {
   IRVerifySSA v;
   v.Visit(ir);
   return v.is_ssa;
 }
 
-Stmt ConvertSSA(const Stmt& stmt) {
+Stmt ConvertSSA(Stmt stmt) {
   return IRConvertSSA().Mutate(stmt);
 }
 
diff --git a/tests/cpp/ir_ssa_test.cc b/tests/cpp/ir_ssa_test.cc
index 0f0f9e6da..2de7dba08 100644
--- a/tests/cpp/ir_ssa_test.cc
+++ b/tests/cpp/ir_ssa_test.cc
@@ -10,9 +10,9 @@ TEST(IRSSA, Convert) {
   Var x("x"), y;
   Expr let = Let::make(x, 1, x + 1);
 
-  auto z = let + let;
+  auto z = Evaluate::make(let + let);
   CHECK(!ir::VerifySSA(z));
-  auto z_ssa = ir::ConvertSSA(Evaluate::make(z));
+  auto z_ssa = ir::ConvertSSA(z);
   CHECK(ir::VerifySSA(z_ssa));
 }
 
@@ -20,7 +20,7 @@ TEST(IRSSA, Basic) {
   using namespace Halide::Internal;
   using namespace tvm;
   Var x("x"), y;
-  auto z = x + y;
+  auto z = Evaluate::make(x + y);
   CHECK(ir::VerifySSA(z));
 }
 
diff --git a/tests/python/test_inline.py b/tests/python/test_inline.py
new file mode 100644
index 000000000..9695a832f
--- /dev/null
+++ b/tests/python/test_inline.py
@@ -0,0 +1,15 @@
+import tvm
+
+def test_inline():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute((m,), lambda i,: A(i) + 10, name='T')
+    X = T(100)
+    stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
+    stmt = tvm.ir_pass.Inline(
+        T, T.source_op.iter_var, T.source_op.body, stmt)
+    print(stmt)
+    assert(tvm.ir_pass.VerifySSA(stmt))
+
+if __name__ == "__main__":
+    test_inline()
diff --git a/tests/python/test_ir_pass.py b/tests/python/test_ir_pass.py
new file mode 100644
index 000000000..23262f1cc
--- /dev/null
+++ b/tests/python/test_ir_pass.py
@@ -0,0 +1,17 @@
+import tvm
+
+def test_verify_ssa():
+    x = tvm.Var('x')
+    y = tvm.Var()
+    z = tvm.make.Evaluate(x + y)
+    assert(tvm.ir_pass.VerifySSA(z))
+
+
+def test_convert_ssa():
+    x = tvm.Var('x')
+    y = tvm.Var()
+    let = tvm.make.Let(x, 1, x + 1)
+    z = tvm.make.Evaluate(let + let)
+    assert(not tvm.ir_pass.VerifySSA(z))
+    z_ssa = tvm.ir_pass.ConvertSSA(z)
+    assert(tvm.ir_pass.VerifySSA(z_ssa))
-- 
GitLab