unroll_loop.cc 5.83 KiB
/*!
* Copyright (c) 2017 by Contributors
* Loop unrolling as in Halide pipeline.
* \file unroll_loop.cc
*/
// Unrolls the loop as in Halide pipeline.
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
/*!
 * \brief IR mutator that unrolls For loops.
 *
 * A serial loop with a constant extent is selected for automatic unrolling
 * when it contains no non-unrolled loop, the unroll nesting is within
 * auto_max_depth, and either the accumulated step count is within
 * auto_max_step or the extent is within auto_max_extent.  Loops already
 * marked ForType::Unrolled are always selected.
 *
 * Selected loops are either expanded in place (explicit unroll) or only
 * re-tagged as ForType::Unrolled.
 */
class LoopUnroller : public IRMutator {
 public:
  /*!
   * \param auto_max_step Upper bound on the number of steps for auto unroll.
   * \param auto_max_depth Upper bound on the nesting depth of unrolled loops.
   * \param auto_max_extent Upper bound on the extent of a loop to auto unroll.
   * \param explicit_unroll Whether selected loops are expanded in place
   *        instead of only being marked as ForType::Unrolled.
   */
  explicit LoopUnroller(int auto_max_step,
                        int auto_max_depth,
                        int auto_max_extent,
                        bool explicit_unroll)
      : auto_max_step_(auto_max_step),
        auto_max_depth_(auto_max_depth),
        auto_max_extent_(auto_max_extent),
        explicit_unroll_(explicit_unroll) {
  }

  /*!
   * \brief Handle the unroll pragmas.
   *
   * "pragma_auto_unroll_max_step" and "pragma_unroll_explicit" override the
   * corresponding setting while the attribute body is mutated; the previous
   * setting is restored afterwards.  The attribute node itself is dropped
   * and only the mutated body is returned.
   */
  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
    if (op->attr_key == "pragma_auto_unroll_max_step") {
      int value = 0;
      CHECK(arith::GetConstInt(op->value, &value));
      // install the override, mutate, then swap the old value back
      std::swap(value, auto_max_step_);
      Stmt ret = this->Mutate(op->body);
      std::swap(value, auto_max_step_);
      return ret;
    } else if (op->attr_key == "pragma_unroll_explicit") {
      int value = 0;
      CHECK(arith::GetConstInt(op->value, &value));
      bool explicit_unroll = value;
      std::swap(explicit_unroll, explicit_unroll_);
      Stmt ret = this->Mutate(op->body);
      std::swap(explicit_unroll, explicit_unroll_);
      return ret;
    } else {
      return IRMutator::Mutate_(op, stmt);
    }
  }

  /*!
   * \brief Decide whether a loop is unrolled and rewrite it accordingly.
   *
   * The body is mutated first, so the counters already describe the
   * content of the loop when the decision is made.
   */
  Stmt Mutate_(const For* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<For>();
    // -1 when the extent is not a constant integer
    int value = GetExtent(op);
    // condition for auto unroll
    bool auto_unroll = (
        op->for_type == ForType::Serial &&
        value >= 0 &&
        normal_loop_depth_ == 0 &&
        unroll_depth_ <= auto_max_depth_);
    auto_unroll = auto_unroll && (
        value * step_count_ <= auto_max_step_ ||
        value <= auto_max_extent_);
    if (op->for_type == ForType::Unrolled) {
      CHECK_GE(value, 0)
          << "Cannot unroll non-constant loop";
      auto_unroll = true;
    }
    if (auto_unroll) {
      step_count_ *= value;
      unroll_depth_ += 1;
    } else {
      normal_loop_depth_ += 1;
    }
    if ((auto_unroll && explicit_unroll_) ||
        // unroll loops with extent = 1, no matter how many steps in body.
        // The lower bound keeps non-constant extents (value == -1) out of
        // Unroll(), which would otherwise fail its constant-extent check.
        (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
      return Unroll(op);
    } else {
      if (auto_unroll) {
        if (op->for_type != ForType::Unrolled) {
          // keep the loop, only tag it as unrolled
          return For::make(
              op->loop_var, op->min, op->extent,
              ForType::Unrolled, op->device_api, op->body);
        }
      }
      return stmt;
    }
  }

  /*! \brief Every store counts as one step. */
  Stmt Mutate_(const Store* op, const Stmt& stmt) final {
    ++step_count_;
    return IRMutator::Mutate_(op, stmt);
  }

  /*! \brief Every evaluate counts as one step. */
  Stmt Mutate_(const Evaluate* op, const Stmt& stmt) final {
    ++step_count_;
    return IRMutator::Mutate_(op, stmt);
  }

  /*!
   * \brief Mutate both halves of a block with independent counters.
   *
   * The counters gathered for the first half are saved and reset before the
   * rest is mutated.  Afterwards the step counts are added up and the depth
   * counters take the maximum of both halves.
   */
  Stmt Mutate_(const Block* op, const Stmt& stmt) final {
    Stmt first = this->Mutate(op->first);
    // cleanup state
    int step_count = step_count_;
    int unroll_depth = unroll_depth_;
    int normal_loop_depth = normal_loop_depth_;
    step_count_ = 0;
    unroll_depth_ = 0;
    normal_loop_depth_ = 0;
    // work on rest part
    Stmt rest = this->Mutate(op->rest);
    step_count_ += step_count;
    normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
    unroll_depth_ = std::max(unroll_depth_, unroll_depth);
    if (first.same_as(op->first) &&
        rest.same_as(op->rest)) {
      return stmt;
    } else {
      return Block::make(first, rest);
    }
  }

  /*!
   * \brief Expand a loop into a sequence of its body instances.
   *
   * Instance i is the loop body with the loop variable substituted by
   * min + i.  The loop must have a constant integer extent; an extent of
   * zero yields Evaluate::make(0).
   *
   * \param op The loop to expand.
   * \return The expanded statement.
   */
  Stmt Unroll(const For* op) {
    using arith::ComputeExpr;
    int value = GetExtent(op);
    // For loop must have a constant integer extent
    CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
    if (value == 0) return Evaluate::make(0);
    Stmt body = op->body;
    Map<Var, Expr> vmap;
    // the substitution key is the same for every iteration
    Var lv(op->loop_var.node_);
    Stmt unrolled;
    for (int i = 0; i < value; ++i) {
      vmap.Set(lv,
               ComputeExpr<Add>(
                   op->min, make_const(op->loop_var.type(), i)));
      Stmt step = Substitute(body, vmap);
      if (unrolled.defined()) {
        unrolled = Block::make(unrolled, step);
      } else {
        unrolled = step;
      }
    }
    return unrolled;
  }

 private:
  // returns the extent of the loop if it's a constant integer, otherwise return -1
  int GetExtent(const For* op) {
    // constant folding.
    Expr extent = ir::Simplify(op->extent);
    const IntImm *v1 = extent.as<IntImm>();
    const UIntImm *v2 = extent.as<UIntImm>();
    int value = -1;
    if (v1 != nullptr) {
      value = static_cast<int>(v1->value);
    }
    if (v2 != nullptr) {
      value = static_cast<int>(v2->value);
    }
    return value;
  }

  // maximum number of steps to perform auto unroll.
  int auto_max_step_;
  // maximum nesting depth of unrolled loops for auto unroll.
  int auto_max_depth_;
  // max extent of loop to auto unroll
  // this does not count the total steps, it is compared against the loop extent only
  int auto_max_extent_;
  // whether selected loops are expanded in place instead of only being tagged
  bool explicit_unroll_;
  // Number of normal loops in scope
  int normal_loop_depth_{0};
  // number of unrolled cases in current scope.
  int unroll_depth_{0};
  // Number of total steps unrolled
  int step_count_{0};
};
/*!
 * \brief Run the loop unroller over a statement.
 *
 * The four settings are forwarded unchanged to LoopUnroller.  When the
 * mutation produced a new statement it is passed through ConvertSSA before
 * being returned; an untouched statement is returned as is.
 */
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
                int auto_max_depth,
                int auto_max_extent,
                bool explicit_unroll) {
  LoopUnroller unroller(auto_max_step,
                        auto_max_depth,
                        auto_max_extent,
                        explicit_unroll);
  Stmt mutated = unroller.Mutate(stmt);
  // nothing was rewritten, so there is nothing to convert
  if (mutated.same_as(stmt)) {
    return mutated;
  }
  return ConvertSSA(mutated);
}
/*!
 * \brief Expand the given loop statement in place.
 *
 * The statement must be a For node; anything else is reported through
 * LOG(FATAL).  The expansion itself is done by LoopUnroller::Unroll.
 */
Stmt UnrollLoopExplicitly(Stmt stmt) {
  const For* loop = stmt.as<For>();
  if (loop == nullptr) {
    LOG(FATAL) << "attempted to unroll a non-loop statement";
  }
  // the auto-unroll settings are not consulted by Unroll()
  LoopUnroller unroller(0, 0, 0, false);
  return unroller.Unroll(loop);
}
} // namespace ir
} // namespace tvm