From b90620ea25c3cd1e3742ff5e684acca7da00c441 Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Fri, 5 Oct 2018 21:30:43 -0700
Subject: [PATCH] [LANG] Generalize compute to tensor region (#1476)

---
 3rdparty/dmlc-core                            |   2 +-
 include/tvm/expr.h                            |   2 +
 include/tvm/operation.h                       |  72 +++-
 include/tvm/tensor_intrin.h                   |  53 +++
 python/tvm/api.py                             |  39 +-
 python/tvm/tensor.py                          |  12 +
 python/tvm/tensor_intrin.py                   |  28 +-
 src/api/api_lang.cc                           |  20 +
 src/lang/tensor.cc                            |  44 ++-
 src/op/compute_op.cc                          |  35 ++
 src/op/compute_op.h                           |  17 +-
 src/op/tensor_compute_op.cc                   | 361 ++++++++++++++++++
 src/op/tensorize.cc                           |  45 ---
 src/pass/arg_binder.cc                        |   4 +-
 src/schedule/schedule_dataflow_rewrite.cc     | 252 +++++++++---
 tests/python/unittest/test_lang_tensor.py     |  74 ++++
 .../unittest/test_schedule_schedule_ops.py    | 130 +++++++
 17 files changed, 1059 insertions(+), 131 deletions(-)
 create mode 100644 src/op/tensor_compute_op.cc

diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index 4f0564ec7..946a54012 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit 4f0564ec769477c66d480dd966088f172050c874
+Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 050ab4c33..e41f5f28d 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -108,6 +108,8 @@ class Range : public HalideIR::IR::Range {
   TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
 };
 
+using Region = Array<Range>;
+
 /*!
  * \brief Type of iteration variable.
  *  Each IterVar have a specific type.
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index c11242c0a..1a1d28ab7 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
   }
   /*!
    * \return The list of iteration variable at root
-   * \note root_iter_vars dedides the shape of the outputs.
+   * \note root_iter_vars decides the shape of the outputs.
    */
   virtual Array<IterVar> root_iter_vars() const = 0;
   /*!
@@ -239,6 +239,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
 };
 
+/*!
+ * \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
+ */
+class TensorComputeOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar on each axis */
+  Array<IterVar> axis;
+  /*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
+  Array<IterVar> reduce_axis;
+  /*! \brief number of axes that can be scheduled */
+  int schedulable_ndim;
+  /*! \brief TensorIntrin used to compute */
+  TensorIntrin intrin;
+  /*! \brief input tensors of intrin */
+  Array<Tensor> inputs;
+  /*! \brief region of input tensors */
+  Array<Region> input_regions;
+  /*! \brief constructor */
+  TensorComputeOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("schedulable_ndim", &schedulable_ndim);
+    v->Visit("intrin", &intrin);
+    v->Visit("inputs", &inputs);
+    v->Visit("input_regions", &input_regions);
+  }
+  static Operation make(std::string name,
+                        std::string tag,
+                        Array<IterVar> axis,
+                        Array<IterVar> reduce_axis,
+                        int schedulable_ndim,
+                        TensorIntrin intrin,
+                        Array<Tensor> tensors,
+                        Array<Region> regions);
+
+  static constexpr const char* _type_key = "TensorComputeOp";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
+};
+
 /*!
  * \brief Symbolic scan.
  */
@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
  public:
   /*! \brief The input tensors */
   Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representationinputs */
+  /*! \brief Symbolic placeholder representation of inputs */
   Array<Buffer> input_placeholders;
   /*! \brief Symbolic placeholder representation of outputs */
   Array<Buffer> output_placeholders;
diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h
index 944498d1e..fbee4bccc 100644
--- a/include/tvm/tensor_intrin.h
+++ b/include/tvm/tensor_intrin.h
@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
 inline const TensorIntrinNode* TensorIntrin::operator->() const {
   return static_cast<const TensorIntrinNode*>(node_.get());
 }
+
+
+// Internal node container of tensor intrinsic calling.
+class TensorIntrinCallNode;
+
+/*! \brief Tensor intrinsic calling node. */
+class TensorIntrinCall : public NodeRef {
+ public:
+  TensorIntrinCall() {}
+  explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const TensorIntrinCallNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = TensorIntrinCallNode;
+};
+
+class TensorIntrinCallNode : public Node {
+ public:
+  /*! \brief the tensor intrinsic */
+  TensorIntrin intrin;
+  /*! \brief input tensors of the intrinsic */
+  Array<Tensor> tensors;
+  /*! \brief regions of input tensors */
+  Array<Region> regions;
+  /*!
+   * \brief IterVar on each reduction axis, if the
+   * intrin will use the reduce axis
+   */
+  Array<IterVar> reduce_axis;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("intrin", &intrin);
+    v->Visit("tensors", &tensors);
+    v->Visit("regions", &regions);
+    v->Visit("reduce_axis", &reduce_axis);
+  }
+  static TensorIntrinCall make(TensorIntrin intrin,
+                               Array<Tensor> tensors,
+                               Array<Region> regions,
+                               Array<IterVar> reduce_axis);
+
+  static constexpr const char* _type_key = "TensorIntrinCall";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
+};
+
+inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
+  return static_cast<const TensorIntrinCallNode*>(node_.get());
+}
+
 }  // namespace tvm
 #endif  // TVM_TENSOR_INTRIN_H_
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 8cf507de6..e275c1122 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    # for python3
+    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
     ndim = len(shape)
     code = fcompute.__code__
 
-    if fcompute.__code__.co_argcount == 0:
+    out_ndim = ndim
+    if code.co_argcount == 0:
         arg_names = ["i%d" % i for i in range(ndim)]
     else:
         arg_names = code.co_varnames[:code.co_argcount]
+        out_ndim = code.co_argcount
 
-    if ndim != len(arg_names):
+    if out_ndim != len(arg_names):
         raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
 
-    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
+    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
     body = fcompute(*[v.var for v in dim_var])
-    if not isinstance(body, (list, tuple)):
-        body = [body]
-    body = convert(body)
-    op_node = _api_internal._ComputeOp(
-        name, tag, attrs, dim_var, body)
+
+    if isinstance(body, _tensor.TensorIntrinCall):
+        for i, s in enumerate(shape[out_ndim:]):
+            var_name = "ax" + str(i)
+            dim_var.append(_IterVar((0, s), var_name, 4))
+        op_node = _api_internal._TensorComputeOp(name,
+                                                 tag,
+                                                 dim_var,
+                                                 body.reduce_axis,
+                                                 out_ndim,
+                                                 body.intrin,
+                                                 body.tensors,
+                                                 body.regions)
+    else:
+        if not isinstance(body, (list, tuple)):
+            body = [body]
+        body = convert(body)
+        op_node = _api_internal._ComputeOp(
+            name, tag, attrs, dim_var, body)
+
     num = op_node.num_outputs
     outputs = tuple(op_node.output(i) for i in range(num))
     return outputs[0] if num == 1 else outputs
@@ -529,14 +548,14 @@ def decl_buffer(shape,
     dtype = float32 if dtype is None else dtype
     strides = () if strides is None else strides
     if offset_factor != 0 and elem_offset is None:
-        elem_offset = var('%s_elem_offset' % name, shape[0].dtype)
+        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
+        elem_offset = var('%s_elem_offset' % name, shape_dtype)
     if data is None:
         data = var(name, "handle")
     return _api_internal._Buffer(
         data, dtype, shape, strides, elem_offset, name, scope,
         data_alignment, offset_factor)
 
-
 def _IterVar(dom, name, iter_type, thread_tag=''):
     """Internal function to create IterVar
 
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index f0d60f514..f32b70eb9 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -30,6 +30,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
         """Data content of the tensor."""
         return self.tensor.dtype
 
+@register_node
+class TensorIntrinCall(NodeBase):
+    """Intermediate structure for calling a tensor intrinsic."""
+    pass
+
 
 itervar_cls = None
 
@@ -106,6 +111,7 @@ class Tensor(NodeBase, _expr.ExprOp):
         return "%s.v%d" % (op.name, self.value_index)
 
 
+
 class Operation(NodeBase):
     """Represent an operation that generate a tensor"""
 
@@ -155,6 +161,12 @@ class ComputeOp(Operation):
         return self.__getattr__("reduce_axis")
 
 
+@register_node
+class TensorComputeOp(Operation):
+    """Tensor operation."""
+    pass
+
+
 @register_node
 class ScanOp(Operation):
     """Scan operation."""
diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py
index 193124b2f..f1f26655f 100644
--- a/python/tvm/tensor_intrin.py
+++ b/python/tvm/tensor_intrin.py
@@ -6,9 +6,25 @@ from . import expr as _expr
 from . import stmt as _stmt
 from . import make as _make
 from . import tensor as _tensor
+from . import schedule as _schedule
 from .build_module import current_build_config
 from ._ffi.node import NodeBase, register_node
 
+
+def _get_region(tslice):
+    region = []
+    for idx in tslice.indices:
+        if isinstance(idx, slice):
+            assert idx.step is None
+            region.append(_api.Range(idx.start, idx.stop))
+        else:
+            if isinstance(idx, _schedule.IterVar):
+                begin = idx.var
+            else:
+                begin = idx
+            region.append(_make.range_by_min_extent(begin, 1))
+    return region
+
 @register_node
 class TensorIntrin(NodeBase):
     """Tensor intrinsic functions for certain computation.
@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase):
     --------
     decl_tensor_intrin: Construct a TensorIntrin
     """
-    pass
-
+    def __call__(self, *args, **kwargs):
+        tensors = [x.tensor for x in args]
+        regions = [_get_region(x) for x in args]
+        reduce_axis = []
+        if "reduce_axis" in kwargs:
+            reduce_axis = kwargs["reduce_axis"]
+            if not isinstance(reduce_axis, (list, tuple)):
+                reduce_axis = [reduce_axis]
+            reduce_axis = _api.convert(reduce_axis)
+        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
 
 def decl_tensor_intrin(op,
                        fcompute,
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 8ca49f19b..75365da5b 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
                                   args[6]);
   });
 
+TVM_REGISTER_API("_TensorIntrinCall")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = TensorIntrinCallNode::make(args[0],
+                                      args[1],
+                                      args[2],
+                                      args[3]);
+  });
+
 TVM_REGISTER_API("_TensorEqual")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
     *ret = args[0].operator Tensor() == args[1].operator Tensor();
@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
                             args[7]);
   });
 
+TVM_REGISTER_API("_TensorComputeOp")
+.set_body([](TVMArgs args,  TVMRetValue* ret) {
+    *ret = TensorComputeOpNode::make(args[0],
+                                     args[1],
+                                     args[2],
+                                     args[3],
+                                     args[4],
+                                     args[5],
+                                     args[6],
+                                     args[7]);
+  });
+
 TVM_REGISTER_API("_ExternOp")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
     *ret = ExternOpNode::make(args[0],
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index 4f9c3e9d1..9b1a58abc 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -10,6 +10,8 @@
 
 namespace tvm {
 
+// Tensor
+
 Expr Tensor::operator()(Array<Var> indices) const {
   Array<Expr> arr(indices.begin(), indices.end());
   return operator()(arr);
@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   return n;
 }
 
+Tensor Operation::output(size_t i) const {
+  auto node = make_node<TensorNode>();
+  node->op = *this;
+  node->value_index = i;
+  node->dtype = (*this)->output_dtype(i);
+  node->shape = (*this)->output_shape(i);
+  return Tensor(node);
+}
+
 Tensor TensorNode::make(Array<Expr> shape,
                         Type dtype,
                         Operation op,
@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorNode);
 
-Tensor Operation::output(size_t i) const {
-  auto node = make_node<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
-}
+
+// TensorIntrin
 
 TensorIntrin TensorIntrinNode::make(std::string name,
                                     Operation op,
@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
+
+// TensorIntrinCall
+
+TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
+                                            Array<Tensor> tensors,
+                                            Array<Region> regions,
+                                            Array<IterVar> reduce_axis) {
+  auto n = make_node<TensorIntrinCallNode>();
+  n->intrin = std::move(intrin);
+  n->tensors = std::move(tensors);
+  n->regions = std::move(regions);
+  n->reduce_axis = std::move(reduce_axis);
+  return TensorIntrinCall(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) {
+    p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
+
 }  // namespace tvm
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index daafac21b..5c972595f 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -13,6 +13,7 @@
 #include "compute_op.h"
 #include "op_util.h"
 #include "../schedule/message_passing.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 
@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) {
   v.Run();
 }
 
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update) {
+  Array<Expr> conds;
+  std::unordered_set<const Variable*> banned;
+  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+    IterVar iv = stage->leaf_iter_vars[i];
+    auto iit = stage->iter_var_attrs.find(iv);
+    if (iit != stage->iter_var_attrs.end()) {
+      const IterVarAttr& attr = (*iit).second;
+      if (attr->iter_type == kTensorized) {
+        break;
+      }
+    }
+    if (iv->iter_type == kCommReduce) {
+      auto vit = dom_map.find(iv);
+      CHECK(vit != dom_map.end());
+      const Range& vrange = vit->second;
+      conds.push_back(likely(iv->var > vrange->min));
+      banned.insert(iv->var.get());
+    }
+  }
+  for (const Expr& pred : n.main_predicates) {
+    if (ir::ExprUseVar(pred, banned)) {
+      LOG(FATAL) << "Tensorize update transform failed, the condition "
+                 << pred << " has a conflict with the reset condition";
+    }
+  }
+
+  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
+                          update, body);
+}
 }  // namespace tvm
diff --git a/src/op/compute_op.h b/src/op/compute_op.h
index 996764c6c..87b0814c1 100644
--- a/src/op/compute_op.h
+++ b/src/op/compute_op.h
@@ -14,7 +14,7 @@
 
 namespace tvm {
 // loop nest structure for general compute
-// This the the loop nest structured used in compute.
+// This the loop nest structured used in compute.
 // Does not include the loop body.
 struct ComputeLoopNest {
   // The common number of loops between init and main
@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Transform the update part when there is no init func in tensorizing
+ * \param stage The stage for tensorizing.
+ * \param dom_map The range of each iter var.
+ * \param n The loop nest structured used in compute.
+ * \param body The body func in tensorize intrin
+ * \param update The update func in tensorize intrin
+ * \return Transformed result.
+ */
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update);
 }  // namespace tvm
 
 #endif  // TVM_OP_COMPUTE_OP_H_
diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc
new file mode 100644
index 000000000..f9b8188d4
--- /dev/null
+++ b/src/op/tensor_compute_op.cc
@@ -0,0 +1,361 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \brief Tensor Compute Op.
+ * \file tensor_compute_op.cc
+ */
+#include <tvm/operation.h>
+#include <tvm/arithmetic.h>
+#include <tvm/ir.h>
+#include <tvm/ir_visitor.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "./op_util.h"
+#include "./compute_op.h"
+#include "../arithmetic/compute_expr.h"
+
+namespace tvm {
+using namespace ir;
+// TensorComputeOpNode
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode *op,
+                                      IRPrinter *p) {
+    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
+
+int TensorComputeOpNode::num_outputs() const {
+  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
+}
+
+Array<IterVar> TensorComputeOpNode::root_iter_vars() const {
+  Array<IterVar> ret = axis;
+  for (IterVar iv : reduce_axis) {
+    ret.push_back(iv);
+  }
+  return ret;
+}
+
+Type TensorComputeOpNode::output_dtype(size_t i) const {
+  return this->intrin->buffers[this->inputs.size() + i]->dtype;
+}
+
+Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
+  Array<Expr> shape;
+  for (const auto& ivar : this->axis) {
+    shape.push_back(ivar->dom->extent);
+  }
+  return shape;
+}
+
+
+Operation TensorComputeOpNode::make(std::string name,
+                                    std::string tag,
+                                    Array<IterVar> axis,
+                                    Array<IterVar> reduce_axis,
+                                    int schedulable_ndim,
+                                    TensorIntrin intrin,
+                                    Array<Tensor> tensors,
+                                    Array<Region> regions) {
+  auto n = make_node<TensorComputeOpNode>();
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->axis = std::move(axis);
+  n->reduce_axis = std::move(reduce_axis);
+  n->schedulable_ndim = std::move(schedulable_ndim);
+  n->intrin = std::move(intrin);
+  n->inputs = std::move(tensors);
+  n->input_regions = std::move(regions);
+  return Operation(n);
+}
+
+Array<Tensor> TensorComputeOpNode::InputTensors() const {
+  return inputs;
+}
+
+Operation TensorComputeOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  auto n = make_node<TensorComputeOpNode>(*this);
+  auto intrin = make_node<TensorIntrinNode>(*(this->intrin.operator->()));
+  intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
+  if (intrin->reduce_init.defined()) {
+    intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
+  }
+  if (intrin->reduce_update.defined()) {
+    intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
+  }
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    Tensor t = n->inputs[i];
+    if (rmap.count(t)) {
+      n->inputs.Set(i, rmap.at(t));
+    }
+  }
+
+  if (intrin->body.same_as(n->intrin->body) &&
+      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
+      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
+      inputs.same_as(n->inputs)) {
+    return self;
+  } else {
+    n->intrin = TensorIntrin(intrin);
+    return Operation(n);
+  }
+}
+
+void TensorComputeOpNode::PropBoundToInputs(
+    const Operation& self,
+    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  for (size_t i = 0; i < this->inputs.size(); ++i) {
+    Tensor t = this->inputs[i];
+    Region region = input_regions[i];
+
+    auto it = out_dom_map->find(t);
+    if (it == out_dom_map->end()) continue;
+    TensorDom& dom = it->second;
+    for (size_t j = 0; j < t.ndim(); ++j) {
+      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
+    }
+  }
+}
+
+void TensorComputeOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+  const TensorDom& tdom = tensor_dom.at(self.output(0));
+  for (size_t i = 0; i < this->axis.size(); ++i) {
+    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
+    CHECK(!out_dom_map->count(this->axis[i]));
+    (*out_dom_map)[this->axis[i]] = r;
+  }
+  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
+    CHECK(!out_dom_map->count(this->reduce_axis[i]));
+    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
+  }
+}
+
+Stmt TensorComputeOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& realize_map,
+    const Stmt& body) const {
+  CHECK_EQ(stage->op.get(), this);
+  HalideIR::Internal::Region bounds;
+  for (IterVar iv : this->axis) {
+    bounds.push_back(realize_map.at(iv));
+  }
+  Stmt realize = body;
+  for (int i = this->num_outputs(); i > 0; --i) {
+    Tensor t = stage->op.output(i-1);
+    realize = ir::Realize::make(t->op, t->value_index,
+      t->dtype, bounds, const_true(), realize);
+    // alignment requirement, only useful for compute
+    for (int i = 0; i < schedulable_ndim; ++i) {
+      auto it = stage->iter_var_attrs.find(this->axis[i]);
+      if (it != stage->iter_var_attrs.end()) {
+        IterVarAttr attr = (*it).second;
+        if (attr->dim_align_factor != 0) {
+          Array<Expr> tuple = {static_cast<int>(i),
+                               attr->dim_align_factor,
+                               attr->dim_align_offset};
+          realize = ir::AttrStmt::make(
+              t, ir::attr::buffer_dim_align,
+              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+              realize);
+        }
+      }
+    }
+  }
+  return realize;
+}
+
+ComputeLoopNest MakeLoopNest(
+    const TensorComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) {
+  CHECK_EQ(stage->op.operator->(), self);
+  ComputeLoopNest ret;
+  // make main loop nest
+  ret.main_nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
+      debug_keep_trivial_loop);
+  ret.main_predicates = schedule::MakeBoundCheck(
+      stage, dom_map, ret.main_vmap, false,
+      std::unordered_set<IterVar>());
+  for (auto& e : ret.main_predicates) {
+    e = likely(e);
+  }
+  if (stage->store_predicate.defined()) {
+    ret.main_predicates.push_back(stage->store_predicate);
+  }
+  if (self->reduce_axis.size() != 0) {
+    // try to find the location to insert the initialization.
+    // Fuse the initialization and provide loop when possible.
+    std::unordered_map<IterVar, int> update_state;
+    for (IterVar iv : self->reduce_axis) {
+      update_state[iv] = 2;
+    }
+    for (int i = 0; i < self->schedulable_ndim; ++i) {
+      update_state[self->axis[i]] = 1;
+    }
+    // find which iter var is related to reduction and which is related to axis.
+    schedule::PassDownBitMaskOr(stage, &update_state);
+    auto leaf_iter_vars = stage->leaf_iter_vars;
+    // first first loop that is related to reduction.
+    size_t begin_loop = leaf_iter_vars.size();
+    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
+      auto iv = leaf_iter_vars[i];
+      int flag = update_state.at(iv);
+      if ((flag & 2) != 0) {
+        begin_loop = i; break;
+      }
+      ret.init_vmap[iv] = ret.main_vmap.at(iv);
+    }
+    ret.num_common_loop = begin_loop;
+    // skip loops that does not relates to axis.
+    std::unordered_set<IterVar> skip_iter;
+    for (auto kv : update_state) {
+      int flag = kv.second;
+      if ((flag & 1) == 0) skip_iter.insert(kv.first);
+    }
+    ret.init_nest = op::MakeLoopNest(
+        stage, dom_map, begin_loop, true,
+        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
+    ret.init_predicates = schedule::MakeBoundCheck(
+        stage, dom_map, ret.init_vmap, true, skip_iter);
+    for (auto& e : ret.init_predicates) {
+      e = likely(e);
+    }
+  } else {
+    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+    ret.num_common_loop = stage->leaf_iter_vars.size();
+  }
+  // copy elison here.
+  return ret;
+}
+
+
+Stmt TensorComputeOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+
+  // Start bind data.
+  Stmt nop = Evaluate::make(0);
+  std::vector<Stmt> input_bind_nest, output_bind_nest;
+  Array<Tensor> inputs = this->InputTensors();
+
+  // input binding
+  size_t num_inputs = inputs.size();
+  for (size_t i = 0; i < num_inputs; ++i) {
+    Tensor tensor = inputs[i];
+    Region region = this->input_regions[i];
+    Buffer buffer = this->intrin->buffers[i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (size_t i = 0; i < region.size(); ++i) {
+      tuple.push_back(region[i]->min);
+      tuple.push_back(region[i]->extent);
+    }
+    input_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // output binding
+  for (int i = 0; i < this->num_outputs(); ++i) {
+    Tensor tensor = stage->op.output(i);
+    Buffer buffer = this->intrin->buffers[num_inputs + i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (size_t i = 0; i < this->axis.size(); ++i) {
+      auto ivar = this->axis[i];
+      if (i < static_cast<size_t>(this->schedulable_ndim)) {
+        tuple.push_back(ivar->var);
+        tuple.push_back(1);
+      } else {
+        Range dom = ivar->dom;
+        tuple.push_back(dom->min);
+        tuple.push_back(dom->extent);
+      }
+    }
+
+    output_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // Check variable remap
+  std::unordered_map<const Variable*, Expr> vmap;
+  ir::ArgBinder binder(&vmap);
+
+  size_t tloc = stage->leaf_iter_vars.size();
+  ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);
+
+  if (this->reduce_axis.size() == 0) {
+    std::vector<std::vector<Stmt> > nest(
+        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+    nest.emplace_back(op::MakeIfNest(n.main_predicates));
+    CHECK_EQ(n.init_predicates.size(), 0U);
+    CHECK(this->intrin->body.defined())
+        << "Normal store op for intrin " << this << " is not defined";
+    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
+    body = MergeNest(input_bind_nest, body);
+    body = ir::Substitute(body, vmap);
+    body = MergeNest(binder.asserts(), body);
+    body = op::Substitute(body, n.main_vmap);
+    Stmt ret =  MergeNest(nest, body);
+    return ret;
+  } else {
+    // Need to split reduction
+    CHECK(this->intrin->reduce_update.defined())
+        << "Reduction update op is not defined";
+    // Need init and update steps
+    CHECK_NE(this->reduce_axis.size(), 0U);
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+    std::vector<std::vector<Stmt> > update_nest(
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
+
+    if (this->intrin->reduce_init.defined()) {
+      // init nest
+      std::vector<std::vector<Stmt> > init_nest(
+          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
+      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
+      init = op::Substitute(init, n.init_vmap);
+      init = MergeNest(init_nest, init);
+      // The update
+      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = op::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, Block::make(init, update));
+    } else {
+      // When init op is not available, use body op for reset in the first iter.
+      CHECK(this->intrin->body.defined())
+          << "Normal body op is not defined";
+      Stmt update = TransformUpdate(stage, dom_map, n,
+                                    this->intrin->body,
+                                    this->intrin->reduce_update);
+      update = MergeNest(output_bind_nest, update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = op::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, update);
+    }
+  }
+}
+
+}  // namespace tvm
diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc
index 6daaedd16..a61aac422 100644
--- a/src/op/tensorize.cc
+++ b/src/op/tensorize.cc
@@ -10,7 +10,6 @@
 #include "op_util.h"
 #include "compute_op.h"
 #include "../schedule/message_passing.h"
-#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 
@@ -323,50 +322,6 @@ void VerifyTensorizeBody(
   }
 }
 
-/*!
- * \brief Transform the update part when there is no init func in tensorizing
- * \param stage The stage for tensorizing.
- * \param dom_map The range of each iter var.
- * \param n The loop nest structured used in compute. 
- * \param body The body func in tensorize intrin
- * \param update The update func in tensorize intrin
- * \return Transformed result.
- */
-Stmt TransformUpdate(const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     const ComputeLoopNest& n,
-                     Stmt body,
-                     Stmt update) {
-  Array<Expr> conds;
-  std::unordered_set<const Variable*> banned;
-  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-    IterVar iv = stage->leaf_iter_vars[i];
-    auto iit = stage->iter_var_attrs.find(iv);
-    if (iit != stage->iter_var_attrs.end()) {
-      const IterVarAttr& attr = (*iit).second;
-      if (attr->iter_type == kTensorized) {
-        break;
-      }
-    }
-    if (iv->iter_type == kCommReduce) {
-      auto vit = dom_map.find(iv);
-      CHECK(vit != dom_map.end());
-      const Range& vrange = vit->second;
-      conds.push_back(likely(iv->var > vrange->min));
-      banned.insert(iv->var.get());
-    }
-  }
-  for (const Expr& pred : n.main_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize update transform failed, the condition "
-                 << pred << " has a conflict with the reset condition";
-    }
-  }
-
-  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
-                          update, body);
-}
-
 Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index 0fac313c0..623886c31 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
   // bind pointer and offset.
   if (is_zero(arg->elem_offset)) {
     CHECK(is_zero(value->elem_offset))
-        << "Trying to bind a Buffer with offset into one without offset";
+        << "Trying to bind a Buffer with offset into one without offset "
+        << " required elem_offset=" << arg->elem_offset
+        << ", provided elem_offset=" << value->elem_offset;
   }
 
   this->Bind(arg->data, value->data, arg_name + ".data");
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 8591c77bd..ccf7fd617 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 }
 
-// Cache write and relayout the data according to loop pattern
-Array<Tensor> CacheWriteWithReLayout(Schedule sch,
-                              const Array<Tensor>& tensor_array,
-                              const std::string& scope) {
-  size_t tensor_size = tensor_array.size();
-  sch->InvalidateCache();
-  Tensor tensor = tensor_array[0];
-  Stage orig_stage = sch[tensor->op];
-  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
-  std::unordered_set<IterVar> red_axis;
-  for (IterVar iv : compute->reduce_axis) {
+template<typename OpType>
+void PrepareAxisMapping(Stage orig_stage,
+                        OpType* op,
+                        std::unordered_set<IterVar>* p_red_axis,
+                        Array<IterVar>* p_new_axis,
+                        std::unordered_map<IterVar, Range>* p_dom_map,
+                        std::unordered_map<const Variable*, Expr>* p_vsub,
+                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
+                        std::vector<Expr>* p_predicates) {
+  auto& red_axis = *p_red_axis;
+  auto& new_axis = *p_new_axis;
+  auto& dom_map = *p_dom_map;
+  auto& vsub = *p_vsub;
+  auto& vsub2newvar = *p_vsub2newvar;
+  auto& predicates = *p_predicates;
+
+  for (IterVar iv : op->reduce_axis) {
     red_axis.insert(iv);
   }
-  std::unordered_map<IterVar, Range> dom_map;
-  Array<IterVar> new_axis;
-
-  for (IterVar iv : compute->axis) {
+  for (IterVar iv : op->axis) {
     dom_map[iv] = iv->dom;
   }
   schedule::PassDownDomain(orig_stage, &dom_map, true);
-  std::unordered_map<const Variable*, Expr> vsub;
-  std::unordered_map<const Variable*, Expr> vsub2newvar;
-  std::vector<Expr> predicates;
   {
     // The source->cache
     std::unordered_map<IterVar, Expr> value_map;
@@ -178,17 +178,85 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
     }
     // skip reduction iteration.
     std::unordered_set<IterVar> skip_bound_check;
-    for (IterVar iv : compute->reduce_axis) {
+    for (IterVar iv : op->reduce_axis) {
       skip_bound_check.insert(iv);
     }
     schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
     predicates = schedule::MakeBoundCheck(
         orig_stage, dom_map, value_map, true, skip_bound_check);
     // The root axis
-    for (IterVar iv : compute->axis) {
-      vsub[iv->var.get()] = value_map.at(iv);
+    for (IterVar iv : op->axis) {
+      if (value_map.count(iv)) {
+        vsub[iv->var.get()] = value_map.at(iv);
+      }  // to handle tensor axis
     }
   }
+}
+
+Array<Tensor> ReplaceOriginalOp(Schedule sch,
+                                Stage orig_stage,
+                                const std::string& scope,
+                                Operation cache_op,
+                                Operation orig_new_op,
+                                size_t tensor_size) {
+  Array<Tensor> cache_tensor_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+  }
+  // The replace of the dataflow
+  std::unordered_map<Tensor, Tensor> vmap;
+  std::unordered_map<Tensor, Tensor> rvmap;
+  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
+    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  }
+  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
+  // mutate orig stage
+  orig_stage->op = orig_new_op;
+  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  orig_stage->relations = Array<IterVarRelation>();
+  // create schedule for new cached stage.
+  ArrayNode* stages = sch->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, orig_stage);
+  Stage cache_stage = Stage(cache_op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos,
+                      cache_stage.node_);
+  sch->stage_map.Set(cache_op, cache_stage);
+  // Update group
+  cache_stage->group = orig_stage->group;
+  if (cache_stage->group.defined()) {
+    ++cache_stage->group->num_child_stages;
+  }
+  return cache_tensor_list;
+}
+
+
+// Cache write and relayout the data according to loop pattern
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                                     const Array<Tensor>& tensor_array,
+                                     const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, compute,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
 
   Expr body;
   Array<Expr> body_list;
@@ -198,7 +266,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
     body = InjectPredicate(predicates, body);
     body = VarReplacer(vsub2newvar).Mutate(body);
     // Reduce nodes in ONE computeOp must be the same except value_index
-    // This is right only if the oringinal body ensures Reduce nodes are the same
+    // This is right only if the original body ensures Reduce nodes are the same
     if (body->is_type<ir::Reduce>()) {
       const ir::Reduce* reduce_body = body.as<ir::Reduce>();
       if (first_reduce != nullptr) {
@@ -234,48 +302,107 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   Operation cache_op = ComputeOpNode::make(
       compute->name + "." + scope, compute->tag, compute->attrs,
       new_axis, body_list);
-  Array<Tensor> cache_tensor_list;
+
   Array<Expr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
     Tensor cache_tensor = cache_op.output(i);
-    cache_tensor_list.push_back(cache_tensor);
     cache_expr_list.push_back(cache_tensor(args));
   }
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->tag, compute->attrs,
       compute->axis, cache_expr_list);
-  // The replace of the dataflow
-  std::unordered_map<Tensor, Tensor> vmap;
-  std::unordered_map<Tensor, Tensor> rvmap;
-  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
-  for (size_t i = 0; i < tensor_size; i++) {
-    vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-    rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
+}
+
+
+// for tensor compute op
+Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
+                                           const Array<Tensor>& tensor_array,
+                                           const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = sch[tensor->op];
+  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
+  CHECK_EQ(tensor_op->num_outputs(), 1)
+      << "cache write only support single output tensor_compute_op";
+
+  std::unordered_set<IterVar> red_axis;
+  Array<IterVar> new_axis;
+  std::unordered_map<IterVar, Range> dom_map;
+
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+
+  PrepareAxisMapping(orig_stage, tensor_op,
+    &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
+
+
+  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar new_iv = IterVarNode::make(
+      iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+    new_axis.push_back(new_iv);
+  }
+  Array<Region> new_regions;
+  for (Region old_region : tensor_op->input_regions) {
+    Region region;
+    for (Range r : old_region) {
+      Expr min = VarReplacer(vsub2newvar).Mutate(r->min);
+      Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent);
+      region.push_back(Range::make_by_min_extent(min, extent));
+    }
+    new_regions.push_back(region);
   }
-  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
-  // mutate orig stage
-  orig_stage->op = orig_new_op;
-  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
-  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
-  orig_stage->relations = Array<IterVarRelation>();
-  // create schedule for new cached stage.
-  ArrayNode* stages = sch->stages.CopyOnWrite();
-  size_t pos = FindNodeRef(stages, orig_stage);
-  Stage cache_stage = Stage(cache_op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos,
-                      cache_stage.node_);
-  sch->stage_map.Set(cache_op, cache_stage);
-  // Update group
-  cache_stage->group = orig_stage->group;
-  if (cache_stage->group.defined()) {
-    ++cache_stage->group->num_child_stages;
+
+  Operation cache_op = TensorComputeOpNode::make(
+      tensor_op->name + "." + scope, tensor_op->tag, new_axis,
+      tensor_op->reduce_axis, tensor_op->schedulable_ndim,
+      tensor_op->intrin, tensor_op->inputs, new_regions);
+
+  // axis will be used in generating compute op
+  Array<IterVar> compute_axis = tensor_op->axis;
+  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+    IterVar iv = tensor_op->axis[i];
+    IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar);
+    compute_axis.Set(i, aiv);
   }
-  return cache_tensor_list;
+
+  // The reader args
+  Array<Expr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : compute_axis) {
+      value_map[iv] = iv->var;
+    }
+    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
+    // tensorized region axis
+    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
+      IterVar iv = compute_axis[i];
+      args.push_back(value_map.at(iv));
+    }
+  }
+
+  Array<Expr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
+  Operation orig_new_op = ComputeOpNode::make(
+      tensor_op->name, tensor_op->tag, {},
+      compute_axis, cache_expr_list);
+  return ReplaceOriginalOp(sch, orig_stage, scope,
+    cache_op, orig_new_op, tensor_size);
 }
 
+
 Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
                              const std::string& scope) {
   (*this)->InvalidateCache();
@@ -291,23 +418,26 @@ Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
     CHECK(orig_stage.same_as(tmp_stage))
         << "Input tensor list must be generated by ONE computeOp";
   }
-
   return CacheWriteWithReLayout(*this, tensor_array, scope);
 }
 
+
 Tensor Schedule::cache_write(const Tensor& tensor,
                              const std::string& scope) {
+  // support original compute and tensor compute both
   (*this)->InvalidateCache();
-  Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(compute)
-      << "cache write only take ComputeOp as writers";
-  CHECK_EQ(compute->num_outputs(), 1)
-      << "cache write only support single output ComputeOp";
-
-  return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  const char* type_key = tensor->op->type_key();
+  if (!strcmp(type_key, "ComputeOp")) {
+    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
+  } else if (!strcmp(type_key, "TensorComputeOp")) {
+    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
+  } else {
+    LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
+    return Tensor();
+  }
 }
 
+
 void RebaseNonZeroMinLoop(const Schedule& sch) {
   std::unordered_map<IterVar, IterVar> rebase_map;
   for (Stage s : sch->stages) {
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index f562a48e4..50492ca41 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -85,6 +85,78 @@ def test_tensor_reduce():
     assert(isinstance(C_loaded, tvm.tensor.Tensor))
     assert(str(C_loaded) == str(C))
 
+def test_tensor_compute1():
+    m = 1024
+    factor = 16
+    dtype = 'float32'
+
+    def intrin_vadd(n):
+        x = tvm.placeholder((n,))
+        y = tvm.placeholder((n,))
+        z = tvm.compute(x.shape, lambda i: x[i] + y[i])
+
+        def intrin_func(ins, outs):
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+            return ib.get()
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vadd = intrin_vadd(factor)
+
+    A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype)
+    C = tvm.compute((m//factor, factor),
+          lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
+
+def test_tensor_compute2():
+    M = 2048
+    N = 1024
+    L = 1024
+    factor = 16
+    factor1 = 32
+    factor2 = 32
+    dtype = 'float32'
+
+    def intrin_gemm(m, n, l):
+        k = tvm.reduce_axis((0, l))
+        x = tvm.placeholder((m, l))
+        y = tvm.placeholder((n, l))
+        # in theory, no relation
+        z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))
+
+        def intrin_func(ins, outs):
+            x_ptr = ins[0].access_ptr("r")
+            y_ptr = ins[1].access_ptr("r")
+            z_ptr = outs[0].access_ptr("w")
+            body = tvm.call_packed(
+                "gemv", x_ptr, y_ptr, z_ptr, m, n, l)
+            reset = tvm.call_packed(
+                "fill_zero", z_ptr, m, n)
+            update = tvm.call_packed(
+                "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
+            return body, reset, update
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vgemm = intrin_gemm(factor1, factor2, factor)
+
+    A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
+    k = tvm.reduce_axis((0, L//factor), name='k')
+    C = tvm.compute((M//factor1, N//factor2, factor1, factor2),
+          lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
+    assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
 
 def test_tensor_scan():
     m = tvm.var("m")
@@ -221,6 +293,8 @@ if __name__ == "__main__":
     test_conv1d()
     test_tensor_slice()
     test_tensor()
+    test_tensor_compute1()
+    test_tensor_compute2()
     test_tensor_reduce()
     test_tensor_scan()
     test_scan_multi_out()
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 8e6f4090d..8774514cf 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -276,6 +276,133 @@ def test_schedule_bound_condition():
    stmt = tvm.ir_pass.Simplify(stmt)
    assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))
 
+
+def intrin_gemv(m, n):
+    w = tvm.placeholder((m, n), name='w')
+    x = tvm.placeholder((n,), name='x')
+    k = tvm.reduce_axis((0, n), name='k')
+    z = tvm.compute((m,), lambda i:
+                    tvm.sum(w[i, k] * x[k], axis=k), name='z')
+    Wb = tvm.decl_buffer(w.shape, w.dtype,
+                         name="W",
+                         offset_factor=16,
+                         strides=[tvm.var('ldw'), 1])
+    def intrin_func(ins, outs):
+        ww, xx = ins
+        zz = outs[0]
+        ww_ptr = ww.access_ptr("r")
+        xx_ptr = xx.access_ptr("r")
+        zz_ptr = zz.access_ptr("w")
+        body = tvm.call_packed(
+            "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        reset = tvm.call_packed(
+            "fill_zero", zz_ptr, n)
+        update = tvm.call_packed(
+            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        return body, reset, update
+
+    with tvm.build_config(data_alignment=16,
+                          offset_factor=16):
+        return tvm.decl_tensor_intrin(z.op, intrin_func,
+                                      binds={w: Wb})
+
+
+def test_schedule_tensor_compute1():
+    # basic: split, reorder, tile
+    M, N, L = 2048, 1024, 512
+    factor, rfactor = 16, 16
+    A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
+    B = tvm.placeholder((M, L//rfactor, rfactor), name='B')
+    k = tvm.reduce_axis((0, L//rfactor), name='k')
+
+    gemv = intrin_gemv(factor, rfactor)
+    C = tvm.compute((N, M//factor, factor),
+        lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k),
+        name='C')
+
+    s = tvm.create_schedule(C.op)
+    ai, aj, ax = s[C].op.axis
+    aio, aii = s[C].split(ai, 16)
+    s[C].reorder(aio, aj, aii)
+    aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)
+
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def intrin_vadd(n, cache_read=False, cache_write=False):
+    scope_ubuf = 'local'
+    dtype = 'float32'
+    x = tvm.placeholder((n,), dtype=dtype, name='vx')
+    y = tvm.placeholder((n,), dtype=dtype, name='vy')
+    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    s = tvm.create_schedule(z.op)
+
+    def create_buffer(t):
+        return tvm.decl_buffer(t.shape, t.dtype,
+                               name='W'+t.name,
+                               scope=scope_ubuf,
+                               offset_factor=16)
+
+    binds = {}
+    if cache_read:
+        binds[x] = create_buffer(x)
+        binds[y] = create_buffer(y)
+    if cache_write:
+        binds[z] = create_buffer(z)
+
+    def intrin_func(ins, outs):
+        ib = tvm.ir_builder.create()
+        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+        return ib.get()
+
+    with tvm.build_config(offset_factor=16):
+        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
+
+
+def test_schedule_tensor_compute2():
+    # cache_read, cache_write
+    M = 1024
+    factor = 16
+    dtype = 'float32'
+    scope_ubuf = 'local'
+
+    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
+
+    vadd = intrin_vadd(factor, True, True)
+    C = tvm.compute((M//factor, factor),
+        lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
+
+    s = tvm.create_schedule(C.op)
+    AL = s.cache_read(A, scope_ubuf, C)
+    BL = s.cache_read(B, scope_ubuf, C)
+    CL = s.cache_write(C, scope_ubuf)
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_tensor_compute3():
+    # compute_at
+    M = 1024
+    factor = 16
+    dtype = 'float32'
+    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
+    Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")
+
+    vadd = intrin_vadd(factor)
+    C = tvm.compute((M//factor, factor),
+        lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
+    s = tvm.create_schedule(C.op)
+    s[Bi].compute_at(s[C], C.op.axis[0])
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
 if __name__ == "__main__":
     test_schedule_middle_cache()
     test_inline_multi_reduce()
@@ -294,3 +421,6 @@ if __name__ == "__main__":
     test_schedule2()
     test_schedule_cache()
     test_schedule_bound_condition()
+    test_schedule_tensor_compute1()
+    test_schedule_tensor_compute2()
+    test_schedule_tensor_compute3()
-- 
GitLab