From 820a85975f09ce5eb2aaad0df3496e69874f3b80 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 13 Feb 2017 16:39:07 -0800
Subject: [PATCH] [LANG] Introduce Scan, Bugfix Canonical (#43)

---
 include/tvm/ir.h                              |  35 ++-
 include/tvm/operation.h                       |  64 ++++
 python/tvm/api.py                             |  55 +++-
 python/tvm/tensor.py                          |   9 +-
 src/api/api_lang.cc                           |   9 +
 src/arithmetic/canonical.cc                   |  11 +-
 src/codegen/codegen_cuda.cc                   |   2 +-
 src/codegen/codegen_cuda.h                    |   2 +-
 src/lang/operation.cc                         |  87 ++++++
 src/pass/inject_virtual_thread.cc             |  20 +-
 src/pass/storage_flatten.cc                   |  34 --
 src/schedule/bound.cc                         | 125 +++++++-
 src/schedule/graph.cc                         |  22 +-
 src/schedule/schedule_lang.cc                 |   2 +
 src/schedule/schedule_ops.cc                  | 290 ++++++++++++++++--
 tests/python/integration/test_scan.py         |  54 ++++
 tests/python/unittest/test_lang_tensor.py     |  14 +
 tests/python/unittest/test_pass_simplify.py   |  10 +-
 .../unittest/test_schedule_schedule_ops.py    |  48 ++-
 19 files changed, 776 insertions(+), 117 deletions(-)
 create mode 100644 tests/python/integration/test_scan.py

diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index d6a258053..e6aa692af 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -49,12 +49,27 @@ struct Reduce : public ExprNode<Reduce> {
   static constexpr const char* Min = "Min";
 };
 
-/*! \brief namespace of possible attribute sin AttrStmt.type_key */
-namespace attr {
 /*!
- * \brief Mark scope of iteration variable, used by Schedule.
+ * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
  */
-constexpr const char* scope = "scope";
+struct TensorKey {
+  FunctionRef f;
+  int value_index;
+
+  inline bool operator==(const TensorKey& other) const {
+    return f == other.f && value_index == other.value_index;
+  }
+  inline std::string GetName() const {
+    if (f->num_outputs() == 1) return f->func_name();
+    std::ostringstream os;
+    os << f->func_name() << ".v" << value_index;
+    return os.str();
+  }
+};
+
+/*! \brief namespace of possible attribute sin AttrStmt.type_key */
+namespace attr {
+// The above attr does not pass to ir stage.
 /*!
  * \brief Mark launching extent of thread, used by device API.
  */
@@ -189,4 +204,16 @@ using Halide::Internal::Evaluate;
 }  // namespace ir
 }  // namespace tvm
 
+namespace std {
+template <>
+struct hash<::tvm::ir::TensorKey> {
+  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
+    size_t lhs = k.f.hash();
+    size_t rhs = static_cast<size_t>(k.value_index);
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+}  // namespace std
+
 #endif  // TVM_IR_H_
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index a48d0e5b8..1d16c3428 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -77,6 +77,55 @@ class ComputeOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
 };
 
+/*!
+ * \brief Symbolic scan.
+ */
+class ScanOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar to scan over */
+  IterVar scan_axis;
+  /*! \brief the initialization tensors */
+  Array<Tensor> init;
+  /*! \brief the update function represented by tensor */
+  Array<Tensor> update;
+  /*! \brief The placeholder to refer as states in update. */
+  Array<Tensor> state_placeholder;
+  /*!
+   * \brief Spatial axis to indicate spatial dimension of each output.
+   *  They corresponds to flattened spatial axis of the outputs.
+   *
+   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
+   *  These are auxiliary data structure for storing result of bound inference.
+   *  They do not corresponds to splittable iterations, thus the name comes
+   *  with underscore.
+   */
+  Array<IterVar> spatial_axis_;
+  /*! \brief constructor */
+  ScanOpNode() {}
+  // override behavior.
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("scan_axis", &scan_axis);
+    v->Visit("init", &init);
+    v->Visit("update", &update);
+    v->Visit("state_placeholder", &state_placeholder);
+    v->Visit("spatial_axis_", &spatial_axis_);
+  }
+  static Operation make(std::string name,
+                        IterVar axis,
+                        Array<Tensor> init,
+                        Array<Tensor> update,
+                        Array<Tensor> state_placeholder);
+
+  static constexpr const char* _type_key = "ScanOp";
+  TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
+};
+
 
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;
@@ -100,6 +149,21 @@ Tensor Placeholder(Array<Expr> shape,
  */
 Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
 
+/*!
+ * \brief Construct new tensors by scan over scan_axis.
+ *
+ * \param scan_axis The iteration representing the scan.
+ * \param init The intialize tensor of first K steps.
+ * \param update The update tensor indicated the updated result after each timestamp.
+ * \param state_placeholder The placeholder for the states.
+ * \param name The optional name of the tensor.
+ */
+Array<Tensor> Scan(IterVar scan_axis,
+                   Array<Tensor> init,
+                   Array<Tensor> update,
+                   Array<Tensor> state_placeholder,
+                   std::string name = "scan");
+
 // same as compute, specialized for different fcompute function
 inline Tensor Compute(Array<Expr> shape,
                       std::function<Expr(Var)> f,
diff --git a/python/tvm/api.py b/python/tvm/api.py
index bb1a563b2..2c3f54483 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -14,6 +14,7 @@ from ._ctypes._function import convert_to_tvm_func as _convert_tvm_func
 from . import _api_internal
 from . import make as _make
 from . import expr as _expr
+from . import tensor as _tensor
 from . import collections as _collections
 
 int32 = "int32"
@@ -111,7 +112,6 @@ def compute(shape, fcompute, name="compute"):
     shape: Tuple of Expr
         The shape of the tensor
 
-
     fcompute: lambda function of *indices-> value
         Specifies the input source expression
 
@@ -137,8 +137,57 @@ def compute(shape, fcompute, name="compute"):
     body = convert(body)
     op_node = _api_internal._ComputeOp(
         name, dim_var, body)
-    return _api_internal._Tensor(
-        shape, body.dtype, op_node, 0)
+    return op_node.output(0)
+
+
+def scan(axis, init, update, state_placeholder, name="scan"):
+    """Construct new tensors by scanning over axis.
+
+    Parameters
+    ----------
+    axis: IterVar
+        The scanning axis.
+
+    init: Tensor or list of Tensor
+        The initial condition of first init.shape[0] timestamps
+
+    update: Tensor or list of Tensor
+        The update rule of the scan given by symbolic tensor.
+
+    state_placeholder: Tensor or list of Tensor
+        The placeholder variables used by update.
+
+    name: str, optional
+        The name hint of the tensor
+
+    Returns
+    -------
+    tensor: tensor.Tensor
+        The created tensor
+
+    Example
+    -------
+    # The following code is equivalent to numpy.cumsum
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    t = tvm.IterVar((1, m), name="t")
+    X = tvm.placeholder((m, n), name="X")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
+    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(t, s_init, s_update, s_state)
+    """
+    if isinstance(init, _tensor.Tensor):
+        init = [init]
+    if isinstance(update, _tensor.Tensor):
+        update = [update]
+    if isinstance(state_placeholder, _tensor.Tensor):
+        state_placeholder = [state_placeholder]
+    if len(init) != len(update) or len(init) != len(state_placeholder):
+        raise ValueError("init, update, state_placeholder must have same length")
+    op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
+    res = [op.output(i) for i in range(len(update))]
+    return (res[0] if len(res) == 1 else res)
 
 
 def Buffer(shape, dtype=None,
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index 47a7ec88c..2dbab96de 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -74,12 +74,17 @@ class Operation(NodeBase):
         """
         return _api_internal._OpGetOutput(self, index)
 
+@register_node
+class PlaceholderOp(Operation):
+    """Placeholder operation."""
+    pass
+
 @register_node
 class ComputeOp(Operation):
     """Compute operation."""
     pass
 
 @register_node
-class PlaceholderOp(Operation):
-    """Placeholder operation."""
+class ScanOp(Operation):
+    """Scan operation."""
     pass
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 769345fc4..32fcc41a1 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -173,6 +173,15 @@ TVM_REGISTER_API(_ComputeOp)
                                args[2]);
   });
 
+TVM_REGISTER_API(_ScanOp)
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = ScanOpNode::make(args[0],
+                            args[1],
+                            args[2],
+                            args[3],
+                            args[4]);
+  });
+
 TVM_REGISTER_API(_OpGetOutput)
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
     *ret = args[0].operator Operation().output(
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 8ae8ed47e..ae95b04a5 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -365,7 +365,7 @@ class Canonical::Internal : public IRMutator {
                   const ComExpr& sumb,
                   int bscale) {
     std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
-    n->base = suma->base + sumb->base;
+    n->base = suma->base + sumb->base * bscale;
     // merge of suma and sumb;
     size_t i = 0, j = 0;
     while (i < suma->elem.size() && j < sumb->elem.size()) {
@@ -417,7 +417,7 @@ class Canonical::Internal : public IRMutator {
   // convert sum to expr
   Expr Sum2Expr(const ComExpr& com, Type t) {
     Expr vsum;
-    if (com->base != 0) {
+    if (com->base > 0) {
       vsum = make_const(t, com->base);
     }
     for (const ComExprEntry& e : com->elem) {
@@ -433,6 +433,13 @@ class Canonical::Internal : public IRMutator {
         }
       }
     }
+    if (com->base < 0) {
+      if (vsum.defined()) {
+        vsum = Sub::make(vsum, make_const(t, -com->base));
+      } else {
+        vsum = make_const(t, com->base);
+      }
+    }
     for (const ComExprEntry& e : com->elem) {
       if (e.scale < 0) {
         Expr v = e.value;
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index c4c5d99f3..c526ec8d0 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -168,7 +168,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
     const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
     code = f(code).operator std::string();
   }
-    LOG(INFO) << code;
+
   std::string ptx;
   if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
     const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h
index 428f9ffdd..641c28f95 100644
--- a/src/codegen/codegen_cuda.h
+++ b/src/codegen/codegen_cuda.h
@@ -42,7 +42,7 @@ class CodeGenCUDA : public CodeGenC {
  private:
   // magic number to add pragma unroll to it.
   // used to generate code that is compact but still unrolls.
-  int max_auto_unroll_{8};
+  int max_auto_unroll_{1025};
 };
 
 }  // namespace codegen
diff --git a/src/lang/operation.cc b/src/lang/operation.cc
index 95c292e48..9e16f1c1b 100644
--- a/src/lang/operation.cc
+++ b/src/lang/operation.cc
@@ -5,6 +5,7 @@
 #include <tvm/operation.h>
 #include <tvm/tensor.h>
 #include <tvm/ir.h>
+#include <tvm/ir_pass.h>
 #include <memory>
 
 namespace tvm {
@@ -120,4 +121,90 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
+// Scan
+inline bool prove_equal(Expr lhs, Expr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
+int ScanOpNode::num_outputs() const {
+  return update.size();
+}
+Array<IterVar> ScanOpNode::root_iter_vars() const {
+  return Array<IterVar>{scan_axis};
+}
+
+Type ScanOpNode::output_dtype(size_t i) const {
+  return update[i]->dtype;
+}
+
+Array<Expr> ScanOpNode::output_shape(size_t i) const {
+  CHECK_LT(i, state_placeholder.size());
+  return state_placeholder[i]->shape;
+}
+
+Operation ScanOpNode::make(std::string name,
+                           IterVar axis,
+                           Array<Tensor> init,
+                           Array<Tensor> update,
+                           Array<Tensor> state_placeholder) {
+  auto n = std::make_shared<ScanOpNode>();
+  CHECK_EQ(init.size(), update.size());
+  CHECK_EQ(init.size(), state_placeholder.size());
+
+  for (size_t i = 0; i < init.size(); ++i) {
+    CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
+    CHECK_EQ(init[i]->dtype, update[i]->dtype);
+    CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
+        << "init.shape[0] need to match scan_axis.dom.min";
+    CHECK(prove_equal(
+        state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
+        << "shate_placeholder.shape[0] need to match"
+        << " scan_axis.dom.min + scan_axis.dom.extent";
+    CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
+        << "The dimension of init need to match state_placeholder";
+    CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
+        << "The update.ndim need to be state_placeholder.ndim - 1";
+    for (size_t k = 0;  k < update[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
+      // setup spatial axis
+      std::ostringstream spatial_name;
+      spatial_name << name << ".out" << i << ".i" << k + 1;
+      n->spatial_axis_.push_back(
+          IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
+                  spatial_name.str()));
+    }
+    for (size_t k = 1;  k < init[i].ndim(); ++k) {
+      CHECK(prove_equal(
+          init[i]->shape[k], state_placeholder[i]->shape[k]));
+    }
+  }
+
+  n->name = name;
+  n->scan_axis = axis;
+  n->init = init;
+  n->update = update;
+  n->state_placeholder = state_placeholder;
+  return Operation(n);
+}
+
+Array<Tensor> Scan(IterVar scan_axis,
+                   Array<Tensor> init,
+                   Array<Tensor> update,
+                   Array<Tensor> state_placeholder,
+                   std::string name) {
+  Operation op = ScanOpNode::make(
+      name, scan_axis, init, update, state_placeholder);
+  Array<Tensor> res;
+  for (int i = 0; i < op->num_outputs(); ++i) {
+    res.push_back(op.output(i));
+  }
+  return res;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
+    p->stream << "scan(" << op->name << ", " << op << ")";
+});
+
 }  // namespace tvm
diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc
index 0a9f5b38f..2ca1d7c41 100644
--- a/src/pass/inject_virtual_thread.cc
+++ b/src/pass/inject_virtual_thread.cc
@@ -191,20 +191,16 @@ class VTInjector : public IRMutator {
   }
   // Attribute
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
-    if (op->type_key == attr::scope) {
-      return Mutate(op->body);
+    Expr value = Mutate(op->value);
+    if (visit_touched_var_) {
+      return InjectVTLoop(s, true);
     } else {
-      Expr value = Mutate(op->value);
-      if (visit_touched_var_) {
-        return InjectVTLoop(s, true);
+      Stmt body = Mutate(op->body);
+      if (value.same_as(op->value) &&
+          body.same_as(op->body)) {
+        return s;
       } else {
-        Stmt body = Mutate(op->body);
-        if (value.same_as(op->value) &&
-            body.same_as(op->body)) {
-          return s;
-        } else {
-          return AttrStmt::make(op->node, op->type_key, value, body);
-        }
+        return AttrStmt::make(op->node, op->type_key, value, body);
       }
     }
   }
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 944a8c0a4..e7a881640 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -11,40 +11,6 @@
 namespace tvm {
 namespace ir {
 
-// key of function buffer
-struct TensorKey {
-  FunctionRef f;
-  int value_index;
-
-  inline bool operator==(const TensorKey& other) const {
-    return f == other.f && value_index == other.value_index;
-  }
-  inline std::string GetName() const {
-    if (f->num_outputs() == 1) return f->func_name();
-    std::ostringstream os;
-    os << f->func_name() << ".v" << value_index;
-    return os.str();
-  }
-};
-
-}  // namespace ir
-}  // namespace tvm
-
-namespace std {
-template <>
-struct hash<::tvm::ir::TensorKey> {
-  std::size_t operator()(const ::tvm::ir::TensorKey& k) const {
-    size_t lhs = k.f.hash();
-    size_t rhs = static_cast<size_t>(k.value_index);
-    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
-    return lhs;
-  }
-};
-}  // namespace std
-
-namespace tvm {
-namespace ir {
-
 using Halide::Internal::Region;
 using runtime::StorageScope;
 using runtime::ThreadScope;
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 4514d0228..88729a3ce 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -23,6 +23,10 @@ inline Expr DivCeil(Expr a, Expr b) {
   return ir::Simplify((a + b - 1) / b);
 }
 
+inline bool prove_equal(Expr lhs, Expr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
 // Downward message passing algorithm on stage schedule s,
 // pass the range state down from the root to the leaves
 // after this pass, every IterVar in the stage hyper graph will have a range(domain)
@@ -41,9 +45,18 @@ void PassDown(const Stage& s,
         if (r->outer->dom.defined()) {
           state[r->outer] = r->outer->dom;
         } else {
-          CHECK(!state.count(r->outer));
-          state[r->outer] = Range::make_with_min_extent(
-              0, DivCeil(range_parent->extent, r->factor));
+          if (!state.count(r->outer)) {
+            state[r->outer] = Range::make_with_min_extent(
+                0, DivCeil(range_parent->extent, r->factor));
+          } else {
+            Expr outer_ext = DivCeil(range_parent->extent, r->factor);
+            Range outer_rng = state.at(r->outer);
+            bool match = is_zero(outer_rng->min);
+            if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
+            CHECK(match)
+                << "IterVar is used in two places as outer scope,"
+                << " cannot prove their extents are the same";
+          }
         }
       } else {
         CHECK(r->outer->dom.defined());
@@ -181,6 +194,21 @@ void PassUp(const Stage& s,
   }
 }
 
+// All the itervars that are needed to output bound of op.
+// For most op, it is root_iter_vars
+// For Scan, it also contains the additional spatial axis.
+Array<IterVar> OutputRelatedIterVars(const Operation& op) {
+  if (op.as<ScanOpNode>()) {
+    const ScanOpNode* scan = op.as<ScanOpNode>();
+    Array<IterVar> ret{scan->scan_axis};
+    for (IterVar iv : scan->spatial_axis_) {
+      ret.push_back(iv);
+    }
+    return ret;
+  } else {
+    return op->root_iter_vars();
+  }
+}
 
 /*! \brief temporary data structure to store Tensor domain */
 struct TensorDom {
@@ -214,6 +242,34 @@ void BoundProp(const Operation& op,
       }
     };
     ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+  } else if (op.as<ScanOpNode>()) {
+    const ScanOpNode* scan = op.as<ScanOpNode>();
+    size_t sp_idx = 0;
+    for (size_t i = 0; i < scan->init.size(); ++i) {
+      TensorDom* init_dom = nullptr;
+      TensorDom* update_dom = nullptr;
+      if (out->count(scan->init[i])) {
+        init_dom = &out->at(scan->init[i]);
+      }
+      if (out->count(scan->update[i])) {
+        update_dom = &out->at(scan->update[i]);
+      }
+      // first dimension, always needed.
+      if (init_dom) {
+        init_dom->data[0].push_back(IntSet::range(
+            Range::make_with_min_extent(0, scan->init[i]->shape[0])));
+      }
+      // The update dimensions
+      for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+        IterVar sp_ax = scan->spatial_axis_[sp_idx];
+        if (init_dom) {
+          init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get()));
+        }
+        if (update_dom) {
+          update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
+        }
+      }
+    }
   } else if (op.as<PlaceholderOpNode>()) {
     // do nothing
   } else {
@@ -221,14 +277,49 @@ void BoundProp(const Operation& op,
   }
 }
 
-void InferOpBound(const Operation& op,
-                  const std::unordered_map<Tensor, TensorDom>& tmap,
-                  std::unordered_map<IterVar, Range>* rmap) {
+// Given the bound of output of op
+// Pass the bound to the related axis in op.
+void GatherOpBound(const ScanOpNode* scan,
+                   const Operation& op,
+                   const std::unordered_map<Tensor, TensorDom>& tmap,
+                   std::unordered_map<IterVar, Range>* rmap) {
+  CHECK(!rmap->count(scan->scan_axis));
+  std::vector<Tensor> output(op->num_outputs());
+  for (size_t i = 0; i < output.size(); ++i) {
+    output[i] = op.output(i);
+  }
+  // Update for time axis.
+  std::vector<IntSet> time_dom;
+  for (size_t i = 0; i < output.size(); ++i) {
+    const TensorDom& d = tmap.at(output[i]);
+    time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end());
+  }
+  LOG(INFO) << time_dom.size();
+  CHECK(!rmap->count(scan->scan_axis));
+  Range sdom = scan->scan_axis->dom;
+  Range r = arith::Union(time_dom).cover_range(sdom);
+  (*rmap)[scan->scan_axis] = Range::make_with_min_extent(
+      sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
+  // Update for spatial axis.
+  size_t sp_idx = 0;
+  for (size_t i = 0; i < output.size(); ++i) {
+    for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+      IterVar sp_ax = scan->spatial_axis_[sp_idx];
+      CHECK(!rmap->count(sp_ax));
+      // In default, we always need all spatial axis
+      // Unless that axis only refers back to itself as a fixed point.
+      // TODO(tqchen): Add fix point detection.
+      (*rmap)[sp_ax] = sp_ax->dom;
+    }
+  }
+}
+
+void GatherOpBound(const Operation& op,
+                   const std::unordered_map<Tensor, TensorDom>& tmap,
+                   std::unordered_map<IterVar, Range>* rmap) {
   if (op.as<ComputeOpNode>()) {
-    auto root_iter_vars = op->root_iter_vars();
     const ComputeOpNode* compute = op.as<ComputeOpNode>();
     const TensorDom& tdom = tmap.at(op.output(0));
-
     for (size_t i = 0; i < compute->axis.size(); ++i) {
       Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
       CHECK(!rmap->count(compute->axis[i]));
@@ -238,6 +329,8 @@ void InferOpBound(const Operation& op,
       CHECK(!rmap->count(compute->reduce_axis[i]));
       (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
     }
+  } else if (op.as<ScanOpNode>()) {
+    GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
   } else if (op.as<PlaceholderOpNode>()) {
     // dp nothing
   } else {
@@ -269,8 +362,7 @@ void InferRootBound(const Stage& stage,
                     std::unordered_map<IterVar, Range>* rmap) {
   if (stage->attach_type == kInline) return;
   if (stage->attach_type == kRoot || stage->attach_type == kNone) {
-    auto root_iter_vars = stage->op->root_iter_vars();
-    for (auto iv :  root_iter_vars) {
+    for (auto iv :  OutputRelatedIterVars(stage->op)) {
       CHECK(iv->dom.defined());
       CHECK(!rmap->count(iv));
       (*rmap)[iv] = iv->dom;
@@ -338,8 +430,13 @@ void InferRootBound(const Stage& stage,
     PassUp(parent, *rmap, &up_state);
 
     std::unordered_map<const Variable*, IntSet> dom_map;
-    for (auto iv : parent->op->root_iter_vars()) {
-      Range r = up_state.at(iv).cover_range(iv->dom);
+    for (auto iv : OutputRelatedIterVars(parent->op)) {
+      Range r;
+      if (up_state.count(iv)) {
+        r = up_state.at(iv).cover_range(iv->dom);
+      } else {
+        r = iv->dom;
+      }
       if (relax_set.size() != 0) {
         dom_map[iv->var.get()] = EvalSet(r, relax_set);
       } else {
@@ -379,13 +476,13 @@ void InferRootBound(const Stage& stage,
     CHECK(found)
         << "Invalid Schedule, cannot find the producer " << stage->op
         << " along the loop nest specified by compute_at of consumer " << op;
-    for (auto iv : op->root_iter_vars()) {
+    for (auto iv : OutputRelatedIterVars(op)) {
       Range r = rmap->at(iv);
       dom_map[iv->var.get()] = EvalSet(r, relax_set);
     }
     BoundProp(op, dom_map, &tmap);
   }
-  InferOpBound(stage->op, tmap, rmap);
+  GatherOpBound(stage->op, tmap, rmap);
 }
 
 FeedGraph CreateFeedGraph(const Schedule& sch) {
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 33272fceb..f1047bf95 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -33,20 +33,28 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
         if (call != nullptr && call->func.defined()) {
           Operation call_op(call->func.node_);
           deps.push_back(call_op.output(call->value_index));
-          if (call_op.defined() && visited.count(call_op.get()) == 0) {
-            visited.insert(call_op.get());
-            stack.push_back(call_op);
-          }
         }
       };
       ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
-      rmap.Set(op, deps);
+    } else if (op.as<ScanOpNode>()) {
+      const ScanOpNode* scan = op.as<ScanOpNode>();
+      for (Tensor t : scan->init) {
+        deps.push_back(t);
+      }
+      for (Tensor t : scan->update) {
+        deps.push_back(t);
+      }
     } else if (op.as<PlaceholderOpNode>()) {
-      // empty set of deps
-      rmap.Set(op, deps);
     } else {
       LOG(FATAL) << "unknown Operation" << op->type_key();
     }
+    rmap.Set(op, deps);
+    for (Tensor t : deps) {
+      if (t->op.defined() && visited.count(t->op.get()) == 0) {
+        visited.insert(t->op.get());
+        stack.push_back(t->op);
+      }
+    }
   }
   return rmap;
 }
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index d7d514b0c..3975e4e90 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -146,6 +146,8 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT
 
 Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
   StageNode* self = operator->();
+  CHECK(!self->op.as<ScanOpNode>())
+      << "Cannot reorder axis of scan";
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
   std::vector<size_t> pos;
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index ed4ad7011..c69381967 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -7,7 +7,9 @@
 #include <tvm/ir_pass.h>
 #include <tvm/ir_visitor.h>
 #include <tvm/schedule_pass.h>
-
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
 #include "../pass/ir_util.h"
 #include "../arithmetic/compute_expr.h"
 #include "./graph.h"
@@ -18,6 +20,12 @@ namespace schedule {
 using namespace arith;
 using namespace ir;
 
+// Two private scope marks
+namespace attr {
+constexpr const char* loop_scope = "loop_scope";
+constexpr const char* scan_scope = "scan_scope";
+}  // namespace attr
+
 /*!
  * \brief message passing to find if IterVar is related to reduction.
  * \param s The stage to be used.
@@ -168,7 +176,6 @@ MakeLoopNest(const Stage& sch,
       value_map[iv] = iv->var;
       continue;
     }
-
     Range dom = dom_map.at(iv);
     // initialize the offset and loop_level
     Var var = iv->var;
@@ -223,7 +230,7 @@ MakeLoopNest(const Stage& sch,
     if (!reduce_init_loop) {
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, ir::attr::scope, iv->var, no_op));
+          AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
     }
   }
   // message passing to get offset of root iter vars.
@@ -307,8 +314,8 @@ Stmt MakeLoop(const Stage& s,
     init = Substitute(init, init_value_map);
     init  = MergeNest(init_nest, init);
     // common nest
-    std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop);
-    std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop, nest.end());
+    std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
+    std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
     provide = MergeNest(reduce, provide);
     return MergeNest(
         common, Block::make(init, provide));
@@ -340,6 +347,29 @@ Stmt MakeRealize(const ComputeOpNode* op,
                        bounds, make_const(Bool(1), true), body);
 }
 
+Stmt MakeRealize(const ScanOpNode* op,
+                 const Map<IterVar, Range>& dom_map,
+                 const std::vector<Tensor>& tensors,
+                 Stmt body) {
+  Range sdom = dom_map.at(op->scan_axis);
+  Range tdom = Range::make_with_min_extent(
+      0, ir::Simplify(sdom->extent + sdom->min));
+  size_t sp_idx = 0;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    const Tensor& t = tensors[i];
+    CHECK_EQ(static_cast<size_t>(t->value_index), i);
+    Halide::Internal::Region bounds;
+    bounds.push_back(tdom);
+    for (size_t k = 0; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
+      IterVar sp_ax = op->spatial_axis_[sp_idx];
+      bounds.push_back(dom_map.at(sp_ax));
+    }
+    body = Realize::make(t->op, t->value_index, t->dtype,
+                         bounds, make_const(Bool(1), true), body);
+  }
+  return body;
+}
+
 
 void MakeReduction(const ComputeOpNode* op,
                    const std::vector<Tensor>& tensors,
@@ -382,12 +412,18 @@ Stmt MakePipeline(const Stage& s,
   Stmt init, provide;
 
   const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+  const ScanOpNode* scan = s->op.as<ScanOpNode>();
   if (compute) {
     if (compute->reduce_axis.size() == 0) {
       provide = MakeProvide(compute, tensors);
     } else {
       MakeReduction(compute, tensors, &init, &provide);
     }
+  } else if (scan) {
+    // Provide is done by the sub operations.
+    provide = AttrStmt::make(
+        s->op, attr::scan_scope, scan->scan_axis->var,
+        Evaluate::make(0));
   } else {
     LOG(FATAL) << "not supported op " << s->op->type_key();
   }
@@ -396,7 +432,12 @@ Stmt MakePipeline(const Stage& s,
   producer = ProducerConsumer::make(s->op, true, producer);
 
   Stmt pipeline = producer;
-  if (consumer.defined()) {
+  // check if consumer is nop.
+  bool is_no_op{false};
+  const Evaluate* ev = consumer.as<Evaluate>();
+  if (ev && ev->value.as<IntImm>()) is_no_op = true;
+
+  if (consumer.defined() && !is_no_op) {
     consumer = ProducerConsumer::make(s->op, false, consumer);
     pipeline = Block::make(producer, consumer);
   }
@@ -404,47 +445,103 @@ Stmt MakePipeline(const Stage& s,
   if (s->op.as<ComputeOpNode>()) {
     pipeline = MakeRealize(s->op.as<ComputeOpNode>(),
                            dom_map, tensors, pipeline);
+  } else if (s->op.as<ScanOpNode>()) {
+    pipeline = MakeRealize(s->op.as<ScanOpNode>(),
+                           dom_map, tensors, pipeline);
   } else {
     LOG(FATAL) << "not supported op";
-    return Stmt();
   }
   // use attribute to mark scope of the operation.
   pipeline = AttrStmt::make(
-      s->op, "realize_scope",
+      s->op, ir::attr::realize_scope,
       StringImm::make(s->scope),
       pipeline);
   return pipeline;
 }
 
 // inject the operator's realization on the stmt.
-class InjectRealize : public IRMutator {
+class InjectAttach : public IRMutator {
  public:
-  InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
-      : schedule(schedule), dom_map(dom_map) {}
+  InjectAttach(const Stage& stage,
+                const Map<IterVar, Range>& dom_map)
+      : stage_(stage), dom_map_(dom_map) {}
 
   Stmt Mutate(Stmt stmt) final {
     CHECK(stmt.defined());
     stmt =  IRMutator::Mutate(stmt);
     const AttrStmt* op = stmt.as<AttrStmt>();
     if (op != nullptr &&
-        op->type_key == "scope") {
-      if (op->node == schedule->attach_ivar) {
+        op->type_key == attr::loop_scope) {
+      if (op->node == stage_->attach_ivar) {
         CHECK(!found_attach);
         found_attach = true;
         stmt = AttrStmt::make(
             op->node, op->type_key, op->value,
-            MakePipeline(schedule, dom_map,
-                         IRMutator::Mutate(op->body)));
+            MakePipeline(stage_, dom_map_, op->body));
       }
     }
     return stmt;
   }
+  // whether attach point is found
+  bool found_attach{false};
+
+ private:
   // the operations to be carried
-  Stage schedule;
+  const Stage& stage_;
   // domain map
-  Map<IterVar, Range> dom_map;
+  const Map<IterVar, Range>& dom_map_;
+};
+
+// inject the operator's realization on the stmt.
+class InjectScanStep : public IRMutator {
+ public:
+  InjectScanStep(const Stage& stage,
+                 const Operation& scan_op,
+                 const Map<IterVar, Range>& dom_map,
+                 bool is_init)
+      : stage_(stage), scan_op_(scan_op),
+        dom_map_(dom_map), is_init_(is_init) {}
+
+  Stmt Mutate(Stmt stmt) final {
+    CHECK(stmt.defined());
+    stmt =  IRMutator::Mutate(stmt);
+    if (is_init_) {
+      const ProducerConsumer* op = stmt.as<ProducerConsumer>();
+      if (op != nullptr &&
+          op->is_producer &&
+          op->func.same_as(scan_op_)) {
+        stmt = ProducerConsumer::make(
+            op->func, true,
+            MakePipeline(stage_, dom_map_, op->body));
+        found_attach = true;
+      }
+    } else {
+      // update
+      const AttrStmt* op = stmt.as<AttrStmt>();
+      if (op != nullptr &&
+          op->type_key == attr::scan_scope) {
+        if (op->node.same_as(scan_op_)) {
+          found_attach = true;
+          stmt = AttrStmt::make(
+              op->node, op->type_key, op->value,
+              MakePipeline(stage_, dom_map_, op->body));
+        }
+      }
+    }
+    return stmt;
+  }
+
   // whether attach point is found
   bool found_attach{false};
+
+ private:
+  // the operations to be carried
+  const Stage& stage_;
+  const Operation& scan_op_;
+  // domain map
+  const Map<IterVar, Range>& dom_map_;
+  // whether it is init.
+  bool is_init_;
 };
 
 Stmt InjectInline(const Operation op, Stmt body) {
@@ -459,27 +556,180 @@ Stmt InjectInline(const Operation op, Stmt body) {
   return Inline(body, op, args, compute->body);
 }
 
+// Postprocessing of schedule op
+// Replace the init and update's expression by scan's buffer.
+class SchedulePostProc : public IRMutator {
+ public:
+  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
+    if (to_remove_.count(op->func.get())) {
+      return this->Mutate(op->body);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
+    if (!HasSideEffect(op->value)) {
+      var_value_[op->var.get()] = Mutate(op->value);
+      return this->Mutate(op->body);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->type_key == attr::loop_scope) {
+      return this->Mutate(op->body);
+    } else if (op->type_key == attr::scan_scope) {
+      const ScanOpNode* scan = op->node.as<ScanOpNode>();
+      CHECK(scan);
+      var_value_[scan->scan_axis->var.get()] = op->value;
+      return this->Mutate(op->body);
+    } else if (op->type_key == ir::attr::realize_scope) {
+      if (to_remove_.count(op->node.get())) {
+        return this->Mutate(op->body);
+      }
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  Stmt Mutate_(const Realize* op, const Stmt& s) final {
+    TensorKey key{op->func, op->value_index};
+    if (replace_.count(key)) {
+      return this->Mutate(op->body);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Stmt Mutate_(const Provide* op, const Stmt& s) final {
+    TensorKey key{op->func, op->value_index};
+    auto it = replace_.find(key);
+    if (it != replace_.end()) {
+      const Tensor& dst = it->second.first;
+      Stmt ret = Provide::make(
+          dst->op, dst->value_index, op->value,
+          RewriteArgs(it->second.second, op->args));
+      return IRMutator::Mutate_(ret.as<Provide>(), ret);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Expr Mutate_(const Call* op, const Expr& e) final {
+    if (op != nullptr && op->call_type == Call::Halide) {
+      TensorKey key{op->func, op->value_index};
+      auto it = replace_.find(key);
+      if (it != replace_.end()) {
+        const Tensor& dst = it->second.first;
+        Expr ret = Call::make(
+            op->type, dst->op->name,
+            RewriteArgs(it->second.second, op->args),
+            op->call_type, dst->op, dst->value_index);
+        return IRMutator::Mutate_(ret.as<Call>(), ret);
+      }
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+
+  Expr Mutate_(const Variable* op, const Expr& e) final {
+    auto it = var_value_.find(op);
+    if (it != var_value_.end()) {
+      return it->second;
+    } else {
+      return e;
+    }
+  }
+
+  void Init(const Schedule& sch) {
+    for (Stage s : sch->stages) {
+      const ScanOpNode* scan = s->op.as<ScanOpNode>();
+      if (!scan) continue;
+      for (size_t i = 0; i < scan->update.size(); ++i) {
+        Tensor t = s->op.output(i);
+        AddReplace(scan->init[i], t, Expr());
+        AddReplace(scan->update[i], t, scan->scan_axis->var);
+        AddReplace(scan->state_placeholder[i], t, Expr());
+      }
+    }
+  }
+
+ private:
+  void AddReplace(Tensor src, Tensor dst, Expr head_idx) {
+    replace_[TensorKey{src->op, src->value_index}]
+        = std::make_pair(dst, head_idx);
+    to_remove_.insert(src->op.get());
+  }
+  Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
+    if (!head.defined()) return args;
+    Array<Expr> new_args{head};
+    for (Expr e : args) {
+      new_args.push_back(e);
+    }
+    return new_args;
+  }
+  // The scan value
+  std::unordered_map<const Variable*, Expr> var_value_;
+  // buffer replacement
+  std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_;
+  // replaced functions
+  std::unordered_set<const Node*> to_remove_;
+};
+
 Stmt ScheduleOps(
     Schedule sch, Map<IterVar, Range> dom_map) {
   Stmt body = Stmt();
+  // scan init and scan updates
+  std::unordered_map<Operation, std::pair<Operation, bool> > scan_attach;
+  for (Stage s : sch->stages) {
+    const ScanOpNode* scan = s->op.as<ScanOpNode>();
+    if (!scan) continue;
+    for (Tensor t : scan->init) {
+      if (scan_attach.count(t->op)) {
+        CHECK(scan_attach.at(t->op).first.same_as(s->op))
+            << "Scan init tensor can only belong to one scan";
+      } else {
+        scan_attach[t->op] = std::make_pair(s->op, true);
+      }
+    }
+    for (Tensor t : scan->update) {
+      if (scan_attach.count(t->op)) {
+        CHECK(scan_attach.at(t->op).first.same_as(s->op))
+            << "Scan update tensor can only belong to one scan";
+      } else {
+        scan_attach[t->op] = std::make_pair(s->op, false);
+      }
+    }
+  }
+
   // reverse the post DFS order.
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage s = sch->stages[i - 1];
     // no need to specify place holder op.
     if (s->op.as<PlaceholderOpNode>()) continue;
-    if (s->attach_type == kInline) {
+    if (scan_attach.count(s->op)) {
+      CHECK(s->attach_type == kNone || s->attach_type == kInline)
+          << "Cannot specify compute_at for scan's init/update";
+      CHECK(body.defined());
+      const auto& p = scan_attach.at(s->op);
+      InjectScanStep mu(s, p.first, dom_map, p.second);
+      body = mu.Mutate(body);
+      CHECK(mu.found_attach)
+          << "did not find attachment point for scan.init/update";
+    } else if (s->attach_type == kInline) {
       body = InjectInline(s->op, body);
     } else if (s->attach_type == kRoot || s-> attach_type == kNone) {
       body = MakePipeline(s, dom_map, body);
     } else if (s->attach_type == kScope) {
       CHECK(body.defined());
-      InjectRealize mutator(s, dom_map);
+      InjectAttach mutator(s, dom_map);
       body = mutator.Mutate(body);
       CHECK(mutator.found_attach)
           << "did not find attachment point";
     }
   }
-  return body;
+  SchedulePostProc post_proc;
+  post_proc.Init(sch);
+  return post_proc.Mutate(body);
 }
 
 }  // namespace schedule
diff --git a/tests/python/integration/test_scan.py b/tests/python/integration/test_scan.py
new file mode 100644
index 000000000..38cd832f2
--- /dev/null
+++ b/tests/python/integration/test_scan.py
@@ -0,0 +1,54 @@
+import tvm
+import numpy as np
+
+def test_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    t = tvm.IterVar((1, m), name="t")
+    X = tvm.placeholder((m, n), name="X")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
+    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(t, s_init, s_update, s_state)
+
+    # schedule
+    s = tvm.Schedule(res.op)
+    num_thread = 256
+    block_x = tvm.IterVar(thread_tag="blockIdx.x")
+    thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
+    _, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
+    _, x = s[s_init].split(x, outer=thread_x)
+    _, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x)
+    _, x = s[s_update].split(x, outer=thread_x)
+
+    # one line to build the function.
+    def check_device(target):
+        codes = []
+        fscan = tvm.build(s, [X, res],
+                          target, record_codes=codes,
+                          name="myscan")
+        if target == "cuda":
+            ctx = tvm.gpu(0)
+        else:
+            ctx = tvm.cl(0)
+        if not ctx.enabled:
+            return
+
+        for c in codes[1:]:
+            print(c)
+        # launch the kernel.
+        n = 1024
+        m = 10
+        a_np = np.random.uniform(size=(m, n)).astype(res.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.zeros((m, n), dtype=res.dtype), ctx)
+        fscan(a, b)
+        np.testing.assert_allclose(
+            b.asnumpy(), np.cumsum(a_np, axis=0))
+
+    tvm.init_opencl()
+    check_device("cuda")
+
+
+if __name__ == "__main__":
+    test_scan()
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index 9d9115f5c..3459e80e9 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -34,6 +34,20 @@ def test_tensor_reduce():
     assert(str(C_loaded) == str(C))
 
 
+def test_tensor_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    t = tvm.IterVar((1, m), "t")
+    x = tvm.placeholder((m, n))
+    s = tvm.placeholder((m, n))
+    res = tvm.scan(t,
+                   tvm.compute((1, n), lambda _, i: x[0, i]),
+                   tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]),
+                   s)
+    assert tuple(res.shape) == (m, n)
+
+
 if __name__ == "__main__":
     test_tensor()
     test_tensor_reduce()
+    test_tensor_scan()
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
index 9002b9686..197c4b1f4 100644
--- a/tests/python/unittest/test_pass_simplify.py
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -18,9 +18,15 @@ def test_simplify():
                                         tvm.make.Load(dtype, Ab.data, i + 4) + 1,
                                         (j + 1) * 4 - 4 * j + i),
                          None)))
-    print(stmt)
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
-    print(stmt)
+
+
+def test_basic():
+    m = tvm.Var('m')
+    ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1))
+    assert str(ret.value) == "(m - 1)"
+
 
 if __name__ == "__main__":
+    test_basic()
     test_simplify()
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 9689a1c34..278d1cc53 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -6,13 +6,11 @@ def test_schedule0():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
-
     s = tvm.Schedule(A1.op)
 
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    print(stmt)
 
 def test_schedule1():
     m = tvm.Var('m')
@@ -25,7 +23,7 @@ def test_schedule1():
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    print(stmt)
+
 
 def test_schedule2():
     m = tvm.Var('m')
@@ -40,25 +38,45 @@ def test_schedule2():
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    l = tvm.Var("l")
+    t = tvm.IterVar((1, m), name="t")
+    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
+    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i])
+    res = tvm.scan(t, s_init, s_update, s_state)
+
+    assert tuple(res.shape) == (m, n)
+    s = tvm.Schedule(res.op)
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[res.op.scan_axis].min.value == 1)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
 
+
 def test_auto_inline():
-  m = tvm.Var('m')
-  n = tvm.Var('n')
-  A = tvm.placeholder((m, n), name='A')
-  B = tvm.placeholder((m, n), name='B')
-  C = tvm.placeholder((m, n), name='C')
-  T1 = tvm.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='T1')
-  T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
+    m = tvm.Var('m')
+    n = tvm.Var('n')
+    A = tvm.placeholder((m, n), name='A')
+    B = tvm.placeholder((m, n), name='B')
+    C = tvm.placeholder((m, n), name='C')
+    T1 = tvm.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='T1')
+    T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
 
-  s = tvm.Schedule(T2.op)
-  tvm.schedule.AutoInlineElemWise(s)
-  bounds = tvm.schedule.InferBound(s)
-  stmt = tvm.schedule.ScheduleOps(s, bounds)
-  print(stmt)
+    s = tvm.Schedule(T2.op)
+    tvm.schedule.AutoInlineElemWise(s)
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
 
 
 if __name__ == "__main__":
+    test_schedule_scan()
     test_schedule0()
     test_schedule1()
     test_schedule2()
-- 
GitLab