diff --git a/src/bound/bound.cc b/src/bound/bound.cc index bb2b80b492f871dc52c6fd82e0742c08102197d5..052920b932a6b20e7bb9395d8e433b50d0ec1a61 100644 --- a/src/bound/bound.cc +++ b/src/bound/bound.cc @@ -66,23 +66,21 @@ void PassUp(const Schedule& s, const std::unordered_map<IterVar, Range>& dom_map, std::unordered_map<IterVar, IntSet>* p_state) { auto& state = *p_state; - for (size_t i = s->relations.size(); i != 0;--i) { + for (size_t i = s->relations.size(); i != 0; --i) { IterVarRelation rel = s->relations[i - 1]; if (rel.as<SplitNode>()) { IntSet parent; const SplitNode* r = rel.as<SplitNode>(); - IntSet::PassUp( - r, dom_map, - state.at(r->outer), state.at(r->inner), - &parent); + PassUp(r, dom_map, + state.at(r->outer), state.at(r->inner), + &parent); state[r->parent] = parent; } else if (rel.as<FuseNode>()) { IntSet outer, inner; const FuseNode* r = rel.as<FuseNode>(); - IntSet::PassUp( - r, dom_map, - state.at(r->fused), - &outer, &inner); + PassUp(r, dom_map, + state.at(r->fused), + &outer, &inner); state[r->outer] = outer; state[r->inner] = inner; } else { diff --git a/src/bound/int_set.cc b/src/bound/int_set.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a4c8a5740d51d269eb0bb7f39c4e2c91a38313f --- /dev/null +++ b/src/bound/int_set.cc @@ -0,0 +1,343 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file int_set.cc + * \brief The integer set functions + */ +#include <tvm/ir.h> +#include "./int_set.h" + +namespace tvm { +namespace bound { + +using namespace ir; + +/*! + * \brief Internal node container of int set. + */ +class IntSetNode : public Node { + public: + /*! \brief The base range scope */ + Range base; + /*! \brief additional strided domain */ + Array<Range> domain; + /*! \brief The stride of each strided domain */ + Array<Expr> stride; + /*! + * \brief The concrete set, + * used when concrete execution is enabled. + */ + std::vector<int32_t> concrete; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("base", &base); + v->Visit("domain", &domain); + v->Visit("stride", &stride); + } + + static constexpr const char* _type_key = "IntSet"; + TVM_DECLARE_NODE_TYPE_INFO(IntSetNode); +}; + +TVM_REGISTER_NODE_TYPE(IntSetNode); + +namespace { + +inline bool Match(const Expr& e, int64_t value) { + const ir::IntImm* v = e.as<ir::IntImm>(); + return v != nullptr && v->value; +} + +// whether a exactly matches b. +inline bool Match(const IntSet& a, + const Range& b) { + if (a->base == b && + a->domain.size() == 0 && + a->concrete.size() == 0) { + return true; + } else { + return false; + } +} + +// whether a exactly matches b. +inline bool Match(const IntSet& a, + const Expr& b) { + if (a->domain.size() == 0 && + a->concrete.size() == 0) { + return Match(a->base->extent, 1) && a->base->min.same_as(b); + } else { + return false; + } +} + +inline bool IsNumber(const IntSet& s) { + if (s->domain.size() != 0) return false; + if (s->concrete.size() != 0) { + return s->concrete.size() == 1; + } + return Match(s->base->extent, 1); +} + +inline Expr AsNumber(const IntSet& s) { + return s->base->min; +} + +// set combination rule by operators +template<typename T> +inline IntSet BinaryCombine(IntSet a, IntSet b) { + LOG(WARNING) << "cannot evaluate binary op " << T::_type_key; + return IntSet::make_all_set(); +} + +template<> +inline IntSet BinaryCombine<Add>(IntSet a, IntSet b) { + auto n = std::make_shared<IntSetNode>(*(a.operator->())); + for (size_t i = 0; i < b->domain.size(); ++i) { + n->domain.push_back(b->domain[i]); + n->stride.push_back(b->stride[i]); + } + + if (IsNumber(a)) { + n->base = Range::make_with_min_extent( + a->base->min + b->base->min, + b->base->extent); + } else if (IsNumber(b)) { + n->base = Range::make_with_min_extent( + a->base->min + b->base->min, + a->base->extent); + } else { + n->base = Range::make_with_min_extent( + a->base->min + b->base->min, + a->base->extent + b->base->extent - 1); + } + return IntSet(n); +} + +inline Range Negation(Range a) { + if (Match(a->extent, 1)) { + return Range::make_with_min_extent(-a->min, a->extent); + } else { + return Range::make_with_min_extent(-(a->min + a->extent - 1), a->extent); + } +} + +inline IntSet Negation(IntSet a) { + CHECK_EQ(a->concrete.size(), 0); + auto n = std::make_shared<IntSetNode>(); + n->base = Negation(a->base); + for (size_t i = 0; i < a->domain.size(); ++i) { + n->domain.push_back(Negation(a->domain[i])); + n->stride.push_back(a->stride[i]); + } + return IntSet(a); +} + +template<> +inline IntSet BinaryCombine<Sub>(IntSet a, IntSet b) { + return BinaryCombine<Add>(a, Negation(b)); +} + +inline IntSet BinaryMul(IntSet a, Expr b) { + // copy construct + if (Match(b, 1)) return a; + if (Match(b, -1)) return Negation(a); + auto n = std::make_shared<IntSetNode>(); + n->base = Range::make_with_min_extent(0, 1); + n->domain.push_back(a->base); + n->stride.push_back(b); + for (size_t i = 0; i < a->domain.size(); ++i) { + n->domain.push_back(a->domain[i]); + n->stride.push_back(a->stride[i] * b); + } + return IntSet(a); +} + +template<> +inline IntSet BinaryCombine<Mul>(IntSet a, IntSet b) { + if (IsNumber(a)) { + return BinaryMul(a, AsNumber(b)); + } else if (IsNumber(b)) { + return BinaryMul(b, AsNumber(a)); + } else { + return IntSet::make_all_set(); + } +} + +} // namespace + +inline const IntSetNode* IntSet::operator->() const { + return static_cast<const IntSetNode*>(node_.get()); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch<IntSetNode>([](const IntSetNode *op, IRPrinter *p) { + p->stream << "int-set(base="; + p->print(op->base); + p->stream << ')'; + }); + +IntSet IntSet::make(Range dom) { + auto n = std::make_shared<IntSetNode>(); + n->base = dom; + return IntSet(n); +} + +void PassUp(const SplitNode* s, + const std::unordered_map<IterVar, Range>& dom_map, + const IntSet& outer, + const IntSet& inner, + IntSet* parent) { + if (dom_map.count(s->outer) && + dom_map.count(s->inner) && + dom_map.count(s->parent) && + Match(outer, dom_map.at(s->outer)) && + Match(inner, dom_map.at(s->inner))) { + *parent = IntSet::make(dom_map.at(s->parent)); + return; + } + // copy construct + auto n = std::make_shared<IntSetNode>(*(inner.operator->())); + + if (IsNumber(outer)) { + // shift the base offset + n->base = Range::make_with_min_extent( + AsNumber(outer) * s->factor + inner->base->min, + inner->base->extent); + *parent = IntSet(n); + } else { + // default use all domains in the data. + n->domain.push_back(outer->base); + n->stride.push_back(s->factor); + for (size_t i = 0; i < outer->domain.size(); ++i) { + n->domain.push_back(outer->domain[i]); + n->stride.push_back(outer->stride[i] * s->factor); + } + } +} + +void PassUp(const FuseNode* s, + const std::unordered_map<IterVar, Range>& dom_map, + const IntSet& fused, + IntSet* outer, + IntSet* inner) { + CHECK(dom_map.count(s->outer)); + CHECK(dom_map.count(s->inner)); + CHECK(dom_map.count(s->fused)); + + if (Match(fused, dom_map.at(s->fused))) { + *outer = IntSet::make(dom_map.at(s->outer)); + *inner = IntSet::make(dom_map.at(s->inner)); + return; + } + + if (IsNumber(fused)) { + Expr value = AsNumber(fused); + Expr factor = dom_map.at(s->outer)->extent; + *outer = IntSet::make(Range::make_with_min_extent(value / factor, 1)); + *inner = IntSet::make(Range::make_with_min_extent(value % factor, 1)); + } else { + LOG(WARNING) << "use fallback inference rule in fuse"; + // simply use the entire set, this rule can be enhanced. + *outer = IntSet::make(dom_map.at(s->outer)); + *inner = IntSet::make(dom_map.at(s->inner)); + return; + } +} + +namespace { +// evaluator to evaluate the int set +class IRSetEvaluator { + public: + inline IntSet Eval(Expr expr) { + static const FType& f = vtable(); + if (f.can_dispatch(expr)) { + return f(expr, expr, this); + } else { + LOG(WARNING) << "cannot evaluate set type " << expr->type_key(); + return IntSet::make_all_set(); + } + } + + using FType = tvm::IRFunctor<IntSet (const NodeRef&, const Expr&, IRSetEvaluator *)>; + static FType& vtable() { // NOLINT(*) + static FType inst; return inst; + } + + std::unordered_map<const Variable*, IntSet> dom_map; +}; + +inline IntSet ConstOp(const NodeRef&, const Expr& e, IRSetEvaluator*) { + return IntSet::make(Range::make_with_min_extent(e, 1)); +} + +TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) +.set_dispatch<IntImm>(ConstOp) +.set_dispatch<UIntImm>(ConstOp) +.set_dispatch<FloatImm>(ConstOp); + +TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) +.set_dispatch<Variable>([](const Variable* op, const Expr& e, IRSetEvaluator* m) { + auto it = m->dom_map.find(op); + if (it != m->dom_map.end()) { + return it->second; + } else { + return IntSet::make(Range::make_with_min_extent(e, 1)); + } + }); + +// binary operator +template<typename T> +inline IntSet Binary(const T* op, const Expr& e, IRSetEvaluator* m) { + IntSet a = m->Eval(op->a); + IntSet b = m->Eval(op->b); + if (IsNumber(a) && IsNumber(b)) { + if (Match(a, op->a) && + Match(b, op->b)) { + return IntSet::make(Range::make_with_min_extent(e, 1)); + } else { + return IntSet::make(Range::make_with_min_extent( + T::make(AsNumber(a), AsNumber(b)), 1)); + } + } else { + return BinaryCombine<T>(a, b); + } +} + +TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) +.set_dispatch<Add>(Binary<Add>) +.set_dispatch<Sub>(Binary<Sub>) +.set_dispatch<Mul>(Binary<Mul>) +.set_dispatch<Div>(Binary<Div>) +.set_dispatch<Mod>(Binary<Mod>) +.set_dispatch<Min>(Binary<Min>) +.set_dispatch<Max>(Binary<Max>); + +// use simply bound for logical expressions for now. +inline IntSet Logical(const NodeRef&, const Expr& e, IRSetEvaluator*) { + return IntSet::make(Range::make_with_min_extent(0, 2)); +} + +TVM_STATIC_IR_FUNCTOR(IRSetEvaluator, vtable) +.set_dispatch<EQ>(Logical) +.set_dispatch<NE>(Logical) +.set_dispatch<LT>(Logical) +.set_dispatch<LE>(Logical) +.set_dispatch<GT>(Logical) +.set_dispatch<GE>(Logical) +.set_dispatch<And>(Logical) +.set_dispatch<Or>(Logical); + +} // namespace + +IntSet Eval(Expr e, + const std::unordered_map<IterVar, IntSet>& dom_map) { + IRSetEvaluator m; + for (auto kv : dom_map) { + m.dom_map[kv.first->var.as<Variable>()] = kv.second; + } + return m.Eval(e); +} + +} // namespace bound +} // namespace tvm + diff --git a/src/bound/int_set.h b/src/bound/int_set.h index ddadacbadd55dbdfa111a7f1f749f2ba7ba28de5..f5531ea4b86680f199ceaabd4796a91a92be9de6 100644 --- a/src/bound/int_set.h +++ b/src/bound/int_set.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2016 by Contributors * \file int_set.h - * \brief Abstract class for iteration integer sets. + * \brief Abstraction for all integer set operations. */ #ifndef TVM_BOUND_INT_SET_H_ #define TVM_BOUND_INT_SET_H_ @@ -11,35 +11,92 @@ namespace tvm { namespace bound { + +// internal node container of int set. +class IntSetNode; + /*! - * \brief abstract class of integer set for iteration sets. + * \brief Integer set class, represent a set of integers in one dimension. */ -class IntSet { +class IntSet : public NodeRef { public: - // constructor - IntSet(); - // whether the set is same as range - bool SameAs(const Range& r) const; - // make integer set by range - static IntSet make(Range r); - // make integer set as a constant value - static IntSet make(Expr value); - // upward inference function - // get the int set of parent given int set of outer and inner - static void PassUp(const SplitNode* s, - const std::unordered_map<IterVar, Range>& dom_map, - const IntSet& outer, - const IntSet& inner, - IntSet* parent); - // upward inference function - // get the int set of outer and inner given int set of fused. - static void PassUp(const FuseNode* s, - const std::unordered_map<IterVar, Range>& dom_map, - const IntSet& fused, - IntSet* outer, - IntSet* inner); + /*! \brief constructor */ + IntSet() {} + // constructor from not deontainer. + explicit IntSet(std::shared_ptr<Node> n) : NodeRef(n) {} + /*! \return whether the set is empty */ + inline bool is_empty() const { + return !defined(); + } + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const IntSetNode* operator->() const; + /*! + * \param dom The domain to be created. + * \return create integer set from existing domain + */ + static IntSet make(Range dom); + /*! + * \return create integer set that represents everything + */ + static IntSet make_all_set(); }; +/*! + * \brief Find an symbolic integer set that contains all possible values of + * e given the domain of each iteration variables. + * + * \param e The expression to be evaluated. + * \param dom_map The domain of each variable. + * \return An integer set that can cover all the possible values of e. + */ +IntSet Eval(Expr e, + const std::unordered_map<IterVar, IntSet>& dom_map); +/*! + * \brief Conditional upward message passing. + * + * Get domain of parent, condition on domain of children. + * Domain is represented as IntSet. + * + * \param s The Split relation node. + * \param dom_map The old domain result from downward message passing. + * Contains the domain set if all the children are full set. + * \param outer domain of outer iteration. + * \param inner domain of inner iteration. + * \param parent The result domain of parent. + */ +void PassUp(const SplitNode* s, + const std::unordered_map<IterVar, Range>& dom_map, + const IntSet& outer, + const IntSet& inner, + IntSet* parent); +/*! + * \brief Conditional upward message passing. + * + * Get domain of parent, condition on domain of children. + * Domain is represented as IntSet. + * + * \param s The Fuse relation node. + * \param dom_map The old domain result from downward message passing. + * Contains the domain set if all the children are full set. + * \param fused domain of fused iteration. + * \param outer The result domain of outer iteration. + * \param inner The result domain of inner iteration. + */ +void PassUp(const FuseNode* s, + const std::unordered_map<IterVar, Range>& dom_map, + const IntSet& fused, + IntSet* outer, + IntSet* inner); +/*! + * \brief Create an union set of all sets + * \param sets The sets to be unioned + * \return the set after union + */ +IntSet Union(const Array<IntSet>& sets); + } // namespace bound } // namespace tvm diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc index 9d5334635487930f9c7176ec2f5b48439b089550..aa7c5b51fb812e09b7d870063a71a0502cf1f7d9 100644 --- a/src/lang/schedule.cc +++ b/src/lang/schedule.cc @@ -152,7 +152,6 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, Expr x_factor, Expr y_factor) { // NOLINT(*) - split(x_parent, p_x_outer, p_x_inner, x_factor); split(y_parent, p_y_outer, p_y_inner, y_factor); reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer})); diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc index eb584bef614612c9340fc5e9b00f206002ef2278..13d7867539ebdbf0ea65419c6d0e37ed9ff5c39d 100644 --- a/src/pass/schedule_ops.cc +++ b/src/pass/schedule_ops.cc @@ -10,8 +10,6 @@ namespace tvm { namespace ir { namespace { - - } // namespace } // namespace ir } // namespace tvm