From 52ad69fcd26de62662d729b2d11c01259d4f5529 Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Fri, 12 May 2017 15:11:36 -0700
Subject: [PATCH] [FIX] Add CombineInterval<Mod> & Fix LoopPartition (#138)

* Add CombineInterval<Mod> & Fix LoopPartition

* Add check for path
---
 src/arithmetic/bound_deducer.cc               | 19 ++++++++++++-------
 src/arithmetic/compute_expr.h                 |  6 ++++++
 src/arithmetic/int_set.cc                     | 17 +++++++++++++++++
 src/pass/loop_partition.cc                    |  4 +++-
 .../unittest/test_pass_loop_partition.py      | 15 +++++++++++++++
 5 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc
index db4b1c57d..75beb3ad8 100644
--- a/src/arithmetic/bound_deducer.cc
+++ b/src/arithmetic/bound_deducer.cc
@@ -204,10 +204,14 @@ void BoundDeducer::Transform() {
 void BoundDeducer::Deduce() {
   Init();
   if (!success) return;
-
   Relax();
+  if (!success) return;
   // get the path
   path_ = GetPath(target_, expr_);
+  if (!path_.size()) {
+    success = false;
+    return;
+  }
   // get the sign of every subexpr
   expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
 
@@ -215,13 +219,14 @@ void BoundDeducer::Deduce() {
 }
 
 void BoundDeducer::Relax() {
-  if (is_greater) {
-    expr_  = EvalSet(expr_ , relax_map_).min();
-    result = EvalSet(result, relax_map_).max();
-  } else {
-    expr_  = EvalSet(expr_ , relax_map_).max();
-    result = EvalSet(result, relax_map_).min();
+  IntSet a = EvalSet(expr_, relax_map_);   // each side is evaluated once under relax_map_
+  IntSet b = EvalSet(result, relax_map_);
+  if (a.is_everything() || b.is_everything()) {  // either side evaluated to "everything"
+    success = false;  // Deduce() returns right after Relax() when this is cleared
+    return;
   }
+  expr_  = is_greater ? a.min() : a.max();  // expr_ and result take opposite ends
+  result = is_greater ? b.max() : b.min();
 }
 
 IntSet DeduceBound(Expr v, Expr e,
diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h
index cec44b033..110e61f23 100644
--- a/src/arithmetic/compute_expr.h
+++ b/src/arithmetic/compute_expr.h
@@ -112,6 +112,12 @@ inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
   return ir::Div::make(a, b);
 }
 
+template<>
+inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
+  if (is_zero(a)) return make_zero(a.type());  // zero dividend: no Mod node is built
+  return ir::Mod::make(a, b);
+}
+
 template<>
 inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
   return Halide::Internal::Interval::make_max(a, b);
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 0451df867..848a7a21d 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -291,6 +291,23 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
   return IntSet::everything();
 }
 
+// a % b: exact for two points, [0, b - 1] for a point divisor, else everything.
+template<>
+inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
+  if (b.is_single_point()) {
+    Expr divisor = b.min;
+    if (is_zero(divisor)) {  // checked first so that point % 0 is rejected too
+      LOG(FATAL) << "Modular by zero in CombineInterval Mod";
+    }
+    if (a.is_single_point()) {
+      return IntSet::single_point(ComputeExpr<Mod>(a.min, divisor));
+    }
+    return IntervalSet::make(make_zero(divisor.type()), divisor - 1);
+  }
+  LOG(WARNING) << "Return Everything in CombineInterval Mod";
+  return IntSet::everything();
+}
+
 template<>
 inline IntSet CombineInterval<Max>(Interval a, Interval b) {
   if (a.is_single_point() && b.is_single_point()) {
diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc
index bc8aea33d..c4567b6ae 100644
--- a/src/pass/loop_partition.cc
+++ b/src/pass/loop_partition.cc
@@ -153,7 +153,9 @@ class PartitionFinder : public IRVisitor {
           std::unordered_set<const Variable*>({current_var_.get()}))) {
         IntSet interval =
           DeduceBound(current_var_, cond, hint_map_, relax_map_);
-        partitions[cond.get()] = Partition{cond, interval};
+        if (!interval.is_nothing()) {
+          partitions[cond.get()] = Partition{cond, interval};
+        }
       }
     } else {
       IRVisitor::Visit_(op);
diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py
index 8ae19eab6..9c97776df 100644
--- a/tests/python/unittest/test_pass_loop_partition.py
+++ b/tests/python/unittest/test_pass_loop_partition.py
@@ -148,6 +148,20 @@ def test_thread_axis2():
     for_body = stmt.body.body.body.body.body.first
     assert('threadIdx' not in str(for_body.extent))
 
+def test_everything_during_deduction():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    ib = tvm.ir_builder.create()
+    with ib.for_range(0, n, 'i') as i:
+        with ib.for_range(0, 32, 'j') as j:
+            with ib.if_scope(ib.likely(i/j < m)):
+                # this guard produces "everything" during deduction; the if must remain (asserted below)
+                ib.emit(tvm.make.Evaluate(m))
+    stmt = ib.get()
+    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))
+
 if __name__ == "__main__":
     test_basic()
     test_multi_loop()
@@ -156,3 +170,4 @@ if __name__ == "__main__":
     test_vectorize()
     test_select()
     test_thread_axis2()
+    test_everything_during_deduction()
-- 
GitLab