diff --git a/HalideIR b/HalideIR
index 5d1bd103c2abe19392b4d8def7e3ff1c854e8683..1ec478bbd0c20b8659f0c897363b5a76e13ef495 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683
+Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index adf0d245d20a84a9bbcfc8508aaaa7e865f8413e..e6057b290088fedfb7b0656992e838276521a184 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -17,6 +17,7 @@ namespace tvm {
 
 using Halide::Type;
 using Halide::Float;
+using Halide::Bool;
 using Halide::Int;
 using Halide::UInt;
 using Halide::Handle;
@@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
 using Halide::Internal::IRPrinter;
 using Halide::Internal::Variable;
 
+using Halide::Internal::make_const;
+
 /*! \brief a named variable in TVM */
 class Var : public Halide::VarExpr {
  public:
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 276bba9448f80265e32130e183efa17445508e73..c93d748a185623d5fc56bc7b4a191797ad9c5a2f 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -18,6 +18,16 @@
 namespace tvm {
 namespace ir {
 
+
+/*!
+ * \brief Lower schedule s and its dependent operations into a statement.
+ *
+ * \param s The schedule to be realized.
+ * \param dom_map The domain of each iter var.
+ * \return The result Stmt.
+ */
+Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
+
 /*!
  * \brief verifies whether the IR stmt or Expr is in SSA form.
  *  That is: each VarExpr is defined and assigned once(in Let/For)
@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
             Expr body,
             Stmt stmt);
 
-/*!
- * \brief Schedule s' dependent operations.
- *
- * \param s The schedule to be realized
- * \return the result Stmt
- */
-Stmt ScheduelOps(Schedule s);
-
 }  // namespace ir
 }  // namespace tvm
 
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 21ab9e5f90e0b9e0d609a7c16171a5a4e71c609c..aff7d9b2d6376d587a0e8b87ba5b4d3d2305d371 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -12,6 +12,36 @@
 
 namespace tvm {
 
+/*!
+ * \brief A placeholder op that represents an input to the computation.
+ */
+class PlaceholderOpNode : public OperationNode {
+ public:
+  /*! \brief The shape of the input */
+  Array<Expr> shape;
+  /*! \brief The data type of the input. */
+  Type dtype;
+
+  int num_outputs() const final {
+    return 1;
+  }
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+  }
+  static Operation make(std::string name,
+                        Array<Expr> shape,
+                        Type dtype);
+
+  static constexpr const char* _type_key = "PlaceholderOp";
+  TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
+};
+
 /*!
  * \brief A Compute op that compute a tensor on certain domain.
  */
@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
   /*! \brief constructor */
   ComputeOpNode() {}
 
-  size_t num_outputs() const final {
+  int num_outputs() const final {
     return 1;
   }
   Array<IterVar> root_iter_vars() const final;
-  std::string output_name(size_t i) const final;
   Type output_dtype(size_t i) const final;
   Array<Expr> output_shape(size_t i) const final;
 
@@ -49,6 +78,16 @@ class ComputeOpNode : public OperationNode {
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;
 
+/*!
+ * \brief Create a placeholder tensor.
+ * \param shape The shape of the tensor.
+ * \param dtype The data type of the tensor.
+ * \param name The name of the tensor.
+ * \return The created placeholder tensor.
+ */
+Tensor Placeholder(Array<Expr> shape,
+                   Type dtype = Float(32),
+                   std::string name = "placeholder");
+
 /*!
  * \brief Construct a new tensor by computing over shape,
  *  using the computation rule: result_tensor[axis] = fcompute(axis)
diff --git a/src/schedule/bound.h b/include/tvm/schedule_pass.h
similarity index 50%
rename from src/schedule/bound.h
rename to include/tvm/schedule_pass.h
index 53aa111d9ce4fe254ae12754f504250afb95a001..45b2745c9eab00d4567128c3a9d0d3485e990019 100644
--- a/src/schedule/bound.h
+++ b/include/tvm/schedule_pass.h
@@ -1,14 +1,17 @@
 /*!
  *  Copyright (c) 2016 by Contributors
- * \file bound.h
- * \brief The bound inference logics on the schedule.
+ * \file schedule_pass.h
+ * \brief Collection of Schedule pass functions.
+ *
+ *  These passes work on the schedule hyper-graph
+ *  and infer information such as bounds, check conditions,
+ *  and read/write dependencies between the IterVars.
  */
-#ifndef TVM_SCHEDULE_BOUND_H_
-#define TVM_SCHEDULE_BOUND_H_
+#ifndef TVM_SCHEDULE_PASS_H_
+#define TVM_SCHEDULE_PASS_H_
 
-#include <tvm/expr.h>
-#include <tvm/schedule.h>
-#include <unordered_map>
+#include "./base.h"
+#include "./schedule.h"
 
 namespace tvm {
 namespace schedule {
@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);
 
 }  // namespace schedule
 }  // namespace tvm
-
-#endif  // TVM_SCHEDULE_BOUND_H_
+#endif  // TVM_SCHEDULE_PASS_H_
diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h
index 0614723597477fda6cbcfec41921e93914567a56..92786b33106df7fd4920100be9701d5fe91c8303 100644
--- a/include/tvm/tensor.h
+++ b/include/tvm/tensor.h
@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
  * \brief Tensor structure representing a possible input,
  *  or intermediate computation result.
  */
-class Tensor : public FunctionRef {
+class Tensor : public NodeRef {
  public:
   /*! \brief default constructor, used internally */
   Tensor() {}
-  explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
-  /*!
-   * \brief constructor of input tensor
-   * \param shape Shape of the tensor.
-   * \param name optional name of the Tensor.
-   * \param dtype The data type of the input tensor.
-   */
-  explicit Tensor(Array<Expr> shape,
-                  std::string name = "tensor",
-                  Type dtype = Float(32));
+  explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
 };
 
 /*! \brief Operation that produces tensors */
-class Operation : public NodeRef {
+class Operation : public FunctionRef {
  public:
   /*! \brief default constructor  */
   Operation() {}
-  explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
+  explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -137,12 +128,10 @@ class Operation : public NodeRef {
 };
 
 /*! \brief Node to represent a tensor */
-class TensorNode : public FunctionBaseNode {
+class TensorNode : public Node {
  public:
   /*! \brief The shape of the tensor */
   Array<Expr> shape;
-  /*! \brief optional name of the tensor */
-  std::string name;
   /*! \brief data type in the content of the tensor */
   Type dtype;
   /*! \brief the source operation, can be None */
@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("shape", &shape);
-    v->Visit("name", &name);
     v->Visit("dtype", &dtype);
     v->Visit("op", &op);
     v->Visit("value_index", &value_index);
   }
-  const std::string& func_name() const final {
-    return name;
-  }
-  int outputs() const final {
-    return 1;
-  }
   static Tensor make(Array<Expr> shape,
-                     std::string name,
                      Type dtype,
                      Operation op,
                      int value_index);
@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
 /*!
  * \brief base class of operation node.
  */
-class OperationNode : public Node {
+class OperationNode : public FunctionBaseNode {
  public:
   /*! \brief optional name of the operation */
   std::string name;
+  /*! \return name of the operation */
+  const std::string& func_name() const final {
+    return name;
+  }
+  /*! \return number of outputs of this op */
+  virtual int num_outputs() const = 0;
   /*! \return the list of iteration variable at root */
   virtual Array<IterVar> root_iter_vars() const = 0;
-  /*! \return number of outputs of this op */
-  virtual size_t num_outputs() const = 0;
-  /*! \return name of i-th output */
-  virtual std::string output_name(size_t i) const = 0;
   /*! \return type of i-th output */
   virtual Type output_dtype(size_t i) const = 0;
   /*! \return shape of i-th output */
diff --git a/python/tvm/function.py b/python/tvm/function.py
index 7088d051ab54d8bb06a4dca291187333960c3b20..3dde071a6b8293227ce8cabee1de3778a9ead75e 100644
--- a/python/tvm/function.py
+++ b/python/tvm/function.py
@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
     return _function_internal._Var(name, dtype)
 
 
-def placeholder(shape, dtype = None, name="TensorObj"):
+def placeholder(shape, dtype = None, name="placeholder"):
     """Construct an empty tensor object.
 
     Parameters
@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
         The created tensor
     """
     dtype = float32 if dtype is None else dtype
-    return _function_internal._Tensor(
-        shape, name, dtype, None, 0)
+    return _function_internal._Placeholder(
+        shape, dtype, name)
 
 
-def compute(shape, fcompute, name="TensorCompute"):
+def compute(shape, fcompute, name="compute"):
     """Construct a new tensor by computing over the shape domain.
 
     The compute rule is result[axis] = fcompute(axis)
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index b8ed62cc96e8d6d82ac50610153ca9079ef12b87..99e14180aa7beb978f38c2eb1c623020d4fab4a0 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -34,7 +34,9 @@ class Tensor(NodeBase):
             else:
                 raise ValueError("The indices must be expression")
 
-        return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
+        return _make.Call(self.dtype, self.op.name,
+                          args, _expr.Call.Halide,
+                          self.op, self.value_index)
 
     def __getitem__(self, indices):
         return TensorSlice(self, indices)
@@ -71,3 +73,7 @@ class Operation(NodeBase):
 @register_node
 class ComputeOp(Operation):
     pass
+
+@register_node
+class PlaceholderOp(Operation):
+    pass
diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc
index 151f9da7a1b3e3a760b9b31200d53f8ba167ef35..af66e9db66900619b014f08c09a9bdff7fe8a0f7 100644
--- a/src/c_api/c_api_ir.cc
+++ b/src/c_api/c_api_ir.cc
@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
                      args.at(5));
   });
 
+TVM_REGISTER_API(_make_Realize)
+.set_body([](const ArgStack& args,  RetValue *ret) {
+    *ret = Realize::make(args.at(0),
+                         args.at(1),
+                         args.at(2),
+                         args.at(3),
+                         args.at(4),
+                         args.at(5));
+  });
+
 TVM_REGISTER_API(_make_Call)
 .set_body([](const ArgStack& args,  RetValue *ret) {
     *ret = Call::make(args.at(0),
@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
 REGISTER_MAKE2(AssertStmt);
 REGISTER_MAKE3(ProducerConsumer);
 REGISTER_MAKE3(Store);
-REGISTER_MAKE3(Provide);
+REGISTER_MAKE4(Provide);
 REGISTER_MAKE1(Free);
-// TODO(tqchen) Realize;
 REGISTER_MAKE2(Block);
 REGISTER_MAKE3(IfThenElse);
 REGISTER_MAKE1(Evaluate);
diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc
index 2e5ad3dcec482ba6eb608d93e2449f5327be4da0..46075f1140d4edb102d71dc970c7d73620e6bc59 100644
--- a/src/c_api/c_api_lang.cc
+++ b/src/c_api/c_api_lang.cc
@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
 TVM_REGISTER_API(_Tensor)
 .set_body([](const ArgStack& args,  RetValue *ret) {
     *ret = TensorNode::make(args.at(0),
-                            args.at(1),
                             args.at(2),
                             args.at(3),
                             args.at(4));
@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
         std::hash<Tensor>()(args.at(0).operator Tensor()));
   });
 
+TVM_REGISTER_API(_Placeholder)
+.set_body([](const ArgStack& args,  RetValue *ret) {
+    *ret = Placeholder(args.at(0),
+                       args.at(1),
+                       args.at(2));
+  });
+
 TVM_REGISTER_API(_ComputeOp)
 .set_body([](const ArgStack& args,  RetValue *ret) {
     *ret = ComputeOpNode::make(args.at(0),
diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc
index 867a5f6f1293d325f34077b57f16b5a23d2f92bf..2d4cb6e3fb55bbaf12803e899635f8aaa358747e 100644
--- a/src/c_api/c_api_pass.cc
+++ b/src/c_api/c_api_pass.cc
@@ -7,7 +7,6 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include "./c_api_registry.h"
-#include "../schedule/bound.h"
 
 namespace tvm {
 namespace ir {
@@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
 REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
+REGISTER_PASS2(ScheduleOps);
 
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/c_api/c_api_schedule.cc b/src/c_api/c_api_schedule.cc
index 7ebe625b7819894a9bf951bf406efc7899127067..6ee41b2d398db2ef4ccf63dc1bc20d3baf99c38b 100644
--- a/src/c_api/c_api_schedule.cc
+++ b/src/c_api/c_api_schedule.cc
@@ -6,8 +6,8 @@
 #include <tvm/expr.h>
 #include <tvm/tensor.h>
 #include <tvm/schedule.h>
+#include <tvm/schedule_pass.h>
 #include "./c_api_registry.h"
-#include "../schedule/bound.h"
 #include "../schedule/graph.h"
 
 namespace tvm {
diff --git a/src/lang/expr.cc b/src/lang/expr.cc
index 1361c0f57a830089e7cdc2f79e3fd2f4d3ed6692..c6ec66fcaa2421be11856d312afc4ea51b5e0ee5 100644
--- a/src/lang/expr.cc
+++ b/src/lang/expr.cc
@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(IterVarNode);
 
-
 }  // namespace tvm
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index 3b06c8b1cbb1fe4f01eccbab534b01325fd60016..29eac1d07290890b23e96567f649a868a6b1d761 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
-    p->stream << "attr " << op->type_key << " = ";
+    p->do_indent();
+    p->stream << "// attr " << op->type_key << " = ";
     p->print(op->value);
     p->stream << '\n';
     p->print(op->body);
diff --git a/src/lang/operation.cc b/src/lang/operation.cc
index 011ad2b92241c97dcbce27910744f67ca8bcb431..1883a5eacff398d533aee830445a9b1e9b4cd6b4 100644
--- a/src/lang/operation.cc
+++ b/src/lang/operation.cc
@@ -9,11 +9,73 @@
 
 namespace tvm {
 
+Tensor Operation::output(size_t i) const {
+  auto node = std::make_shared<TensorNode>();
+  node->op = *this;
+  node->value_index = static_cast<int>(i);
+  node->dtype = (*this)->output_dtype(i);
+  node->shape = (*this)->output_shape(i);
+  return Tensor(node);
+}
+
+// PlaceholderOpNode
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
-    p->stream << "op(" << op << ")";
+.set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
+    p->stream << "placeholder(" << op->name << ", " << op << ")";
 });
 
+TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
+
+Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
+  return {};
+}
+
+Type PlaceholderOpNode::output_dtype(size_t i) const {
+  CHECK_EQ(i, 0U);
+  return dtype;
+}
+
+Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
+  CHECK_EQ(i, 0U);
+  return shape;
+}
+
+Operation PlaceholderOpNode::make(std::string name,
+                                  Array<Expr> shape,
+                                  Type dtype) {
+  auto n = std::make_shared<PlaceholderOpNode>();
+  n->name = name;
+  n->shape = shape;
+  n->dtype = dtype;
+  return Operation(n);
+}
+
+Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
+  return PlaceholderOpNode::make(name, shape, dtype).output(0);
+}
+
+// ComputeOpNode
+Array<IterVar> ComputeOpNode::root_iter_vars() const {
+  return axis;
+}
+
+Type ComputeOpNode::output_dtype(size_t i) const {
+  CHECK_EQ(i, 0U);
+  return body.type();
+}
+
+Array<Expr> ComputeOpNode::output_shape(size_t i) const {
+  CHECK_EQ(i, 0U);
+  std::vector<Expr> shape;
+  for (size_t i = 0; i < axis.size(); ++i) {
+    const Range& r = axis[i]->dom;
+    shape.push_back(r->extent);
+  }
+  return Array<Expr>(shape);
+}
+
 Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
   auto op_node = std::make_shared<ComputeOpNode>();
   // compute dimension.
@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name,
   return Operation(n);
 }
 
-Tensor Operation::output(size_t i) const {
-  auto node = std::make_shared<TensorNode>();
-  node->op = *this;
-  node->value_index = 0;
-  node->name =  (*this)->output_name(i);
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
-}
-
-Array<IterVar> ComputeOpNode::root_iter_vars() const {
-  return axis;
-}
-
-std::string ComputeOpNode::output_name(size_t i) const {
-  CHECK_EQ(i, 0U);
-  return name;
-}
-
-Type ComputeOpNode::output_dtype(size_t i) const {
-  CHECK_EQ(i, 0U);
-  return body.type();
-}
-
-Array<Expr> ComputeOpNode::output_shape(size_t i) const {
-  CHECK_EQ(i, 0U);
-  std::vector<Expr> shape;
-  for (size_t i = 0; i < axis.size(); ++i) {
-    const Range& r = axis[i]->dom;
-    shape.push_back(r->extent);
-  }
-  return Array<Expr>(shape);
-}
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
+    p->stream << "compute(" << op->name << ", " << op << ")";
+});
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index 16a7af8e995e5d04e329f0f8eef22ae0eefbb09b..4a4571147bbaec8f6e0ceefe3bf8c7a1bdb23805 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -8,33 +8,24 @@
 
 namespace tvm {
 
-Tensor::Tensor(Array<Expr> shape, std::string name, Type dtype) {
-  auto node = std::make_shared<TensorNode>();
-  node->name = std::move(name);
-  node->dtype = dtype;
-  node->shape = std::move(shape);
-  node_ = std::move(node);
-}
-
 Expr Tensor::operator()(Array<Expr> indices) const {
   using Halide::Internal::Call;
   CHECK_EQ(ndim(), indices.size())
       << "Tensor dimension mismatch in read"
       << "ndim = " << ndim() << ", indices.size=" << indices.size();
   auto n = Call::make(
-      (*this)->dtype, (*this)->name, indices, Call::Halide, *this);
+      (*this)->dtype, (*this)->op->name, indices, Call::Halide,
+      (*this)->op, (*this)->value_index);
   return n;
 }
 
 
 Tensor TensorNode::make(Array<Expr> shape,
-                        std::string name,
                         Type dtype,
                         Operation op,
                         int value_index) {
   auto n = std::make_shared<TensorNode>();
   n->shape = shape;
-  n->name = name;
   n->dtype = dtype;
   n->op = op;
   n->value_index = value_index;
@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape,
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
     p->stream << "Tensor(shape=" << t->shape
-              << ", name=" << t->name << ')';
+              << ", op.name=" << t->op->name << ')';
   });
 
 TVM_REGISTER_NODE_TYPE(TensorNode);
diff --git a/src/pass/inline.cc b/src/pass/inline.cc
index 1fe16372d5ef5d7256c64302184d8cd4335664a0..b912e30897dba2c3da940cfaa383b60cd81fcb04 100644
--- a/src/pass/inline.cc
+++ b/src/pass/inline.cc
@@ -22,6 +22,7 @@ class IRInline : public IRMutator {
     expr = IRMutator::Mutate(expr);
     const Call* call = expr.as<Call>();
     if (call != nullptr && call->func == f_) {
+      CHECK_EQ(call->value_index, 0);
       return InlineCall(call);
     } else {
       return expr;
@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f,
             Array<Var> args,
             Expr body,
             Stmt stmt) {
+  CHECK_EQ(f->num_outputs(), 1)
+      << "can only inline output single value operation";
   return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
 }
 }  // namespace ir
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index 3bdafa75bb367a32c9b01a25230c8ff879130a1c..2c534a6c1b28195a23aacd2d5ffc49767895c1e8 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
   })
 .set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
     auto new_args = MutateArray(op->args, m);
-    auto new_values = MutateArray(op->values, m);
-    if (op->args.same_as(new_args) && op->values.same_as(new_values)) {
+    auto new_value = m->Mutate(op->value);
+    if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
       return s;
     } else {
-      return Provide::make(op->func, new_values, new_args);
+      return Provide::make(op->func, op->value_index, new_value, new_args);
     }
   })
 .set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
         condition.same_as(op->condition)) {
       return s;
     } else {
-      return Realize::make(op->func, op->types, new_bounds,
+      return Realize::make(op->func, op->value_index,
+                           op->type, new_bounds,
                            condition, body);
     }
   })
@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
 .set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
     Expr condition = m->Mutate(op->condition);
     Stmt then_case = m->Mutate(op->then_case);
-    Stmt else_case = m->Mutate(op->else_case);
+    Stmt else_case;
+    if (op->else_case.defined()) {
+      else_case = m->Mutate(op->else_case);
+    }
     if (condition.same_as(op->condition) &&
         then_case.same_as(op->then_case) &&
         else_case.same_as(op->else_case)) {
diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc
index f625274d0e3e7f5c7fc825ddeabd278561682d06..3bbcbbd002ada3796fd3086bcfa48c2ea0ae5a1b 100644
--- a/src/pass/ir_visitor.cc
+++ b/src/pass/ir_visitor.cc
@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
   })
 .set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
     VisitArray(op->args, v);
-    VisitArray(op->values, v);
+    v->Visit(op->value);
   })
 .set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
     for (size_t i = 0; i < op->extents.size(); i++) {
diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc
index e4dc5bc7e03edb8ee606ae5b5bd7d46d6a5ce00b..ae3b96f27be5c07d55233df04598afe1ae5f919e 100644
--- a/src/pass/schedule_ops.cc
+++ b/src/pass/schedule_ops.cc
@@ -6,7 +6,10 @@
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_visitor.h>
+#include <tvm/schedule_pass.h>
+
 #include "./scope.h"
+#include "../schedule/graph.h"
 
 namespace tvm {
 namespace ir {
@@ -20,7 +23,7 @@ namespace {
  *     IterVar->The assignment.
  */
 void PassUpOffset(const Schedule& s,
-                  const std::unordered_map<IterVar, Range>& dom_map,
+                  const Map<IterVar, Range>& dom_map,
                   std::unordered_map<IterVar, Expr>* p_state) {
   auto& state = *p_state;
   for (size_t i = s->relations.size(); i != 0; --i) {
@@ -28,8 +31,8 @@ void PassUpOffset(const Schedule& s,
     if (rel.as<SplitNode>()) {
       const SplitNode* s = rel.as<SplitNode>();
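+      // recover the parent index: parent = outer * extent(inner) + inner (+ parent min if nonzero)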
       Expr outer = state.at(s->outer);
-      Expr inner = state.at(s->outer);
-      Expr factor = dom_map.at(s->outer)->extent;
+      Expr inner = state.at(s->inner);
+      Expr factor = dom_map.at(s->inner)->extent;
       Expr offset = inner + outer * factor;
       Expr outer_min = dom_map.at(s->parent)->min;
       if (!is_zero(outer_min)) {
@@ -39,7 +42,7 @@ void PassUpOffset(const Schedule& s,
     } else if (rel.as<FuseNode>()) {
       const FuseNode* s = rel.as<FuseNode>();
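+      // split the fused index back: outer = fused / extent(inner), inner = fused % extent(inner)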
       Expr value = state.at(s->fused);
-      Expr factor = dom_map.at(s->outer)->extent;
+      Expr factor = dom_map.at(s->inner)->extent;
       state[s->outer] = value / factor;
       state[s->inner] = value % factor;
     } else {
@@ -84,24 +87,35 @@ void SplitByAdd(Expr expr,
  * \param nest A list of For and LetStmt, whose body is not defined.
  * \param body body
  */
-Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) {
-  while (!nest.empty()) {
-    Stmt s = std::move(nest.back());
-    nest.pop_back();
-    if (s.as<For>()) {
-      auto n = std::make_shared<For>(*s.as<For>());
-      n->body = body;
-      body = Stmt(n);
-    } else if (s.as<LetStmt>()) {
-      auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
-      n->body = body;
-      body = Stmt(n);
-    } else if (s.as<AttrStmt>()) {
-      auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
-      n->body = body;
-      body = Stmt(n);
-    } else {
-      LOG(FATAL) << "not supported nest type";
+Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
+  // use reverse iteration
+  for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
+    for (auto rj = ri->rbegin(); rj != ri->rend(); ++rj) {
+      Stmt s = *rj;
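+      // each element is a no-op shell; splice the accumulated body into it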
+      if (s.as<For>()) {
+        auto n = std::make_shared<For>(*s.as<For>());
+        CHECK(is_no_op(n->body));
+        n->body = body;
+        body = Stmt(n);
+      } else if (s.as<LetStmt>()) {
+        auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
+        CHECK(is_no_op(n->body));
+        n->body = body;
+        body = Stmt(n);
+      } else if (s.as<AttrStmt>()) {
+        auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
+        CHECK(is_no_op(n->body));
+        n->body = body;
+        body = Stmt(n);
+      } else if (s.as<IfThenElse>()) {
+        auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
+        CHECK(is_no_op(n->then_case));
+        CHECK(!n->else_case.defined());
+        n->then_case = body;
+        body = Stmt(n);
+      } else {
+        LOG(FATAL) << "not supported nest type";
+      }
     }
   }
   return body;
@@ -111,119 +125,251 @@ Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) {
  * \brief Make the loop nest of the correspondings schedule.
  * \param sch The schedule.
  * \param dom_map The domain map.
+ *
+ * \return A nested representation of loop statements.
+ *  When flattened, the statements are ordered from outermost to innermost.
  */
-std::vector<Stmt> MakeLoopNest(
+std::vector<std::vector<Stmt> > MakeLoopNest(
     const Schedule& sch,
-    const std::unordered_map<IterVar, Range>& dom_map) {
+    const Map<IterVar, Range>& dom_map) {
   // optional, use let to define some CSE in dom_map.
   auto leaf_iter_vars = sch->leaf_iter_vars;
   std::unordered_map<IterVar, Expr> offset;
   std::unordered_map<const Variable*, size_t> loop_level;
-
+  Stmt no_op = Evaluate::make(0);
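+  // no_op serves as a placeholder body; MergeNest splices the real body in later.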
   // create the loop nest
-  std::vector<Stmt> nest;
-  nest.resize(leaf_iter_vars.size() + 1, Stmt());
+  std::vector<std::vector<Stmt> > nest;
+  nest.resize(leaf_iter_vars.size() + 1);
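+  // nest[i + 1] holds the loop and scope marker for the i-th leaf IterVar;
+  // nest[0] holds statements placed before any loop.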
 
   for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
     auto iv = leaf_iter_vars[i];
     // initialize the offset and loop_level
     offset[iv] = iv->var;
     loop_level[iv->var.as<Variable>()] = i + 1;
-
-    nest[i] = AttrStmt::make(iv->var, "scope", iv, Stmt());
+    // Mark the iter var's scope in the IR so later passes can locate this point.
     if (iv->thread_tag.length() == 0) {
       Range dom = dom_map.at(iv);
-      nest[i] = For::make(iv->var, dom->min, dom->extent,
-                          ForType::Serial, DeviceAPI::None, nest[i]);
+      nest[i + 1].emplace_back(
+          For::make(iv->var, dom->min, dom->extent,
+                    ForType::Serial, DeviceAPI::None, no_op));
     }
+    nest[i + 1].emplace_back(
+        AttrStmt::make(iv, "scope", iv->var, no_op));
   }
   // message passing to get offset of root iter vars.
   PassUpOffset(sch, dom_map, &offset);
+
   for (IterVar iv : sch->op->root_iter_vars()) {
     Expr value = offset.at(iv);
-    if (value.same_as(iv->var)) continue;
-    using Entry = std::pair<size_t, Expr>;
-    std::vector<Entry> splits;
-    SplitByAdd(value, loop_level, &splits);
-
-    Expr offset = 0;
-    for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
-      auto iv = leaf_iter_vars[i];
-      for (const auto& kv : splits) {
-        if (kv.first == i) {
-          offset = offset + splits[i].second;
+    if (!value.same_as(iv->var)) {
+      using Entry = std::pair<size_t, Expr>;
+      std::vector<Entry> splits;
+      SplitByAdd(value, loop_level, &splits);
+
+      Expr offset = 0;
+      size_t nsplit_left = splits.size() - 1;
+      for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
+        size_t hit = 0;
+        for (const auto& kv : splits) {
+          if (kv.first == i) {
+            if (is_zero(offset)) {
+              offset = kv.second;
+            } else {
+              offset = offset + kv.second;
+              ++hit;
+            }
+          }
         }
+        nsplit_left -= hit;
+        if (hit != 0) {
+          std::ostringstream os;
+          os << iv->var->name_hint << ".at.l" << i;
+          Var base_offset(os.str());
+          if (nsplit_left == 0) {
+            base_offset = iv->var;
+          }
+          nest[i].emplace_back(
+              LetStmt::make(base_offset, offset, no_op));
+          offset = base_offset;
+        }
+      }
+      Range dom = dom_map.at(iv);
+      if (!offset.same_as(iv->var)) {
+        // define the iv->var
+        nest.back().emplace_back(
+            LetStmt::make(iv->var, offset, no_op));
       }
-      std::ostringstream os;
-      os << iv->var->name_hint << ".at.l" << i;
-      Var base_offset(os.str());
-      nest[i] = LetStmt::make(base_offset, offset, nest[i]);
-      offset = base_offset;
+      Expr condition = (iv->var - dom->min) < dom->extent;
+      // Boundary condition checking
+      // Need better boundary condition here.
+      nest.back().emplace_back(IfThenElse::make(condition, no_op));
     }
-    nest.back() = LetStmt::make(iv->var, offset, nest.back());
   }
   return nest;
 }
 
+
 /*!
- * \brief Make the loop nest of the correspondings schedule.
- * \param op The operation.
+ * \brief Make the provide statement for a compute op node.
+ * \param op The compute node.
+ * \param tensors The output tensors written by the provide.
  */
-Stmt MakeBody(const Operation& op) {
-  Stmt body;
-  if (op.as<ComputeOpNode>()) {
-    const ComputeOpNode* compute = op.as<ComputeOpNode>();
-    // Note: Tensor's address cannot uniquely
-    Tensor t = op.output(0);
-    Array<Expr> args;
-    for (IterVar iv : compute->axis) {
-      args.push_back(iv->var);
-    }
-    body = Provide::make(t, {compute->body}, args);
+Stmt MakeProvide(const ComputeOpNode* op,
+                 const std::vector<Tensor>& tensors) {
+  Tensor t = tensors[0];
+  Array<Expr> args;
+  for (IterVar iv : op->axis) {
+    args.push_back(iv->var);
+  }
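+  // write op->body into tensor t at the current axis indices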
+  return Provide::make(t->op, t->value_index, op->body, args);
+}
+
+/*!
+ * \brief Make the realize statement that wraps the pipeline of a compute op node.
+ * \param op The compute node.
+ * \param dom_map The domain map.
+ * \param tensors The output tensors to be realized.
+ * \param body The content of the pipeline.
+ */
+Stmt MakeRealize(const ComputeOpNode* op,
+                 const Map<IterVar, Range>& dom_map,
+                 const std::vector<Tensor>& tensors,
+                 Stmt body) {
+  Tensor t = tensors[0];
+  Halide::Internal::Region bounds;
+  for (IterVar iv : op->axis) {
+    bounds.push_back(dom_map.at(iv));
+  }
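+  // realize t over its full axis bounds; the condition is a constant true for now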
+  return Realize::make(t->op, t->value_index, t->dtype,
+                       bounds, make_const(Bool(1), true), body);
+}
+
+Stmt MakePipeline(const Schedule& sch,
+                  const Map<IterVar, Range>& dom_map,
+                  Stmt consumer) {
+  std::vector<Tensor> tensors;
+  for (int i = 0; i < sch->op->num_outputs(); ++i) {
+    tensors.emplace_back(sch->op.output(i));
+  }
+
+  Stmt provide;
+  if (sch->op.as<ComputeOpNode>()) {
+    provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
   } else {
     LOG(FATAL) << "not supported op";
   }
-  return body;
-}
+  std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
+  Stmt producer = MergeNest(nest, provide);
+  producer = ProducerConsumer::make(sch->op, true, producer);
 
-Stmt MakePipeline(const Schedule& sch, Stmt body) {
-  return body;
+  Stmt pipeline = producer;
+  if (consumer.defined()) {
+    consumer = ProducerConsumer::make(sch->op, false, consumer);
+    pipeline = Block::make(producer, consumer);
+  }
+
+  if (sch->op.as<ComputeOpNode>()) {
+    return MakeRealize(sch->op.as<ComputeOpNode>(),
+                       dom_map, tensors, pipeline);
+  } else {
+    LOG(FATAL) << "not supported op";
+    return Stmt();
+  }
 }
 
 // inject the operator's realization on the stmt.
 class InjectRealize : public IRMutator {
  public:
-  explicit InjectRealize(Schedule sch)
-      : sch_(sch) {}
+  InjectRealize(Schedule schedule, Map<IterVar, Range> dom_map)
+      : schedule(schedule), dom_map(dom_map) {}
 
   Stmt Mutate(Stmt stmt) final {
+    CHECK(stmt.defined());
+    stmt = IRMutator::Mutate(stmt);
     const AttrStmt* op = stmt.as<AttrStmt>();
-    if (op != nullptr) {
-      attr_scope_.Push({op->node, op->type_key}, op->value);
-      stmt = IRMutator::Mutate(stmt);
-      attr_scope_.Pop({op->node, op->type_key});
-    } else {
-      stmt = IRMutator::Mutate(stmt);
-    }
-
     if (op != nullptr &&
-        op->type_key == "scope" &&
-        op->node == sch_->attach_parent) {
-      return AttrStmt::make(
-          op->node, op->type_key, op->value,
-          MakePipeline(sch_, op->body));
-    } else {
-      return stmt;
+        op->type_key == "scope") {
+      if (op->node == schedule->attach_parent) {
+        CHECK(!found_attach);
+        found_attach = true;
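+        // wrap the original body with this schedule's producer pipeline at the attach point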
+        stmt = AttrStmt::make(
+            op->node, op->type_key, op->value,
+            MakePipeline(schedule, dom_map,
+                         IRMutator::Mutate(op->body)));
+      }
     }
+    return stmt;
   }
-
- private:
   // the operations to be carried
-  Schedule sch_;
-  Scope<AttrKey, Expr> attr_scope_;
+  Schedule schedule;
+  // domain map
+  Map<IterVar, Range> dom_map;
+  // whether attach point is found
+  bool found_attach{false};
 };
 
 
+
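+// Recursively collect the mapping from each operation to its schedule.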
+void GetOpToScheduleMap(
+    Schedule s,
+    std::unordered_map<Operation, Schedule>* ret) {
+  CHECK(!ret->count(s->op))
+      << "Duplicated schedule for op";
+  (*ret)[s->op] = s;
+  for (Schedule c : s->children) {
+    GetOpToScheduleMap(c, ret);
+  }
+}
+
+// order schedule by DFS calling order of ops
+std::vector<Schedule> OrderSchedule(Schedule s) {
+  auto g = schedule::CreateReadGraph(s->op);
+  auto post_order = schedule::PostDFSOrder(s->op, g);
+  std::unordered_map<Operation, Schedule> op2sch;
+  GetOpToScheduleMap(s, &op2sch);
+  std::vector<Schedule> sorder;
+
+  // reverse iteration.
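+  // visiting post_order in reverse places each consumer before its producers.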
+  for (size_t i = post_order.size(); i != 0; --i) {
+    sorder.push_back(op2sch.at(post_order[i - 1]));
+  }
+  return sorder;
+}
+
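+// Inline the body of a compute op into every call site inside the given statement.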
+Stmt InjectInline(const Operation op, Stmt body) {
+  CHECK(body.defined());
+  const ComputeOpNode* compute = op.as<ComputeOpNode>();
+  CHECK(compute != nullptr)
+      << "can only inline compute op";
+  Array<Var> args;
+  for (auto iv : compute->axis) {
+    args.push_back(iv->var);
+  }
+  return Inline(op, args, compute->body, body);
+}
+
 }  // namespace
+
+Stmt ScheduleOps(
+    Schedule s, Map<IterVar, Range> dom_map) {
+  std::vector<Schedule> svec = OrderSchedule(s);
+  Stmt body = Stmt();
+
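+  // Walk ops from the final output back to the inputs, merging each producer into the accumulated body.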
+  for (Schedule s : svec) {
+    if (s->attach_type == kInline) {
+      body = InjectInline(s->op, body);
+    } else if (s->attach_type == kRoot || s->attach_type == kNone) {
+      body = MakePipeline(s, dom_map, body);
+    } else if (s->attach_type == kScope) {
+      CHECK(body.defined());
+      InjectRealize mutator(s, dom_map);
+      body = mutator.Mutate(body);
+      CHECK(mutator.found_attach)
+          << "did not find attachment point";
+    }
+  }
+  return body;
+}
+
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/pass/scope.h b/src/pass/scope.h
index 36a38d67c55eb7857f8079496dc6f9a69aaf55d8..8fee949d86dacb8af2d500b2367db799d491db63 100644
--- a/src/pass/scope.h
+++ b/src/pass/scope.h
@@ -36,7 +36,7 @@ class Scope {
    */
   inline void Pop(const K& key) {
     auto& v = data_[key];
-    CHECK_NE(v.size(), 0);
+    CHECK_NE(v.size(), 0U);
     v.pop_back();
   }
 
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 67fafc493b99a97968517fe47a4b5c991080e496..86ed36fb91f9c34439203609de9db9ca6e6928eb 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -5,8 +5,8 @@
  */
 #include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
+#include <tvm/schedule_pass.h>
 #include "./int_set.h"
-#include "./bound.h"
 #include "./graph.h"
 
 namespace tvm {
@@ -113,7 +113,7 @@ void PassToOperation(
       (*result)[root_iter_vars[i]].push_back(dim_bounds[i]);
     }
   } else {
-    LOG(FATAL) << "unknown operation mode";
+    LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
   }
 }
 
@@ -140,8 +140,8 @@ BoundProp(const Array<Operation>& post_order,
       auto fvisit = [p_state, &result](const NodeRef& n) {
         auto *call = n.as<ir::Call>();
         if (call != nullptr && call->func.defined()) {
-          Tensor t(call->func.node_);
-          if (t->op.defined()) {
+          Tensor t = Operation(call->func.node_).output(call->value_index);
+          if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
             std::vector<IntSet> arg_bounds;
             for (size_t i = 0; i < t.ndim(); ++i) {
               arg_bounds.push_back(EvalSet(call->args[i], result));
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index c6693210cad604796b8368007356dc263a96b77b..ade0e433d20f1bbe6811fd7ed0355a793e460866 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -27,18 +27,20 @@ ReadGraph CreateReadGraph(const Operation& root) {
       auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
         auto *call = n.as<ir::Call>();
         if (call != nullptr && call->func.defined()) {
-          Tensor t(call->func.node_);
-          deps.push_back(t);
-          if (t->op.defined() && visited.count(t->op.get()) == 0) {
-            visited.insert(t->op.get());
-            stack.push_back(t->op);
+          Operation call_op(call->func.node_);
+          deps.push_back(call_op.output(call->value_index));
+          if (call_op.defined() && visited.count(call_op.get()) == 0) {
+            visited.insert(call_op.get());
+            stack.push_back(call_op);
           }
         }
       };
       ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
       rmap.Set(op, deps);
     } else {
-      LOG(FATAL) << "unknown operation mode";
+      if (!op.as<PlaceholderOpNode>()) {
+        LOG(FATAL) << "unknown Operation " << op->type_key();
+      }
     }
   }
   return rmap;
@@ -51,7 +53,7 @@ void PostDFSOrder(const Operation& op,
                   Array<Operation>* post_order) {
   visited->insert(op);
   for (const auto& t : g.at(op)) {
-    if (t->op.defined() && !visited->count(t->op)) {
+    if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) {
       PostDFSOrder(t->op, g, visited, post_order);
     }
   }
diff --git a/src/schedule/int_set.cc b/src/schedule/int_set.cc
index 2b5cd182236226964ea95b0982a30dee9872414f..ac0b0c6ac910d5daec3d79bcba775331762e2518 100644
--- a/src/schedule/int_set.cc
+++ b/src/schedule/int_set.cc
@@ -220,7 +220,7 @@ void PassUp(const SplitNode* s,
     *parent = IntSet::make_range(dom_map.at(s->parent));
     return;
   }
-  Expr factor = dom_map.at(s->outer)->extent;
+  Expr factor = dom_map.at(s->inner)->extent;
   CHECK(outer.defined());
   CHECK(inner.defined());
   CHECK(factor.defined());
@@ -261,7 +261,7 @@ void PassUp(const FuseNode* s,
 
   if (IsNumber(fused)) {
     Expr value = AsNumber(fused);
-    Expr factor = dom_map.at(s->outer)->extent;
+    Expr factor = dom_map.at(s->inner)->extent;
     *outer = IntSet::make_point(value / factor);
     *inner = IntSet::make_point(value % factor);
   } else {
diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc
index 493399388fa10c7484a3e1e24480975d3357212f..a62a0b09af0ea4e1d41504d6d6bc68c556957ac8 100644
--- a/tests/cpp/tensor_test.cc
+++ b/tests/cpp/tensor_test.cc
@@ -5,8 +5,9 @@
 TEST(Tensor, Basic) {
   using namespace tvm;
   Var m("m"), n("n"), l("l");
-  Tensor A({m, l}, "A");
-  Tensor B({n, l}, "B");
+
+  Tensor A = Placeholder({m, l}, Float(32), "A");
+  Tensor B = Placeholder({n, l}, Float(32), "B");
 
   auto C = Compute({m, n}, [&](Var i, Var j) {
       return A[i][j];
@@ -19,8 +20,8 @@ TEST(Tensor, Basic) {
 TEST(Tensor, Reduce) {
   using namespace tvm;
   Var m("m"), n("n"), l("l");
-  Tensor A({m, l}, "A");
-  Tensor B({n, l}, "B");
+  Tensor A = Placeholder({m, l}, Float(32), "A");
+  Tensor B = Placeholder({n, l}, Float(32), "B");
   IterVar rv(Range{0, l}, "k");
 
   auto C = Compute({m, n}, [&](Var i, Var j) {
diff --git a/tests/python/test_basic.py b/tests/python/test_lang_basic.py
similarity index 100%
rename from tests/python/test_basic.py
rename to tests/python/test_lang_basic.py
diff --git a/tests/python/test_container.py b/tests/python/test_lang_container.py
similarity index 100%
rename from tests/python/test_container.py
rename to tests/python/test_lang_container.py
diff --git a/tests/python/test_schedule.py b/tests/python/test_lang_schedule.py
similarity index 100%
rename from tests/python/test_schedule.py
rename to tests/python/test_lang_schedule.py
diff --git a/tests/python/test_tensor.py b/tests/python/test_lang_tensor.py
similarity index 94%
rename from tests/python/test_tensor.py
rename to tests/python/test_lang_tensor.py
index f24aa987e82b6942aa0da97f2c7fabb1dbc61e59..ca695813d7a20e0dec1a2b62ead90288caa275cf 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_lang_tensor.py
@@ -10,7 +10,7 @@ def test_tensor():
     print(T)
     print(T.op.body)
     assert(tuple(T.shape) == (m, n, l))
-    assert(A.op is None)
+    assert(isinstance(A.op, tvm.tensor.PlaceholderOp))
     assert(A == A)
     assert(T.op.output(0) == T)
     assert(T.op.output(0).__hash__() == T.__hash__())
diff --git a/tests/python/test_ir_pass.py b/tests/python/test_pass_basic.py
similarity index 100%
rename from tests/python/test_ir_pass.py
rename to tests/python/test_pass_basic.py
diff --git a/tests/python/test_inline.py b/tests/python/test_pass_inline.py
similarity index 83%
rename from tests/python/test_inline.py
rename to tests/python/test_pass_inline.py
index c3f6b6aa7b15cb62c36d11c0808fc1379bde693e..858864c60b7551e4291e0a02c3026224190229c9 100644
--- a/tests/python/test_inline.py
+++ b/tests/python/test_pass_inline.py
@@ -6,7 +6,7 @@ def test_inline():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        T, [x.var for x in T.op.axis], T.op.body, stmt)
+        T.op, [x.var for x in T.op.axis], T.op.body, stmt)
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))
 
@@ -14,7 +14,7 @@ def test_inline():
         # pass in int array(wrong argument type)
         # must raise an error
         stmt = tvm.ir_pass.Inline(
-            T, [1,2,3], T.op.body, stmt)
+            T.op, [1,2,3], T.op.body, stmt)
         assert False
     except tvm.TVMError:
         pass
diff --git a/tests/python/test_pass_schedule_ops.py b/tests/python/test_pass_schedule_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a74bd7f53bfeafa4300f9f13bb94678769b5e70
--- /dev/null
+++ b/tests/python/test_pass_schedule_ops.py
@@ -0,0 +1,46 @@
+import tvm
+
+
+def test_schedule0():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    sA1 = tvm.Schedule(A1.op)
+    bounds = tvm.schedule.InferBound(sA1)
+    assert isinstance(bounds, tvm.collections.Map)
+    stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
+    print(stmt)
+
+def test_schedule1():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    sA1 = tvm.Schedule(A1.op)
+    xo, xi = sA1.split(A1.op.axis[0], 8)
+    bounds = tvm.schedule.InferBound(sA1)
+    assert isinstance(bounds, tvm.collections.Map)
+    stmt = tvm.ir_pass.ScheduleOps(sA1, bounds)
+    print(stmt)
+
+def test_schedule2():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+    sA1 = tvm.Schedule(A1.op)
+    sA2 = tvm.Schedule(A2.op)
+    xo, xi = sA2.split(A2.op.axis[0], 8)
+    sA1.compute_at(sA2, xo)
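+    # A1 will be computed inside the outer loop xo of A2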
+    bounds = tvm.schedule.InferBound(sA2)
+    assert isinstance(bounds, tvm.collections.Map)
+    stmt = tvm.ir_pass.ScheduleOps(sA2, bounds)
+    print(stmt)
+
+
+if __name__ == "__main__":
+    test_schedule0()
+    test_schedule1()
+    test_schedule2()
diff --git a/tests/python/test_bound_inference.py b/tests/python/test_schedule_bound_inference.py
similarity index 100%
rename from tests/python/test_bound_inference.py
rename to tests/python/test_schedule_bound_inference.py
index 9a1626f1038f77ef6cef32912fdc99afad3da57d..7970d99080b77e48c7a1679c83597c460b0e9ff4 100644
--- a/tests/python/test_bound_inference.py
+++ b/tests/python/test_schedule_bound_inference.py
@@ -65,7 +65,7 @@ def test_create_read_graph():
 
 
 if __name__ == "__main__":
+    test_create_read_graph()
     test_bound3()
     test_bound1()
     test_bound2()
-    test_create_read_graph()