/*!
* Copyright (c) 2017 by Contributors
* \file storage_access.h
* \brief Common data structure for storage access analysis.
*/
#ifndef TVM_PASS_STORAGE_ACCESS_H_
#define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
#include <unordered_map>
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
using runtime::StorageRank;
/*!
* \brief Base class of storage access analysis
*/
class StorageAccessVisitor : public IRVisitor {
public:
/*! \brief Storage access type */
enum AccessType {
kRead,
kWrite,
kSync,
kAlloc,
// acquired version of read, only need to handle WAR dep.
kReadAcquire
};
/*! \brief An access entry */
struct AccessEntry {
/*! \brief The thread index that access this entry */
Array<IterVar> threads;
/*! \brief The buffer variable, if any */
/*! \brief The access data type */
Type dtype;
/*! \brief The touched access range */
arith::IntSet touched;
/*! \brief The type of access */
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
/*! \brief Whether the access is double buffer write */
bool double_buffer_write{false};
};
/*! \brief Access pattern about a single statement */
struct StmtEntry {
/*! \brief The statement */
const Node* stmt;
/*! \brief access patterns in the statement */
std::vector<AccessEntry> access;
};
// override visitor pattern
void Visit_(const Load* op) final;
void Visit_(const Store* op) final;
void Visit_(const Evaluate* op) final;
void Visit_(const AttrStmt* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Call* op) final;
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
protected:
StorageAccessVisitor() {
scope_.push_back(std::vector<StmtEntry>());
}
/*! \return number of conditions in the current scope. */
int condition_counter() const {
return condition_counter_;
}
/*! \return whether we are in device environment. */
bool in_device_env() const {
return in_device_env_;
}
/*! \return environment threads */
const Array<IterVar>& env_threads() const {
return env_threads_;
}
/*!
* \brief Whether we need analyze the buffer in current scope.
* \param buffer The buffer to be checked
* \param scope The scope of the buffer.
* \return Whether the analysis of buffer is enabled.
*/
virtual bool Enabled(const Variable* buffer,
const StorageScope& scope) const {
return true;
}
/*!
* \brief Summarize the sequence of operations into parent.
*
* Insert synchronization if necessary and remove un-necessary
* memory access which are already synced.
*
* \param seq The sequence of the access operations.
* \param loop Pass loop node if it is a loop, otherwise nullptr.
* \return The summarized sequence that represent access that
* the parent should taken care of to synchronize.
*/
virtual std::vector<AccessEntry> Summarize(
std::vector<StmtEntry> seq, const For* loop) = 0;
/*!
* \brief Get the scope of the buffer array.
* \return The scope of the final buffer array.
*/
StorageScope GetScope(const Variable* buf) const;
// access scope
std::vector<std::vector<StmtEntry> > scope_;
private:
// whether access appending is enabled.
bool allow_append_{false};
// Whether we are in device environment
bool in_device_env_{false};
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
const Variable* double_buffer_write_{nullptr};
// the current free stmt entry.
StmtEntry curr_stmt_;
// The involving threads
Array<IterVar> env_threads_;
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageScope> storage_scope_;
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_STORAGE_ACCESS_H_