From 6a62beb2b3f6f7fb7ea5dbb5ef3902af0b362f98 Mon Sep 17 00:00:00 2001
From: Ziheng Jiang <jzhtomas@gmail.com>
Date: Wed, 8 Feb 2017 22:33:52 -0800
Subject: [PATCH] [FUSION] add 'void AutoFuseEwise(Schedule sch)' (#36)

* [FUSION] add Fusion(Schedule)

* [FUSION] rename to AutoFuseEwise, detect whether the stage has been scheduled

* [FUSION] change to visitor pattern

* [FUSION] rename filename

* [FUSION] fine-tune the interface

* [FUSION] typo

* move elem_wise to schedule

* rename test function
---
 include/tvm/ir_pass.h                         |  1 -
 include/tvm/schedule.h                        | 11 +++
 include/tvm/schedule_pass.h                   |  7 ++
 python/tvm/schedule.py                        |  2 +-
 src/api/api_schedule.cc                       |  5 ++
 src/schedule/auto_inline_elem_wise.cc         | 76 +++++++++++++++++++
 .../unittest/test_schedule_schedule_ops.py    | 16 ++++
 7 files changed, 116 insertions(+), 2 deletions(-)
 create mode 100644 src/schedule/auto_inline_elem_wise.cc

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index b11486d90..9e3e1b0a1 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -167,7 +167,6 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
  */
 LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
 
-
 }  // namespace ir
 }  // namespace tvm
 
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index f115dbc6f..a7cd58c96 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -123,6 +123,12 @@ class Stage : public NodeRef {
               IterVar* p_x_outer, IterVar* p_y_outer,
               IterVar* p_x_inner, IterVar* p_y_inner,
               Expr x_factor, Expr y_factor);
+  /*!
+   * \brief Check whether any scheduling has been applied to this stage.
+   * \return true if the stage has relations or an attach type other than kNone.
+   */
+  inline bool is_scheduled() const;
+
   // declare container type
   using ContainerType = StageNode;
 };
@@ -353,6 +359,11 @@ inline StageNode* Stage::operator->() {
   return static_cast<StageNode*>(node_.get());
 }
 
+inline bool Stage::is_scheduled() const {
+  const StageNode* self = operator->();
+  return !self->relations.empty() || self->attach_type != kNone;  // any relation or attachment
+}
+
 inline const ScheduleNode* Schedule::operator->() const {
   return static_cast<const ScheduleNode*>(node_.get());
 }
diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h
index 57e442c5c..c4e82cde1 100644
--- a/include/tvm/schedule_pass.h
+++ b/include/tvm/schedule_pass.h
@@ -33,6 +33,13 @@ Map<IterVar, Range> InferBound(Schedule sch);
  */
 Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
 
+/*!
+ * \brief Automatically inline element-wise operations in a schedule.
+ *
+ * \param sch The schedule whose element-wise stages will be inlined.
+ */
+void AutoInlineElemWise(Schedule sch);
+
 }  // namespace schedule
 }  // namespace tvm
 #endif  // TVM_SCHEDULE_PASS_H_
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 3fd7f9730..fee0fb3b1 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -135,7 +135,7 @@ class Stage(NodeBase):
         parent : Stage
             The parent stage
         """
-        _api_internal._StageComputeInline(self)
+        _api_internal._StageComputeRoot(self)
 
     def reorder(self, *args):
         """reorder the arguments in the specified order.
diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc
index a4462117d..882ff94bd 100644
--- a/src/api/api_schedule.cc
+++ b/src/api/api_schedule.cc
@@ -13,6 +13,11 @@
 namespace tvm {
 namespace schedule {
 
+TVM_REGISTER_API(_schedule_AutoInlineElemWise)  // registers the pass under this API name
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    AutoInlineElemWise(args[0]);  // args[0] is passed as the Schedule; ret is left untouched
+  });
+
 #define REGISTER_SCHEDULE_PASS1(PassName)                         \
   TVM_REGISTER_API(_schedule_## PassName)                         \
   .set_body([](TVMArgs args,  TVMRetValue *ret) {                 \
diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc
new file mode 100644
index 000000000..66816c955
--- /dev/null
+++ b/src/schedule/auto_inline_elem_wise.cc
@@ -0,0 +1,76 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file auto_inline_elem_wise.cc
+ */
+#include <tvm/schedule_pass.h>
+#include <tvm/ir_visitor.h>
+
+namespace tvm {
+namespace ir {
+
+/*! \brief Visitor that records whether every Call is indexed exactly by the axis vars. */
+class ElemWiseDetector : public IRVisitor {
+ public:
+  explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
+
+  // Stop descending as soon as one mismatching call has been recorded.
+  void Visit(const NodeRef& e) final {
+    if (is_elem_wise_) IRVisitor::Visit(e);
+  }
+
+  // Clear the flag unless the call's arguments are the axis variables, in order.
+  void Visit_(const Call* op) final {
+    if (!IndexedByAxis(op->args)) {
+      is_elem_wise_ = false;
+      return;
+    }
+    IRVisitor::Visit_(op);
+  }
+
+  // Result flag; stays true until a mismatching call is visited.
+  bool is_elem_wise_{true};
+
+ private:
+  // True iff sizes agree and args[k].same_as(axis_[k]->var) holds for every k.
+  bool IndexedByAxis(const Array<Expr>& args) const {
+    if (args.size() != axis_.size()) return false;
+    for (size_t k = 0; k < axis_.size(); ++k) {
+      if (!args[k].same_as(axis_[k]->var)) return false;
+    }
+    return true;
+  }
+  Array<IterVar> axis_;
+};
+
+
+// Only ComputeOpNode operations can qualify; the verdict comes from ElemWiseDetector.
+bool IsElemWise(const Operation& op) {
+  const ComputeOpNode* compute = op.as<ComputeOpNode>();
+  if (compute == nullptr) return false;
+  ElemWiseDetector detector(compute->axis);
+  detector.Visit(compute->body);
+  return detector.is_elem_wise_;
+}
+
+}  // namespace ir
+
+namespace schedule {
+
+// Inline each unscheduled element-wise stage whose op is not listed in sch->roots.
+void AutoInlineElemWise(Schedule sch) {
+  for (Stage stage : sch->stages) {
+    // Stages that were already scheduled, or are not element-wise, are skipped.
+    if (stage.is_scheduled() || !ir::IsElemWise(stage->op)) continue;
+    bool found_in_roots = false;
+    for (auto root : sch->roots) {
+      if (root == stage->op) {
+        found_in_roots = true;
+        break;
+      }
+    }
+    if (!found_in_roots) stage.compute_inline();
+  }
+}
+
+}  // namespace schedule
+}  // namespace tvm
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index feed951e2..9689a1c34 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -42,8 +42,24 @@ def test_schedule2():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
 
+def test_auto_inline():
+  m = tvm.Var('m')
+  n = tvm.Var('n')
+  A = tvm.placeholder((m, n), name='A')
+  B = tvm.placeholder((m, n), name='B')
+  C = tvm.placeholder((m, n), name='C')
+  T1 = tvm.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='T1')
+  T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')
+
+  s = tvm.Schedule(T2.op)
+  tvm.schedule.AutoInlineElemWise(s)
+  bounds = tvm.schedule.InferBound(s)
+  stmt = tvm.schedule.ScheduleOps(s, bounds)
+  print(stmt)
+
 
 if __name__ == "__main__":
     test_schedule0()
     test_schedule1()
     test_schedule2()
+    test_auto_inline()
-- 
GitLab