From 52ad69fcd26de62662d729b2d11c01259d4f5529 Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Fri, 12 May 2017 15:11:36 -0700
Subject: [PATCH] [FIX] Add CombineInternal<Mod> & Fix LoopPartition (#138)

* Add CombineInternal<Mod> & Fix LoopPartition

* Add check for path
---
NOTE(review): this patch was restored from a copy whose line breaks and
indentation had been collapsed. Line counts of every hunk match its @@ header
and the diffstat below, but leading whitespace and the wrap point of the
DeduceBound(...) context line in loop_partition.cc are reconstructed by
convention -- TODO confirm against the tree (or apply with
`git apply --ignore-whitespace`).
NOTE(review): the subject says "CombineInternal<Mod>"; the function added in
int_set.cc is CombineInterval<Mod>.

 src/arithmetic/bound_deducer.cc               | 19 ++++++++++++-------
 src/arithmetic/compute_expr.h                 |  6 ++++++
 src/arithmetic/int_set.cc                     | 17 +++++++++++++++++
 src/pass/loop_partition.cc                    |  4 +++-
 .../unittest/test_pass_loop_partition.py      | 15 +++++++++++++++
 5 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc
index db4b1c57d..75beb3ad8 100644
--- a/src/arithmetic/bound_deducer.cc
+++ b/src/arithmetic/bound_deducer.cc
@@ -204,10 +204,14 @@ void BoundDeducer::Transform() {
 void BoundDeducer::Deduce() {
   Init();
   if (!success) return;
-
   Relax();
+  if (!success) return;
   // get the path
   path_ = GetPath(target_, expr_);
+  if (!path_.size()) {
+    success = false;
+    return;
+  }
 
   // get the sign of every subexpr
   expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
@@ -215,13 +219,14 @@ void BoundDeducer::Deduce() {
 }
 
 void BoundDeducer::Relax() {
-  if (is_greater) {
-    expr_ = EvalSet(expr_ , relax_map_).min();
-    result = EvalSet(result, relax_map_).max();
-  } else {
-    expr_ = EvalSet(expr_ , relax_map_).max();
-    result = EvalSet(result, relax_map_).min();
+  IntSet a = EvalSet(expr_, relax_map_);
+  IntSet b = EvalSet(result, relax_map_);
+  if (a.is_everything() || b.is_everything()) {
+    success = false;
+    return;
   }
+  expr_ = is_greater ? a.min() : a.max();
+  result = is_greater ? b.max() : b.min();
 }
 
 IntSet DeduceBound(Expr v, Expr e,
diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h
index cec44b033..110e61f23 100644
--- a/src/arithmetic/compute_expr.h
+++ b/src/arithmetic/compute_expr.h
@@ -112,6 +112,12 @@ inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
   return ir::Div::make(a, b);
 }
 
+template<>
+inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) {
+  if (is_zero(a)) return make_zero(a.type());
+  return ir::Mod::make(a, b);
+}
+
 template<>
 inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
   return Halide::Internal::Interval::make_max(a, b);
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 0451df867..848a7a21d 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -291,6 +291,23 @@ inline IntSet CombineInterval<Div>(Interval a, Interval b) {
   return IntSet::everything();
 }
 
+template<>
+inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
+  if (a.is_single_point() && b.is_single_point()) {
+    return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min));
+  }
+  if (b.is_single_point()) {
+    Expr divisor = b.min;
+    if (is_zero(divisor)) {
+      LOG(FATAL) << "Modular by zero in CombineInterval Mod";
+    }
+    return IntervalSet::make(make_zero(divisor.type()), divisor - 1);
+  }
+
+  LOG(WARNING) << "Return Everything in CombineInterval Mod";
+  return IntSet::everything();
+}
+
 template<>
 inline IntSet CombineInterval<Max>(Interval a, Interval b) {
   if (a.is_single_point() && b.is_single_point()) {
diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc
index bc8aea33d..c4567b6ae 100644
--- a/src/pass/loop_partition.cc
+++ b/src/pass/loop_partition.cc
@@ -153,7 +153,9 @@ class PartitionFinder : public IRVisitor {
           std::unordered_set<const Variable*>({current_var_.get()}))) {
         IntSet interval = DeduceBound(current_var_, cond,
           hint_map_, relax_map_);
-        partitions[cond.get()] = Partition{cond, interval};
+        if (!interval.is_nothing()) {
+          partitions[cond.get()] = Partition{cond, interval};
+        }
       }
     } else {
       IRVisitor::Visit_(op);
diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py
index 8ae19eab6..9c97776df 100644
--- a/tests/python/unittest/test_pass_loop_partition.py
+++ b/tests/python/unittest/test_pass_loop_partition.py
@@ -148,6 +148,20 @@ def test_thread_axis2():
     for_body = stmt.body.body.body.body.body.first
     assert('threadIdx' not in str(for_body.extent))
 
+def test_everything_during_deduction():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    ib = tvm.ir_builder.create()
+    with ib.for_range(0, n, 'i') as i:
+        with ib.for_range(0, 32, 'j') as j:
+            with ib.if_scope(ib.likely(i/j < m)):
+                # this guard will produce everything during deduction
+                ib.emit(tvm.make.Evaluate(m))
+    stmt = ib.get()
+    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse))
+
 if __name__ == "__main__":
     test_basic()
     test_multi_loop()
@@ -156,3 +170,4 @@ if __name__ == "__main__":
     test_vectorize()
     test_select()
     test_thread_axis2()
+    test_everything_during_deduction()
-- 
GitLab