From 3555769efd52c08d004962757d4cf15e3e5b4294 Mon Sep 17 00:00:00 2001 From: Ziheng Jiang <jzhtomas@gmail.com> Date: Tue, 21 Feb 2017 12:42:46 -0800 Subject: [PATCH] [ARITH] Add CombineInterval<Div> in IntSet (#48) * [FIX] add CombineInterval<Div> * fix error message and add comment about rounding * fix comment --- src/arithmetic/int_set.cc | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 8c89d93e6..709da26a6 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -244,7 +244,7 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) { if (is_one(b.min)) return IntervalSet::make(a); Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min; Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max; - // This is relaxiation + // no relaxation is needed in here due to set is inclusive // TODO(tqchen): consider convert to StrideSet. if (is_positive_const(b.min)) { return IntervalSet::make(e1, e2); @@ -259,6 +259,32 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) { return IntSet::everything(); } +template<> +inline IntSet CombineInterval<Div>(Interval a, Interval b) { + if (a.is_single_point() && b.is_single_point()) { + return IntSet::single_point(ComputeExpr<Div>(a.min, b.min)); + } + if (b.is_single_point()) { + if (is_zero(b.min)) { + LOG(FATAL) << "Divide by zero in CombineInterval Div"; + } + if (is_one(b.min)) return IntervalSet::make(a); + Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min; + Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max; + // no relaxation is needed in here due to set is inclusive + if (is_positive_const(b.min)) { + return IntervalSet::make(e1, e2); + } else if (is_negative_const(b.min)) { + return IntervalSet::make(e2, e1); + } else if (a.is_bounded()) { + Expr cmp = b.min >= make_zero(b.min.type().element_of()); + return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1)); + } + } + LOG(WARNING) << "Return Everything in CombineInterval Div"; + return IntSet::everything(); +} + template<> inline IntSet CombineInterval<Max>(Interval a, Interval b) { if (a.is_single_point() && b.is_single_point()) { -- GitLab