Skip to content
Snippets Groups Projects
Commit cea88d00 authored by tqchen's avatar tqchen
Browse files

Skeleton of bound inference passing rule

parent f650216b
No related branches found
No related tags found
No related merge requests found
......@@ -171,4 +171,13 @@ inline IterVar::operator Expr() const {
}
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::IterVar> {
std::size_t operator()(const ::tvm::IterVar& k) const {
return k.hash();
}
};
}
#endif // TVM_EXPR_H_
......@@ -2,5 +2,5 @@
- c_api C API related functions
- lang The definition of DSL related data structure
- schedule The Schedule->Stmt generation logic
- codegen Backend code generation related
\ No newline at end of file
- pass The optimization pass on the IR structure
- bound Bound inference logics.
/*!
* Copyright (c) 2016 by Contributors
* \file bound.cc
* \brief The bound inference logic.
*/
#include <tvm/ir.h>
#include "./int_set.h"
#include "./bound.h"
namespace tvm {
namespace bound {
// result = ceil((a / b)), both a and b are positive integer
inline Expr DivCeil(Expr a, Expr b) {
return (a + b - 1) / b;
}
// Downward message passing algorithm on schedule s,
// pass the range state down from the root to the leaves
// after this pass, every IterVar in the schedule hyper graph will have a range(domain)
void PassDown(const Schedule& s,
std::unordered_map<IterVar, Range>* p_state) {
auto& state = *p_state;
// forwar iteration on relations
for (size_t i = 0; i < s->relations.size(); ++i) {
IterVarRelation rel = s->relations[i];
if (rel.as<SplitNode>()) {
const SplitNode* r = rel.as<SplitNode>();
CHECK(state.count(r->parent));
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
if (r->factor.defined()) {
state[r->inner] = Range::make_with_min_extent(0, r->factor);
if (r->outer->dom.defined()) {
state[r->outer] = r->outer->dom;
} else {
CHECK(!state.count(r->outer));
state[r->outer] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->factor));
}
} else {
CHECK(r->outer->dom.defined());
state[r->outer] = r->outer->dom;
state[r->inner] = Range::make_with_min_extent(
0, DivCeil(range_parent->extent, r->outer->dom->extent));
}
} else if (rel.as<FuseNode>()) {
const FuseNode* r = rel.as<FuseNode>();
CHECK(state.count(r->outer));
CHECK(state.count(r->inner));
const Range& range_outer = state.at(r->outer);
const Range& range_inner = state.at(r->inner);
state[r->fused] = Range::make_with_min_extent(
0, range_outer->extent * range_inner->extent);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
// upward message passing algorithm
// pass the integer set on each leave loop up to the root
// dom_map is the result of PassDown, it records the domain of each IterVar.
// dom_map can be used to get cached result in reverse construction.
void PassUp(const Schedule& s,
const std::unordered_map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, IntSet>* p_state) {
auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0;--i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
IntSet parent;
const SplitNode* r = rel.as<SplitNode>();
IntSet::PassUp(
r, dom_map,
state.at(r->outer), state.at(r->inner),
&parent);
state[r->parent] = parent;
} else if (rel.as<FuseNode>()) {
IntSet outer, inner;
const FuseNode* r = rel.as<FuseNode>();
IntSet::PassUp(
r, dom_map,
state.at(r->fused),
&outer, &inner);
state[r->outer] = outer;
state[r->inner] = inner;
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
} // namespace bound
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file bound.h
* \brief The bound inference logics on the schedule.
*/
#ifndef TVM_BOUND_BOUND_H_
#define TVM_BOUND_BOUND_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
namespace tvm {
namespace bound {
/*!
* \brief Infer the bound of all iteration variables relates to the schedule.
*
* \param sch The root schedule to infer all the bounds.
* \return the result bound of the iteration Variable
*/
std::unordered_map<IterVar, Range> InferBound(Schedule sch);
} // namespace bound
} // namespace tvm
#endif // TVM_BOUND_BOUND_H_
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.h
* \brief Abstract class for iteration integer sets.
*/
#ifndef TVM_BOUND_INT_SET_H_
#define TVM_BOUND_INT_SET_H_
#include <tvm/expr.h>
#include <tvm/schedule.h>
namespace tvm {
namespace bound {
/*!
* \brief abstract class of integer set for iteration sets.
*/
class IntSet {
public:
// constructor
IntSet();
// whether the set is same as range
bool SameAs(const Range& r) const;
// make integer set by range
static IntSet make(Range r);
// make integer set as a constant value
static IntSet make(Expr value);
// upward inference function
// get the int set of parent given int set of outer and inner
static void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
// upward inference function
// get the int set of outer and inner given int set of fused.
static void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);
};
} // namespace bound
} // namespace tvm
#endif // TVM_BOUND_INT_SET_H_
......@@ -148,8 +148,9 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
return *this;
}
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)
split(x_parent, p_x_outer, p_x_inner, x_factor);
......
......@@ -11,71 +11,6 @@ namespace tvm {
namespace ir {
namespace {
/*!
* \brief make nest loops given list of stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body The inner-most body of the loop
*/
Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
while (!nest.empty()) {
Stmt s = std::move(nest.back()); nest.pop_back();
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
}
return body;
}
Stmt MakePipeline(const Schedule& sch, Stmt body) {
return body;
}
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
public:
explicit InjectRealize(Schedule sch)
: sch_(sch) {}
Stmt Mutate(Stmt stmt) final {
const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr) {
attr_scope_.Push({op->node, op->type_key}, op->value);
stmt = IRMutator::Mutate(stmt);
attr_scope_.Pop({op->node, op->type_key});
} else {
stmt = IRMutator::Mutate(stmt);
}
if (op != nullptr &&
op->type_key == "split" &&
op->node == sch_->attach_parent) {
return AttrStmt::make(
op->node, op->type_key, op->value,
MakePipeline(sch_, op->body));
} else {
return stmt;
}
}
private:
// the operations to be carried
Schedule sch_;
Scope<AttrKey, Expr> attr_scope_;
};
} // namespace
} // namespace ir
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment