-
Tianqi Chen authoredTianqi Chen authored
storage_rewrite.cc 20.34 KiB
/*!
* Copyright (c) 2017 by Contributors
* \file storage_rewrite.cc
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/target_info.h>
#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
using namespace storage;
// Find a linear pattern of storage access.
// A composite scope (loop/thread_launch/IfThenElse) is represented by two points:
// before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// The accesses to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which we need to keep memory
// alive under the same scope as the allocate.
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
class StorageAccessPatternFinder final : public IRVisitor {
public:
// Visit the statement and return the linearized access sequence.
// The sequence is moved out of the finder's linear_seq_ member.
std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
  this->Visit(s);
  std::vector<StmtEntry> seq = std::move(linear_seq_);
  return seq;
}
void Visit_(const Allocate* op) final {
CHECK(!in_parallel_env_)
<< "Allocation inside parallel is not yet handled.";
size_t level = scope_.size();
const Variable* buf = op->buffer_var.get();
CHECK(!alloc_scope_level_.count(buf));
alloc_scope_level_[buf] = level;
StmtEntry e;
e.stmt = op;
e.access.emplace_back(
AccessEntry(buf, Expr(), kAlloc, GetScope(buf)));
linear_seq_.emplace_back(std::move(e));
IRVisitor::Visit_(op);
}
void Visit_(const Store* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
IRVisitor::Visit_(op);
// Add write access.
const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
scope_[it->second].access.emplace_back(
AccessEntry(buf, op->index, kWrite, GetScope(buf)));
}
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.access.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void Visit_(const Evaluate* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
IRVisitor::Visit_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.access.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void Visit_(const Load* op) final {
// Add write access.
IRVisitor::Visit_(op);
const Variable* buf = op->buffer_var.get();
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size())
<< "Load memory in places other than store.";
scope_[it->second].access.emplace_back(
AccessEntry(buf, op->index, kRead, GetScope(buf)));
}
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load* l = op->args[0].as<Load>();
this->Visit(l->index);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Variable* buf) final {
// Directly reference to the variable count as a read.
auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second].access.emplace_back(
AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
}
}
template<typename T>
void VisitNewScope(const T* op) {
scope_.push_back(StmtEntry());
StmtEntry e;
e.stmt = op;
// before scope.
linear_seq_.push_back(e);
IRVisitor::Visit_(op);
// after scope.
e.access = std::move(scope_.back().access);
scope_.pop_back();
linear_seq_.push_back(e);
}
// Visit an attribute statement.
// - storage_scope: remember the storage scope declared for the buffer.
// - thread_extent: only the outermost one opens a new scope.
// - anything else (including nested thread extents): plain visit.
void Visit_(const AttrStmt* op) final {
  if (op->attr_key == attr::storage_scope) {
    const Variable* buf = op->node.as<Variable>();
    storage_scope_[buf] =
        StorageScope::make(op->value.as<StringImm>()->value);
    IRVisitor::Visit_(op);
    return;
  }
  const bool outermost_thread =
      op->attr_key == attr::thread_extent && !in_thread_env_;
  if (!outermost_thread) {
    IRVisitor::Visit_(op);
    return;
  }
  in_thread_env_ = true;
  VisitNewScope(op);
  in_thread_env_ = false;
}
// Every loop opens a new scope. While the body of a parallel loop is
// visited, in_parallel_env_ is set; the previous value is restored after.
void Visit_(const For* op) final {
  if (op->for_type != ForType::Parallel) {
    VisitNewScope(op);
    return;
  }
  const bool saved_flag = in_parallel_env_;
  in_parallel_env_ = true;
  VisitNewScope(op);
  in_parallel_env_ = saved_flag;
}
void Visit_(const IfThenElse* op) final {
VisitNewScope(op);
}
private:
// Look up the storage scope recorded for a buffer.
// Fails a CHECK when no storage_scope attribute was seen for the buffer.
StorageScope GetScope(const Variable* buf) const {
  auto it = storage_scope_.find(buf);
  CHECK(it != storage_scope_.end());
  const StorageScope& recorded = it->second;
  return recorded;
}
// True while visiting inside the outermost thread_extent attribute;
// nested thread extents do not open another scope.
bool in_thread_env_{false};
// True while visiting inside the body of a parallel For loop;
// allocations are rejected while this is set.
bool in_parallel_env_{false};
// Linearized access sequence that is built up during the visit.
std::vector<StmtEntry> linear_seq_;
// The scope stack: one entry per scope/statement currently being visited,
// collecting the accesses attributed to it.
std::vector<StmtEntry> scope_;
// The storage scope of each buffer, taken from storage_scope attributes.
std::unordered_map<const Variable*, StorageScope> storage_scope_;
// buffer -> depth of the scope stack at the point where it was allocated.
std::unordered_map<const Variable*, size_t> alloc_scope_level_;
};
// Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator {
public:
// Entry point of the planner.
// Linearizes the accesses of stmt, plans storage sharing, rewrites the
// statement, and finally wraps the result with the allocations attached
// to the outermost level (attach scope nullptr).
Stmt Rewrite(Stmt stmt) {
  std::vector<StmtEntry> seq =
      StorageAccessPatternFinder().GetLinearSeq(stmt);
  this->FindFreeLocation(seq);
  this->PlanMemory(seq);
  this->PrepareNewAlloc();
  stmt = this->Mutate(stmt);
  auto outermost = attach_map_.find(nullptr);
  if (outermost == attach_map_.end()) return stmt;
  std::vector<Stmt> nest;
  for (StorageEntry* e : outermost->second) {
    CHECK_EQ(e->scope.rank, 0);
    // Entries without a replacement allocation emit nothing.
    if (!e->new_alloc.defined()) continue;
    nest.emplace_back(AttrStmt::make(
        e->alloc_var, attr::storage_scope,
        StringImm::make(e->scope.to_string()),
        Evaluate::make(0)));
    nest.push_back(e->new_alloc);
  }
  stmt = MergeNest(nest, stmt);
  return stmt;
}
// Redirect a store to the planned storage entry of its buffer, shifting
// the index by the entry's element offset. Stores to buffers that are not
// in the allocation map are returned as mutated by the base class.
Stmt Mutate_(const Store* op, const Stmt& s) final {
  Stmt mutated = IRMutator::Mutate_(op, s);
  const Store* store = mutated.as<Store>();
  auto found = alloc_map_.find(store->buffer_var.get());
  if (found == alloc_map_.end()) return mutated;
  StorageEntry* entry = found->second;
  Expr new_index = RemapIndex(store->value.type(), store->index, entry);
  return Store::make(entry->alloc_var,
                     store->value,
                     new_index,
                     store->predicate);
}
// Redirect a load to the planned storage entry of its buffer, shifting
// the index by the entry's element offset. Loads from buffers that are not
// in the allocation map are returned as mutated by the base class.
Expr Mutate_(const Load* op, const Expr& e) final {
  Expr mutated = IRMutator::Mutate_(op, e);
  const Load* load = mutated.as<Load>();
  auto found = alloc_map_.find(load->buffer_var.get());
  if (found == alloc_map_.end()) return mutated;
  StorageEntry* entry = found->second;
  Expr new_index = RemapIndex(load->type, load->index, entry);
  return Load::make(load->type,
                    entry->alloc_var,
                    new_index,
                    load->predicate);
}
// Replace a direct reference to a planned buffer variable by the variable
// of its storage entry. A warning is logged when the entry sits at a
// non-zero offset inside a merged buffer, because the replacement variable
// does not carry that offset.
Expr Mutate_(const Variable* op, const Expr& e) final {
  auto found = alloc_map_.find(op);
  if (found == alloc_map_.end()) return e;
  StorageEntry* entry = found->second;
  if (entry->elem_offset != 0) {
    LOG(WARNING) << "Use a merged buffer variable address, could cause error";
  }
  return entry->alloc_var;
}
// Rewrite tvm_access_ptr calls that refer to a planned buffer:
// the buffer argument becomes the entry's variable and the offset argument
// is shifted by the entry's element offset (expressed in units of dtype,
// hence the division by the lane count). All other calls, and access_ptr
// calls on unplanned buffers, are handled by the base mutator.
Expr Mutate_(const Call* op, const Expr& e) final {
  if (!op->is_intrinsic(intrinsic::tvm_access_ptr)) {
    return IRMutator::Mutate_(op, e);
  }
  CHECK_EQ(op->args.size(), 5U);
  Type dtype = op->args[0].type();
  const Variable* buffer = op->args[1].as<Variable>();
  auto found = alloc_map_.find(buffer);
  if (found == alloc_map_.end()) {
    return IRMutator::Mutate_(op, e);
  }
  const StorageEntry* se = found->second;
  Expr offset = Mutate(op->args[2]);
  Expr extent = Mutate(op->args[3]);
  CHECK_EQ(se->elem_type, dtype.element_of())
      << " buffer=" << buffer->name_hint;
  CHECK_EQ(se->elem_offset % dtype.lanes(), 0);
  if (se->elem_offset != 0) {
    offset = make_const(offset.type(), se->elem_offset / dtype.lanes()) + offset;
  }
  return Call::make(
      op->type, op->name,
      {op->args[0], se->alloc_var, offset, extent, op->args[4]},
      op->call_type);
}
// Rewrite attribute statements.
// - storage_scope: dropped; only the mutated body is kept, since the
//   planner re-emits the scope attribute next to each new allocation.
// - thread_extent: the allocations attached to this thread extent are
//   re-created at the top of its body.
// - volatile_scope: retargeted to the variable of the planned entry.
// - anything else: handled by the base mutator.
// virtual_thread attributes must not be present any more at this point.
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
  CHECK(op->attr_key != attr::virtual_thread)
      << "InjectVirtualThread before StoragePlan";
  if (op->attr_key == attr::storage_scope) {
    return this->Mutate(op->body);
  } else if (op->attr_key == attr::thread_extent) {
    // remake all the allocation at the thread extent.
    if (attach_map_.count(op)) {
      std::vector<Stmt> nest;
      for (StorageEntry* e : attach_map_.at(op)) {
        // Entries merged into another entry have no allocation of their
        // own; pushing their undefined new_alloc would hand an undefined
        // Stmt to MergeNest. Rewrite() applies the same guard.
        if (!e->new_alloc.defined()) continue;
        nest.emplace_back(AttrStmt::make(
            e->alloc_var, attr::storage_scope,
            StringImm::make(e->scope.to_string()),
            Evaluate::make(0)));
        nest.push_back(e->new_alloc);
      }
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<AttrStmt>();
      Stmt body = MergeNest(nest, op->body);
      return AttrStmt::make(
          op->node, op->attr_key, op->value, body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  } else if (op->attr_key == attr::volatile_scope) {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<AttrStmt>();
    auto it = alloc_map_.find(op->node.as<Variable>());
    if (it == alloc_map_.end()) return stmt;
    return AttrStmt::make(
        it->second->alloc_var, op->attr_key, op->value, op->body);
  } else {
    return IRMutator::Mutate_(op, s);
  }
}
// Loops are mutated by the base class; vectorized loops are rejected
// because they must not be present when this pass runs.
Stmt Mutate_(const For* op, const Stmt& s) final {
  CHECK(op->for_type != ForType::Vectorized)
      << "VectorizeLoop before LiftStorageAlloc";
  Stmt mutated = IRMutator::Mutate_(op, s);
  return mutated;
}
// Original allocations are removed: only the mutated body survives.
// The replacement allocations are emitted at their attach scope instead.
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
  Stmt body = this->Mutate(op->body);
  return body;
}
private:
// Allocation entry: one planned piece of storage that may be shared by
// several original Allocate nodes.
struct StorageEntry {
// The scope that this alloc attaches after
// For shared/local memory it is beginning of the thread extent.
// for global memory it is nullptr, means beginning of everything.
const Node* attach_scope_{nullptr};
// The constant size of the buffer in bits, only used if it is constant
// (0 means the size is not a compile-time constant).
size_t const_nbits{0};
// The storage scope.
StorageScope scope;
// Allocs that shares this entry.
std::vector<const Allocate*> allocs;
// The children of this entry, not including itself.
// Filled for tagged memory, where entries of the same scope are folded
// into one allocation.
std::vector<StorageEntry*> merged_children;
// The replacement allocation, if any.
// Stays undefined for entries that were folded into another entry.
Stmt new_alloc;
// The var expr of new allocation.
VarExpr alloc_var;
// The allocation element type.
Type elem_type;
// This is non-zero if this allocate is folded into another one
// the address becomes alloc_var + sizeof(elem_type) * elem_offset;
size_t elem_offset{0};
};
// Translate an index into the original buffer to an index into the
// planned storage entry by adding the entry's element offset.
// The accessed element type must match the entry's element type.
Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
  CHECK_EQ(dtype.element_of(), e->elem_type);
  if (e->elem_offset != 0) {
    return make_const(index.type(), e->elem_offset) + index;
  }
  return index;
}
// Prepare the new allocations
void PrepareNewAlloc() {
for (size_t i = 0; i < alloc_vec_.size(); ++i) {
StorageEntry* e = alloc_vec_[i].get();
attach_map_[e->attach_scope_].push_back(e);
}
// find allocation via attach map.
for (auto &kv : attach_map_) {
// find the element with the most amount of bytes.
std::vector<StorageEntry*>& vec = kv.second;
// try to find merge, for tagged memory
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry* e = vec[i];
if (e->scope.tag.length() != 0) {
CHECK_NE(e->const_nbits, 0U)
<< "Special tagged memory must be const size";
for (size_t j = 0; j < i; ++j) {
if (e->scope == vec[j]->scope) {
vec[j]->merged_children.push_back(e);
break;
}
}
}
}
// Start allocation
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry* e = vec[i];
// already merged
if (e->elem_offset != 0) continue;
if (e->merged_children.size() != 0) {
NewAllocTagMerged(e); continue;
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
Type alloc_type = e->allocs[0]->type;
for (const Allocate* op : e->allocs) {
if (op->type.lanes() > alloc_type.lanes()) {
alloc_type = op->type;
}
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate::make(0));
} else {
// Build a merged allocation
Expr combo_size;
for (const Allocate* op : e->allocs) {
Expr sz = arith::ComputeReduce<Mul>(op->extents);
if (alloc_type.lanes() != op->type.lanes()) {
sz = (sz * make_const(sz.type(), op->type.lanes()) +
make_const(sz.type(), alloc_type.lanes() - 1)) /
make_const(sz.type(), alloc_type.lanes());
}
if (combo_size.defined()) {
combo_size = max(combo_size, sz);
} else {
combo_size = sz;
}
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, {combo_size}, const_true(),
Evaluate::make(0));
}
}
}
}
// New allocation for merged data.
// Lays out the entry e and all of its merged children back to back in one
// allocation of e's element type. Each region is padded up to a multiple
// of `align` elements, where align is derived from the max_simd_bits of
// the memory info registered for the scope (1 if no info is registered).
// Every child receives its element offset and the shared allocation
// variable. When memory info is available, the total size is checked
// against its max_num_bits.
// NOTE(review): child sizes are converted with the child's own element
// type while offsets are counted in elements of e->elem_type; this looks
// like it assumes all merged entries share one element type - confirm.
void NewAllocTagMerged(StorageEntry* e) {
  CHECK_NE(e->scope.tag.length(), 0U);
  // allocate with element type.
  CHECK_NE(e->const_nbits, 0U);
  MemoryInfo info = GetMemoryInfo(e->scope.to_string());
  size_t align = 1;
  if (info.defined()) {
    align = (info->max_simd_bits + e->elem_type.bits() - 1) / e->elem_type.bits();
  }
  size_t total_elem = e->const_nbits / e->elem_type.bits();
  if (total_elem % align != 0) {
    total_elem += align - (total_elem % align);
  }
  e->alloc_var = e->allocs[0]->buffer_var;
  for (StorageEntry* child : e->merged_children) {
    // Validate the child's size (the parent's was already checked above).
    CHECK_NE(child->const_nbits, 0U);
    // A child offset of zero would be mistaken for "not merged".
    CHECK_NE(total_elem, 0U);
    size_t num_elem = child->const_nbits / child->elem_type.bits();
    child->elem_offset = total_elem;
    child->alloc_var = e->alloc_var;
    total_elem += num_elem;
    if (total_elem % align != 0) {
      total_elem += align - (total_elem % align);
    }
  }
  Expr alloc_size = make_const(e->allocs[0]->extents[0].type(), total_elem);
  e->new_alloc = Allocate::make(
      e->alloc_var, e->elem_type, {alloc_size}, const_true(),
      Evaluate::make(0));
  if (info.defined()) {
    CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
        << "Allocation exceed bound of memory tag " << e->scope.to_string();
  }
}
// Find the free location of each variable.
// The sequence is scanned backwards; the first position at which a buffer
// shows up in that scan is its last access, and the buffer is registered
// in free_loc_ under that position.
void FindFreeLocation(const std::vector<StmtEntry>& seq) {
  std::unordered_set<const Variable*> touched;
  for (size_t pos = seq.size(); pos-- > 0;) {
    for (const AccessEntry& access : seq[pos].access) {
      // insert() reports whether the buffer is met for the first time.
      if (touched.insert(access.buffer).second) {
        free_loc_[pos].push_back(access.buffer);
      }
    }
  }
}
// Memory plan algorithm
void PlanMemory(const std::vector<StmtEntry>& seq) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (s.stmt->is_type<AttrStmt>()) {
const auto* op = static_cast<const AttrStmt*>(s.stmt);
CHECK_EQ(op->attr_key, attr::thread_extent);
if (thread_scope_ != nullptr) {
CHECK(thread_scope_ == op);
// erase all non-global memory from constant free map.
for (auto it = const_free_map_.begin();
it != const_free_map_.end();) {
if (it->second->scope.rank != 0) {
it = const_free_map_.erase(it);
} else {
++it;
}
}
thread_scope_ = nullptr;
} else {
thread_scope_ = op;
}
} else if (s.stmt->is_type<Allocate>()) {
const auto* op = static_cast<const Allocate*>(s.stmt);
StorageEntry* e = this->FindAlloc(op, s.access[0].scope);
e->allocs.emplace_back(op);
alloc_map_[op->buffer_var.get()] = e;
}
// free list
if (free_loc_.count(i)) {
for (const Variable* var : free_loc_.at(i)) {
this->Free(var);
}
}
}
}
// Allocate new storage entry.
// The entry is attached to the current thread scope and is owned by
// alloc_vec_; the returned pointer is a non-owning handle to it.
StorageEntry* NewAlloc(const Allocate* op,
                       const StorageScope& scope,
                       size_t const_nbits) {
  std::unique_ptr<StorageEntry> owned(new StorageEntry());
  StorageEntry* created = owned.get();
  created->attach_scope_ = thread_scope_;
  created->scope = scope;
  created->elem_type = op->type.element_of();
  created->const_nbits = const_nbits;
  alloc_vec_.push_back(std::move(owned));
  return created;
}
// Find a storage entry for an allocation, reusing a free one if possible.
// - Scopes with rank > 1 and handle-typed allocations always get a fresh
//   entry.
// - Untagged constant allocations of at most 32 bits always get a fresh
//   entry.
// - Constant-size allocations search the free map for an entry of the
//   same scope and element type whose size key lies in
//   [size / match_range, size * match_range]: first the entries that are
//   at least as large (smallest first), then the smaller ones (largest
//   first).
// - Non-constant allocations take the first matching entry of the
//   symbolic free list.
// A fresh entry is created when nothing can be reused.
StorageEntry* FindAlloc(const Allocate* op,
                        const StorageScope& scope) {
  const size_t match_range = 16;
  // Widen before multiplying so that large buffers do not overflow the
  // int arithmetic of the original expression.
  const size_t const_nbits =
      static_cast<size_t>(op->constant_allocation_size()) *
      static_cast<size_t>(op->type.bits()) *
      static_cast<size_t>(op->type.lanes());
  // skip plan for local variable,
  // compiler can do a better job with register allocation.
  if (scope.rank > 1 || op->type.is_handle()) {
    return NewAlloc(op, scope, const_nbits);
  }
  // disable reuse of small arrays, they will be lowered to registers in LLVM
  if (const_nbits > 0 &&
      const_nbits <= 32 &&
      scope.tag.length() == 0) {
    return NewAlloc(op, scope, const_nbits);
  }
  if (const_nbits != 0) {
    // constant allocation.
    auto begin = const_free_map_.lower_bound(const_nbits / match_range);
    auto mid = const_free_map_.lower_bound(const_nbits);
    auto end = const_free_map_.upper_bound(const_nbits * match_range);
    // Entries that are large enough already.
    for (auto it = mid; it != end; ++it) {
      StorageEntry *e = it->second;
      if (e->scope != scope) continue;
      if (e->elem_type != op->type.element_of()) continue;
      e->const_nbits = std::max(const_nbits, e->const_nbits);
      const_free_map_.erase(it);
      return e;
    }
    // Smaller entries, walking backwards from the requested size.
    for (auto it = mid; it != begin;) {
      --it;
      StorageEntry *e = it->second;
      if (e->scope != scope) continue;
      if (e->elem_type != op->type.element_of()) continue;
      // The entry now has to hold the larger request: grow its recorded
      // size so that the tagged-memory layout and the key used when it is
      // freed again reflect the real requirement.
      e->const_nbits = std::max(const_nbits, e->const_nbits);
      const_free_map_.erase(it);
      return e;
    }
  } else {
    // Simple strategy: round robin.
    for (auto it = sym_free_list_.begin();
         it != sym_free_list_.end(); ++it) {
      StorageEntry* e = *it;
      if (e->scope != scope) continue;
      if (e->elem_type != op->type.element_of()) continue;
      sym_free_list_.erase(it);
      return e;
    }
  }
  return NewAlloc(op, scope, const_nbits);
}
// Simulated free: make the storage entry of var available for reuse.
// Entries with scope rank > 1, handle-typed entries and constant entries
// of at most 32 bits are never put back. Constant-size entries go to the
// size-keyed free map, all others to the symbolic free list.
void Free(const Variable* var) {
  auto it = alloc_map_.find(var);
  CHECK(it != alloc_map_.end());
  StorageEntry* e = it->second;
  // Disable sharing of local memory.
  const bool never_shared =
      e->scope.rank > 1 || e->allocs[0]->type.is_handle();
  if (never_shared) return;
  // disable reuse of small arrays
  const bool small_array = e->const_nbits > 0 && e->const_nbits <= 32;
  if (small_array) return;
  // normal free.
  if (e->const_nbits == 0) {
    sym_free_list_.push_back(e);
  } else {
    const_free_map_.insert({e->const_nbits, e});
  }
}
// The thread_extent node currently being planned inside, or nullptr when
// planning outside of any thread extent.
const Node* thread_scope_{nullptr};
// Locations of free ops: position in the linear sequence -> buffers whose
// last access is at that position.
std::unordered_map<size_t,
std::vector<const Variable*> > free_loc_;
// The allocation attach map: attach scope -> storage entries allocated
// there (key nullptr is the outermost level).
std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
// The allocation assign map: original buffer variable -> storage entry.
std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
// constant size free map, keyed by the size in bits at the time of free.
std::multimap<size_t, StorageEntry*> const_free_map_;
// symbolic free list, for non constant items.
std::list<StorageEntry*> sym_free_list_;
// The allocations; owns every StorageEntry referenced by the maps above.
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
};
// Public entry point of the pass: plan storage sharing for stmt with a
// fresh StoragePlanRewriter and return the rewritten statement.
Stmt StorageRewrite(Stmt stmt) {
  StoragePlanRewriter rewriter;
  return rewriter.Rewrite(stmt);
}
} // namespace ir
} // namespace tvm