From 3555769efd52c08d004962757d4cf15e3e5b4294 Mon Sep 17 00:00:00 2001
From: Ziheng Jiang <jzhtomas@gmail.com>
Date: Tue, 21 Feb 2017 12:42:46 -0800
Subject: [PATCH] [ARITH] Add CombineInterval<Div> in IntSet (#48)

* [FIX] add CombineInterval<Div>

* fix error message and add comment about rounding

* fix comment
---
 src/arithmetic/int_set.cc | 28 +++++++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 8c89d93e6..709da26a6 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -244,7 +244,7 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
     if (is_one(b.min)) return IntervalSet::make(a);
     Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
     Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
-    // This is relaxiation
+    // no relaxation is needed in here due to set is inclusive
     // TODO(tqchen): consider convert to StrideSet.
     if (is_positive_const(b.min)) {
       return IntervalSet::make(e1, e2);
@@ -259,6 +259,32 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
   return IntSet::everything();
 }
 
+template<>
+inline IntSet CombineInterval<Div>(Interval a, Interval b) {
+  if (a.is_single_point() && b.is_single_point()) {
+    return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
+  }
+  if (b.is_single_point()) {
+    if (is_zero(b.min)) {
+      LOG(FATAL) << "Divide by zero in CombineInterval Div";
+    }
+    if (is_one(b.min)) return IntervalSet::make(a);
+    Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
+    Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
+    // no relaxation is needed in here due to set is inclusive
+    if (is_positive_const(b.min)) {
+      return IntervalSet::make(e1, e2);
+    } else if (is_negative_const(b.min)) {
+      return IntervalSet::make(e2, e1);
+    } else if (a.is_bounded()) {
+      Expr cmp = b.min >= make_zero(b.min.type().element_of());
+      return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
+    }
+  }
+  LOG(WARNING) << "Return Everything in CombineInterval Div";
+  return IntSet::everything();
+}
+
 template<>
 inline IntSet CombineInterval<Max>(Interval a, Interval b) {
   if (a.is_single_point() && b.is_single_point()) {
-- 
GitLab