From f8f028295e93423d72c1ec666eb6181f51faeebb Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 12 Feb 2017 13:19:13 -0800
Subject: [PATCH] [SCHEDULE] Refactor bound inference logic (#41)

---
 include/tvm/schedule_pass.h |   2 +-
 src/arithmetic/int_set.cc   |  22 +--
 src/arithmetic/int_set.h    |   3 +
 src/schedule/bound.cc       | 293 +++++++++++++++++++++++-------------
 4 files changed, 205 insertions(+), 115 deletions(-)

diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h
index c4e82cde1..b3f64db1e 100644
--- a/include/tvm/schedule_pass.h
+++ b/include/tvm/schedule_pass.h
@@ -22,7 +22,7 @@ namespace schedule {
  * \param sch The root schedule to infer all the bounds.
  * \return the result bound of the iteration Variable
  */
-Map<IterVar, Range> InferBound(Schedule sch);
+Map<IterVar, Range> InferBound(const Schedule& sch);
 
 /*!
  * \brief Schedule s' dependent operations.
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index d60504f2c..a80594335 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -432,7 +432,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
 .set_dispatch<And>(Binary<And>)
 .set_dispatch<Or>(Binary<Or>);
 
-
 IntSet EvalSet(Expr e,
                const std::unordered_map<const Variable*, IntSet>& dom_map) {
   return IntSetEvaluator(dom_map).Eval(e);
@@ -444,17 +443,12 @@ IntSet EvalSet(Expr e,
   for (auto kv : dom_map) {
     dmap[kv.first->var.as<Variable>()] = kv.second;
   }
-  IntSetEvaluator m(dmap);
-  return m.Eval(e);
+  return EvalSet(e, dmap);
 }
 
 IntSet EvalSet(Range r,
-               const Map<IterVar, IntSet>& dom_map) {
-  std::unordered_map<const Variable*, IntSet> dmap;
-  for (auto kv : dom_map) {
-    dmap[kv.first->var.as<Variable>()] = kv.second;
-  }
-  IntSetEvaluator m(dmap);
+               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+  IntSetEvaluator m(dom_map);
   IntSet min_set = m.Eval(r->min);
   IntSet ext_set = m.Eval(r->extent).cover_interval();
   const Interval& ei = ext_set.as<IntervalSet>()->i;
@@ -463,6 +457,15 @@ IntSet EvalSet(Range r,
   return Combine<Add>(min_set, ext_set);
 }
 
+IntSet EvalSet(Range r,
+               const Map<IterVar, IntSet>& dom_map) {
+  std::unordered_map<const Variable*, IntSet> dmap;
+  for (auto kv : dom_map) {
+    dmap[kv.first->var.as<Variable>()] = kv.second;
+  }
+  return EvalSet(r, dmap);
+}
+
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
     p->stream << "interval-set["
@@ -470,6 +473,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
               << op->i.max << ']';
   });
 
-
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h
index 979d138af..f5de74501 100644
--- a/src/arithmetic/int_set.h
+++ b/src/arithmetic/int_set.h
@@ -103,6 +103,9 @@ IntSet EvalSet(Expr e,
  */
 IntSet EvalSet(Range r,
                const Map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(Range r,
+               const std::unordered_map<const Variable*, IntSet>& dom_map);
+
 
 /*!
  * \brief Create an union set of all sets
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 9fe530b67..4514d0228 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -7,6 +7,8 @@
 #include <tvm/ir_visitor.h>
 #include <tvm/ir_pass.h>
 #include <tvm/schedule_pass.h>
+#include <unordered_map>
+#include <unordered_set>
 #include "./graph.h"
 #include "../arithmetic/int_set.h"
 #include "../runtime/thread_storage_scope.h"
@@ -131,7 +133,6 @@ void PassUp(const FuseNode* s,
   }
 }
 
-
 void PassUp(const RebaseNode* s,
             const std::unordered_map<IterVar, Range>& dom_map,
             const IntSet& rebased,
@@ -180,82 +181,69 @@ void PassUp(const Stage& s,
   }
 }
 
-/*!
- * \brief Pass the bound of tensor read
- *  to the corresponding bound of the IterVar of operation
- * \param tensor The tensor to be passed.
- * \param dim_bounds The read index set on each dimension.
- * \param The result IterVar bound .
- */
-void PassToOperation(
-    const Tensor& tensor,
-    const std::vector<IntSet>& dim_bounds,
-    std::unordered_map<IterVar, std::vector<IntSet> >* result) {
-  // This is a push style operation, given output bound, push to the op IterVar bound.
-  // It cannot handle complicated cases where op bound is coupled with bounds of
-  // all of its outputs, without having a simple communicative union relation.
-  //
-  // Eventually, we need to change the inference to be a Pull style inference
-  if (tensor->op.as<ComputeOpNode>()) {
-    auto root_iter_vars = tensor->op->root_iter_vars();
-    const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
-    CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
-    for (size_t i = 0; i < op->axis.size(); ++i) {
-      (*result)[op->axis[i]].push_back(dim_bounds[i]);
-    }
-    // reduction.
-    for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
-      (*result)[op->reduce_axis[i]].push_back(
-          IntSet::range(op->reduce_axis[i]->dom));
-    }
-  } else {
-    LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
-  }
-}
+
+/*! \brief temporary data structure to store Tensor domain */
+struct TensorDom {
+  // constructor
+  explicit TensorDom(int ndim)
+      : data(ndim) {}
+  /*! \brief The domain data*/
+  std::vector<std::vector<IntSet> > data;
+};
 
 /*!
- * \brief Recursively propagate bound
- * \param post_order The propagation order.
+ * \brief Propagate bound to target
  * \param dom_map The domain map to be propagated
+ * \param out The tensor set to be passed
  * \return The result bound
  */
-std::unordered_map<IterVar, IntSet>
-BoundProp(const Array<Operation>& post_order,
-          std::unordered_map<IterVar, std::vector<IntSet> > *p_state) {
-  std::unordered_map<IterVar, IntSet> result;
-
-  for (size_t i = post_order.size(); i != 0; --i) {
-    Operation op = post_order[i - 1];
-    if (op.as<ComputeOpNode>()) {
-      for (auto iv : op->root_iter_vars()) {
-        CHECK(p_state->count(iv))
-            << "Bound of root operator must exists";
-        CHECK(!result.count(iv));
-        result[iv] = Union(p_state->at(iv));
-      }
-      auto fvisit = [p_state, &result](const NodeRef& n) {
-        auto *call = n.as<ir::Call>();
-        if (call != nullptr && call->func.defined()) {
-          Tensor t = Operation(call->func.node_).output(call->value_index);
-          if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
-            std::vector<IntSet> arg_bounds;
-            for (size_t i = 0; i < t.ndim(); ++i) {
-              arg_bounds.push_back(EvalSet(call->args[i], result));
-            }
-            PassToOperation(t, arg_bounds, p_state);
+void BoundProp(const Operation& op,
+               const std::unordered_map<const Variable*, IntSet>& dom_map,
+               std::unordered_map<Tensor, TensorDom> *out) {
+  if (op.as<ComputeOpNode>()) {
+    auto fvisit = [&dom_map, out](const NodeRef& n) {
+      auto *call = n.as<ir::Call>();
+      if (call != nullptr && call->func.defined()) {
+        Tensor t = Operation(call->func.node_).output(call->value_index);
+        if (t->op.defined() && out->count(t)) {
+          TensorDom& dom = out->at(t);
+          for (size_t i = 0; i < t.ndim(); ++i) {
+            dom.data[i].push_back(EvalSet(call->args[i], dom_map));
           }
         }
-      };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
-    } else if (op.as<PlaceholderOpNode>()) {
-      // do nothing
-    } else {
-      LOG(FATAL) << "unknown operation mode " << op->type_key();
-    }
+      }
+    };
+    ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+  } else if (op.as<PlaceholderOpNode>()) {
+    // do nothing
+  } else {
+    LOG(FATAL) << "unknown operation mode " << op->type_key();
   }
-  return result;
 }
 
+void InferOpBound(const Operation& op,
+                  const std::unordered_map<Tensor, TensorDom>& tmap,
+                  std::unordered_map<IterVar, Range>* rmap) {
+  if (op.as<ComputeOpNode>()) {
+    auto root_iter_vars = op->root_iter_vars();
+    const ComputeOpNode* compute = op.as<ComputeOpNode>();
+    const TensorDom& tdom = tmap.at(op.output(0));
+
+    for (size_t i = 0; i < compute->axis.size(); ++i) {
+      Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
+      CHECK(!rmap->count(compute->axis[i]));
+      (*rmap)[compute->axis[i]] = r;
+    }
+    for (size_t i = 0; i < compute->reduce_axis.size(); ++i) {
+      CHECK(!rmap->count(compute->reduce_axis[i]));
+      (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
+    }
+  } else if (op.as<PlaceholderOpNode>()) {
+    // do nothing
+  } else {
+    LOG(FATAL) << "unknown operation mode " << op->type_key();
+  }
+}
 
 // check if scope
 inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
@@ -267,8 +255,18 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
   return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
 }
 
-void InferBound(const Stage& stage,
-                std::unordered_map<IterVar, Range>* rmap) {
+// The map between a tensor and the operations it feeds to
+using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
+
+// AttachPath maps op-> a list of IterVar
+// That represents the loop nest op sits in from inner most to outermost
+using AttachPath = Map<Operation, Array<IterVar> >;
+
+
+void InferRootBound(const Stage& stage,
+                    const FeedGraph& feed_graph,
+                    const AttachPath& attach_path,
+                    std::unordered_map<IterVar, Range>* rmap) {
   if (stage->attach_type == kInline) return;
   if (stage->attach_type == kRoot || stage->attach_type == kNone) {
     auto root_iter_vars = stage->op->root_iter_vars();
@@ -277,15 +275,46 @@ void InferBound(const Stage& stage,
       CHECK(!rmap->count(iv));
       (*rmap)[iv] = iv->dom;
     }
+    return;
   }
+  // Infer root bounds for the attached node.
+  CHECK_EQ(stage->attach_type, kScope);
+  Stage parent = stage->attach_stage;
+  CHECK(parent.defined());
 
-  if (stage->attach_type == kScope) {
-    Stage parent = stage->attach_stage;
-    CHECK(parent.defined());
-    auto g = CreateReadGraph({parent->op});
-    auto post_order = PostDFSOrder({parent->op}, g);
-    std::unordered_map<IterVar, IntSet> up_state;
+  // The tensor domain.
+  std::unordered_map<Tensor, TensorDom> tmap;
+  // consumers other than parent
+  std::unordered_set<Operation> consumers;
+  // initialize the result
+  bool direct_consume_by_parent = false;
+  for (int i = 0; i < stage->op->num_outputs(); ++i) {
+    Tensor t = stage->op.output(i);
+    tmap.emplace(t, TensorDom(t.ndim()));
+    auto it = feed_graph.find(t);
+    if (it != feed_graph.end()) {
+      for (const Operation& op : it->second) {
+        if (op != parent->op) {
+          consumers.insert(op);
+        } else {
+          direct_consume_by_parent = true;
+        }
+      }
+    }
+  }
+  // The relax set
+  // This specifies the iteration variables that need to be relaxed
+  // from the already inferred bounds.
+  std::unordered_map<const Variable*, IntSet> relax_set;
+  for (IterVar iv : attach_path.at(stage->op)) {
+    if (ScopeRelax(iv, stage->scope)) {
+      relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
+    }
+  }
 
+  if (direct_consume_by_parent) {
+    // Bound inference logics in parent.
+    std::unordered_map<IterVar, IntSet> up_state;
     bool fix_value = true;
     for (auto iv : parent->leaf_iter_vars) {
       Range vrange = rmap->at(iv);
@@ -305,48 +334,104 @@ void InferBound(const Stage& stage,
         fix_value = false;
       }
     }
-    // get the bound of the root IterVars given the current condition
+    // get the bound of the root IterVars given current location.
     PassUp(parent, *rmap, &up_state);
-    std::unordered_map<IterVar, std::vector<IntSet> > bp_state;
+
+    std::unordered_map<const Variable*, IntSet> dom_map;
     for (auto iv : parent->op->root_iter_vars()) {
-      CHECK(up_state.count(iv));
-      bp_state[iv] = {up_state.at(iv)};
+      Range r = up_state.at(iv).cover_range(iv->dom);
+      if (relax_set.size() != 0) {
+        dom_map[iv->var.get()] = EvalSet(r, relax_set);
+      } else {
+        dom_map[iv->var.get()] = IntSet::range(r);
+      }
     }
-    auto result = BoundProp(post_order, &bp_state);
-
-    // Set relaxation for the threads in parent.
-    Map<IterVar, IntSet> relax_set;
-    Stage s = stage;
-    while (s->attach_type == kScope) {
-      s = s->attach_stage;
-      for (auto iv : s->leaf_iter_vars) {
-        if (ScopeRelax(iv, stage->scope)) {
-          relax_set.Set(iv, IntSet::range(rmap->at(iv)));
-        }
+    // prop from parent.
+    BoundProp(parent->op, dom_map, &tmap);
+  }
+  // Bound prop by other consumers.
+  // To explain the general logic, consider the example:
+  //
+  // for (i_outer, 0, 10) {
+  //   producer
+  //
+  //   for (i_inner, 0, 4) {
+  //     consumer op
+  //   }
+  // }
+  // - Get the domain of each consumer op, say [i_inner + i_outer*8, extent=4)
+  // - We need to relax it since the producer is attached at i_outer
+  // - Consumer's path is [i_inner, i_outer], then [i_inner] needs to be relaxed
+  // - Traverse attach_path, relax until reaching the producer's attachment point.
+  for (const Operation& op : consumers) {
+    std::unordered_map<const Variable*, IntSet> dom_map;
+    bool found = false;
+    for (IterVar iv : attach_path.at(op)) {
+      if (iv == stage->attach_ivar) {
+        found = true; break;
       }
+      Range vrange = rmap->at(iv);
+      CHECK(is_zero(vrange->min))
+          << "InferBound requires every leaf iter var's min equals 0, "
+          << "call schedule.normalize to achieve this.";
+      relax_set[iv->var.get()] = IntSet::range(vrange);
     }
+    CHECK(found)
+        << "Invalid Schedule, cannot find the producer " << stage->op
+        << " along the loop nest specified by compute_at of consumer " << op;
+    for (auto iv : op->root_iter_vars()) {
+      Range r = rmap->at(iv);
+      dom_map[iv->var.get()] = EvalSet(r, relax_set);
+    }
+    BoundProp(op, dom_map, &tmap);
+  }
+  InferOpBound(stage->op, tmap, rmap);
+}
 
-    for (auto iv : stage->op->root_iter_vars()) {
-      CHECK(result.count(iv));
-      CHECK(!rmap->count(iv));
-      Range r = result.at(iv).cover_range(iv->dom);
-      if (relax_set.size() != 0) {
-        r = EvalSet(r, relax_set).cover_range(iv->dom);
-      }
-      (*rmap)[iv] = r;
+FeedGraph CreateFeedGraph(const Schedule& sch) {
+  auto g = CreateReadGraph(sch->roots);
+  FeedGraph fg;
+  for (auto kv : g) {
+    for (Tensor t : kv.second) {
+      fg[t].push_back(kv.first);
     }
   }
-  // get range of all child iter vars.
-  PassDown(stage, rmap);
+  return fg;
 }
 
+// Create AttachPath that maps op -> a list of IterVar
+// That represents the loop nest op sits in from inner most to outermost
+AttachPath CreateAttachPath(const Schedule& sch) {
+  AttachPath ret;
+  for (Stage stage : sch->stages) {
+    Array<IterVar> path;
+    for (Stage s = stage; s->attach_type == kScope;) {
+      IterVar attach_ivar = s->attach_ivar;
+      s = s->attach_stage;
+      bool start_attach = false;
+      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
+        IterVar iv = s->leaf_iter_vars[i - 1];
+        if (iv == attach_ivar) start_attach = true;
+        if (start_attach) path.push_back(iv);
+      }
+      CHECK(start_attach)
+          << "Invalid Schedule: cannot find attach point " << attach_ivar
+          << " in the schedule of " << s->op;
+    }
+    ret.Set(stage->op, path);
+  }
+  return ret;
+}
 
-Map<IterVar, Range> InferBound(Schedule sch) {
+Map<IterVar, Range> InferBound(const Schedule& sch) {
+  FeedGraph feed_graph = CreateFeedGraph(sch);
+  AttachPath attach_path = CreateAttachPath(sch);
   std::unordered_map<IterVar, Range> ret;
-  // reverse post DFS order, from out most stage to the innermost
   for (size_t i = sch->stages.size(); i != 0; --i) {
-    Stage stage = sch->stages[i - 1];
-    InferBound(stage, &ret);
+    const Stage& stage = sch->stages[i - 1];
+    InferRootBound(stage, feed_graph, attach_path, &ret);
+    // pass down to get bound of all iter vars.
+    PassDown(stage, &ret);
   }
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }
-- 
GitLab