/*!
 * Copyright (c) 2017 by Contributors
 * \file storage_rewrite.cc
 * \brief Memory access pattern analysis and optimization.
 *  Re-write data access to enable memory sharing when possible.
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/target_info.h>
#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "./ir_util.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

using namespace storage;
// Find a linear pattern of storage access.
// A composite scope (loop/thread_launch/IfThenElse) is represented by two points:
// before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// Accesses to the arrays are recorded at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which the memory
// needs to be kept alive, under the same scope as the allocate.
// The storage needs to stay alive between the allocate and its last access.
// The free point is only inserted at the same scope as the allocate.
//
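// Illustrative sketch: assuming A is allocated at the outer scope,
//
//   allocate A
//   for (i, 0, n) { A[i] = ...; }
//
// is linearized roughly as
//
//   [Allocate A : kAlloc] [For : before_scope] [For : after_scope, kWrite A]
//
// The write inside the loop body is attributed to the loop's after_scope
// point, which is therefore the last point where A must stay alive.
//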
class StorageAccessPatternFinder final : public IRVisitor {
 public:
  // Get linear access pattern.
  std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
    this->Visit(s);
    return std::move(linear_seq_);
  }
  void Visit_(const Allocate* op) final {
    CHECK(!in_parallel_env_)
        << "Allocation inside parallel is not yet handled.";
    size_t level = scope_.size();
    const Variable* buf = op->buffer_var.get();
    CHECK(!alloc_scope_level_.count(buf));
    alloc_scope_level_[buf] = level;
    StmtEntry e;
    e.stmt = op;
    e.access.emplace_back(
        AccessEntry(buf, Expr(), kAlloc, GetScope(buf)));
    linear_seq_.emplace_back(std::move(e));
    IRVisitor::Visit_(op);
  }
  void Visit_(const Store* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    IRVisitor::Visit_(op);
    // Add write access.
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_scope_level_.find(buf);
    if (it != alloc_scope_level_.end()) {
      scope_[it->second].access.emplace_back(
        AccessEntry(buf, op->index, kWrite, GetScope(buf)));
    }
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.access.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void Visit_(const Evaluate* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    IRVisitor::Visit_(op);
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.access.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void Visit_(const Load* op) final {
    // Add read access.
    IRVisitor::Visit_(op);
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_scope_level_.find(buf);
    if (it != alloc_scope_level_.end()) {
      CHECK_LT(it->second, scope_.size())
          << "Load memory in places other than store.";
      scope_[it->second].access.emplace_back(
          AccessEntry(buf, op->index, kRead, GetScope(buf)));
    }
  }
  void Visit_(const Call* op) final {
    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
      const Load* l = op->args[0].as<Load>();
      CHECK(l != nullptr);
      this->Visit(l->index);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const Variable* buf) final {
    // A direct reference to the variable is recorded as an opaque access.
    auto it = alloc_scope_level_.find(buf);
    if (it != alloc_scope_level_.end()) {
      CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
      scope_[it->second].access.emplace_back(
          AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
    }
  }
  template<typename T>
  void VisitNewScope(const T* op) {
    scope_.push_back(StmtEntry());
    StmtEntry e;
    e.stmt = op;
    // before scope.
    linear_seq_.push_back(e);
    IRVisitor::Visit_(op);
    // after scope.
    e.access = std::move(scope_.back().access);
    scope_.pop_back();
    linear_seq_.push_back(e);
  }
  void Visit_(const AttrStmt* op) final {
    // Only record the outermost thread extent.
    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
      in_thread_env_ = true;
      VisitNewScope(op);
      in_thread_env_ = false;
    } else if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      storage_scope_[buf] =
          StorageScope::make(op->value.as<StringImm>()->value);
      IRVisitor::Visit_(op);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const For* op) final {
    if (op->for_type == ForType::Parallel) {
      bool in_par = in_parallel_env_;
      in_parallel_env_ = true;
      VisitNewScope(op);
      in_parallel_env_ = in_par;
    } else {
      VisitNewScope(op);
    }
  }
  void Visit_(const IfThenElse* op) final {
    VisitNewScope(op);
  }

 private:
  // Get storage scope of buffer.
  StorageScope GetScope(const Variable* buf) const {
    auto it = storage_scope_.find(buf);
    CHECK(it != storage_scope_.end());
    return it->second;
  }
  // Whether already in thread env.
  bool in_thread_env_{false};
  // Whether already in parallel env.
  bool in_parallel_env_{false};
  // linearized access sequence.
  std::vector<StmtEntry> linear_seq_;
  // The scope stack.
  std::vector<StmtEntry> scope_;
  // The storage scope of each buffer
  std::unordered_map<const Variable*, StorageScope> storage_scope_;
  // buffer -> allocated scope level in the IR.
  std::unordered_map<const Variable*, size_t> alloc_scope_level_;
};

// Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator {
 public:
  Stmt Rewrite(Stmt stmt) {
    std::vector<StmtEntry> seq =
        StorageAccessPatternFinder().GetLinearSeq(stmt);
    this->FindFreeLocation(seq);
    this->PlanMemory(seq);
    this->PrepareNewAlloc();
    stmt = this->Mutate(stmt);
    if (attach_map_.count(nullptr)) {
      std::vector<Stmt> nest;
      for (StorageEntry* e : attach_map_.at(nullptr)) {
        CHECK_EQ(e->scope.rank, 0);
        if (e->new_alloc.defined()) {
          nest.emplace_back(AttrStmt::make(
              e->alloc_var, attr::storage_scope,
              StringImm::make(e->scope.to_string()),
              Evaluate::make(0)));
          nest.push_back(e->new_alloc);
        }
      }
      stmt = MergeNest(nest, stmt);
    }
    return stmt;
  }
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return stmt;
    return Store::make(it->second->alloc_var,
                       op->value,
                       RemapIndex(op->value.type(), op->index, it->second),
                       op->predicate);
  }
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return expr;
    return Load::make(op->type,
                      it->second->alloc_var,
                      RemapIndex(op->type, op->index, it->second),
                      op->predicate);
  }
  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = alloc_map_.find(op);
    if (it != alloc_map_.end()) {
      if (it->second->elem_offset != 0) {
        LOG(WARNING) << "Use a merged buffer variable address, could cause error";
      }
      return it->second->alloc_var;
    } else {
      return e;
    }
  }
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      CHECK_EQ(op->args.size(), 5U);
      Type dtype = op->args[0].type();
      const Variable* buffer = op->args[1].as<Variable>();
      auto it = alloc_map_.find(buffer);
      if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e);
      const StorageEntry* se = it->second;
      Expr offset = Mutate(op->args[2]);
      Expr extent = Mutate(op->args[3]);
      CHECK_EQ(se->elem_type, dtype.element_of())
          << " buffer=" << buffer->name_hint;
      CHECK_EQ(se->elem_offset % dtype.lanes(), 0);
      if (se->elem_offset != 0) {
        offset = make_const(offset.type(), se->elem_offset / dtype.lanes()) + offset;
      }
      return Call::make(
          op->type, op->name,
          {op->args[0], se->alloc_var, offset, extent, op->args[4]},
          op->call_type);
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    CHECK(op->attr_key != attr::virtual_thread)
        << "InjectVirtualThread before StoragePlan";
    if (op->attr_key == attr::storage_scope) {
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::thread_extent) {
      // Remake all the allocations attached at this thread extent.
      if (attach_map_.count(op)) {
        std::vector<Stmt> nest;
        for (StorageEntry* e : attach_map_.at(op)) {
          nest.emplace_back(AttrStmt::make(
              e->alloc_var, attr::storage_scope,
              StringImm::make(e->scope.to_string()),
              Evaluate::make(0)));
          nest.push_back(e->new_alloc);
        }
        Stmt stmt = IRMutator::Mutate_(op, s);
        op = stmt.as<AttrStmt>();
        Stmt body = MergeNest(nest, op->body);
        return AttrStmt::make(
            op->node, op->attr_key, op->value, body);
      } else {
        return IRMutator::Mutate_(op, s);
      }
    } else if (op->attr_key == attr::volatile_scope) {
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<AttrStmt>();
      auto it = alloc_map_.find(op->node.as<Variable>());
      if (it == alloc_map_.end()) return stmt;
      return AttrStmt::make(
          it->second->alloc_var, op->attr_key, op->value, op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const For* op, const Stmt& s) final {
    CHECK(op->for_type != ForType::Vectorized)
        << "VectorizeLoop before LiftStorageAlloc";
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    return this->Mutate(op->body);
  }

 private:
  // Allocation entry of a node.
  struct StorageEntry {
    // The scope that this allocation attaches to.
    // For shared/local memory it is the beginning of the thread extent;
    // for global memory it is nullptr, meaning the beginning of everything.
    const Node* attach_scope_{nullptr};
    // The constant size of the buffer in bits; only used if the size is constant.
    size_t const_nbits{0};
    // The storage scope.
    StorageScope scope;
    // Allocs that share this entry.
    std::vector<const Allocate*> allocs;
    // The children of this entry, not including itself.
    std::vector<StorageEntry*> merged_children;
    // The replacement allocation, if any.
    Stmt new_alloc;
    // The var expr of new allocation.
    VarExpr alloc_var;
    // The allocation element type.
    Type elem_type;
    // This is non-zero if this allocate is folded into another one;
    // the address becomes alloc_var + sizeof(elem_type) * elem_offset.
    size_t elem_offset{0};
  };
  // Remap the index
  Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
    CHECK_EQ(dtype.element_of(), e->elem_type);
    if (e->elem_offset == 0) return index;
    return make_const(index.type(), e->elem_offset) + index;
  }
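  // Illustrative example: if this allocation was folded into another entry
  // with elem_offset = 64, a load at index i becomes a load at index
  // (64 + i) of the merged buffer's alloc_var.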
  // Prepare the new allocations
  void PrepareNewAlloc() {
    for (size_t i = 0; i < alloc_vec_.size(); ++i) {
      StorageEntry* e = alloc_vec_[i].get();
      attach_map_[e->attach_scope_].push_back(e);
    }
    // find allocation via attach map.
    for (auto &kv : attach_map_) {
      std::vector<StorageEntry*>& vec = kv.second;
      // Try to merge entries that use the same tagged memory scope.
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        if (e->scope.tag.length() != 0) {
          CHECK_NE(e->const_nbits, 0U)
              << "Special tagged memory must be const size";
          for (size_t j = 0; j < i; ++j) {
            if (e->scope == vec[j]->scope) {
              vec[j]->merged_children.push_back(e);
              break;
            }
          }
        }
      }
      // Start allocation
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        // already merged
        if (e->elem_offset != 0) continue;
        if (e->merged_children.size() != 0) {
          NewAllocTagMerged(e); continue;
        }
        // Get the allocation size;
        e->alloc_var = e->allocs[0]->buffer_var;
        Type alloc_type = e->allocs[0]->type;
        for (const Allocate* op : e->allocs) {
          if (op->type.lanes() > alloc_type.lanes()) {
            alloc_type = op->type;
          }
        }
        if (e->allocs.size() == 1) {
          // simply use the original allocation.
          e->new_alloc = Allocate::make(
              e->alloc_var, alloc_type, e->allocs[0]->extents,
              e->allocs[0]->condition, Evaluate::make(0));
        } else {
          // Build a merged allocation
          Expr combo_size;
          for (const Allocate* op : e->allocs) {
            Expr sz = arith::ComputeReduce<Mul>(op->extents);
            if (alloc_type.lanes() != op->type.lanes()) {
              sz = (sz * make_const(sz.type(), op->type.lanes()) +
                    make_const(sz.type(), alloc_type.lanes() - 1)) /
                  make_const(sz.type(), alloc_type.lanes());
            }
            if (combo_size.defined()) {
              combo_size = max(combo_size, sz);
            } else {
              combo_size = sz;
            }
          }
          combo_size = ir::Simplify(combo_size);
          e->new_alloc = Allocate::make(
              e->alloc_var, alloc_type, {combo_size}, const_true(),
              Evaluate::make(0));
        }
      }
    }
  }
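  // Illustrative example: if two allocations with the same element type and
  // lanes end up sharing one StorageEntry, with extents {100} and {256}, the
  // merged allocation gets combo_size = max(100, 256) = 256 elements.  When
  // lane counts differ, each size is first rounded up to units of the widest
  // vector type among the allocs.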
  // New allocation for merged data
  void NewAllocTagMerged(StorageEntry* e) {
    CHECK_NE(e->scope.tag.length(), 0U);
    // allocate with element type.
    CHECK_NE(e->const_nbits, 0U);
    MemoryInfo info = GetMemoryInfo(e->scope.to_string());
    size_t align = 1;
    if (info.defined()) {
      align = (info->max_simd_bits + e->elem_type.bits() - 1) / e->elem_type.bits();
    }
    size_t total_elem = e->const_nbits / e->elem_type.bits();
    if (total_elem % align != 0) {
      total_elem += align - (total_elem % align);
    }
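    // Illustrative example: with 32-bit elements and max_simd_bits = 128,
    // align = (128 + 31) / 32 = 4 elements, so total_elem = 10 would be
    // rounded up to 12 before the merged children are appended.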
    e->alloc_var = e->allocs[0]->buffer_var;
    for (StorageEntry* child : e->merged_children) {
      CHECK_NE(e->const_nbits, 0U);
      CHECK_NE(total_elem, 0U);
      size_t num_elem = child->const_nbits / child->elem_type.bits();
      child->elem_offset = total_elem;
      child->alloc_var = e->alloc_var;
      total_elem += num_elem;
      if (total_elem % align != 0) {
        total_elem += align - (total_elem % align);
      }
    }
    Expr alloc_size = make_const(e->allocs[0]->extents[0].type(), total_elem);
    e->new_alloc = Allocate::make(
        e->alloc_var, e->elem_type, {alloc_size}, const_true(),
        Evaluate::make(0));
    if (info.defined()) {
      CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
          << "Allocation exceed bound of memory tag " << e->scope.to_string();
    }
  }
  // Find the free location of each variable.
  // Just do a reverse linear scan.
  void FindFreeLocation(const std::vector<StmtEntry>& seq) {
    std::unordered_set<const Variable*> touched;
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const AccessEntry& e : s.access) {
        if (!touched.count(e.buffer)) {
          touched.insert(e.buffer);
          free_loc_[i - 1].push_back(e.buffer);
        }
      }
    }
  }
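  // Illustrative example: if buffer A is last touched by the entry at index 7
  // of the linearized sequence, the reverse scan records A in free_loc_[7],
  // and PlanMemory releases A's storage right after processing seq[7].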
  // Memory plan algorithm
  void PlanMemory(const std::vector<StmtEntry>& seq) {
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      if (s.stmt->is_type<AttrStmt>()) {
        const auto* op = static_cast<const AttrStmt*>(s.stmt);
        CHECK_EQ(op->attr_key, attr::thread_extent);
        if (thread_scope_ != nullptr) {
          CHECK(thread_scope_ == op);
          // erase all non-global memory from constant free map.
          for (auto it = const_free_map_.begin();
               it != const_free_map_.end();) {
            if (it->second->scope.rank != 0) {
              it = const_free_map_.erase(it);
            } else {
              ++it;
            }
          }
          thread_scope_ = nullptr;
        } else {
          thread_scope_ = op;
        }
      } else if (s.stmt->is_type<Allocate>()) {
        const auto* op = static_cast<const Allocate*>(s.stmt);
        StorageEntry* e = this->FindAlloc(op, s.access[0].scope);
        e->allocs.emplace_back(op);
        alloc_map_[op->buffer_var.get()] = e;
      }
      // Release the buffers whose last access happens at this point.
      if (free_loc_.count(i)) {
        for (const Variable* var : free_loc_.at(i)) {
          this->Free(var);
        }
      }
    }
  }
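  // Illustrative note: once Free(A) places A's entry into const_free_map_
  // (or sym_free_list_ for symbolic sizes), a later Allocate with a
  // compatible scope, element type, and size can pick the entry up again
  // in FindAlloc and share the same underlying buffer.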
  // Allocate new storage entry.
  StorageEntry* NewAlloc(const Allocate* op,
                         const StorageScope& scope,
                         size_t const_nbits) {
    // Reuse was not successful; allocate a new buffer.
    std::unique_ptr<StorageEntry> entry(new StorageEntry());
    entry->attach_scope_ = thread_scope_;
    entry->scope = scope;
    entry->elem_type = op->type.element_of();
    entry->const_nbits = const_nbits;
    StorageEntry* e = entry.get();
    alloc_vec_.emplace_back(std::move(entry));
    return e;
  }
  StorageEntry* FindAlloc(const Allocate* op,
                          const StorageScope& scope) {
    // Skip planning for local variables;
    // the compiler can do a better job with register allocation.
    const size_t match_range = 16;
    size_t const_nbits = static_cast<size_t>(
        op->constant_allocation_size() * op->type.bits() * op->type.lanes());
    if (scope.rank > 1 || op->type.is_handle()) {
      return NewAlloc(op, scope, const_nbits);
    }
    // Disable reuse of small arrays; they will be lowered to registers in LLVM.
    if (const_nbits > 0 &&
        const_nbits <= 32 &&
        scope.tag.length() == 0) {
      return NewAlloc(op, scope, const_nbits);
    }
    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(const_nbits / match_range);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
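      // Illustrative example: for const_nbits = 1024 the candidate window is
      // [1024 / 16, 1024 * 16] = [64, 16384] bits.  Entries of at least 1024
      // bits (mid..end) are preferred; otherwise smaller entries down to 64
      // bits (begin..mid) are scanned backwards.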
      for (auto it = mid; it != end; ++it) {
        StorageEntry *e = it->second;
        if (e->scope != scope) continue;
        if (e->elem_type != op->type.element_of()) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      for (auto it = mid; it != begin;) {
        --it;
        StorageEntry *e = it->second;
        if (e->scope != scope) continue;
        if (e->elem_type != op->type.element_of()) continue;
        const_free_map_.erase(it);
        return e;
      }
    } else {
      // Simple strategy: take the first matching entry from the free list.
      for (auto it = sym_free_list_.begin();
           it != sym_free_list_.end(); ++it) {
        StorageEntry* e = *it;
        if (e->scope != scope) continue;
        if (e->elem_type != op->type.element_of()) continue;
        sym_free_list_.erase(it);
        return e;
      }
    }
    return NewAlloc(op, scope, const_nbits);
  }
  // simulated free.
  void Free(const Variable* var) {
    auto it = alloc_map_.find(var);
    CHECK(it != alloc_map_.end());
    StorageEntry* e = it->second;
    // Disable sharing of local memory.
    if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
    // disable reuse of small arrays
    if (e->const_nbits > 0 && e->const_nbits <= 32) return;
    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    } else {
      sym_free_list_.push_back(e);
    }
  }
  // thread scope.
  const Node* thread_scope_{nullptr};
  // Locations of free ops.
  std::unordered_map<size_t,
                     std::vector<const Variable*> > free_loc_;
  // The allocation attach map
  std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
  // The allocation assign map
  std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
  // constant size free map.
  std::multimap<size_t, StorageEntry*> const_free_map_;
  // symbolic free list, for non constant items.
  std::list<StorageEntry*> sym_free_list_;
  // The allocations
  std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
};

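// Usage sketch (assuming the pass ordering implied by the CHECKs above,
// i.e. VectorizeLoop and InjectVirtualThread have already been applied):
//
//   stmt = ir::StorageRewrite(stmt);
//
// Allocate/Load/Store nodes are rewritten so that buffers with disjoint
// live ranges can share the same underlying storage.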
Stmt StorageRewrite(Stmt stmt) {
  return StoragePlanRewriter().Rewrite(stmt);
}
}  // namespace ir
}  // namespace tvm