From 0c72ca970e5e3862bd2914bc9802335618ce15f0 Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Mon, 5 Dec 2016 11:21:32 +0100
Subject: [PATCH] Finish schedule operation

---
 HalideIR                      |  2 +-
 include/tvm/schedule.h        |  8 ++-
 python/tvm/function.py        |  8 +--
 python/tvm/schedule.py        | 99 +++++++++++++++++++++++++++++++++--
 python/tvm/tensor.py          |  2 +
 src/c_api/c_api_lang.cc       | 49 +++++++++++++++++
 src/c_api/c_api_registry.h    |  2 +-
 src/lang/expr.cc              |  4 +-
 src/lang/schedule.cc          | 94 ++++++++++++++++++++++++++++++---
 tests/python/test_schedule.py | 36 ++++++++-----
 10 files changed, 272 insertions(+), 32 deletions(-)

diff --git a/HalideIR b/HalideIR
index 29fd3defa..ea1a81be8 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit 29fd3defa3dbf810e52dbc2ecd3933604989dcc8
+Subproject commit ea1a81be8baa43665f6ebd4d75d51c081283ebc8
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index ad23f5135..259ef2d17 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -50,16 +50,19 @@ class Schedule : public NodeRef {
    * \brief specify the schedule to be computed at the parent schedule's scope.
    * \param parent The parent schedule.
    * \param scope The iteration point to carry the schedule.
+   * \return reference to self.
    */
   Schedule& compute_at(Schedule parent, IterVar scope);   // NOLINT(*)
   /*!
    * \brief Compute the function inline, attach it at parent.
    * \param parent The parent schedule to be attached to.
+   * \return reference to self.
    */
   Schedule& compute_inline(Schedule parent);   // NOLINT(*)
   /*!
    * \brief Compute the function at root, attach it to its parent.
    * \param parent The parent schedule to be attached to.
+   * \return reference to self.
    */
   Schedule& compute_root(Schedule parent);  // NOLINT(*)
   /*!
@@ -68,7 +71,7 @@ class Schedule : public NodeRef {
    * \param p_outer The result outer domain
    * \param p_inner The result inner domain.
    * \param factor The split factor of the loop.
-   * \param outer The generated
+   * \return reference to self.
    */
   Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor);  // NOLINT(*)
   /*!
@@ -80,6 +83,7 @@ class Schedule : public NodeRef {
    * \param p_inner The result inner domain.
    * \param factor Optional, the factor of the split,
    *  factor must be provided such that factor * outer.extent >= parent.extent.
+   * \return reference to self.
    */
   Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr());   // NOLINT(*)
   /*!
@@ -87,11 +91,13 @@ class Schedule : public NodeRef {
    * \param inner The inner domain to be fused
    * \param outer The outer domain to be fused.
    * \param p_target The result target domain.
+   * \return reference to self.
    */
   Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target);  // NOLINT(*)
   /*!
    * \brief Reorder the iteration
    * \param order The order of iteration variable.
+   * \return reference to self.
    */
   Schedule& reorder(const Array<IterVar>& order);   // NOLINT(*)
 };
diff --git a/python/tvm/function.py b/python/tvm/function.py
index b1f91bd44..7088d051a 100644
--- a/python/tvm/function.py
+++ b/python/tvm/function.py
@@ -79,6 +79,9 @@ def compute(shape, fcompute, name="TensorCompute"):
     tensor: tensor.Tensor
         The created tensor
     """
+    if isinstance(shape, _expr.Expr):
+        shape = (shape, )
+
     ndim = len(shape)
     arg_names = fcompute.__code__.co_varnames
     if ndim != len(arg_names):
@@ -86,6 +89,7 @@ def compute(shape, fcompute, name="TensorCompute"):
 
     dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
     body = fcompute(*[v.var for v in dim_var])
+    body = convert(body)
     op_node = _function_internal._ComputeOp(
         name, dim_var, body)
     return _function_internal._Tensor(
@@ -174,8 +178,4 @@ def Schedule(tensor, scope="global"):
     return _function_internal._Schedule(tensor, scope)
 
 
-def Split(dim, factor, over_rdom=False):
-    return _function_internal._DimSplit(dim, factor, over_rdom)
-
-
 _init_function_module("tvm")
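
A minimal sketch of what the compute() change above enables (a 1-D shape passed
as a bare Expr, with the body run through convert()); it mirrors the usage in
tests/python/test_schedule.py and assumes the tvm Python API at this revision:

    import tvm

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    # shape given as a single Expr rather than a tuple; compute() wraps it into (m,)
    T = tvm.compute(m, lambda i: A[i + 1])
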
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 5e368b8b1..dee3f3309 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -4,13 +4,106 @@ from ._ctypes._api import NodeBase, register_node
 from . import _function_internal
 
 @register_node
-class DimSplit(NodeBase):
+class Split(NodeBase):
     pass
 
 @register_node
-class AttachSpec(NodeBase):
+class Fuse(NodeBase):
     pass
 
 @register_node
 class Schedule(NodeBase):
-    pass
+    def split(self, parent, factor=None, outer=None):
+        """Split the schedule either by factor providing outer scope, or both
+
+        Parameters
+        ----------
+        parent : IterVar
+             The parent iter var.
+
+        factor : Expr, optional
+             The splitting factor
+
+        outer : IterVar, optional
+             The outer split variable
+
+        Returns
+        -------
+        outer : IterVar
+            The outer variable of iteration.
+
+        inner : IterVar
+            The inner variable of iteration.
+        """
+        if outer is not None:
+            if outer.thread_tag == '':
+                raise ValueError("split by outer must have special thread_tag")
+            if outer.dom is None:
+                raise ValueError("split by outer must have specified domain")
+            inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor)
+        else:
+            if factor is None:
+                raise ValueError("either outer or factor need to be provided")
+            outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor)
+        return outer, inner
+
+    def fuse(self, inner, outer):
+        """Fuse inner and outer to a single iteration variable.
+
+        Parameters
+        ----------
+        inner : IterVar
+            The inner variable of iteration.
+
+        outer : IterVar
+            The outer variable of iteration.
+
+        Returns
+        -------
+        fused : IterVar
+            The fused variable of iteration.
+        """
+        return _function_internal._ScheduleFuse(self, inner, outer)
+
+    def compute_at(self, parent, scope):
+        """Attach the schedule at parent's scope
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule
+
+        scope : IterVar
+            The loop scope to be attached to.
+        """
+        _function_internal._ScheduleComputeAt(self, parent, scope)
+
+    def compute_inline(self, parent):
+        """Attach the schedule at parent, and mark it as inline
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule
+        """
+        _function_internal._ScheduleComputeInline(self, parent)
+
+    def compute_root(self, parent):
+        """Attach the schedule at parent, and mark it as root
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule
+        """
+        _function_internal._ScheduleComputeRoot(self, parent)
+
+    def reorder(self, *args):
+        """reorder the arguments in the specified order.
+
+        Parameters
+        ----------
+        args : list of IterVar
+            The iteration variables, in the order they should be arranged.
+        """
+        _function_internal._ScheduleReorder(self, args)
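
A minimal usage sketch of the schedule methods added above, following the
patterns in tests/python/test_schedule.py (assumes the tvm Python API at this
revision; variable names are illustrative):

    import tvm

    m, n = tvm.Var('m'), tvm.Var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j])

    s = tvm.Schedule(T.op, scope="shared")
    xo, xi = s.split(T.op.dim_var[0], factor=10)  # split the first axis by a factor
    fused = s.fuse(xi, xo)                        # fuse the adjacent (outer, inner) pair back together
    s.reorder(T.op.dim_var[1], fused)             # permute the remaining leaf iteration variables
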
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index ab1bdbe90..957080006 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -7,6 +7,8 @@ from . import expr as _expr
 class TensorSlice(SliceBase, _expr.ExprOp):
     """Auxiliary data structure for enable slicing syntax from tensor."""
     def __init__(self, tensor, indices):
+        if not isinstance(indices, tuple):
+            indices = (indices,)
         self.tensor = tensor
         self.indices = indices
 
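
With the TensorSlice change above, a single non-tuple index is wrapped into a
one-element tuple, which is what permits the 1-D indexing used in the tests, e.g.:

    T = tvm.compute(m, lambda i: A[i + 1])  # A[i + 1] is stored internally as A[(i + 1,)]
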
diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc
index af6ce62e6..6ee137a7b 100644
--- a/src/c_api/c_api_lang.cc
+++ b/src/c_api/c_api_lang.cc
@@ -103,4 +103,53 @@ TVM_REGISTER_API(_Schedule)
     *ret = Schedule(args.at(0), args.at(1));
   });
 
+TVM_REGISTER_API(_ScheduleSplitByFactor)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar outer, inner;
+    args.at(0).operator Schedule()
+        .split(args.at(1), &outer, &inner, args.at(2));
+    *ret = Array<IterVar>({outer, inner});
+  });
+
+TVM_REGISTER_API(_ScheduleSplitByOuter)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar inner;
+    args.at(0).operator Schedule()
+        .split(args.at(1), args.at(2), &inner, args.at(3));
+    *ret = inner;
+  });
+
+TVM_REGISTER_API(_ScheduleFuse)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    IterVar fused;
+    args.at(0).operator Schedule()
+        .fuse(args.at(1), args.at(2), &fused);
+    *ret = fused;
+  });
+
+TVM_REGISTER_API(_ScheduleComputeAt)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_at(args.at(1), args.at(2));
+  });
+
+TVM_REGISTER_API(_ScheduleComputeInline)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_inline(args.at(1));
+  });
+
+TVM_REGISTER_API(_ScheduleComputeRoot)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .compute_root(args.at(1));
+  });
+
+TVM_REGISTER_API(_ScheduleReorder)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Schedule()
+        .reorder(args.at(1));
+  });
+
+
 }  // namespace tvm
diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h
index 3e57d43c2..9edce8ea6 100644
--- a/src/c_api/c_api_registry.h
+++ b/src/c_api/c_api_registry.h
@@ -115,7 +115,7 @@ class APIVariantValue {
     CHECK_EQ(type_id, kNodeHandle);
     // use dynamic RTTI for safety
     CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
-        << "wrong type specified";
+        << "wrong type specified, expected " << typeid(typename T::ContainerType).name();
     return T(sptr);
   }
   inline operator Expr() const {
diff --git a/src/lang/expr.cc b/src/lang/expr.cc
index df2350148..3082bd579 100644
--- a/src/lang/expr.cc
+++ b/src/lang/expr.cc
@@ -57,7 +57,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     if (op->var->name_hint.length() != 0) {
       p->stream  << op->var->name_hint << ", ";
     }
-    p->stream << op->dom;
+    if (op->dom.defined()) {
+      p->stream << op->dom;
+    }
     if (op->thread_tag.length() != 0) {
       p->stream << ", " << op->thread_tag;
     }
diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc
index beb5ff1f1..47f5ee744 100644
--- a/src/lang/schedule.cc
+++ b/src/lang/schedule.cc
@@ -17,12 +17,38 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
   return array_node->data.size();
 }
 
-size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* const IterVar& v) {
-  size_t pos = Find(leaf_iter_vars, parent);
+size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
+  size_t pos = FindIterVar(leaf_vars, v);
+  if (pos < leaf_vars->data.size()) return pos;
+
+  if (FindIterVar(all_vars, v) < all_vars->data.size()) {
+    LOG(FATAL) << "Operate on iter var " << v
+               << "that has already been splitted";
+  } else {
+    LOG(FATAL) << "Operate on iter var " << v
+               << "that is not part of the schedule";
+  }
+  return 0;
 }
 
+void Split(ScheduleNode* self, IterVar parent,
+           IterVar outer, IterVar inner, Expr factor) {
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+
+  self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
+  // add vars to all vars
+  all_vars->data.push_back(outer.node_);
+  all_vars->data.push_back(inner.node_);
+  // replace the position.
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner.node_);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer.node_);
 }
 
+}  // namespace
+
 Schedule::Schedule(Operation op, std::string scope) {
   auto n = std::make_shared<ScheduleNode>();
   n->op = op;
@@ -36,6 +62,14 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) {   // NOLINT(*)
   CHECK_EQ((*this)->attach_type, kNone);
   (*this)->attach_type = kScope;
   (*this)->attach_parent = scope;
+  bool found = false;
+  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
+    if (scope == parent->leaf_iter_vars[i]) {
+      found = true; break;
+    }
+  }
+  CHECK(found)
+      << "Cannot compute at a iteration variable that is not part of parent leaf vars";
   parent->children.push_back(*this);
   return *this;
 }
@@ -56,17 +90,63 @@ Schedule& Schedule::compute_root(Schedule parent) {   // NOLINT(*)
 
 Schedule& Schedule::split(
     IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
-  ScheduleNode* self = operator->();
-  ArrayNode* leaf_iter_vars = self->leaf_iter_vars.CopyOnWrite();
+  // placeholders for the split results.
+  IterVar outer(Range(), parent->var->name_hint + ".outer");
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_outer = outer; *p_inner = inner;
 
-  CHECK(pos != leaf_iter_vars->data.size())
-      << "Cannot find IterVar " << parent << " in the active leaf vars"
-      << " this means "
+  Split(operator->(), parent, outer, inner, factor);
+  return *this;
+}
+
+Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) { // NOLINT(*)
+  // placeholder for the split result.
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_inner = inner;
+  Split(operator->(), parent, outer, inner, factor);
 
   return *this;
 }
 
+Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT(*)
+  IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+
+  self->relations.push_back(FuseNode::make(inner, outer, fused));
+  all_vars->data.push_back(fused.node_);
+
+  size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
+  size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
+  CHECK_EQ(pos_inner, pos_outer + 1)
+      << "Can only fuse iterations that are consecutive between each other";
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
+                        leaf_vars->data.begin() + pos_inner);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
+                         fused.node_);
+  return *this;
+}
+
+Schedule& Schedule::reorder(const Array<IterVar>& order) {  // NOLINT(*)
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  std::vector<size_t> pos;
 
+  for (size_t i = 0; i < order.size(); ++i) {
+    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
+  }
+  std::vector<std::shared_ptr<Node> > temp;
+  for (size_t i = 0; i < pos.size(); ++i) {
+    temp.emplace_back(leaf_vars->data[pos[i]]);
+  }
+  std::sort(pos.begin(), pos.end());
+  for (size_t i = 0; i < pos.size(); ++i) {
+    leaf_vars->data[pos[i]] = temp[i];
+  }
+  return *this;
+}
 
 IterVarRelation SplitNode::make(
     IterVar parent, IterVar outer,
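
The reorder bookkeeping above leaves untouched leaf variables where they are and
permutes only the requested ones within their original slots; an illustrative
Python sketch of the same index logic (not the patch code itself):

    def reorder(leaf_vars, order):
        # positions of the requested vars in the current leaf list
        pos = [leaf_vars.index(v) for v in order]
        # the requested vars, in their new relative order
        temp = [leaf_vars[p] for p in pos]
        # write them back into the same set of slots, lowest slot first
        for p, v in zip(sorted(pos), temp):
            leaf_vars[p] = v
        return leaf_vars

    # reorder(['xo', 'xi1', 'xi2'], ['xi2', 'xi1', 'xo']) -> ['xi2', 'xi1', 'xo']
    # reorder(['a', 'b', 'c'], ['c', 'a'])                -> ['c', 'b', 'a']
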
diff --git a/tests/python/test_schedule.py b/tests/python/test_schedule.py
index 5d3f76131..773be8b55 100644
--- a/tests/python/test_schedule.py
+++ b/tests/python/test_schedule.py
@@ -6,28 +6,36 @@ def test_schedule_create():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.placeholder((n, l), name='B')
+    AA = tvm.compute((m, l), lambda i, j: A[i, j])
     T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
 
-    Tsch = tvm.Schedule(T.op, scope="shared")
-    Asch = tvm.Schedule(A.op)
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    sch_A = tvm.Schedule(AA.op, scope="global")
 
-    T.op.
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
 
+    sch_A.compute_at(sch_T, xi1)
+    xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
 
-    xo, xi = sch.split(sch.dim_var[0], factor)
-    Asch.compute_at(Tsch, xi)
+    sch_T.reorder(xi2, xi1)
+    assert T.op.dim_var[1] in sch_T.leaf_iter_vars
 
-    xf = sch.fuse(xo, xi)
-
-
-    tk1 = tvm.Split(T.op.dim_var[0], 10)
-    assert isinstance(sch, tvm.schedule.Schedule)
-    assert isinstance(tk1, tvm.schedule.DimSplit)
+def test_reorder():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute(m, lambda i: A[i+1])
 
-    print(tk1.var)
-    print(sch.scope)
-    print(sch.attachs)
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
+    order = (xi2, xi1, xo)
+    assert tuple(sch_T.leaf_iter_vars) != order
+    sch_T.reorder(*order)
+    assert tuple(sch_T.leaf_iter_vars) == order
 
 
 if __name__ == "__main__":
     test_schedule_create()
+    test_reorder()
+
-- 
GitLab