From 84aeaf4803901ef738c8d21005e21c899218c8ac Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Sun, 4 Jun 2017 17:34:14 -0700
Subject: [PATCH] Change Schedule Array constructor to static make method
 (#170)

* Change Schedule Array constructor to static make method

* Add CreateSchedule

* Add doc

* Change CreateSchedule to create_schedule on the C++ side
---
 include/tvm/schedule.h        | 21 +++++++---
 python/tvm/schedule.py        |  2 +-
 src/api/api_lang.cc           |  4 +-
 src/schedule/schedule_lang.cc | 79 ++++++++++++++++++-----------------
 4 files changed, 59 insertions(+), 47 deletions(-)

diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 085730a9b..4479c4fbe 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -197,11 +197,6 @@ class Schedule : public NodeRef {
  public:
   Schedule() {}
   explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
-  /*!
-   * \brief construct schedule for array of ops(and their dependencies).
-   * \param ops The ops to be scheduled.
-   */
-  explicit Schedule(Array<Operation> ops);
   /*!
    * \brief Get a copy of current schedule.
    * \return The copied schedule.
@@ -439,10 +434,26 @@ class ScheduleNode : public Node {
   /*! \brief Invalidate temp cache. */
   void InvalidateCache();
 
+  /*!
+   * \brief Create a schedule for array of ops(and their dependencies).
+   * \param ops The ops to be scheduled.
+   * \return The created Schedule.
+   */
+  static Schedule make(Array<Operation> ops);
+
   static constexpr const char* _type_key = "Schedule";
   TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
 };
 
+/*!
+ * \brief Create a schedule for array of ops(and their dependencies).
+ * \param ops The ops to be scheduled.
+ * \return The created Schedule.
+ */
+inline Schedule create_schedule(Array<Operation> ops) {
+  return ScheduleNode::make(ops);
+}
+
 /*! \brief node container for IterVar attr */
 class IterVarAttrNode : public Node {
  public:
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index a23721514..8b2e63f9c 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -73,7 +73,7 @@ def create_schedule(ops):
     """
     if not isinstance(ops, (list, _collections.Array)):
         ops = [ops]
-    return _api_internal._Schedule(ops)
+    return _api_internal._CreateSchedule(ops)
 
 
 @register_node
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 496b499f2..8dc53bc28 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -225,9 +225,9 @@ TVM_REGISTER_API("_IterVar")
         args[3]);
   });
 
-TVM_REGISTER_API("_Schedule")
+TVM_REGISTER_API("_CreateSchedule")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = Schedule(args[0].operator Array<Operation>());
+    *ret = create_schedule(args[0].operator Array<Operation>());
   });
 
 TVM_REGISTER_API("_StageSetScope")
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index 9e8056be1..25d72e346 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -322,45 +322,6 @@ Stage& Stage::parallel(IterVar var) {   // NOLINT(*)
   return *this;
 }
 
-Schedule::Schedule(Array<Operation> ops) {
-  auto n = std::make_shared<ScheduleNode>();
-  node_ = n;
-  n->outputs = ops;
-  auto g = schedule::CreateReadGraph(n->outputs);
-  Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
-  // output set.
-  std::unordered_set<Operation> output_set;
-  for (Operation x : ops) {
-    output_set.insert(x);
-  }
-  for (Operation op : post_order) {
-    Stage stage(op);
-    stage->is_output = output_set.count(op) != 0;
-    n->stages.push_back(stage);
-    n->stage_map.Set(op, stage);
-    // mark scan updates.
-    if (op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = op.as<ScanOpNode>();
-      Array<Tensor> inputs;
-      for (Tensor t : scan->state_placeholder) {
-        inputs.push_back(t);
-      }
-      for (Tensor t : scan->inputs) {
-        inputs.push_back(t);
-      }
-      // Create the scan group.
-      Stage scan_group = create_group(scan->update, inputs, false);
-      scan_group->attach_type = kScanUpdate;
-      scan_group->attach_stage = stage;
-
-      for (size_t i = 0; i < scan->update.size(); ++i) {
-        Stage s = n->stage_map[scan->update[i]->op];
-        CHECK(scan_group.same_as(s->group));
-      }
-    }
-  }
-}
-
 Stage CopyStage(const Stage& s) {
   std::shared_ptr<StageNode> n =
       std::make_shared<StageNode>(*s.operator->());
@@ -580,6 +541,46 @@ void ScheduleNode::InitCache() {
   CHECK_EQ(op2stage_cache_.size(), stages.size());
 }
 
+Schedule ScheduleNode::make(Array<Operation> ops) {
+  auto n = std::make_shared<ScheduleNode>();
+  Schedule sch(n);
+  n->outputs = ops;
+  auto g = schedule::CreateReadGraph(n->outputs);
+  Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g);
+  // output set.
+  std::unordered_set<Operation> output_set;
+  for (Operation x : ops) {
+    output_set.insert(x);
+  }
+  for (Operation op : post_order) {
+    Stage stage(op);
+    stage->is_output = output_set.count(op) != 0;
+    n->stages.push_back(stage);
+    n->stage_map.Set(op, stage);
+    // mark scan updates.
+    if (op.as<ScanOpNode>()) {
+      const ScanOpNode* scan = op.as<ScanOpNode>();
+      Array<Tensor> inputs;
+      for (Tensor t : scan->state_placeholder) {
+        inputs.push_back(t);
+      }
+      for (Tensor t : scan->inputs) {
+        inputs.push_back(t);
+      }
+      // Create the scan group.
+      Stage scan_group = sch.create_group(scan->update, inputs, false);
+      scan_group->attach_type = kScanUpdate;
+      scan_group->attach_stage = stage;
+
+      for (size_t i = 0; i < scan->update.size(); ++i) {
+        Stage s = n->stage_map[scan->update[i]->op];
+        CHECK(scan_group.same_as(s->group));
+      }
+    }
+  }
+  return sch;
+}
+
 IterVarRelation SplitNode::make(IterVar parent,
                                 IterVar outer,
                                 IterVar inner,
-- 
GitLab