diff --git a/HalideIR b/HalideIR index dbf043a8d8bf379b05c56d8aa9025db55f589d6d..a40a3e2fedee88d2f7b97ba4caf8a9d0eb25886f 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit dbf043a8d8bf379b05c56d8aa9025db55f589d6d +Subproject commit a40a3e2fedee88d2f7b97ba4caf8a9d0eb25886f diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 9bc2d22318ffe8cbc6b9ae7598d84179c49fd2a3..6f7afa0b75d4e10cb2f501fd90922514859512bc 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -118,7 +118,7 @@ class IntSet : public NodeRef { * \brief Range of a linear integer function. * Use to do specify the possible index values. * - * set = { base + coeff * x | x in Z } + * set = { coeff * x + base | x in Z } * * When coeff != 0, it can also be written as * set = { n | n % coeff == base } @@ -127,16 +127,17 @@ class IntSet : public NodeRef { * For example, if index = 0 + 4 x, then we know it can be divided by 4. */ struct ModularEntry { - /*! \brief The base */ - int base{0}; /*! \brief linear co-efficient */ int coeff{1}; + /*! \brief The base */ + int base{0}; /*! \return entry represent everything */ static ModularEntry everything() { // always safe to set 0 + x, so it can be everything. ModularEntry e; - e.base = 0; e.coeff = 1; + e.coeff = 1; + e.base = 0; return e; } /*! @@ -157,14 +158,25 @@ struct IntSetNode : public Node { TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); }; - /*! - * \brief Detect if e can be rewritten as e = base + var * coeff + * \brief Detect if e can be rewritten as e = sum_{i=0}^n var[i] * coeff[i] + coeff[n] * Where coeff and base are invariant of var. * - * \return [base, coeff] if it is possible, empty array if it is not. + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return [coeff[i]] if it is possible, empty array if it is not. + */ +Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars); + +/*! + * \brief Detect if expression corresponds to clip bound of the vars + * + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value + * return empty if the e does not match the pattern. */ -Array<Expr> DetectLinearEquation(Expr e, Var var); +Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars); /*! * \brief Find an symbolic integer set that contains all possible values of diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index e3b619f1e66ae9e66eeca1c918b401fc9cf13296..31ff5ccb3a15d8f2f4957017936bd806620f0283 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -36,6 +36,11 @@ TVM_REGISTER_API("arith.DetectLinearEquation") *ret = DetectLinearEquation(args[0], args[1]); }); +TVM_REGISTER_API("arith.DetectClipBound") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = DetectClipBound(args[0], args[1]); + }); + TVM_REGISTER_API("arith.DeduceBound") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = DeduceBound(args[0], args[1], diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index b4f7db50fd84944c894b8ac50218189cb1237673..63f582160312a7f013d37987f4714f25db25d5c3 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -21,22 +21,27 @@ struct LinearEqEntry { Expr coeff; }; +struct IntervalEntry { + Expr min_value; + Expr max_value; +}; + class LinearEqDetector : public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> { public: explicit LinearEqDetector(Var var) : var_(var) {} - Array<Expr> Detect(const Expr& e) { - LinearEqEntry ret = VisitExpr(e, e); - if (fail_) return Array<Expr>(); - if (!ret.base.defined()) { - ret.base = make_zero(var_.type()); + bool Detect(const Expr& e, LinearEqEntry* ret) { + *ret = VisitExpr(e, e); + if (fail_) return false; + if (!ret->base.defined()) { + ret->base = make_zero(var_.type()); } - if (!ret.coeff.defined()) { - ret.coeff = make_zero(var_.type()); + if (!ret->coeff.defined()) { + ret->coeff = make_zero(var_.type()); } - return Array<Expr>{ret.base, ret.coeff}; + return true; } LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final { @@ -48,6 +53,17 @@ class LinearEqDetector ret.coeff = AddCombine(a.coeff, b.coeff); return ret; } + + LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final { + if (fail_) return LinearEqEntry(); + LinearEqEntry a = VisitExpr(op->a, op->a); + LinearEqEntry b = VisitExpr(op->b, op->b); + LinearEqEntry ret; + ret.base = SubCombine(a.base, b.base); + ret.coeff = SubCombine(a.coeff, b.coeff); + return ret; + } + LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); @@ -94,6 +110,11 @@ class LinearEqDetector if (!b.defined()) return a; return ComputeExpr<Add>(a, b); } + Expr SubCombine(Expr a, Expr b) { + if (!a.defined()) return -b; + if (!b.defined()) return a; + return ComputeExpr<Sub>(a, b); + } Expr MulCombine(Expr a, Expr b) { if (!a.defined()) return a; if (!b.defined()) return b; @@ -101,9 +122,134 @@ class LinearEqDetector } }; -Array<Expr> DetectLinearEquation(Expr e, Var var) { - return LinearEqDetector(var).Detect(e); +Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) { + CHECK_GE(vars.size(), 1U); + Expr base = e; + Array<Expr> coeff; + + for (Var v : vars) { + LinearEqEntry ret; + if (!LinearEqDetector(v).Detect(base, &ret)) { + return Array<Expr>(); + } + coeff.push_back(ret.coeff); + base = std::move(ret.base); + } + + std::unordered_set<const Variable*> vset; + for (size_t i = vars.size(); i != 1; --i) { + vset.insert(vars[i - 1].get()); + // The previous coeff contains the variable + if (ExprUseVar(coeff[i - 2], vset)) { + return Array<Expr>(); + } + } + coeff.push_back(base); + return coeff; } +// Detect clip condition as min max value +bool DetectClipBound( + const Expr& cond, + std::unordered_map<const Variable*, IntervalEntry>* bmap) { + int flag = 0; + Var var; + auto fvisit = [&bmap, &flag, &var](const NodeRef& n) { + if (const Variable* v = n.as<Variable>()) { + if (bmap->count(v)) { + if (flag == 0) { + var = Var(n.node_); + flag = 1; + } else if (flag == 1) { + if (!var.same_as(n)) { + flag = -1; + } + } + } + } + }; + PostOrderVisit(cond, fvisit); + if (flag != 1) return false; + // canonical form: exp >= 0 + Expr canonical; + if (const LT* op = cond.as<LT>()) { + if (!op->a.type().is_int()) return false; + canonical = op->b - op->a - make_const(op->a.type(), 1); + } else if (const LE* op = cond.as<LE>()) { + if (!op->a.type().is_int()) return false; + canonical = op->b - op->a; + } else if (const GT* op = cond.as<GT>()) { + if (!op->a.type().is_int()) return false; + canonical = op->a - op->b - make_const(op->a.type(), 1); + } else if (const GE* op = cond.as<GE>()) { + if (!op->a.type().is_int()) return false; + canonical = op->a - op->b; + } else { + return false; + } + LinearEqEntry ret; + if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; + ret.coeff = Simplify(ret.coeff); + IntervalEntry& p = (*bmap)[var.get()]; + if (is_one(ret.coeff)) { + // var + shift >=0 -> var >= -shift + if (p.min_value.defined()) { + p.min_value = ir::Max::make(p.min_value, -ret.base); + } else { + p.min_value = -ret.base; + } + return true; + } + if (is_const(ret.coeff, -1)) { + // -var + shift >=0 -> var <= shift + if (p.max_value.defined()) { + p.max_value = ir::Min::make(p.max_value, ret.base); + } else { + p.max_value = ret.base; + } + return true; + } + return false; +} + + +template<typename OP> +void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) { + if (const OP* op = e.as<OP>()) { + SplitCommExpr<OP>(op->a, ret); + SplitCommExpr<OP>(op->b, ret); + } else { + ret->push_back(e); + } +} + +// Detect the lower and upper bound from the expression. +// e must be connected by and. +Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) { + std::vector<Expr> splits; + SplitCommExpr<ir::And>(e, &splits); + std::unordered_map<const Variable*, IntervalEntry> rmap; + for (Var v : vars) { + rmap[v.get()] = IntervalEntry(); + } + for (Expr cond : splits) { + if (!DetectClipBound(cond, &rmap)) return Array<Expr>(); + } + Array<Expr> ret; + for (Var v : vars) { + IntervalEntry e = rmap[v.get()]; + if (e.min_value.defined()) { + e.min_value = Simplify(e.min_value); + } + if (e.max_value.defined()) { + e.max_value = Simplify(e.max_value); + } + ret.push_back(e.min_value); + ret.push_back(e.max_value); + } + return ret; +} + + } // namespace arith } // namespace tvm diff --git a/src/pass/narrow_channel_access.cc b/src/pass/narrow_channel_access.cc index 219fe769bba028adae77b2f9e660f6dd0c8c90a8..733eeffb632e0506035afbdd03dbfb6ae46abb00 100644 --- a/src/pass/narrow_channel_access.cc +++ b/src/pass/narrow_channel_access.cc @@ -175,10 +175,10 @@ class ChannelAccessRewriter : public IRMutator { r = Range::make_by_min_extent( ir::Simplify(r->min), ir::Simplify(r->extent)); if (ExprUseVar(r->extent, var)) return body; - Array<Expr> linear_eq = DetectLinearEquation(r->min, var); + Array<Expr> linear_eq = DetectLinearEquation(r->min, {var}); if (linear_eq.size() == 0) return body; - Expr base = linear_eq[0]; - Expr coeff = linear_eq[1]; + Expr coeff = linear_eq[0]; + Expr base = linear_eq[1]; if (!is_zero(base)) return body; Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); if (!can_prove(left >= 0)) return body; diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ea61ac072229d98710f4d03a5ca99ea8933986 --- /dev/null +++ b/tests/python/unittest/test_arith_detect_clip_bound.py @@ -0,0 +1,21 @@ +import tvm + +def test_basic(): + a = tvm.var("a") + b = tvm.var("b") + c = tvm.var("c") + m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6, + a - 1 > 0), [a]) + assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0 + assert m[0].value == 2 + m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6, + a - 1 > 0), [a, b]) + assert len(m) == 0 + m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20, + b - 1 > 0), [a, b]) + assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0 + assert tvm.ir_pass.Simplify(m[2] - 2).value == 0 + + +if __name__ == "__main__": + test_basic() diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 412effa2f0a27ab46511e513b27abb8aedc65144..9d875c910d1c10690b681734c2d5089ac5d75788 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -3,22 +3,41 @@ import tvm def test_basic(): a = tvm.var("a") b = tvm.var("b") - m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, a) - assert m[1].value == 4 - assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7)).value == 0 + m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a]) + assert m[0].value == 4 + assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0 - m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, a) + m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, [a]) assert len(m) == 0 - m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, a) - assert m[1].value == 5 - assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0 + m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, [a]) + assert m[0].value == 5 + assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0 - m = tvm.arith.DetectLinearEquation(a * b + 7, a) - assert m[1] == b + m = tvm.arith.DetectLinearEquation(a * b + 7, [a]) + assert m[0] == b - m = tvm.arith.DetectLinearEquation(b * 7, a) - assert m[1].value == 0 + m = tvm.arith.DetectLinearEquation(b * 7, [a]) + assert m[0].value == 0 + +def test_multivariate(): + v = [tvm.var("v%d" % i) for i in range(4)] + b = tvm.var("b") + m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v) + assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5)) + assert(m[1].value == 8) + + m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) + assert(len(m) == 0) + + m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v) + assert(len(m) == 0) + + m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v) + assert(m[1].value == 16) + assert(m[2].value == 2) + assert(m[len(m)-1].value == 2) if __name__ == "__main__": test_basic() + test_multivariate()