diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 945f7a8c30b96d12ff14d86d92fe602f6b1a8c51..8a88ed23e2625ebc52c675d9e4d58621443f8d4b 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -132,27 +132,18 @@ IntSet IntSet::interval(Expr min, Expr max) { return IntervalSet::make(min, max); } +inline bool prove_equal(Expr lhs, Expr rhs) { + return is_zero(ir::Simplify(lhs - rhs)); +} + // Check if a is created from b. bool IntSet::match_range(const Range& b) const { const IntSet& a = *this; const IntervalSet* a_int = a.as<IntervalSet>(); if (!a_int) return false; const Interval& i = a_int->i; - if (!i.min.same_as(b)) return false; - if (is_one(b->extent)) return i.is_single_point(); - if (is_positive_const(b->extent) && is_const(b->min)) { - // deep equality - return Equal( - ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1), - a_int->i.max); - } - const Sub* sub = i.max.as<Sub>(); - if (!sub) return false; - if (is_one(sub->b)) return false; - const Add* add = sub->a.as<Add>(); - return add && - add->a.same_as(b->min) && - add->b.same_as(b->extent); + return prove_equal(i.min, b->min) && + prove_equal(i.max, ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1)); } inline bool MatchPoint(const IntSet& a,