From 4f1473f3a10f613df795816bcee541feef12f48d Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 22 Jan 2017 17:45:31 -0800
Subject: [PATCH] [CODEGEN] Add LoweredFunc, MakeAPI to build a C API function
 (#23)

* [CODEGEN] Add LoweredFunc, MakeAPI and SplitHostDevice

* update halideir
---
 HalideIR                                  |   2 +-
 include/tvm/buffer.h                      |   3 +
 include/tvm/c_runtime_api.h               |  40 ++-
 include/tvm/codegen.h                     |  68 +++++
 include/tvm/ir.h                          |  42 +++
 include/tvm/ir_mutator.h                  |  15 +
 include/tvm/ir_pass.h                     |   7 +-
 include/tvm/ir_visitor.h                  |  11 +
 include/tvm/module.h                      | 108 ++++++++
 python/tvm/collections.py                 |   6 +
 src/base/common.h                         |   3 +-
 src/c_api/c_api_codegen.cc                |  13 +-
 src/codegen/codegen_c.cc                  | 217 +++++++++++----
 src/codegen/codegen_c.h                   |  23 +-
 src/codegen/make_api.cc                   | 200 ++++++++++++++
 src/codegen/split_host_device.cc          | 218 +++++++++++++++
 src/pass/inline.cc                        |  36 +--
 src/pass/ir_mutator.cc                    | 316 ++++++++++++----------
 src/pass/ir_util.h                        |  70 +++++
 src/pass/ir_visitor.cc                    | 123 +++++----
 src/pass/schedule_ops.cc                  |  81 +++---
 src/pass/simple_passes.cc                 |  36 +++
 tests/python/test_codegen_cuda.py         |  28 +-
 tests/python/test_codegen_makeapi.py      |  27 ++
 tests/python/test_pass_storage_flatten.py |   1 +
 25 files changed, 1346 insertions(+), 348 deletions(-)
 create mode 100644 include/tvm/codegen.h
 create mode 100644 include/tvm/module.h
 create mode 100644 src/codegen/make_api.cc
 create mode 100644 src/codegen/split_host_device.cc
 create mode 100644 src/pass/ir_util.h
 create mode 100644 src/pass/simple_passes.cc
 create mode 100644 tests/python/test_codegen_makeapi.py

diff --git a/HalideIR b/HalideIR
index adfa66240..30bf0f043 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf
+Subproject commit 30bf0f043e6388418958fd1f29259ee43c42b600
diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h
index beed7e9d1..2e4d7debc 100644
--- a/include/tvm/buffer.h
+++ b/include/tvm/buffer.h
@@ -50,6 +50,9 @@ class Buffer : public NodeRef {
    * \return the pointer to the internal node container
    */
   inline const BufferNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = BufferNode;
 };
 
 /*! \brief Node to represent a buffer */
diff --git a/include/tvm/c_runtime_api.h b/include/tvm/c_runtime_api.h
index 1a21adc41..25b81d80c 100644
--- a/include/tvm/c_runtime_api.h
+++ b/include/tvm/c_runtime_api.h
@@ -30,6 +30,7 @@
 #endif
 
 #include <stdint.h>
+#include <stddef.h>
 
 
 TVM_EXTERN_C {
@@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
 TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
 
 /*!
- * \brief Launch a generated TVM function
+ * \brief TVM Function API: Get resource requirement
+ *
+ *  By default, TVM functions try not to do internal allocations.
+ *  Instead, TVMFuncRequirement can be called, given the input arguments.
+ *
+ * \param func function handle to be launched.
+ * \param args The arguments
+ * \param arg_type_ids The type id of the arguments
+ * \param num_args Number of arguments.
+ * \param out_workspace_size The workspace size needed to launch this function.
+ * \param out_workspace_align The alignment requirement of workspace.
+ *
+ * \note The data pointer in the arrays is not used by requirement.
+ */
+TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
+                               TVMArg* args,
+                               int* arg_type_ids,
+                               int num_args,
+                               size_t* out_workspace_size,
+                               size_t* out_workspace_align);
+
+/*!
+ * \brief TVM Function API: Launch generated function.
+ *
  * \param func function handle to be launched.
  * \param args The arguments
  * \param arg_type_ids The type id of the arguments
  * \param num_args Number of arguments.
  * \param stream The stream this function to be launched on.
+ * \param workspace Additional workspace used to launch this function.
+ *
+ * \sa TVMFuncRequirement
  */
-TVM_DLL int TVMLaunch(TVMFunctionHandle func,
-                      TVMArg* args,
-                      int* arg_type_ids,
-                      int num_args,
-                      TVMStreamHandle stream);
+TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
+                          TVMArg* args,
+                          int* arg_type_ids,
+                          int num_args,
+                          TVMStreamHandle stream,
+                          TVMArrayHandle workspace);
 }  // TVM_EXTERN_C
 
 #endif  // TVM_C_RUNTIME_API_H_
diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h
new file mode 100644
index 000000000..b4a15e5a5
--- /dev/null
+++ b/include/tvm/codegen.h
@@ -0,0 +1,68 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file codegen.h
+ * \brief Collection of low-level IR passes for codegen.
+ */
+#ifndef TVM_CODEGEN_H_
+#define TVM_CODEGEN_H_
+
+#include <string>
+#include "./base.h"
+#include "./expr.h"
+#include "./module.h"
+
+namespace tvm {
+/*! \brief namespace for lowlevel IR pass and codegen */
+namespace codegen {
+/*!
+ * \brief Make a user-callable API LoweredFunc.
+ *
+ *  The main task of this function is to create code to:
+ *   - Map the values in api_args to the Vars that are required by body.
+ *   - Insert assertions to check type/value of the passed arguments.
+ *
+ * \param body The body of the function.
+ * \param name The name of the function.
+ * \param api_args Arguments to the function, can be either Var, or Buffer
+ * \param num_packed_args Number of arguments that are processed in packed form.
+ * \return a LoweredFunc with the specified signature.
+ *
+ * \note
+ *  The function signature has two cases
+ *
+ *  if num_packed_args is zero:
+ *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
+ *
+ *  if num_packed_args is not zero:
+ *       f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
+ *         api_arg_k, api_arg_k+1, ... api_arg_n)
+ *
+ *       where n == len(api_args), k == num_packed_args
+ *
+ *  There is no thread_axis in generated function.
+ */
+LoweredFunc MakeAPI(Stmt body,
+                    std::string name,
+                    Array<NodeRef> api_args,
+                    int num_packed_args);
+
+/*!
+ * \brief Collect the undefined vars in f.
+ * \param f The function to be checked.
+ * \return Array of undefined vars.
+ */
+Array<Var> UndefinedVars(const LoweredFunc& f);
+
+/*!
+ * \brief Split the function into a host function and device functions.
+ * \param func The function to be split.
+ *
+ * \return Array of functions, the first one is host function,
+ *     the others are device functions.
+ */
+Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
+
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_CODEGEN_H_
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index dd53d53b2..067610421 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -49,6 +49,48 @@ struct Reduce : public ExprNode<Reduce> {
   static constexpr const char* Min = "Min";
 };
 
+/*! \brief namespace of TVM Intrinsic functions */
+namespace intrinsic {
+// Most of the intrinsics are used when generating API functions.
+/*!
+ * \brief See pseudo code
+ *
+ *  Type tvm_api_load_arg(TVMArg* args, int* arg_type_ids, int i) {
+ *     assert(arg_type_ids[i] == typeid(Type));
+ *     return args[i];
+ *  }
+ */
+constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
+/*!
+ * \brief See pseudo code
+ *
+ *  Type tvm_array_get_field(TVMArray* arr, int field_id) {
+ *     return arr->field;
+ *  }
+ * \sa TVMArrayFieldKind
+ */
+constexpr const char* tvm_array_get_field = "tvm_array_get_field";
+/*!
+ * \brief See pseudo code
+ *
+ *  bool tvm_handle_is_null(void* handle) {
+ *     return handle == nullptr;
+ *  }
+ */
+constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
+
+/*! \brief The field id of each field in array */
+enum TVMArrayFieldKind {
+  kData = 0,
+  kNDim = 1,
+  kShape = 2,
+  kStrides = 3,
+  kTypeCode = 4,
+  kTypeBits = 5,
+  kTypeLanes = 6
+};
+}   // namespace intrinsic
+
 // Reuse IR node defintiion from HalideIR
 using Halide::Internal::IntImm;
 using Halide::Internal::UIntImm;
diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h
index 3106df1ff..b57bca25e 100644
--- a/include/tvm/ir_mutator.h
+++ b/include/tvm/ir_mutator.h
@@ -9,6 +9,7 @@
 #include <tvm/ir_functor.h>
 #include <unordered_map>
 #include "./expr.h"
+#include "./ir.h"
 
 namespace tvm {
 namespace ir {
@@ -51,6 +52,20 @@ class IRMutator {
   static FMutateExpr& vtable_expr();  // NOLINT(*)
   /*! \return internal stmt of expr */
   static FMutateStmt& vtable_stmt();  // NOLINT(*)
+  // Set of overloadable functions
+  // The underscore allows Mutate not to be shadowed by inheritance
+  virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
+  virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
+  virtual Stmt Mutate_(const For* op, const Stmt& s);
+  virtual Stmt Mutate_(const Provide* op, const Stmt& s);
+  virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
+  virtual Stmt Mutate_(const Realize* op, const Stmt& s);
+  virtual Stmt Mutate_(const Store* op, const Stmt& s);
+  virtual Stmt Mutate_(const Free* op, const Stmt& s);
+  virtual Expr Mutate_(const Call* op, const Expr& e);
+  virtual Expr Mutate_(const Load* op, const Expr& s);
+  virtual Expr Mutate_(const Variable* op, const Expr& e);
+  virtual Expr Mutate_(const Let* op, const Expr& e);
 };
 
 /*!
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index a45bbbb91..a2c2956a9 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -56,6 +56,12 @@ Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
  */
 bool VerifySSA(const Stmt& ir);
 
+/*!
+ * \brief Whether the expression has side effects.
+ * \return whether the expression has side effects
+ */
+bool HasSideEffect(const Expr& e);
+
 /*!
  * \brief Convert a IR node to be SSA form.
  * \param stmt The source statement to be converted.
@@ -79,7 +85,6 @@ Stmt Inline(Stmt stmt,
             Array<Var> args,
             Expr body);
 
-
 /*!
  * \brief Flatten the multi-dimensional read/write
  *  to single dimensional Load/Store
diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h
index b64406d7e..0df5d3e32 100644
--- a/include/tvm/ir_visitor.h
+++ b/include/tvm/ir_visitor.h
@@ -34,6 +34,17 @@ class IRVisitor {
   using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
   /*! \return internal vtable*/
   static FVisit& vtable();
+  // overloadable visit function.
+  virtual void Visit_(const Variable* op);
+  virtual void Visit_(const AttrStmt* op);
+  virtual void Visit_(const LetStmt* op);
+  virtual void Visit_(const For* op);
+  virtual void Visit_(const Allocate* op);
+  virtual void Visit_(const Load* op);
+  virtual void Visit_(const Store* op);
+  virtual void Visit_(const Let* op);
+  virtual void Visit_(const Free* op);
+  virtual void Visit_(const Call* op);
 };
 
 /*!
diff --git a/include/tvm/module.h b/include/tvm/module.h
new file mode 100644
index 000000000..263fdc2f2
--- /dev/null
+++ b/include/tvm/module.h
@@ -0,0 +1,108 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file module.h
+ * \brief Low level IR module,
+ *  Contains lowered function information.
+ */
+#ifndef TVM_MODULE_H_
+#define TVM_MODULE_H_
+
+#include <tvm/container.h>
+#include <ir/FunctionBase.h>
+#include <string>
+
+#include "./base.h"
+#include "./expr.h"
+#include "./tensor.h"
+
+namespace tvm {
+
+// Internal node container of lowered function.
+class LoweredFuncNode;
+
+// Internal node container of module.
+class ModuleNode;
+
+/*!
+ * \brief LoweredFunc represents function after lowering.
+ *  This is the final IR representation before codegen.
+ */
+class LoweredFunc : public FunctionRef {
+ public:
+  LoweredFunc() {}
+  explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const LoweredFuncNode* operator->() const;
+  /*! \brief specify container node */
+  using ContainerType = LoweredFuncNode;
+};
+
+/*! \brief Node container of LoweredFunc */
+class LoweredFuncNode : public FunctionBaseNode {
+ public:
+  /*! \brief The name of the function */
+  std::string name;
+  /*!
+   * \brief The arguments of the function
+   *  This function can only take pod type(int, float) and void* as arguments.
+   */
+  Array<Var> args;
+  /*!
+   * \brief The IterVar axis of threads
+   *  Each axis needs the host function to specify a size.
+   * \note Calling convention into LoweredFunc
+   *
+   * Assume we have a LoweredFunc f, a call into f
+   *   Call(f, arg1, arg2, ..., arg_n,
+   *        size_axis_1, size_axis_2, ... size_axis_m)
+   *
+   * Here n = len(args), m = len(thread_axis)
+   *
+   * The CodeGen should take this and translate this call
+   * to corresponding API specific kernel launches or function calls.
+   */
+  Array<IterVar> thread_axis;
+  /*!
+   * \brief The hint data type of Var handles defined in LetStmt
+   *  Can be used as hint when generating type signature.
+   *  The creation rule is given by
+   *  handle_data_type[var_handle] = make_const(the_type, 0);
+   *
+   * \note Expr is used instead of Type, because Type cannot be held by Map.
+   *  A constant Expr of the given type is used.
+   */
+  Map<Var, Expr> handle_data_type;
+  /*! \brief The body statement of the function */
+  Stmt body;
+  /*! \return name of the operation */
+  const std::string& func_name() const final {
+    return name;
+  }
+  // there is no return value, but return 1
+  // to enable Call into this function.
+  int num_outputs() const final {
+    return 1;
+  }
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("thread_axis", &thread_axis);
+    v->Visit("handle_data_type", &handle_data_type);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "LoweredFunc";
+  TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
+};
+
+// Implementations of inline functions
+inline const LoweredFuncNode* LoweredFunc::operator->() const {
+  return static_cast<const LoweredFuncNode*>(node_.get());
+}
+
+}  // namespace tvm
+
+#endif  // TVM_MODULE_H_
diff --git a/python/tvm/collections.py b/python/tvm/collections.py
index 85e629cc9..2e43e2e6b 100644
--- a/python/tvm/collections.py
+++ b/python/tvm/collections.py
@@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp):
 class Buffer(NodeBase):
     """Represent a Buffer in TVM."""
     pass
+
+
+@register_node
+class LoweredFunc(NodeBase):
+    """Represent a LoweredFunc in TVM."""
+    pass
diff --git a/src/base/common.h b/src/base/common.h
index ea2f4bdad..432ec74db 100644
--- a/src/base/common.h
+++ b/src/base/common.h
@@ -7,6 +7,7 @@
 #define TVM_BASE_COMMON_H_
 
 #include <tvm/base.h>
+#include <tvm/expr.h>
 #include <string>
 
 namespace tvm {
@@ -30,7 +31,7 @@ inline Type String2Type(std::string s) {
   } else if (s.substr(0, 5) == "float") {
     code = Type::Float; s = s.substr(5);
   } else if (s == "handle") {
-    return Type(Type::Handle, 32, 1);
+    return Handle();
   } else {
     LOG(FATAL) << "unknown type " << s;
   }
diff --git a/src/c_api/c_api_codegen.cc b/src/c_api/c_api_codegen.cc
index 365033ea4..0fa5973a4 100644
--- a/src/c_api/c_api_codegen.cc
+++ b/src/c_api/c_api_codegen.cc
@@ -5,6 +5,7 @@
  */
 #include <tvm/expr.h>
 #include <tvm/ir.h>
+#include <tvm/codegen.h>
 
 #include "./c_api_registry.h"
 #include "../codegen/codegen_c.h"
@@ -17,9 +18,19 @@ using RetValue = APIVariantValue;
 
 TVM_REGISTER_API(_codegen_CompileToC)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    *ret = CodeGenC().Compile(
+    *ret = CodeGenC().Compile(args.at(0), args.at(1));
+  });
+
+TVM_REGISTER_API(_codegen_MakeAPI)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = MakeAPI(
         args.at(0), args.at(1), args.at(2), args.at(3));
   });
 
+TVM_REGISTER_API(_codegen_SplitHostDevice)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = SplitHostDevice(args.at(0));
+  });
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index a42569e9a..327778db0 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -9,24 +9,27 @@ namespace codegen {
 
 using namespace ir;
 
-std::string CodeGenC::Compile(
-    Stmt stmt, std::string fun_name,
-    Array<Var> args, bool output_ssa) {
+std::string CodeGenC::Compile(LoweredFunc f,
+                              bool output_ssa) {
   print_ssa_form_ = output_ssa;
   // skip the first underscore, so SSA variable starts from _1
   if (print_ssa_form_) GetUniqueName("_");
+  // add to alloc buffer type.
+  for (const auto & kv : f->handle_data_type) {
+    HandleTypeRegister(kv.first.get(), kv.second.type());
+  }
 
   this->indent += 2;
-  this->stream << "void " << fun_name << "(";
-  for (size_t i = 0; i < args.size(); ++i) {
-    Var v = args[i];
+  this->stream << "void " << f->name << "(";
+  for (size_t i = 0; i < f->args.size(); ++i) {
+    Var v = f->args[i];
     std::string vid = AllocVarID(v.get());
     if (i != 0) stream << ", ";
     PrintType(v.type(), stream);
     stream << ' ' << vid;
   }
   stream << ") {\n";
-  this->PrintStmt(stmt);
+  this->PrintStmt(f->body);
   this->indent -= 2;
   this->PrintIndent();
   this->stream << "}\n";
@@ -104,12 +107,22 @@ std::string CodeGenC::GetVarID(const Variable* v) const {
   return it->second;
 }
 
-bool CodeGenC::BufferTypeMatch(const Variable* buf_var, Type t) const {
-  auto it = alloc_buf_type_.find(buf_var);
-  if (it == alloc_buf_type_.end()) return false;
+bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
+  auto it = handle_data_type_.find(buf_var);
+  if (it == handle_data_type_.end()) return false;
   return it->second == t;
 }
 
+void CodeGenC::HandleTypeRegister(const Variable* buf_var, Type t) {
+  auto it = handle_data_type_.find(buf_var);
+  if (it == handle_data_type_.end()) {
+    handle_data_type_[buf_var] = t;
+  } else {
+    CHECK(it->second == t)
+        << "conflicting buf var type";
+  }
+}
+
 void CodeGenC::PrintIndent() {
   for (int i = 0; i < this->indent; ++i) {
     this->stream << ' ';
@@ -234,6 +247,18 @@ inline void PrintBinaryExpr(const T* op,
   os << ')';
 }
 
+inline void PrintBinaryIntrinsitc(const Call* op,
+                                  const char *opstr,
+                                  std::ostream& os,  // NOLINT(*)
+                                  CodeGenC* p) {
+  CHECK_EQ(op->args.size(), 2U);
+  os << '(';
+  p->PrintExpr(op->args[0], os);
+  os << opstr;
+  p->PrintExpr(op->args[1], os);
+  os << ')';
+}
+
 TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
 .set_dispatch<Cast>([](const Cast *op, std::ostream& os, CodeGenC *p) {  // NOLINT(*)
     p->PrintType(op->type, os);
@@ -300,24 +325,9 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
 .set_dispatch<Not>([](const Not *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
     os << '!';
     p->PrintExpr(op->a, os);
-  })
-.set_dispatch<Call>([](const Call *op, std::ostream& os, CodeGenC* p) {  // NOLINT(*)
-    os << op->name << "(";
-    for (size_t i = 0; i < op->args.size(); i++) {
-      p->PrintExpr(op->args[i], os);
-      if (i < op->args.size() - 1) {
-        os << ", ";
-      }
-    }
-    os << ")";
   });
 
 TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
-.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) {
-    std::string cond = p->PrintExpr(op->condition);
-    p->PrintIndent();
-    p->stream << "assert(" << cond << ");\n";
-  })
 .set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenC* p) {
     p->PrintStmt(op->body);
   })
@@ -372,14 +382,95 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
 
 TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
 .DISPATCH_EXPR(Load)
+.DISPATCH_EXPR(Call)
 .DISPATCH_EXPR(Let)
 .DISPATCH_EXPR(Ramp)
 .DISPATCH_EXPR(Broadcast)
 .DISPATCH_EXPR(Select);
 
+
+void CodeGenC::PrintExpr(const Call *op, std::ostream& os) {  // NOLINT(*)
+  CodeGenC* p = this;
+  if (op->is_intrinsic(Call::bitwise_and)) {
+    PrintBinaryIntrinsitc(op, " & ", os, p);
+  } else if (op->is_intrinsic(Call::bitwise_xor)) {
+    PrintBinaryIntrinsitc(op, " ^ ", os, p);
+  } else if (op->is_intrinsic(Call::bitwise_or)) {
+    PrintBinaryIntrinsitc(op, " | ", os, p);
+  } else if (op->is_intrinsic(Call::bitwise_not)) {
+    CHECK_EQ(op->args.size(), 1U);
+    os << "(~";
+    p->PrintExpr(op->args[0], os);
+    os << ')';
+  } else if (op->is_intrinsic(Call::shift_left)) {
+    PrintBinaryIntrinsitc(op, " << ", os, p);
+  } else if (op->is_intrinsic(Call::shift_right)) {
+    PrintBinaryIntrinsitc(op, " >> ", os, p);
+  } else if (op->is_intrinsic(Call::address_of)) {
+    const Load *l = op->args[0].as<Load>();
+    CHECK(op->args.size() == 1 && l);
+    os << "((";
+    p->PrintType(l->type.element_of(), os);
+    os << " *)" << p->GetVarID(l->buffer_var.get())
+       << " + ";
+    p->PrintExpr(l->index, os);
+    os << ')';
+  } else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
+    CHECK_EQ(op->args.size(), 3U);
+    if (!op->type.is_handle()) {
+      os << '(';
+      p->PrintType(op->type, os);
+      os << ')';
+    }
+    os << "(((TVMArg*)";
+    p->PrintExpr(op->args[0], os);
+    os << ")[" << op->args[2] << "].";
+    if (op->type.is_handle()) {
+      os << "v_handle";
+    } else if (op->type.is_float()) {
+      os << "v_double";
+    } else if (op->type.is_int() || op->type.is_uint()) {
+      os << "v_long";
+    } else {
+      LOG(FATAL) << "donot know how to handle type" << op->type;
+    }
+    os << ")";
+  } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
+    CHECK_EQ(op->args.size(), 2U);
+    os << "(((TVMArray*)";
+    p->PrintExpr(op->args[0], os);
+    os << ")->";
+    switch (op->args[1].as<IntImm>()->value) {
+      case intrinsic::kData: os << "data"; break;
+      case intrinsic::kShape: os << "shape"; break;
+      case intrinsic::kStrides: os << "strides"; break;
+      case intrinsic::kNDim: os << "ndim"; break;
+      case intrinsic::kTypeCode: os << "dtype.type_code"; break;
+      case intrinsic::kTypeBits: os << "dtype.bits"; break;
+      case intrinsic::kTypeLanes: os << "dtype.lanes"; break;
+      default: LOG(FATAL) << "unknown field code";
+    }
+    os << ')';
+  } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
+    CHECK_EQ(op->args.size(), 1U);
+    os << "(";
+    p->PrintExpr(op->args[0], os);
+    os << " == NULL)";
+  } else {
+    os << op->name << "(";
+    for (size_t i = 0; i < op->args.size(); i++) {
+      p->PrintExpr(op->args[i], os);
+      if (i < op->args.size() - 1) {
+        os << ", ";
+      }
+    }
+    os << ")";
+  }
+}
+
 void CodeGenC::PrintExpr(const Load* op, std::ostream& os) {  // NOLINT(*)
   std::string vid = GetVarID(op->buffer_var.get());
-  if (!BufferTypeMatch(op->buffer_var.get(), op->type)) {
+  if (!HandleTypeMatch(op->buffer_var.get(), op->type)) {
     os << "((const ";
     PrintType(op->type, os);
     os << "*)" << vid << ')';
@@ -416,7 +507,8 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
 .set_dispatch<LetStmt>([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); })
 .set_dispatch<Store>([](const Store *op, CodeGenC* p) { p->PrintStmt(op); })
 .set_dispatch<Allocate>([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); })
-.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); });
+.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); })
+.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); });
 
 
 void CodeGenC::PrintStmt(const LetStmt* op) {
@@ -426,10 +518,20 @@ void CodeGenC::PrintStmt(const LetStmt* op) {
     var_idmap_[op->var.get()] = value;
   } else {
     PrintIndent();
-    PrintType(op->var.type(), this->stream);
-    this->stream << ' '
-       << AllocVarID(op->var.get())
-       << " = " << value << ";\n";
+    if (op->var.type() == Handle() &&
+        handle_data_type_.count(op->var.get())) {
+      PrintType(handle_data_type_.at(op->var.get()), stream);
+      stream << "* "
+             << AllocVarID(op->var.get())
+             << " = (";
+      PrintType(handle_data_type_.at(op->var.get()), stream);
+      stream << "*)"  << value << ";\n";
+    } else {
+      PrintType(op->var.type(), this->stream);
+      this->stream << ' '
+                   << AllocVarID(op->var.get())
+                   << " = " << value << ";\n";
+    }
   }
   PrintStmt(op->body);
 }
@@ -439,7 +541,7 @@ void CodeGenC::PrintStmt(const Store* op) {
   std::string value = this->PrintExpr(op->value);
   this->PrintIndent();
   std::string vid = GetVarID(op->buffer_var.get());
-  if (!BufferTypeMatch(op->buffer_var.get(), op->value.type())) {
+  if (!HandleTypeMatch(op->buffer_var.get(), op->value.type())) {
     this->stream << "((";
     PrintType(op->value.type(), this->stream);
     this->stream << "*)" << vid << ')';
@@ -452,16 +554,25 @@ void CodeGenC::PrintStmt(const Store* op) {
 }
 
 void CodeGenC::PrintStmt(const Allocate* op) {
-  this->PrintIndent();
-  int32_t constant_size = op->constant_allocation_size();
-  std::string vid = AllocVarID(op->buffer_var.get());
-  CHECK(!op->new_expr.defined());
   CHECK(!is_zero(op->condition));
-  CHECK_GT(constant_size, 0)
-      << "Can only handle constant size stack allocation for now";
-  PrintType(op->type, stream);
-  stream << ' '<< vid << '['
-         << constant_size << "]\n;";
+  std::string vid = AllocVarID(op->buffer_var.get());
+  if (op->new_expr.defined()) {
+    // Prefer global static allocation for the program
+    CHECK_EQ(op->free_function, "nop");
+    std::string new_data = PrintExpr(op->new_expr);
+    this->PrintIndent();
+    PrintType(op->type, stream);
+    stream << "* "<< vid << '=' << new_data << ";\n";
+  } else {
+    this->PrintIndent();
+    int32_t constant_size = op->constant_allocation_size();
+    CHECK_GT(constant_size, 0)
+        << "Can only handle constant size stack allocation for now";
+    PrintType(op->type, stream);
+    stream << ' '<< vid << '['
+           << constant_size << "]\n;";
+  }
+  HandleTypeRegister(op->buffer_var.get(), op->type);
   this->PrintStmt(op->body);
 }
 
@@ -469,15 +580,29 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
   if (op->type_key == "scope") {
     IterVar iv(op->node.node_);
     if (iv->thread_tag.length() != 0) {
-      this->PrintIndent();
-      PrintType(iv->var.type(), stream);
-      stream << ' '
-             << AllocVarID(iv->var.get())
-             << " = " << iv->thread_tag << ";\n";
+      if (!var_idmap_.count(iv->var.get())) {
+        this->PrintIndent();
+        PrintType(iv->var.type(), stream);
+        stream << ' '
+               << AllocVarID(iv->var.get())
+               << " = " << iv->thread_tag << ";\n";
+      }
     }
   }
   this->PrintStmt(op->body);
 }
 
+void CodeGenC::PrintStmt(const AssertStmt* op) {
+  std::string cond = PrintExpr(op->condition);
+  PrintIndent();
+  if (op->message.as<StringImm>()) {
+    // GLOG style check
+    stream << "CHECK(" << cond << ") << \""
+           << op->message.as<StringImm>()->value << "\";\n";
+  } else {
+    stream << "assert(" << cond << ");\n";
+  }
+}
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h
index a8ce1828e..4630e9990 100644
--- a/src/codegen/codegen_c.h
+++ b/src/codegen/codegen_c.h
@@ -8,6 +8,7 @@
 
 #include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
+#include <tvm/module.h>
 #include <string>
 #include <unordered_map>
 
@@ -23,16 +24,12 @@ class CodeGenC {
  public:
   /*!
    * \brief Generate the C code of statement
-   * \param body The body of the function.
-   * \param fun_name The name of the function.
-   * \param args The arguments to the function.
+   * \param f The function to be compiled
    * \param output_ssa Whether output ssa form.
    * \note Only call compile once,
    *  create a new codegen object each time.
    */
-  std::string Compile(Stmt body,
-                      std::string fun_name,
-                      Array<Var> args,
+  std::string Compile(LoweredFunc f,
                       bool output_ssa);
   /*!
    * \brief Print the Stmt n to CodeGenC->stream
@@ -49,7 +46,7 @@ class CodeGenC {
    * \brief Same as PrintExpr, but simply returns result string
    * \param n The expression to be printed.
    */
-  inline std::string PrintExpr(const Expr& n) {
+  std::string PrintExpr(const Expr& n) {
     std::ostringstream os;
     PrintExpr(n, os);
     return os.str();
@@ -85,7 +82,9 @@ class CodeGenC {
   virtual void PrintStmt(const ir::Store* op);
   virtual void PrintStmt(const ir::Allocate* op);
   virtual void PrintStmt(const ir::AttrStmt* op);
+  virtual void PrintStmt(const ir::AssertStmt* op);
   virtual void PrintExpr(const ir::Load* op, std::ostream& os);  // NOLINT(*)
+  virtual void PrintExpr(const ir::Call* op, std::ostream& os);  // NOLINT(*)
   virtual void PrintExpr(const ir::Let* op, std::ostream& os);  // NOLINT(*)
   virtual void PrintExpr(const ir::Ramp* op, std::ostream& os);  // NOLINT(*)
   virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os);  // NOLINT(*)
@@ -116,7 +115,13 @@ class CodeGenC {
    * \param buf_var The buffer variable.
    * \param t The type to be checked.
    */
-  bool BufferTypeMatch(const Variable* buf_var, Type t) const;
+  bool HandleTypeMatch(const Variable* buf_var, Type t) const;
+  /*!
+   * \brief Register the data type of buf_var
+   * \param buf_var The buffer variable.
+   * \param t The data type to be registered.
+   */
+  void HandleTypeRegister(const Variable* buf_var, Type t);
   /*!
    * \brief get a unique name with the corresponding prefix
    * \param prefix The prefix of the name
@@ -128,7 +133,7 @@ class CodeGenC {
   /*! \brief name of each variable */
   std::unordered_map<const Variable*, std::string> var_idmap_;
   /*! \brief the data type of allocated buffers */
-  std::unordered_map<const Variable*, Type> alloc_buf_type_;
+  std::unordered_map<const Variable*, Type> handle_data_type_;
   /*! \brief name allocation map */
   std::unordered_map<std::string, int> name_alloc_map_;
   /*! \brief assignment map of ssa */
diff --git a/src/codegen/make_api.cc b/src/codegen/make_api.cc
new file mode 100644
index 000000000..227faf37f
--- /dev/null
+++ b/src/codegen/make_api.cc
@@ -0,0 +1,200 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file make_api.cc Build API function.
+ */
+#include <tvm/codegen.h>
+#include <tvm/ir.h>
+#include <tvm/buffer.h>
+
+#include <vector>
+#include <utility>
+#include <unordered_set>
+
+#include "../pass/ir_util.h"
+
+namespace tvm {
+namespace codegen {
+using namespace ir;
+
+inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) {
+  // Emit tvm_array_get_field(arr, kind) as a pure intrinsic of type t.
+  Array<Expr> field_args{arr, IntImm::make(Int(32), kind)};
+  return Call::make(t, intrinsic::tvm_array_get_field, field_args,
+                    Call::PureIntrinsic);
+}
+
+inline Stmt AssertNull(Var handle, std::string msg) {
+  Expr is_null = Call::make(Bool(1), intrinsic::tvm_handle_is_null,
+                            {handle}, Call::PureIntrinsic);
+  return AssertStmt::make(is_null, msg);  // msg becomes the assert message
+}
+
+inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
+  return AssertStmt::make(lhs == rhs, msg);  // condition lhs == rhs, message msg
+}
+
+LoweredFunc MakeAPI(Stmt body,                // body of the generated function
+                    std::string name,         // name of the generated function
+                    Array<NodeRef> api_args,  // one Var or Buffer per argument
+                    int num_packed_args) {    // leading args read from `args`
+  const Type tvm_index_type = UInt(32);
+  const Stmt nop = Evaluate::make(0);
+  // Data field definitions
+  // The packed fields
+  Var v_packed_args("args", Handle());
+  Var v_packed_arg_type_ids("arg_type_ids", Handle());
+  Var v_num_packed_args("num_args", Int(32));
+  // The arguments of the function.
+  Array<Var> args;
+  // seq_init gives the sequence of initialization statements
+  // seq_check gives the sequence of checks that run after init
+  std::vector<Stmt> seq_init, seq_check;
+  std::unordered_set<const Variable*> visited;
+  // the handle data types
+  Map<Var, Expr> handle_data_type;
+  // ---------------------------
+  // local function definitions
+  // load i-th argument as type t
+  auto f_arg_value = [&](Type t, int i) {
+    Array<Expr> call_args{
+      v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)};
+    return Call::make(
+        t, intrinsic::tvm_api_load_arg, call_args,
+        Call::PureIntrinsic);
+  };
+  // get declaration of argument i
+  auto f_arg_decl = [&](int i) {
+    std::ostringstream os;
+    os << "arg" << i;
+    const Variable* v = api_args[i].as<Variable>();
+    return Var(os.str(), v ? v->type: Handle());
+  };
+  // Turn (sym, value) into a variable definition or an assertion,
+  // given the symbolic declaration and concrete value
+  auto f_push = [&](Expr sym, Expr value, std::string field) {
+    if (sym.as<Variable>()) {
+      // If sym is a Variable and this Variable is not yet defined
+      // add a definition for it and report that by returning true.
+      Var v(sym.node_);
+      if (!visited.count(v.get())) {
+        seq_init.emplace_back(LetStmt::make(v, value, nop));
+        visited.insert(v.get());
+        return true;
+      }
+    }
+    // otherwise, assume sym is already defined, insert assertion.
+    std::ostringstream os;
+    os << "Field " << field << " has a unsatisfied constraint";
+    seq_check.emplace_back(MakeAssertEQ(sym, value, os.str()));
+    return false;
+  };
+  // ---------------------------
+  // start of the main logic
+  // add signature for packed arguments.
+  if (num_packed_args != 0) {
+    args.push_back(v_packed_args);
+    args.push_back(v_packed_arg_type_ids);
+    args.push_back(v_num_packed_args);
+    std::ostringstream os;
+    os << "expected num_args to be " << num_packed_args;
+    seq_init.emplace_back(
+        MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
+  }
+
+  for (size_t i = 0; i < api_args.size(); ++i) {
+    Var v_arg = f_arg_decl(i);
+    if (i < static_cast<size_t>(num_packed_args)) {
+      seq_init.emplace_back(LetStmt::make(
+          v_arg, f_arg_value(v_arg.type(), i), nop));
+    } else {
+      args.push_back(v_arg);
+    }
+    // bind the symbolic api_arg to v_arg, or check it if already bound.
+    if (api_args[i].as<Variable>()) {
+      f_push(Var(api_args[i].node_), v_arg, v_arg->name_hint);
+    } else {
+      // Buffer checks
+      CHECK(api_args[i].as<BufferNode>())
+          << "api_args can only be Buffer or Var";
+      Buffer buf(api_args[i].node_);
+      // dimension checks
+      Expr v_ndim = TVMArrayGet(tvm_index_type, v_arg, intrinsic::kNDim);
+      std::ostringstream ndim_err_msg;
+      ndim_err_msg << "arg_" << i
+                   << ".ndim is expected to equal "
+                   << buf->shape.size();
+      seq_init.emplace_back(
+          MakeAssertEQ(v_ndim, UIntImm::make(tvm_index_type, buf->shape.size()),
+                       ndim_err_msg.str()));
+      // type checks
+      Type dtype = buf->dtype;
+      std::ostringstream type_err_msg;
+      type_err_msg << "arg" << i << ".dtype is expected to be " << dtype;
+      Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) ==
+                   UIntImm::make(UInt(8), dtype.code()) &&
+                   TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) ==
+                   UIntImm::make(UInt(8), dtype.bits()) &&
+                   TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) ==
+                   UIntImm::make(UInt(16), dtype.lanes()));
+      seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
+      // Data Field
+      if (f_push(buf->ptr, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
+                 v_arg->name_hint + ".data")) {
+        Var vptr(buf->ptr);
+        handle_data_type.Set(vptr, make_const(buf->dtype, 0));
+      }
+      // shape field
+      Var v_shape(v_arg->name_hint + ".shape", Handle());
+      handle_data_type.Set(v_shape, UIntImm::make(tvm_index_type, 0));
+      seq_init.emplace_back(LetStmt::make(
+          v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop));
+      for (size_t k = 0; k < buf->shape.size(); ++k) {
+        std::ostringstream field_name;
+        field_name << v_shape->name_hint << '[' << k << ']';
+        f_push(buf->shape[k],
+               cast(buf->shape[k].type(),
+                    Load::make(tvm_index_type, v_shape, IntImm::make(Int(32), k))),
+               field_name.str());
+      }
+      // strides field
+      Var v_strides(v_arg->name_hint + ".strides", Handle());
+      handle_data_type.Set(v_strides, UIntImm::make(tvm_index_type, 0));
+      seq_init.emplace_back(LetStmt::make(
+          v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop));
+      if (buf->strides.size() == 0) {
+        std::ostringstream stride_err_msg;
+        stride_err_msg << "arg_" << i << ".strides:"
+                       << " expected to be nullptr for contiguous array";
+        seq_init.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
+      } else {
+        for (size_t k = 0; k < buf->strides.size(); ++k) {
+          std::ostringstream field_name;
+          field_name << v_strides->name_hint << '[' << k << ']';
+          f_push(buf->strides[k],
+                 cast(buf->strides[k].type(),
+                      Load::make(tvm_index_type, v_strides, IntImm::make(Int(32), k))),
+                 field_name.str());
+        }
+      }
+    }
+  }
+
+  std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
+  n->name = name;
+  n->args = args;
+  n->handle_data_type = handle_data_type;
+  n->body = MergeNest({seq_init, seq_check}, body);
+  LoweredFunc f(n);
+  Array<Var> undefined = UndefinedVars(f);
+  if (undefined.size() != 0) {
+    std::ostringstream os;
+    for (Var v : undefined) {
+      os << " \'" << v->name_hint << "\' ";
+    }
+    os << " does not appeared in api_args";
+    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
+  }
+  return f;
+}
+}  // namespace codegen
+}  // namespace tvm
diff --git a/src/codegen/split_host_device.cc b/src/codegen/split_host_device.cc
new file mode 100644
index 000000000..1560fda4e
--- /dev/null
+++ b/src/codegen/split_host_device.cc
@@ -0,0 +1,218 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file split_host_device.cc
+ * \brief Split device function from host.
+ */
+#include <tvm/codegen.h>
+#include <tvm/ir.h>
+#include <tvm/module.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_mutator.h>
+#include <unordered_map>
+
+namespace tvm {
+namespace codegen {
+
+using namespace ir;
+
+// Use/def analysis over the IR; also deletes lets whose variable is never used.
+class IRUseDefAnalysis : public IRMutator {
+ public:
+  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
+    if (op->type_key == "thread_extent") {
+      IterVar iv(op->node.node_);
+      CHECK_NE(iv->thread_tag.length(), 0U);
+      // thread_extent can appear multiple times
+      // use the first appearance as def.
+      if (!use_count_.count(iv->var.get())) {
+        this->HandleDef(iv->var.get());
+        thread_axis_.push_back(iv);
+        thread_extent_.push_back(op->value);
+      }
+
+      Expr value = op->value;
+      if (visit_thread_extent_) {
+        value = this->Mutate(value);
+      }
+      Stmt body = this->Mutate(op->body);
+      if (value.same_as(op->value) && body.same_as(op->body)) return s;
+      return AttrStmt::make(op->node, op->type_key, value, body);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Stmt Mutate_(const LetStmt *op, const Stmt& s) final {
+    this->HandleDef(op->var.get());
+    Stmt body = this->Mutate(op->body);  // body first, to learn the use count
+    // eliminate unreferenced let
+    if (use_count_.at(op->var.get()) == 0 &&
+        !HasSideEffect(op->value)) {
+      return body;
+    } else {
+      Expr value = this->Mutate(op->value);
+      if (body.same_as(op->body) &&
+          value.same_as(op->value)) {
+        return s;
+      } else {
+        return LetStmt::make(op->var, value, body);
+      }
+    }
+  }
+
+  Stmt Mutate_(const For *op, const Stmt& s) final {
+    this->HandleDef(op->loop_var.get());  // the loop variable is a definition
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Stmt Mutate_(const Allocate *op, const Stmt& s) final {
+    this->HandleDef(op->buffer_var.get());  // the buffer is defined here
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Stmt Mutate_(const Store *op, const Stmt& s) final {
+    this->HandleUse(op->buffer_var);  // writing to a buffer counts as a use
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Expr Mutate_(const Let *op, const Expr& e) final {
+    this->HandleDef(op->var.get());
+    Expr body = this->Mutate(op->body);  // body first, to learn the use count
+    // eliminate unreferenced let
+    if (use_count_.at(op->var.get()) == 0 &&
+        !HasSideEffect(op->value)) {
+      return body;
+    } else {
+      Expr value = this->Mutate(op->value);
+      if (body.same_as(op->body) &&
+          value.same_as(op->value)) {
+        return e;
+      } else {
+        return Let::make(op->var, value, body);
+      }
+    }
+  }
+
+  Expr Mutate_(const Variable *op, const Expr& e) final {
+    this->HandleUse(e);
+    return IRMutator::Mutate_(op, e);
+  }
+
+  Expr Mutate_(const Load *op, const Expr& e) final {
+    this->HandleUse(op->buffer_var);  // reading from a buffer counts as a use
+    return IRMutator::Mutate_(op, e);
+  }
+
+  void HandleDef(const Variable* v) {
+    CHECK(!use_count_.count(v))
+        << "variable is already defined";
+    use_count_[v] = 0;  // defined, not used yet
+  }
+
+  void HandleUse(const Expr& v) {
+    CHECK(v.as<Variable>());
+    Var var(v.node_);
+    auto it = use_count_.find(var.get());
+    if (it != use_count_.end()) {
+      if (it->second >= 0) {  // negative entries mark undefined variables
+        ++it->second;
+      }
+    } else {
+      undefined_.push_back(var);  // recorded once, on first use
+      use_count_[var.get()] = -1;
+    }
+  }
+
+  // The fields are publicly readable so that
+  // users of the analysis can access the results.
+  bool visit_thread_extent_{true};  // whether thread_extent values are mutated
+  Array<Var> undefined_;            // variables used without a definition
+  Array<IterVar> thread_axis_;      // thread axes, in order of first appearance
+  Array<Expr> thread_extent_;       // extent of each entry in thread_axis_
+  std::unordered_map<const Variable*, int> use_count_;  // -1 marks undefined
+};
+
+class HostDeviceSplitter : public IRMutator {
+ public:
+  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
+    if (op->type_key == "thread_extent") {
+      // The outermost thread_extent marks the start of device code:
+      // the whole statement is moved into its own device function.
+      return SplitDeviceFunc(s);
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Array<LoweredFunc> Split(LoweredFunc f) {
+    for (auto kv : f->handle_data_type) {
+      handle_data_type_[kv.first.get()] = kv.second;
+    }
+    name_ = f->name;
+    std::shared_ptr<LoweredFuncNode> n =
+        std::make_shared<LoweredFuncNode>(*f.operator->());  // copy of f
+    n->body = this->Mutate(f->body);  // fills device_funcs_ as a side effect
+
+    Array<LoweredFunc> ret{LoweredFunc(n)};  // host function comes first
+    for (LoweredFunc x : device_funcs_) {
+      ret.push_back(x);
+    }
+    return ret;
+  }
+
+ private:
+  Stmt SplitDeviceFunc(Stmt body) {
+    std::ostringstream os;
+    os << name_ << "_kernel" << device_funcs_.size();
+    std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
+    // isolate the device function.
+    IRUseDefAnalysis m;
+    m.visit_thread_extent_ = false;
+    n->body = m.Mutate(body);
+    n->name = os.str();
+    n->args = m.undefined_;  // variables used but not defined inside body
+    CHECK_NE(m.thread_extent_.size(), 0U);
+
+    // carry over the known handle data types of the kernel arguments
+    for (Var arg : n->args) {
+      auto it = handle_data_type_.find(arg.get());
+      if (it != handle_data_type_.end()) {
+        n->handle_data_type.Set(arg, it->second);
+      }
+    }
+    LoweredFunc f_device(n);
+    Array<Expr> call_args;
+    for (Var arg : n->args) {
+      call_args.push_back(arg);
+    }
+
+    for (Expr ext : m.thread_extent_) {  // thread extents follow the arguments
+      call_args.push_back(ext);
+    }
+    device_funcs_.emplace_back(f_device);
+    return Evaluate::make(Call::make(
+        Int(32), f_device->name, call_args, Call::Extern, f_device));
+  }
+
+  // function name
+  std::string name_;
+  // the device functions
+  std::vector<LoweredFunc> device_funcs_;
+  std::unordered_map<const Variable*, Expr> handle_data_type_;
+};
+
+
+Array<Var> UndefinedVars(const LoweredFunc& f) {
+  // Function arguments count as defined (use count 0) before the walk.
+  IRUseDefAnalysis analysis;
+  for (Var param : f->args) analysis.use_count_[param.get()] = 0;
+  // Walking the body records every variable used without a definition.
+  analysis.Mutate(f->body);
+  return analysis.undefined_;
+}
+
+Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
+  return HostDeviceSplitter().Split(func);  // host function first, then kernels
+}
+
+}  // namespace codegen
+}  // namespace tvm
diff --git a/src/pass/inline.cc b/src/pass/inline.cc
index 085fe738e..de452c364 100644
--- a/src/pass/inline.cc
+++ b/src/pass/inline.cc
@@ -17,36 +17,28 @@ class IRInline : public IRMutator {
   IRInline(FunctionRef f, Array<Var> args, Expr body)
       : f_(f), args_(args), body_(body) {}
 
-  Expr Mutate(Expr expr) final {
-    expr = IRMutator::Mutate(expr);
-    const Call* call = expr.as<Call>();
-    if (call != nullptr && call->func == f_) {
-      CHECK_EQ(call->value_index, 0);
-      return InlineCall(call);
-    } else {
+  Expr Mutate_(const Call* op, const Expr& e) final {
+    Expr expr = IRMutator::Mutate_(op, e);
+    op = expr.as<Call>();
+    // expr now holds the call with its arguments already rewritten.
+    if (op->func == f_) {
+      CHECK_EQ(op->value_index, 0);
+      Expr expr = body_;  // shadows the outer expr, which still owns *op
+      CHECK_EQ(args_.size(), op->args.size())
+          << op->args.size() << " vs " << args_.size();
+      for (size_t i = 0; i < args_.size(); ++i) {
+        expr = Let::make(args_[i], op->args[i], expr);
+      }
       return expr;
+    } else {
+      return expr;  // keep inlining already done inside the arguments
     }
   }
 
-  Stmt Mutate(Stmt stmt) final {
-    return IRMutator::Mutate(stmt);
-  }
-
  private:
   FunctionRef f_;
   Array<Var> args_;
   Expr body_;
-
-  Expr InlineCall(const Call* op) {
-    Expr expr = body_;
-
-    CHECK_EQ(args_.size(), op->args.size())
-        << op->args.size() << " vs " << args_.size();
-    for (size_t i = 0; i < args_.size(); ++i) {
-      expr = Let::make(args_[i], op->args[i], expr);
-    }
-    return expr;
-  }
 };
 
 Stmt Inline(Stmt stmt,
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index ad0ace10f..85b0589ce 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -58,6 +58,188 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
   }
 }
 
+// Register a dispatch entry that forwards OP statements to the virtual
+// IRMutator::Mutate_ overload, so that subclasses can override the rule.
+#define DISPATCH_TO_MUTATE_STMT(OP)                                 \
+  set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) {  \
+      return m->Mutate_(op, s);                                     \
+    })
+
+// Same as DISPATCH_TO_MUTATE_STMT, for expression nodes.
+#define DISPATCH_TO_MUTATE_EXPR(OP)                                 \
+  set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) {  \
+      return m->Mutate_(op, e);                                     \
+    })
+
+// Every statement type handled by a Mutate_ overload must be listed here.
+TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
+.DISPATCH_TO_MUTATE_STMT(LetStmt)
+.DISPATCH_TO_MUTATE_STMT(AttrStmt)
+.DISPATCH_TO_MUTATE_STMT(Provide)
+.DISPATCH_TO_MUTATE_STMT(Realize)
+.DISPATCH_TO_MUTATE_STMT(Store)
+.DISPATCH_TO_MUTATE_STMT(For)
+.DISPATCH_TO_MUTATE_STMT(Allocate)
+.DISPATCH_TO_MUTATE_STMT(Free);
+
+Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
+  // Default rule: rewrite the bound value first, then the body.
+  Expr new_value = this->Mutate(op->value);
+  Stmt new_body = this->Mutate(op->body);
+  const bool unchanged =
+      new_value.same_as(op->value) && new_body.same_as(op->body);
+  // Reuse the incoming statement when neither child was rewritten.
+  if (unchanged) return s;
+  return LetStmt::make(op->var, new_value, new_body);
+}
+
+Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
+  // Default rule: rewrite the attribute value first, then the body;
+  // node and type_key are carried over unchanged.
+  Expr new_value = this->Mutate(op->value);
+  Stmt new_body = this->Mutate(op->body);
+  const bool unchanged =
+      new_value.same_as(op->value) && new_body.same_as(op->body);
+  if (unchanged) return s;
+  return AttrStmt::make(op->node, op->type_key, new_value, new_body);
+}
+
+Stmt IRMutator::Mutate_(const For *op, const Stmt& s) {
+  // Children are rewritten in the order min, extent, body.
+  Expr new_min = this->Mutate(op->min);
+  Expr new_extent = this->Mutate(op->extent);
+  Stmt new_body = this->Mutate(op->body);
+  const bool unchanged = new_min.same_as(op->min) &&
+      new_extent.same_as(op->extent) &&
+      new_body.same_as(op->body);
+  // Reuse the incoming statement when no child was rewritten.
+  if (unchanged) return s;
+  return For::make(op->loop_var, new_min, new_extent,
+                   op->for_type, op->device_api, new_body);
+}
+
+Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
+  // Rewrite order: each extent, the body, the condition and finally the
+  // optional new_expr (left undefined when the original is undefined).
+  std::vector<Expr> extents;
+  bool extents_changed = false;
+  for (size_t i = 0; i < op->extents.size(); ++i) {
+    Expr old_extent = op->extents[i];
+    Expr extent = this->Mutate(old_extent);
+    if (!extent.same_as(old_extent)) extents_changed = true;
+    extents.push_back(extent);
+  }
+  Stmt new_body = this->Mutate(op->body);
+  Expr new_condition = this->Mutate(op->condition);
+  Expr new_expr =
+      op->new_expr.defined() ? this->Mutate(op->new_expr) : Expr();
+  const bool unchanged = !extents_changed &&
+      new_body.same_as(op->body) &&
+      new_condition.same_as(op->condition) &&
+      new_expr.same_as(op->new_expr);
+  // Reuse the incoming statement when nothing was rewritten.
+  if (unchanged) return s;
+  return Allocate::make(
+      op->buffer_var, op->type,
+      extents, new_condition, new_body,
+      new_expr, op->free_function);
+}
+
+Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
+  // The argument list is rewritten before the provided value.
+  auto rewritten_args = MutateArray(op->args, this);
+  auto rewritten_value = this->Mutate(op->value);
+  const bool unchanged = op->args.same_as(rewritten_args) &&
+      op->value.same_as(rewritten_value);
+  if (unchanged) return s;
+  return Provide::make(op->func, op->value_index, rewritten_value, rewritten_args);
+}
+
+Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
+  // Rewrite order: for every bound its min then its extent, followed by
+  // the body and finally the condition.
+  Halide::Internal::Region bounds;
+  bool changed = false;
+
+  for (size_t i = 0; i < op->bounds.size(); ++i) {
+    Expr old_min = op->bounds[i]->min;
+    Expr old_extent = op->bounds[i]->extent;
+    Expr min = this->Mutate(old_min);
+    Expr extent = this->Mutate(old_extent);
+    // Track whether any bound expression was replaced.
+    changed = changed ||
+        !min.same_as(old_min) || !extent.same_as(old_extent);
+    bounds.push_back(
+        Range::make_by_min_extent(min, extent));
+  }
+
+  Stmt body = this->Mutate(op->body);
+  Expr condition = this->Mutate(op->condition);
+  changed = changed ||
+      !body.same_as(op->body) ||
+      !condition.same_as(op->condition);
+  // Reuse the incoming statement when nothing was rewritten.
+  if (!changed) return s;
+  return Realize::make(op->func, op->value_index,
+                       op->type, bounds,
+                       condition, body);
+}
+
+Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
+  // The stored value is rewritten before the index; buffer_var is kept.
+  Expr new_value = this->Mutate(op->value);
+  Expr new_index = this->Mutate(op->index);
+  const bool unchanged =
+      new_value.same_as(op->value) && new_index.same_as(op->index);
+  if (unchanged) return s;
+  return Store::make(op->buffer_var, new_value, new_index);
+}
+
+Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
+  return s;  // default rule rewrites nothing for Free
+}
+
+TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)  // forwarded to virtual Mutate_
+.DISPATCH_TO_MUTATE_EXPR(Call)
+.DISPATCH_TO_MUTATE_EXPR(Let)
+.DISPATCH_TO_MUTATE_EXPR(Load)
+.DISPATCH_TO_MUTATE_EXPR(Variable);
+
+Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
+  // Only the arguments are rewritten; type, name, call_type, func and
+  // value_index are copied from the original call.
+  auto rewritten_args = MutateArray(op->args, this);
+  // Reuse the incoming expression when no argument changed.
+  if (op->args.same_as(rewritten_args)) return e;
+  return Call::make(op->type, op->name, rewritten_args, op->call_type,
+                    op->func, op->value_index);
+}
+
+Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
+  // Only the index is rewritten; type and buffer_var are carried over
+  // unchanged.
+  Expr new_index = this->Mutate(op->index);
+  // Reuse the incoming expression when the index did not change.
+  if (new_index.same_as(op->index)) return e;
+  return Load::make(op->type, op->buffer_var, new_index);
+}
+
+
+Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
+  return e;  // default rule leaves variables untouched
+}
+
+Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
+  // Default rule: rewrite the bound value first, then the body.
+  Expr new_value = this->Mutate(op->value);
+  Expr new_body = this->Mutate(op->body);
+  const bool unchanged =
+      new_value.same_as(op->value) && new_body.same_as(op->body);
+  // Reuse the incoming expression when neither child was rewritten.
+  if (unchanged) return e;
+  return Let::make(op->var, new_value, new_body);
+}
+
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
     Array<IterVar> new_rdom  = MutateRDom(op->rdom, m);
@@ -70,24 +247,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
     }
   });
 
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
-.set_dispatch<AttrStmt>([](const AttrStmt* op, const Stmt& s, IRMutator* m) {
-    Expr value = m->Mutate(op->value);
-    Stmt body = m->Mutate(op->body);
-    if (value.same_as(op->value) &&
-        body.same_as(op->body)) {
-      return s;
-    } else {
-      return AttrStmt::make(op->node, op->type_key, value, body);
-    }
-  });
-
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .set_dispatch<IntImm>(ReturnSelfExpr)
 .set_dispatch<UIntImm>(ReturnSelfExpr)
 .set_dispatch<FloatImm>(ReturnSelfExpr)
-.set_dispatch<StringImm>(ReturnSelfExpr)
-.set_dispatch<Variable>(ReturnSelfExpr);
+.set_dispatch<StringImm>(ReturnSelfExpr);
 
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
@@ -150,14 +314,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
       return Select::make(cond, t, f);
     }
   })
-.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) {
-    Expr index = m->Mutate(op->index);
-    if (index.same_as(op->index)) {
-      return e;
-    } else {
-      return Load::make(op->type, op->buffer_var, index);
-    }
-  })
 .set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
     Expr base = m->Mutate(op->base);
     Expr stride = m->Mutate(op->stride);
@@ -175,38 +331,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
     } else {
       return Broadcast::make(value, op->lanes);
     }
-  })
-.set_dispatch<Call>([](const Call *op, const Expr& e, IRMutator* m) {
-    auto new_args = MutateArray(op->args, m);
-    if (op->args.same_as(new_args)) {
-      return e;
-    } else {
-      return Call::make(op->type, op->name, new_args, op->call_type,
-                        op->func, op->value_index);
-    }
-  })
-.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) {
-    Expr value = m->Mutate(op->value);
-    Expr body = m->Mutate(op->body);
-    if (value.same_as(op->value) &&
-        body.same_as(op->body)) {
-      return e;
-    } else {
-      return Let::make(op->var, value, body);
-    }
   });
 
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
-.set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) {
-    Expr value = m->Mutate(op->value);
-    Stmt body = m->Mutate(op->body);
-    if (value.same_as(op->value) &&
-        body.same_as(op->body)) {
-      return s;
-    } else {
-      return LetStmt::make(op->var, value, body);
-    }
-  })
 .set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
     Expr condition = m->Mutate(op->condition);
     Expr message = m->Mutate(op->message);
@@ -225,93 +352,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
       return ProducerConsumer::make(op->func, op->is_producer, body);
     }
   })
-.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) {
-    Expr min = m->Mutate(op->min);
-    Expr extent = m->Mutate(op->extent);
-    Stmt body = m->Mutate(op->body);
-    if (min.same_as(op->min) &&
-        extent.same_as(op->extent) &&
-        body.same_as(op->body)) {
-      return s;
-    } else {
-      return For::make(
-          op->loop_var, min, extent, op->for_type, op->device_api, body);
-    }
-  })
-.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) {
-    Expr value = m->Mutate(op->value);
-    Expr index = m->Mutate(op->index);
-    if (value.same_as(op->value) && index.same_as(op->index)) {
-      return s;
-    } else {
-      return Store::make(op->buffer_var, value, index);
-    }
-  })
-.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
-    auto new_args = MutateArray(op->args, m);
-    auto new_value = m->Mutate(op->value);
-    if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
-      return s;
-    } else {
-      return Provide::make(op->func, op->value_index, new_value, new_args);
-    }
-  })
-.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
-    std::vector<Expr> new_extents;
-    bool all_extents_unmodified = true;
-    for (size_t i = 0; i < op->extents.size(); i++) {
-        new_extents.push_back(m->Mutate(op->extents[i]));
-        all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
-    }
-    Stmt body = m->Mutate(op->body);
-    Expr condition = m->Mutate(op->condition);
-    Expr new_expr;
-    if (op->new_expr.defined()) {
-      new_expr = m->Mutate(op->new_expr);
-    }
-    if (all_extents_unmodified &&
-        body.same_as(op->body) &&
-        condition.same_as(op->condition) &&
-        new_expr.same_as(op->new_expr)) {
-      return s;
-    } else {
-      return Allocate::make(
-          op->buffer_var, op->type,
-          new_extents, condition, body,
-          new_expr, op->free_function);
-    }
-  })
-.set_dispatch<Free>([](const Free *op, const Stmt& s, IRMutator* m) {
-  return s;
-  })
-.set_dispatch<Realize>([](const Realize *op, const Stmt& s, IRMutator* m) {
-    Halide::Internal::Region new_bounds;
-    bool bounds_changed = false;
-
-    // Mutate the bounds
-    for (size_t i = 0; i < op->bounds.size(); i++) {
-        Expr old_min = op->bounds[i]->min;
-        Expr old_extent = op->bounds[i]->extent;
-        Expr new_min = m->Mutate(old_min);
-        Expr new_extent = m->Mutate(old_extent);
-        if (!new_min.same_as(old_min))  bounds_changed = true;
-        if (!new_extent.same_as(old_extent)) bounds_changed = true;
-        new_bounds.push_back(
-            Range::make_by_min_extent(new_min, new_extent));
-    }
-
-    Stmt body = m->Mutate(op->body);
-    Expr condition = m->Mutate(op->condition);
-    if (!bounds_changed &&
-        body.same_as(op->body) &&
-        condition.same_as(op->condition)) {
-      return s;
-    } else {
-      return Realize::make(op->func, op->value_index,
-                           op->type, new_bounds,
-                           condition, body);
-    }
-  })
 .set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
     Stmt first = m->Mutate(op->first);
     Stmt rest = m->Mutate(op->rest);
diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h
new file mode 100644
index 000000000..794dcd820
--- /dev/null
+++ b/src/pass/ir_util.h
@@ -0,0 +1,70 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file ir_util.h
+ * \brief Helper functions to construct and compose IR nodes.
+ */
+#ifndef TVM_PASS_IR_UTIL_H_
+#define TVM_PASS_IR_UTIL_H_
+
+#include <tvm/ir.h>
+#include <vector>
+
+namespace tvm {
+namespace ir {
+
+/*!
+ * \brief Wrap body in a list of statements whose own body is a no-op.
+ * \param nest For/LetStmt/AttrStmt/IfThenElse shells or AssertStmt, outermost first.
+ * \param body The innermost body.
+ * \return The combined Stmt
+ */
+inline Stmt MergeNest(std::vector<Stmt> nest, Stmt body) {
+  // Wrap from the last entry backwards so that nest[0] ends up outermost.
+  for (auto it = nest.rbegin(); it != nest.rend(); ++it) {
+    const Stmt& s = *it;
+    if (const For* loop = s.as<For>()) {
+      auto n = std::make_shared<For>(*loop);
+      CHECK(is_no_op(n->body));
+      n->body = body;
+      body = Stmt(n);
+    } else if (const LetStmt* let = s.as<LetStmt>()) {
+      auto n = std::make_shared<LetStmt>(*let);
+      CHECK(is_no_op(n->body));
+      n->body = body;
+      body = Stmt(n);
+    } else if (const AttrStmt* attr = s.as<AttrStmt>()) {
+      auto n = std::make_shared<AttrStmt>(*attr);
+      CHECK(is_no_op(n->body));
+      n->body = body;
+      body = Stmt(n);
+    } else if (const IfThenElse* branch = s.as<IfThenElse>()) {
+      auto n = std::make_shared<IfThenElse>(*branch);
+      CHECK(is_no_op(n->then_case));
+      CHECK(!n->else_case.defined());
+      n->then_case = body;
+      body = Stmt(n);
+    } else if (s.as<AssertStmt>()) {
+      body = Block::make(s, body);
+    } else {
+      LOG(FATAL) << "not supported nest type";
+    }
+  }
+  return body;
+}
+
+/*!
+ * \brief combine several nests into one Stmt wrapped around body.
+ * \param nest A list of nests; nest[0] becomes the outermost group.
+ * \param body body
+ * \return The combined Stmt
+ */
+inline Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
+  for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
+    body = MergeNest(*ri, body);
+  }
+  return body;
+}
+
+}  // namespace ir
+}  // namespace tvm
+#endif  // TVM_PASS_IR_UTIL_H_
diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc
index 3bbcbbd00..77ce3928f 100644
--- a/src/pass/ir_visitor.cc
+++ b/src/pass/ir_visitor.cc
@@ -8,7 +8,6 @@
 
 namespace tvm {
 namespace ir {
-namespace {
 // visitor to implement apply
 class IRApplyVisit : public IRVisitor {
  public:
@@ -26,7 +25,6 @@ class IRApplyVisit : public IRVisitor {
   std::unordered_set<const Node*> visited_;
 };
 
-}  // namespace
 
 void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
   IRApplyVisit(fvisit).Visit(node);
@@ -36,12 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() {  // NOLINT(*)
   static FVisit inst; return inst;
 }
 
-
-// namespace to register the functors.
-namespace {
-
-using namespace Halide::Internal;
-
 void NoOp(const NodeRef& n, IRVisitor* v) {
 }
 
@@ -59,24 +51,85 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
   }
 }
 
-TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
-    VisitRDom(op->rdom, v);
-    v->Visit(op->source);
-  });
+// Register a dispatch entry that forwards OP nodes to the virtual
+// IRVisitor::Visit_ overload; every type with a Visit_ must be listed below.
+#define DISPATCH_TO_VISIT(OP)                       \
+  set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
+      v->Visit_(op);                                \
+    })
 
 TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<AttrStmt>([](const AttrStmt* op, IRVisitor* v) {
-    v->Visit(op->value);
-    v->Visit(op->body);
-  });
+.DISPATCH_TO_VISIT(Variable)
+.DISPATCH_TO_VISIT(LetStmt)
+.DISPATCH_TO_VISIT(AttrStmt)
+.DISPATCH_TO_VISIT(For)
+.DISPATCH_TO_VISIT(Allocate)
+.DISPATCH_TO_VISIT(Load)
+.DISPATCH_TO_VISIT(Store)
+.DISPATCH_TO_VISIT(Let)
+.DISPATCH_TO_VISIT(Call)
+.DISPATCH_TO_VISIT(Free);
+
+void IRVisitor::Visit_(const Variable* op) {}  // default: nothing to visit
+
+void IRVisitor::Visit_(const LetStmt *op) {
+  this->Visit(op->value);  // bound value first
+  this->Visit(op->body);   // then the body
+}
+
+void IRVisitor::Visit_(const AttrStmt* op) {
+  this->Visit(op->value);  // attribute value first
+  this->Visit(op->body);   // then the body
+}
+
+void IRVisitor::Visit_(const For *op) {
+  // Visit the loop bounds (min, then extent) before the loop body.
+  this->Visit(op->min);
+  this->Visit(op->extent);
+  this->Visit(op->body);
+}
+
+void IRVisitor::Visit_(const Allocate *op) {
+  // Order: every extent, the body, the condition, then new_expr if set.
+  for (size_t i = 0; i < op->extents.size(); ++i) {
+    this->Visit(op->extents[i]);
+  }
+  this->Visit(op->body);
+  this->Visit(op->condition);
+  if (op->new_expr.defined()) {
+    this->Visit(op->new_expr);
+  }
+}
+
+void IRVisitor::Visit_(const Load *op) {
+  this->Visit(op->index);  // only the index is visited, not buffer_var
+}
+
+void IRVisitor::Visit_(const Store *op) {
+  this->Visit(op->value);  // stored value first
+  this->Visit(op->index);  // then the index; buffer_var is not visited
+}
+
+void IRVisitor::Visit_(const Let *op) {
+  this->Visit(op->value);  // bound value first
+  this->Visit(op->body);   // then the body
+}
+
+void IRVisitor::Visit_(const Free* op) {}  // default: nothing to visit
+
+void IRVisitor::Visit_(const Call *op) {
+  VisitArray(op->args, this);  // only the call arguments are visited
+}
 
 TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
+.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
+    VisitRDom(op->rdom, v);
+    v->Visit(op->source);
+  })
 .set_dispatch<IntImm>(NoOp)
 .set_dispatch<UIntImm>(NoOp)
 .set_dispatch<FloatImm>(NoOp)
-.set_dispatch<StringImm>(NoOp)
-.set_dispatch<Variable>(NoOp);
+.set_dispatch<StringImm>(NoOp);
 
 TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
@@ -116,29 +166,15 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
     v->Visit(op->true_value);
     v->Visit(op->false_value);
   })
-.set_dispatch<Load>([](const Load *op, IRVisitor* v) {
-    v->Visit(op->index);
-  })
 .set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
     v->Visit(op->base);
     v->Visit(op->stride);
   })
 .set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
     v->Visit(op->value);
-  })
-.set_dispatch<Call>([](const Call *op, IRVisitor* v) {
-    VisitArray(op->args, v);
-  })
-.set_dispatch<Let>([](const Let *op, IRVisitor* v) {
-    v->Visit(op->value);
-    v->Visit(op->body);
   });
 
 TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) {
-    v->Visit(op->value);
-    v->Visit(op->body);
-  })
 .set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
     v->Visit(op->condition);
     v->Visit(op->message);
@@ -146,30 +182,10 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
     v->Visit(op->body);
   })
-.set_dispatch<For>([](const For *op, IRVisitor* v) {
-    v->Visit(op->min);
-    v->Visit(op->extent);
-    v->Visit(op->body);
-  })
-.set_dispatch<Store>([](const Store *op, IRVisitor* v) {
-    v->Visit(op->value);
-    v->Visit(op->index);
-  })
 .set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
     VisitArray(op->args, v);
     v->Visit(op->value);
   })
-.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
-    for (size_t i = 0; i < op->extents.size(); i++) {
-      v->Visit(op->extents[i]);
-    }
-    v->Visit(op->body);
-    v->Visit(op->condition);
-    if (op->new_expr.defined()) {
-      v->Visit(op->new_expr);
-    }
-  })
-.set_dispatch<Free>(NoOp)
 .set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
     // Mutate the bounds
     for (size_t i = 0; i < op->bounds.size(); i++) {
@@ -193,6 +209,5 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
     v->Visit(op->value);
   });
 
-}  // namespace
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc
index a62cf678b..c2332a819 100644
--- a/src/pass/schedule_ops.cc
+++ b/src/pass/schedule_ops.cc
@@ -9,6 +9,7 @@
 #include <tvm/schedule_pass.h>
 
 #include "./scope.h"
+#include "./ir_util.h"
 #include "../schedule/graph.h"
 
 namespace tvm {
@@ -32,18 +33,27 @@ void PassUpOffset(const Stage& s,
       Expr outer = state.at(s->outer);
       Expr inner = state.at(s->inner);
       Expr factor = dom_map.at(s->inner)->extent;
-      Expr offset = inner + outer * factor;
-      Expr outer_min = dom_map.at(s->parent)->min;
-      if (!is_zero(outer_min)) {
-        offset = outer_min + offset;
+      Expr parent_min = dom_map.at(s->parent)->min;
+      state[s->parent] = inner + outer * factor;
+      // add min if they exist
+      if (!is_zero(parent_min)) {
+        state[s->parent] = parent_min + state[s->parent];
       }
-      state[s->parent] = offset;
     } else if (rel.as<FuseNode>()) {
       const FuseNode* s = rel.as<FuseNode>();
       Expr value = state.at(s->fused);
       Expr factor = dom_map.at(s->inner)->extent;
+      Expr outer_min = dom_map.at(s->outer)->min;
+      Expr inner_min = dom_map.at(s->inner)->min;
       state[s->outer] = value / factor;
       state[s->inner] = value % factor;
+      // add each axis' own domain min when it is non-zero
+      if (!is_zero(outer_min)) {
+        state[s->outer] = outer_min + state[s->outer];
+      }
+      if (!is_zero(inner_min)) {
+        state[s->inner] = inner_min + state[s->inner];
+      }
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -81,45 +91,6 @@ void SplitByAdd(Expr expr,
   }
 }
 
-/*!
- * \brief combine the nest stmt, whose body is not defined.
- * \param nest A list of For and LetStmt, whose body is not defined.
- * \param body body
- */
-Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
-  // use reverse iteration
-  for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
-    for (auto rj = ri->rbegin(); rj != ri->rend(); ++rj) {
-      Stmt s = *rj;
-      if (s.as<For>()) {
-        auto n = std::make_shared<For>(*s.as<For>());
-        CHECK(is_no_op(n->body));
-        n->body = body;
-        body = Stmt(n);
-      } else if (s.as<LetStmt>()) {
-        auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
-        CHECK(is_no_op(n->body));
-        n->body = body;
-        body = Stmt(n);
-      } else if (s.as<AttrStmt>()) {
-        auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
-        CHECK(is_no_op(n->body));
-        n->body = body;
-        body = Stmt(n);
-      } else if (s.as<IfThenElse>()) {
-        auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
-        CHECK(is_no_op(n->then_case));
-        CHECK(!n->else_case.defined());
-        n->then_case = body;
-        body = Stmt(n);
-      } else {
-        LOG(FATAL) << "not supported nest type";
-      }
-    }
-  }
-  return body;
-}
-
 /*!
  * \brief Make the loop nest of the correspondings schedule.
  * \param sch The schedule.
@@ -142,16 +113,32 @@ std::vector<std::vector<Stmt> > MakeLoopNest(
 
   for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
     auto iv = leaf_iter_vars[i];
+    Range dom = dom_map.at(iv);
     // initialize the offset and loop_level
     offset[iv] = iv->var;
     loop_level[iv->var.as<Variable>()] = i + 1;
     // Mark the iter var in the IR, to remember the point
     if (iv->thread_tag.length() == 0) {
-      Range dom = dom_map.at(iv);
+      if (is_zero(dom->min)) {
+        nest[i + 1].emplace_back(
+            For::make(iv->var, 0, dom->extent,
+                      ForType::Serial, DeviceAPI::None, no_op));
+      } else {
+        Var idx(iv->var->name_hint + ".idx", iv->var.type());
+        nest[i + 1].emplace_back(
+            For::make(idx, 0, dom->extent,
+                      ForType::Serial, DeviceAPI::None, no_op));
+        nest[i + 1].emplace_back(
+            LetStmt::make(iv->var, dom->min + idx, no_op));
+      }
+    } else {
+      // Always restrict a threaded IterVar to start from 0.
+      CHECK(is_zero(dom->min));
+      // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          For::make(iv->var, dom->min, dom->extent,
-                    ForType::Serial, DeviceAPI::None, no_op));
+          AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
     }
+    // annotate the scope of the IterVar
     nest[i + 1].emplace_back(
         AttrStmt::make(iv, "scope", iv->var, no_op));
   }
diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc
new file mode 100644
index 000000000..389394597
--- /dev/null
+++ b/src/pass/simple_passes.cc
@@ -0,0 +1,36 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file simple_passes.cc
+ * \brief Implementation of simple passes
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_pass.h>
+
+namespace tvm {
+namespace ir {
+
+// Visitor that sets has_side_effect_ when it reaches a Call that is not pure.
+class IRSideEffect : public IRVisitor {
+ public:
+  void Visit(const NodeRef& e) final {
+    if (!has_side_effect_) IRVisitor::Visit(e);  // skip once already found
+  }
+  // An impure call is a side effect; a pure call may still contain one.
+  void Visit_(const Call* op) final {
+    if (!op->is_pure()) {
+      has_side_effect_ = true; return;
+    }
+    IRVisitor::Visit_(op);
+  }
+  // True once a non-pure Call has been visited.
+  bool has_side_effect_{false};
+};
+
+bool HasSideEffect(const Expr& e) {
+  IRSideEffect v;
+  v.Visit(e);
+  return v.has_side_effect_;  // flag as left by the IRSideEffect walk over e
+}
+}  // namespace ir
+}  // namespace tvm
diff --git a/tests/python/test_codegen_cuda.py b/tests/python/test_codegen_cuda.py
index 0f0a8df30..dc20dda36 100644
--- a/tests/python/test_codegen_cuda.py
+++ b/tests/python/test_codegen_cuda.py
@@ -24,31 +24,15 @@ def mock_test_add():
     Bb = tvm.Buffer(B.shape, B.dtype, name='B')
     Cb = tvm.Buffer(C.shape, C.dtype, name='C')
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
+    stmt = tvm.ir_pass.Simplify(stmt)
     print(stmt)
     output_ssa = False
-    code = tvm.codegen.CompileToC(stmt, "myadd",
-                                  [Ab.ptr, Bb.ptr, Cb.ptr, n],
-                                  output_ssa)
-
-    print(code)
-    def codegen():
-        # generate host/device code
-        host_code, device_code = tvm.codegen.GenCUDA(
-            s,
-            inputs={A: Ab, B:Bb},
-            outputs={C: Cb},
-            args=[A, B, C])
-        # generate a function based on the code
-        f = tvm.cuda.build_function(host_code, device_code)
-        # create arrays
-        a = tvm.nd.array(np.ones(10), ctx=tvm.gpu(0))
-        b = tvm.nd.array(np.ones(10), ctx=tvm.gpu(0))
-        c = tvm.nd.array(np.zeros(10), ctx=tvm.gpu(0))
-        # calll the generated code
-        f(a, b, c)
-        # sync the result
-        np.testing.assert_equal(c.asnumpy(), np.ones(10) * 2)
+    f = tvm.codegen.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 1)
 
+    f_list = tvm.codegen.SplitHostDevice(f)
+    for x in f_list:
+        code = tvm.codegen.CompileToC(x, output_ssa)
+        print(code)
 
 if __name__ == "__main__":
     mock_test_add()
diff --git a/tests/python/test_codegen_makeapi.py b/tests/python/test_codegen_makeapi.py
new file mode 100644
index 000000000..ebe6f4e63
--- /dev/null
+++ b/tests/python/test_codegen_makeapi.py
@@ -0,0 +1,27 @@
+import tvm
+import numpy
+
+def test_makeapi():
+    """Mock design: check the metadata of a MakeAPI-built vector add."""
+    n = tvm.Var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.Schedule(C.op)
+
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.ir_pass.ScheduleOps(s, bounds)
+
+    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.Buffer(B.shape, B.dtype, name='B')
+    Cb = tvm.Buffer(C.shape, C.dtype, name='C')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb})
+    num_packed_args = 2
+    f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
+    # The buffer's data pointer must be registered with the buffer's dtype.
+    assert f.handle_data_type[Ab.ptr].dtype == Ab.dtype
+    assert len(f.args) == 5
+
+
+if __name__ == "__main__":
+    test_makeapi()
diff --git a/tests/python/test_pass_storage_flatten.py b/tests/python/test_pass_storage_flatten.py
index b7dff05d0..98200bc7d 100644
--- a/tests/python/test_pass_storage_flatten.py
+++ b/tests/python/test_pass_storage_flatten.py
@@ -18,6 +18,7 @@ def test_flatten2():
     Ab = tvm.Buffer(A.shape, A.dtype, name='A')
     A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
+    stmt = tvm.ir_pass.Simplify(stmt)
     print(stmt)
 
 if __name__ == "__main__":
-- 
GitLab