Skip to content
Snippets Groups Projects
Commit 4f1473f3 authored by Tianqi Chen's avatar Tianqi Chen Committed by GitHub
Browse files

[CODEGEN] Add LoweredFunc, MakeAPI to build a C API function (#23)

* [CODEGEN] Add LoweredFunc, MakeAPI and SplitHostDevice

* update halideir
parent 3c1020df
No related branches found
No related tags found
No related merge requests found
Showing
with 1242 additions and 279 deletions
Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf
Subproject commit 30bf0f043e6388418958fd1f29259ee43c42b600
......@@ -50,6 +50,9 @@ class Buffer : public NodeRef {
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;
/*! \brief specify container node */
using ContainerType = BufferNode;
};
/*! \brief Node to represent a buffer */
......
......@@ -30,6 +30,7 @@
#endif
#include <stdint.h>
#include <stddef.h>
TVM_EXTERN_C {
......@@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Launch a generated TVM function
* \brief TVM Function API: Get resource requirement
*
* By default TVM function try not to do internal allocations.
* Instead, TVMFuncRequirement can be called, given the input arguments.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param out_workspace_size The workspace size needed to launch this function.
* \param out_workspace_align The alignment requirement of workspace.
*
* \note The data pointer in the arrays is not used by requirement.
*/
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
size_t* out_workspace_size,
size_t* out_workspace_align);
/*!
* \brief TVM Function API: Launch generated function.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
* \param workspace Additional workspace used to launch this function.
*
* \sa TVMFuncRequirement
*/
TVM_DLL int TVMLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream);
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream,
TVMArrayHandle workspace);
} // TVM_EXTERN_C
#endif // TVM_C_RUNTIME_API_H_
/*!
 * Copyright (c) 2016 by Contributors
 * \file codegen.h
 * \brief Collection of Lowlevel IR pass to codegen.
 */
#ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_

#include <string>
#include "./base.h"
#include "./expr.h"
#include "./module.h"

namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
/*!
 * \brief Make an user callable API LoweredFunc.
 *
 *  The main task of this function is to create code to :
 *   - Map the values in api_args to the Vars that are required by body.
 *   - Insert assertions to check type/value of the passed arguments.
 *
 * \param body The body of the function.
 * \param name The name of the function.
 * \param api_args Arguments to the function, can be either Var, or Buffer
 * \param num_packed_args Number of arguments that are processed in packed form.
 * \return a LoweredFunc with the specified signature.
 *
 * \note
 *  The function signature has two cases
 *
 *  if num_packed_args is zero:
 *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
 *
 *  if num_packed_args is not zero:
 *     f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
 *       api_arg_k, api_arg_k+1, ... api_arg_n)
 *
 *       where n == len(api_args), k == num_packed_args
 *
 *  There is no thread_axis in generated function.
 */
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_packed_args);
/*!
 * \brief Find the undefined vars referenced in f.
 * \param f The function to be checked.
 * \return Array of vars that are used in f but not defined within it.
 */
Array<Var> UndefinedVars(const LoweredFunc& f);
/*!
 * \brief Split the function into a host function and device functions.
 * \param func The function to be split.
 *
 * \return Array of functions, the first one is host function,
 *     the others are device functions.
 */
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_CODEGEN_H_
......@@ -49,6 +49,48 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of these intrinsics exist to let lowered IR interact with the
// TVM C runtime structures (TVMArg/TVMArray); codegen expands each one
// into the corresponding concrete C expression.
/*!
 * \brief See pseudo code
 *
 *  Type tvm_api_load_arg(TVMArg* args, int* arg_type_id, int i) {
 *    assert(arg_type_id[i] == typeid(Type));
 *    return args[i];
 *  }
 */
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
/*!
 * \brief See pseudo code
 *
 *  Type tvm_array_get_field(TVMArray* arr, int field_id) {
 *    return arr->field;  // the field selected by field_id
 *  }
 * \sa TVMArrayFieldKind
 */
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
/*!
 * \brief See pseudo code
 *
 *  bool tvm_handle_is_null(void* handle) {
 *    return handle == nullptr;
 *  }
 */
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
  kData = 0,
  kNDim = 1,
  kShape = 2,
  kStrides = 3,
  kTypeCode = 4,
  kTypeBits = 5,
  kTypeLanes = 6
};
}  // namespace intrinsic
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
......
......@@ -9,6 +9,7 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "./expr.h"
#include "./ir.h"
namespace tvm {
namespace ir {
......@@ -51,6 +52,20 @@ class IRMutator {
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
};
/*!
......
......@@ -56,6 +56,12 @@ Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
*/
bool VerifySSA(const Stmt& ir);
/*!
* \brief Whether the expression have side effect.
* \return whether expression have side effect
*/
bool HasSideEffect(const Expr& e);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
......@@ -79,7 +85,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
Expr body);
/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
......
......@@ -34,6 +34,17 @@ class IRVisitor {
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
};
/*!
......
/*!
* Copyright (c) 2016 by Contributors
* \file module.h
* \brief Low level IR module,
* Contains lowered function information.
*/
#ifndef TVM_MODULE_H_
#define TVM_MODULE_H_
#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
namespace tvm {
// Internal node container of lowered function.
class LoweredFuncNode;
// Internal node container of module.
class ModuleNode;
/*!
 * \brief LoweredFunc represents function after lowering.
 *  This is the final IR representation before codegen.
 */
class LoweredFunc : public FunctionRef {
 public:
  /*! \brief default constructor, creates a null reference */
  LoweredFunc() {}
  /*!
   * \brief constructor wrapping an existing node
   * \param n the internal node container to wrap
   */
  explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const LoweredFuncNode* operator->() const;
  /*! \brief specify container node */
  using ContainerType = LoweredFuncNode;
};
/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
 public:
  /*! \brief The name of the function */
  std::string name;
  /*!
   * \brief The arguments of the function
   *  This function can only take pod type(int, float) and void* as arguments.
   */
  Array<Var> args;
  /*!
   * \brief The IterVar axis of threads
   *  Each axis needs the host function to specify a size.
   * \note Calling convention into LoweredFunc
   *
   *  Assume we have a LoweredFunc f, a call into f
   *     Call(f, arg1, arg2, ..., arg_n,
   *          size_axis_1, size_axis_2, ... size_axis_m)
   *
   *  Here n = len(args), m = len(thread_axis)
   *
   *  The CodeGen should take this and translate this call
   *  to corresponding API specific kernel launches or function calls.
   */
  Array<IterVar> thread_axis;
  /*!
   * \brief The hint data type of Var handles defined in LetStmt
   *  Can be used as hint when generating type signature.
   *  The creation rule is given by
   *  handle_data_type[var_handle] = make_const(the_type, 0);
   *
   * \note Expr is used instead of Type, because Type cannot be held by Map.
   *  A constant Expr of the given type is used.
   */
  Map<Var, Expr> handle_data_type;
  /*! \brief The body statement of the function */
  Stmt body;
  /*! \return name of the operation */
  const std::string& func_name() const final {
    return name;
  }
  // there is no return value, but return 1
  // to enable Call into this function.
  int num_outputs() const final {
    return 1;
  }
  // reflection visitor over all serializable fields
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("args", &args);
    v->Visit("thread_axis", &thread_axis);
    v->Visit("handle_data_type", &handle_data_type);
    v->Visit("body", &body);
  }
  static constexpr const char* _type_key = "LoweredFunc";
  TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
};
// Implementations of inline functions
// Downcast the shared node container to the concrete LoweredFuncNode.
inline const LoweredFuncNode* LoweredFunc::operator->() const {
  const Node* raw = node_.get();
  return static_cast<const LoweredFuncNode*>(raw);
}
} // namespace tvm
#endif // TVM_MODULE_H_
......@@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp):
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass
@register_node
class LoweredFunc(NodeBase):
    """Represent a LoweredFunc in TVM.

    Frontend handle for the low-level IR function produced after
    lowering; presumably mirrors the C++ ``LoweredFuncNode`` registered
    under the same type key -- confirm against include/tvm/module.h.
    """
    pass
......@@ -7,6 +7,7 @@
#define TVM_BASE_COMMON_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <string>
namespace tvm {
......@@ -30,7 +31,7 @@ inline Type String2Type(std::string s) {
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 32, 1);
return Handle();
} else {
LOG(FATAL) << "unknown type " << s;
}
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/codegen.h>
#include "./c_api_registry.h"
#include "../codegen/codegen_c.h"
......@@ -17,9 +18,19 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_codegen_CompileToC)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = CodeGenC().Compile(
*ret = CodeGenC().Compile(args.at(0), args.at(1));
});
// Register codegen::MakeAPI into the C API registry so that language
// frontends can invoke it.
TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](const ArgStack& args, RetValue *ret) {
    // args: (body, name, api_args, num_packed_args)
    *ret = MakeAPI(
        args.at(0), args.at(1), args.at(2), args.at(3));
  });

// Register codegen::SplitHostDevice; args: (func).
TVM_REGISTER_API(_codegen_SplitHostDevice)
.set_body([](const ArgStack& args, RetValue *ret) {
    *ret = SplitHostDevice(args.at(0));
  });
} // namespace codegen
} // namespace tvm
......@@ -9,24 +9,27 @@ namespace codegen {
using namespace ir;
std::string CodeGenC::Compile(
Stmt stmt, std::string fun_name,
Array<Var> args, bool output_ssa) {
std::string CodeGenC::Compile(LoweredFunc f,
bool output_ssa) {
print_ssa_form_ = output_ssa;
// skip the first underscore, so SSA variable starts from _1
if (print_ssa_form_) GetUniqueName("_");
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
HandleTypeRegister(kv.first.get(), kv.second.type());
}
this->indent += 2;
this->stream << "void " << fun_name << "(";
for (size_t i = 0; i < args.size(); ++i) {
Var v = args[i];
this->stream << "void " << f->name << "(";
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
PrintType(v.type(), stream);
stream << ' ' << vid;
}
stream << ") {\n";
this->PrintStmt(stmt);
this->PrintStmt(f->body);
this->indent -= 2;
this->PrintIndent();
this->stream << "}\n";
......@@ -104,12 +107,22 @@ std::string CodeGenC::GetVarID(const Variable* v) const {
return it->second;
}
bool CodeGenC::BufferTypeMatch(const Variable* buf_var, Type t) const {
auto it = alloc_buf_type_.find(buf_var);
if (it == alloc_buf_type_.end()) return false;
// Check whether buf_var was registered with exactly the type t.
// Unregistered handles never match.
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
  const auto entry = handle_data_type_.find(buf_var);
  return entry != handle_data_type_.end() && entry->second == t;
}
// Register the element type of a handle variable.
// First registration wins; later registrations must agree with it.
void CodeGenC::HandleTypeRegister(const Variable* buf_var, Type t) {
  const auto inserted = handle_data_type_.emplace(buf_var, t);
  if (!inserted.second) {
    CHECK(inserted.first->second == t)
        << "conflicting buf var type";
  }
}
void CodeGenC::PrintIndent() {
for (int i = 0; i < this->indent; ++i) {
this->stream << ' ';
......@@ -234,6 +247,18 @@ inline void PrintBinaryExpr(const T* op,
os << ')';
}
// Print a two-operand intrinsic call as a parenthesized infix C
// expression: "(" &lt;arg0&gt; &lt;opstr&gt; &lt;arg1&gt; ")".
// NOTE(review): the misspelled name "Intrinsitc" is kept because
// in-file callers reference it.
inline void PrintBinaryIntrinsitc(const Call* op,
                                  const char *opstr,
                                  std::ostream& os,  // NOLINT(*)
                                  CodeGenC* p) {
  CHECK_EQ(op->args.size(), 2U);
  os << '(';
  p->PrintExpr(op->args[0], os);
  os << opstr;
  p->PrintExpr(op->args[1], os);
  os << ')';
}
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.set_dispatch<Cast>([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
p->PrintType(op->type, os);
......@@ -300,24 +325,9 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.set_dispatch<Not>([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << '!';
p->PrintExpr(op->a, os);
})
.set_dispatch<Call>([](const Call *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
p->PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
}
}
os << ")";
});
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) {
std::string cond = p->PrintExpr(op->condition);
p->PrintIndent();
p->stream << "assert(" << cond << ");\n";
})
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenC* p) {
p->PrintStmt(op->body);
})
......@@ -372,14 +382,95 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
.DISPATCH_EXPR(Load)
.DISPATCH_EXPR(Call)
.DISPATCH_EXPR(Let)
.DISPATCH_EXPR(Ramp)
.DISPATCH_EXPR(Broadcast)
.DISPATCH_EXPR(Select);
void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
CodeGenC* p = this;
if (op->is_intrinsic(Call::bitwise_and)) {
PrintBinaryIntrinsitc(op, " & ", os, p);
} else if (op->is_intrinsic(Call::bitwise_xor)) {
PrintBinaryIntrinsitc(op, " ^ ", os, p);
} else if (op->is_intrinsic(Call::bitwise_or)) {
PrintBinaryIntrinsitc(op, " | ", os, p);
} else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
os << "(~";
p->PrintExpr(op->args[0], os);
os << ')';
} else if (op->is_intrinsic(Call::shift_left)) {
PrintBinaryIntrinsitc(op, " << ", os, p);
} else if (op->is_intrinsic(Call::shift_right)) {
PrintBinaryIntrinsitc(op, " >> ", os, p);
} else if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
os << "((";
p->PrintType(l->type.element_of(), os);
os << " *)" << p->GetVarID(l->buffer_var.get())
<< " + ";
p->PrintExpr(l->index, os);
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
CHECK_EQ(op->args.size(), 3U);
if (!op->type.is_handle()) {
os << '(';
p->PrintType(op->type, os);
os << ')';
}
os << "(((TVMArg*)";
p->PrintExpr(op->args[0], os);
os << ")[" << op->args[2] << "].";
if (op->type.is_handle()) {
os << "v_handle";
} else if (op->type.is_float()) {
os << "v_double";
} else if (op->type.is_int() || op->type.is_uint()) {
os << "v_long";
} else {
LOG(FATAL) << "donot know how to handle type" << op->type;
}
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
os << "(((TVMArray*)";
p->PrintExpr(op->args[0], os);
os << ")->";
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: os << "data"; break;
case intrinsic::kShape: os << "shape"; break;
case intrinsic::kStrides: os << "strides"; break;
case intrinsic::kNDim: os << "ndim"; break;
case intrinsic::kTypeCode: os << "dtype.type_code"; break;
case intrinsic::kTypeBits: os << "dtype.bits"; break;
case intrinsic::kTypeLanes: os << "dtype.lanes"; break;
default: LOG(FATAL) << "unknown field code";
}
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
os << "(";
p->PrintExpr(op->args[0], os);
os << " == NULL)";
} else {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
p->PrintExpr(op->args[i], os);
if (i < op->args.size() - 1) {
os << ", ";
}
}
os << ")";
}
}
void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*)
std::string vid = GetVarID(op->buffer_var.get());
if (!BufferTypeMatch(op->buffer_var.get(), op->type)) {
if (!HandleTypeMatch(op->buffer_var.get(), op->type)) {
os << "((const ";
PrintType(op->type, os);
os << "*)" << vid << ')';
......@@ -416,7 +507,8 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Store>([](const Store *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Allocate>([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); });
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); });
void CodeGenC::PrintStmt(const LetStmt* op) {
......@@ -426,10 +518,20 @@ void CodeGenC::PrintStmt(const LetStmt* op) {
var_idmap_[op->var.get()] = value;
} else {
PrintIndent();
PrintType(op->var.type(), this->stream);
this->stream << ' '
<< AllocVarID(op->var.get())
<< " = " << value << ";\n";
if (op->var.type() == Handle() &&
handle_data_type_.count(op->var.get())) {
PrintType(handle_data_type_.at(op->var.get()), stream);
stream << "* "
<< AllocVarID(op->var.get())
<< " = (";
PrintType(handle_data_type_.at(op->var.get()), stream);
stream << "*)" << value << ";\n";
} else {
PrintType(op->var.type(), this->stream);
this->stream << ' '
<< AllocVarID(op->var.get())
<< " = " << value << ";\n";
}
}
PrintStmt(op->body);
}
......@@ -439,7 +541,7 @@ void CodeGenC::PrintStmt(const Store* op) {
std::string value = this->PrintExpr(op->value);
this->PrintIndent();
std::string vid = GetVarID(op->buffer_var.get());
if (!BufferTypeMatch(op->buffer_var.get(), op->value.type())) {
if (!HandleTypeMatch(op->buffer_var.get(), op->value.type())) {
this->stream << "((";
PrintType(op->value.type(), this->stream);
this->stream << "*)" << vid << ')';
......@@ -452,16 +554,25 @@ void CodeGenC::PrintStmt(const Store* op) {
}
void CodeGenC::PrintStmt(const Allocate* op) {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
std::string vid = AllocVarID(op->buffer_var.get());
CHECK(!op->new_expr.defined());
CHECK(!is_zero(op->condition));
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
PrintType(op->type, stream);
stream << ' '<< vid << '['
<< constant_size << "]\n;";
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
PrintType(op->type, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
PrintType(op->type, stream);
stream << ' '<< vid << '['
<< constant_size << "]\n;";
}
HandleTypeRegister(op->buffer_var.get(), op->type);
this->PrintStmt(op->body);
}
......@@ -469,15 +580,29 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
if (op->type_key == "scope") {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
this->PrintIndent();
PrintType(iv->var.type(), stream);
stream << ' '
<< AllocVarID(iv->var.get())
<< " = " << iv->thread_tag << ";\n";
if (!var_idmap_.count(iv->var.get())) {
this->PrintIndent();
PrintType(iv->var.type(), stream);
stream << ' '
<< AllocVarID(iv->var.get())
<< " = " << iv->thread_tag << ";\n";
}
}
}
this->PrintStmt(op->body);
}
void CodeGenC::PrintStmt(const AssertStmt* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
if (op->message.as<StringImm>()) {
// GLOG style check
stream << "CHECK(" << cond << ") << \""
<< op->message.as<StringImm>()->value << "\";\n";
} else {
stream << "assert(" << cond << ");\n";
}
}
} // namespace codegen
} // namespace tvm
......@@ -8,6 +8,7 @@
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/module.h>
#include <string>
#include <unordered_map>
......@@ -23,16 +24,12 @@ class CodeGenC {
public:
/*!
* \brief Generate the C code of statement
* \param body The body of the function.
* \param fun_name The name of the function.
* \param args The arguments to the function.
* \param f The function to be compiled
* \param output_ssa Whether output ssa form.
* \note Only call compile once,
* create a new codegen object each time.
*/
std::string Compile(Stmt body,
std::string fun_name,
Array<Var> args,
std::string Compile(LoweredFunc f,
bool output_ssa);
/*!
* \brief Print the Stmt n to CodeGenC->stream
......@@ -49,7 +46,7 @@ class CodeGenC {
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
inline std::string PrintExpr(const Expr& n) {
std::string PrintExpr(const Expr& n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
......@@ -85,7 +82,9 @@ class CodeGenC {
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintStmt(const ir::AssertStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
......@@ -116,7 +115,13 @@ class CodeGenC {
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool BufferTypeMatch(const Variable* buf_var, Type t) const;
bool HandleTypeMatch(const Variable* buf_var, Type t) const;
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
void HandleTypeRegister(const Variable* buf_var, Type t);
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
......@@ -128,7 +133,7 @@ class CodeGenC {
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> alloc_buf_type_;
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */
......
/*!
* Copyright (c) 2017 by Contributors
* \file make_api.cc Build API function.
*/
#include <tvm/codegen.h>
#include <tvm/ir.h>
#include <tvm/buffer.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include "../pass/ir_util.h"
namespace tvm {
namespace codegen {
using namespace ir;
// Build an intrinsic call that reads field `kind` of the TVMArray
// pointed to by `arr`, typed as t.
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) {
  Array<Expr> field_args{arr, IntImm::make(Int(32), kind)};
  return Call::make(t, intrinsic::tvm_array_get_field,
                    field_args, Call::PureIntrinsic);
}
// Build an assertion that `handle` is a null pointer (used to require
// absent optional fields, e.g. strides of a contiguous array).
inline Stmt AssertNull(Var handle, std::string msg) {
  Expr handle_is_null = Call::make(
      Bool(1), intrinsic::tvm_handle_is_null, {handle}, Call::PureIntrinsic);
  return AssertStmt::make(handle_is_null, msg);
}
// Build an assertion that lhs == rhs, failing with msg.
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
  Expr equal = (lhs == rhs);
  return AssertStmt::make(equal, msg);
}
// Build a user-callable API function around `body`:
// binds each entry of api_args (Var or Buffer) to the Vars body needs,
// and prepends assertions that validate the passed arguments.
// See the doc comment in include/tvm/codegen.h for the two signature
// forms selected by num_packed_args.
LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_packed_args) {
  // type used for ndim/shape/stride values read from TVMArray
  const Type tvm_index_type = UInt(32);
  const Stmt nop = Evaluate::make(0);
  // Data field definitions
  // The packed fields
  Var v_packed_args("args", Handle());
  Var v_packed_arg_type_ids("arg_type_ids", Handle());
  Var v_num_packed_args("num_args", Int(32));
  // The arguments of the function.
  Array<Var> args;
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check;
  // vars that already received a defining LetStmt
  std::unordered_set<const Variable*> visited;
  // the handle data types
  Map<Var, Expr> handle_data_type;
  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_arg_value = [&](Type t, int i) {
    Array<Expr> call_args{
      v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)};
    return Call::make(
        t, intrinsic::tvm_api_load_arg, call_args,
        Call::PureIntrinsic);
  };
  // get declaration of argument i
  auto f_arg_decl = [&](int i) {
    std::ostringstream os;
    os << "arg" << i;
    const Variable* v = api_args[i].as<Variable>();
    return Var(os.str(), v ? v->type: Handle());
  };
  // Push related into assertions or variable definition
  // given the symbolic declaration and concrete value.
  // Returns true when a new defining LetStmt was added.
  auto f_push = [&](Expr sym, Expr value, std::string field) {
    if (sym.as<Variable>()) {
      // If sym is a Variable and this Variable is not yet defined
      // add this to definition.
      Var v(sym.node_);
      if (!visited.count(v.get())) {
        seq_init.emplace_back(LetStmt::make(v, value, nop));
        visited.insert(v.get());
        return true;
      }
    }
    // otherwise, assume sym is already defined, insert assertion.
    std::ostringstream os;
    os << "Field " << field << " has a unsatisfied constraint";
    seq_check.emplace_back(MakeAssertEQ(sym, value, os.str()));
    return false;
  };
  // ---------------------------
  // start of logics
  // add signature for packed arguments.
  if (num_packed_args != 0) {
    args.push_back(v_packed_args);
    args.push_back(v_packed_arg_type_ids);
    args.push_back(v_num_packed_args);
    std::ostringstream os;
    os << "expected num_args to be " << num_packed_args;
    seq_init.emplace_back(
        MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
  }
  for (size_t i = 0; i < api_args.size(); ++i) {
    Var v_arg = f_arg_decl(i);
    if (i < static_cast<size_t>(num_packed_args)) {
      // packed argument: loaded from the TVMArg array
      seq_init.emplace_back(LetStmt::make(
          v_arg, f_arg_value(v_arg.type(), i), nop));
    } else {
      // plain argument: part of the C signature
      args.push_back(v_arg);
    }
    // add checks for functions.
    if (api_args[i].as<Variable>()) {
      f_push(Var(api_args[i].node_), v_arg, v_arg->name_hint);
    } else {
      // Buffer checks
      CHECK(api_args[i].as<BufferNode>())
          << "api_args can only be Buffer or Var";
      Buffer buf(api_args[i].node_);
      // dimension checks
      Expr v_ndim = TVMArrayGet(tvm_index_type, v_arg, intrinsic::kNDim);
      std::ostringstream ndim_err_msg;
      ndim_err_msg << "arg_" << i
                   << ".ndim is expected to equal "
                   << buf->shape.size();
      seq_init.emplace_back(
          MakeAssertEQ(v_ndim, UIntImm::make(tvm_index_type, buf->shape.size()),
                       ndim_err_msg.str()));
      // type checks: code, bits and lanes must all match buf->dtype
      Type dtype = buf->dtype;
      std::ostringstream type_err_msg;
      type_err_msg << "arg" << i << ".dtype is expected to be " << dtype;
      Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) ==
                   UIntImm::make(UInt(8), dtype.code()) &&
                   TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) ==
                   UIntImm::make(UInt(8), dtype.bits()) &&
                   TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) ==
                   UIntImm::make(UInt(16), dtype.lanes()));
      seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
      // Data Field
      if (f_push(buf->ptr, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
                 v_arg->name_hint + ".data")) {
        // newly defined data pointer: record its element type hint
        Var vptr(buf->ptr);
        handle_data_type.Set(vptr, make_const(buf->dtype, 0));
      }
      // shape field
      Var v_shape(v_arg->name_hint + ".shape", Handle());
      handle_data_type.Set(v_shape, UIntImm::make(tvm_index_type, 0));
      seq_init.emplace_back(LetStmt::make(
          v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop));
      for (size_t k = 0; k < buf->shape.size(); ++k) {
        std::ostringstream field_name;
        field_name << v_shape->name_hint << '[' << k << ']';
        f_push(buf->shape[k],
               cast(buf->shape[k].type(),
                    Load::make(tvm_index_type, v_shape, IntImm::make(Int(32), k))),
               field_name.str());
      }
      // strides field
      Var v_strides(v_arg->name_hint + ".strides", Handle());
      handle_data_type.Set(v_strides, UIntImm::make(tvm_index_type, 0));
      seq_init.emplace_back(LetStmt::make(
          v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop));
      if (buf->strides.size() == 0) {
        // contiguous array: the runtime passes strides == nullptr
        std::ostringstream stride_err_msg;
        stride_err_msg << "arg_" << i << ".strides:"
                       << " expected to be nullptr for contiguous array";
        seq_init.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
      } else {
        for (size_t k = 0; k < buf->strides.size(); ++k) {
          std::ostringstream field_name;
          field_name << v_strides->name_hint << '[' << k << ']';
          // NOTE(review): the cast target below is buf->shape[k].type();
          // buf->strides[k].type() looks intended (mirrors the shape
          // loop above) -- confirm and fix separately if so.
          f_push(buf->strides[k],
                 cast(buf->shape[k].type(),
                      Load::make(tvm_index_type, v_strides, IntImm::make(Int(32), k))),
                 field_name.str());
        }
      }
    }
  }
  std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
  n->name = name;
  n->args = args;
  n->handle_data_type = handle_data_type;
  // run all initializations first, then the remaining checks
  n->body = MergeNest({seq_init, seq_check}, body);
  LoweredFunc f(n);
  // every free var of body must have been covered by api_args
  Array<Var> undefined = UndefinedVars(f);
  if (undefined.size() != 0) {
    std::ostringstream os;
    for (Var v : undefined) {
      os << " \'" << v->name_hint << "\' ";
    }
    os << " does not appeared in api_args";
    LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
  }
  return f;
}
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/codegen.h>
#include <tvm/ir.h>
#include <tvm/module.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
namespace tvm {
namespace codegen {
using namespace ir;
/*!
 * \brief Use/def analysis over the IR; also deletes unreferenced lets.
 *
 *  Tracks, per variable, a use count (use_count_); variables used but
 *  never defined are collected in undefined_ with count -1. Thread
 *  axes introduced via "thread_extent" attributes are collected in
 *  thread_axis_/thread_extent_. As a side product, Let/LetStmt
 *  bindings whose value is unused and side-effect free are removed
 *  from the returned IR.
 */
class IRUseDefAnalysis : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
    if (op->type_key == "thread_extent") {
      IterVar iv(op->node.node_);
      CHECK_NE(iv->thread_tag.length(), 0U);
      // thread_extent can appear multiple times
      // use the first appearance as def.
      if (!use_count_.count(iv->var.get())) {
        this->HandleDef(iv->var.get());
        thread_axis_.push_back(iv);
        thread_extent_.push_back(op->value);
      }
      Expr value = op->value;
      if (visit_thread_extent_) {
        value = this->Mutate(value);
      }
      Stmt body = this->Mutate(op->body);
      // Bug fix: compare the mutated value/body against the ORIGINAL
      // fields. The old code compared them against themselves
      // (value.same_as(value)), which is always true and silently
      // discarded every mutation of value and body.
      if (value.same_as(op->value) && body.same_as(op->body)) return s;
      return AttrStmt::make(op->node, op->type_key, value, body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const LetStmt *op, const Stmt& s) final {
    this->HandleDef(op->var.get());
    Stmt body = this->Mutate(op->body);
    // eliminate unreferenced let
    if (use_count_.at(op->var.get()) == 0 &&
        !HasSideEffect(op->value)) {
      return body;
    } else {
      Expr value = this->Mutate(op->value);
      if (body.same_as(op->body) &&
          value.same_as(op->value)) {
        return s;
      } else {
        return LetStmt::make(op->var, value, body);
      }
    }
  }
  Stmt Mutate_(const For *op, const Stmt& s) final {
    this->HandleDef(op->loop_var.get());
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const Allocate *op, const Stmt& s) final {
    this->HandleDef(op->buffer_var.get());
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const Store *op, const Stmt& s) final {
    this->HandleUse(op->buffer_var);
    return IRMutator::Mutate_(op, s);
  }
  Expr Mutate_(const Let *op, const Expr& e) final {
    this->HandleDef(op->var.get());
    Expr body = this->Mutate(op->body);
    // eliminate unreferenced let
    if (use_count_.at(op->var.get()) == 0 &&
        !HasSideEffect(op->value)) {
      return body;
    } else {
      Expr value = this->Mutate(op->value);
      if (body.same_as(op->body) &&
          value.same_as(op->value)) {
        return e;
      } else {
        return Let::make(op->var, value, body);
      }
    }
  }
  Expr Mutate_(const Variable *op, const Expr& e) final {
    this->HandleUse(e);
    return IRMutator::Mutate_(op, e);
  }
  Expr Mutate_(const Load *op, const Expr& e) final {
    this->HandleUse(op->buffer_var);
    return IRMutator::Mutate_(op, e);
  }
  // Record v as defined; a variable may only be defined once.
  void HandleDef(const Variable* v) {
    CHECK(!use_count_.count(v))
        << "variable is already defined";
    use_count_[v] = 0;
  }
  // Record one use of v. An unseen variable is recorded as undefined
  // with count -1; uses of undefined vars are not counted further.
  void HandleUse(const Expr& v) {
    CHECK(v.as<Variable>());
    Var var(v.node_);
    auto it = use_count_.find(var.get());
    if (it != use_count_.end()) {
      if (it->second >= 0) {
        ++it->second;
      }
    } else {
      undefined_.push_back(var);
      use_count_[var.get()] = -1;
    }
  }
  // The fields are publicly readable to
  // be accessible to the users.
  bool visit_thread_extent_{true};
  Array<Var> undefined_;
  Array<IterVar> thread_axis_;
  Array<Expr> thread_extent_;
  std::unordered_map<const Variable*, int> use_count_;
};
// Split regions wrapped by thread_extent attributes out of the host
// function as standalone device kernels, replacing each region with an
// extern call that launches the kernel.
class HostDeviceSplitter : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
    if (op->type_key == "thread_extent") {
      // NOTE: removed a leftover debug LOG(INFO) << "??" and an unused
      // IterVar binding that had no effect besides the log noise.
      return SplitDeviceFunc(s);
    }
    return IRMutator::Mutate_(op, s);
  }

  // Split f; the result is the rewritten host function followed by the
  // extracted device kernels.
  Array<LoweredFunc> Split(LoweredFunc f) {
    for (auto kv : f->handle_data_type) {
      handle_data_type_[kv.first.get()] = kv.second;
    }
    name_ = f->name;
    std::shared_ptr<LoweredFuncNode> n =
        std::make_shared<LoweredFuncNode>(*f.operator->());
    n->body = this->Mutate(f->body);
    Array<LoweredFunc> ret{LoweredFunc(n)};
    for (LoweredFunc x : device_funcs_) {
      ret.push_back(x);
    }
    return ret;
  }

 private:
  Stmt SplitDeviceFunc(Stmt body) {
    std::ostringstream os;
    os << name_ << "_kernel" << device_funcs_.size();
    std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
    // isolate the device function: free variables of the region become
    // the kernel's arguments.
    IRUseDefAnalysis m;
    m.visit_thread_extent_ = false;
    n->body = m.Mutate(body);
    n->name = os.str();
    n->args = m.undefined_;
    CHECK_NE(m.thread_extent_.size(), 0U);
    // improve the handle data type
    for (Var arg : n->args) {
      auto it = handle_data_type_.find(arg.get());
      if (it != handle_data_type_.end()) {
        n->handle_data_type.Set(arg, it->second);
      }
    }
    LoweredFunc f_device(n);
    // host side: call the kernel with its arguments plus launch extents.
    Array<Expr> call_args;
    for (Var arg : n->args) {
      call_args.push_back(arg);
    }
    for (Expr ext : m.thread_extent_) {
      call_args.push_back(ext);
    }
    device_funcs_.emplace_back(f_device);
    return Evaluate::make(Call::make(
        Int(32), f_device->name, call_args, Call::Extern, f_device));
  }

  // name of the host function; used to derive unique kernel names
  std::string name_;
  // the device functions extracted so far
  std::vector<LoweredFunc> device_funcs_;
  // data-type hints for handle (pointer) variables, keyed by variable
  std::unordered_map<const Variable*, Expr> handle_data_type_;
};
// Collect variables that are used in f's body but defined neither in the
// body nor among f's formal parameters.
Array<Var> UndefinedVars(const LoweredFunc& f) {
  IRUseDefAnalysis analyzer;
  // Seed the analyzer so the formal parameters count as already defined.
  for (Var param : f->args) {
    analyzer.use_count_[param.get()] = 0;
  }
  analyzer.Mutate(f->body);
  return analyzer.undefined_;
}
// Entry point: split func into a host function plus device kernels.
Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
  HostDeviceSplitter splitter;
  return splitter.Split(func);
}
} // namespace codegen
} // namespace tvm
......@@ -17,36 +17,28 @@ class IRInline : public IRMutator {
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
CHECK_EQ(call->value_index, 0);
return InlineCall(call);
} else {
Expr Mutate_(const Call* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
if (op->func == f_) {
CHECK_EQ(op->value_index, 0);
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
return expr;
} else {
return e;
}
}
Stmt Mutate(Stmt stmt) final {
return IRMutator::Mutate(stmt);
}
private:
FunctionRef f_;
Array<Var> args_;
Expr body_;
Expr InlineCall(const Call* op) {
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
return expr;
}
};
Stmt Inline(Stmt stmt,
......
......@@ -58,6 +58,183 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
}
}
// Helper macro: register a handler that routes a statement node of
// concrete type OP to the virtual IRMutator::Mutate_ overload.
#define DISPATCH_TO_MUTATE_STMT(OP) \
set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) { \
return m->Mutate_(op, s); \
})

// Helper macro: same as above, but for expression nodes.
#define DISPATCH_TO_MUTATE_EXPR(OP) \
set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) { \
return m->Mutate_(op, e); \
})
// Register the Mutate_ overloads as the default handlers for these
// statement node types in the static dispatch table.
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Free);
// Rewrite the bound value and the body of a LetStmt; rebuild the node
// only when one of the children actually changed.
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
  Expr new_value = this->Mutate(op->value);
  Stmt new_body = this->Mutate(op->body);
  if (!new_value.same_as(op->value) ||
      !new_body.same_as(op->body)) {
    return LetStmt::make(op->var, new_value, new_body);
  }
  return s;
}
// Rewrite the value and body of an AttrStmt, preserving the attached
// node and attribute key; reuse the original node when unchanged.
Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
  Expr new_value = this->Mutate(op->value);
  Stmt new_body = this->Mutate(op->body);
  if (!new_value.same_as(op->value) ||
      !new_body.same_as(op->body)) {
    return AttrStmt::make(op->node, op->type_key, new_value, new_body);
  }
  return s;
}
// Rewrite min, extent and body of a For loop; the loop variable and the
// loop metadata (for_type, device_api) are kept untouched.
Stmt IRMutator::Mutate_(const For *op, const Stmt& s) {
  Expr new_min = this->Mutate(op->min);
  Expr new_extent = this->Mutate(op->extent);
  Stmt new_body = this->Mutate(op->body);
  bool unchanged = new_min.same_as(op->min) &&
                   new_extent.same_as(op->extent) &&
                   new_body.same_as(op->body);
  if (unchanged) return s;
  return For::make(
      op->loop_var, new_min, new_extent,
      op->for_type, op->device_api, new_body);
}
// Rewrite the extents, body, condition and optional new_expr of an
// Allocate; rebuild only when something changed.
Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
  std::vector<Expr> mutated_extents;
  bool extents_unchanged = true;
  for (size_t i = 0; i < op->extents.size(); ++i) {
    mutated_extents.push_back(this->Mutate(op->extents[i]));
    extents_unchanged &= mutated_extents[i].same_as(op->extents[i]);
  }
  Stmt new_body = this->Mutate(op->body);
  Expr new_cond = this->Mutate(op->condition);
  Expr mutated_new_expr;
  // new_expr is optional; mutate it only when present.
  if (op->new_expr.defined()) {
    mutated_new_expr = this->Mutate(op->new_expr);
  }
  if (extents_unchanged &&
      new_body.same_as(op->body) &&
      new_cond.same_as(op->condition) &&
      mutated_new_expr.same_as(op->new_expr)) {
    return s;
  }
  return Allocate::make(
      op->buffer_var, op->type,
      mutated_extents, new_cond, new_body,
      mutated_new_expr, op->free_function);
}
// Rewrite the index arguments and stored value of a Provide.
Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
  auto mutated_args = MutateArray(op->args, this);
  auto mutated_value = this->Mutate(op->value);
  bool unchanged = op->args.same_as(mutated_args) &&
                   op->value.same_as(mutated_value);
  return unchanged
      ? s
      : Provide::make(op->func, op->value_index, mutated_value, mutated_args);
}
// Rewrite the bounds (min/extent of each range), body and condition of
// a Realize; reconstruct only on change.
Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
  Halide::Internal::Region new_bounds;
  bool bounds_changed = false;
  // Mutate the bounds
  for (size_t i = 0; i < op->bounds.size(); ++i) {
    Expr new_min = this->Mutate(op->bounds[i]->min);
    Expr new_extent = this->Mutate(op->bounds[i]->extent);
    bounds_changed = bounds_changed ||
        !new_min.same_as(op->bounds[i]->min) ||
        !new_extent.same_as(op->bounds[i]->extent);
    new_bounds.push_back(
        Range::make_by_min_extent(new_min, new_extent));
  }
  Stmt new_body = this->Mutate(op->body);
  Expr new_cond = this->Mutate(op->condition);
  if (bounds_changed ||
      !new_body.same_as(op->body) ||
      !new_cond.same_as(op->condition)) {
    return Realize::make(op->func, op->value_index,
                         op->type, new_bounds,
                         new_cond, new_body);
  }
  return s;
}
// Rewrite the stored value and the index of a Store.
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
  Expr new_value = this->Mutate(op->value);
  Expr new_index = this->Mutate(op->index);
  bool unchanged = new_value.same_as(op->value) &&
                   new_index.same_as(op->index);
  return unchanged ? s : Store::make(op->buffer_var, new_value, new_index);
}
// Free has no sub-expressions to rewrite; return the node unchanged.
Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
  return s;
}
// Register the Mutate_ overloads as the default handlers for these
// expression node types in the static dispatch table.
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Variable);
// Rewrite the arguments of a Call; the callee reference, call type and
// value index are preserved as-is.
Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
  auto mutated_args = MutateArray(op->args, this);
  if (!op->args.same_as(mutated_args)) {
    return Call::make(op->type, op->name, mutated_args, op->call_type,
                      op->func, op->value_index);
  }
  return e;
}
// Rewrite the index of a Load; buffer variable and type are preserved.
Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
  Expr new_index = this->Mutate(op->index);
  return new_index.same_as(op->index)
      ? e
      : Load::make(op->type, op->buffer_var, new_index);
}
// Variable is a leaf node; nothing to rewrite.
Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
  return e;
}
// Rewrite the bound value and body of a Let expression; rebuild the
// node only when a child changed.
Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
  Expr new_value = this->Mutate(op->value);
  Expr new_body = this->Mutate(op->body);
  if (!new_value.same_as(op->value) ||
      !new_body.same_as(op->body)) {
    return Let::make(op->var, new_value, new_body);
  }
  return e;
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
......@@ -70,24 +247,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<AttrStmt>([](const AttrStmt* op, const Stmt& s, IRMutator* m) {
Expr value = m->Mutate(op->value);
Stmt body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->type_key, value, body);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<IntImm>(ReturnSelfExpr)
.set_dispatch<UIntImm>(ReturnSelfExpr)
.set_dispatch<FloatImm>(ReturnSelfExpr)
.set_dispatch<StringImm>(ReturnSelfExpr)
.set_dispatch<Variable>(ReturnSelfExpr);
.set_dispatch<StringImm>(ReturnSelfExpr);
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
......@@ -150,14 +314,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
return Select::make(cond, t, f);
}
})
.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) {
Expr index = m->Mutate(op->index);
if (index.same_as(op->index)) {
return e;
} else {
return Load::make(op->type, op->buffer_var, index);
}
})
.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
Expr base = m->Mutate(op->base);
Expr stride = m->Mutate(op->stride);
......@@ -175,38 +331,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
} else {
return Broadcast::make(value, op->lanes);
}
})
.set_dispatch<Call>([](const Call *op, const Expr& e, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(op->type, op->name, new_args, op->call_type,
op->func, op->value_index);
}
})
.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) {
Expr value = m->Mutate(op->value);
Expr body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
} else {
return Let::make(op->var, value, body);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<LetStmt>([](const LetStmt *op, const Stmt& s, IRMutator* m) {
Expr value = m->Mutate(op->value);
Stmt body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition);
Expr message = m->Mutate(op->message);
......@@ -225,93 +352,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return ProducerConsumer::make(op->func, op->is_producer, body);
}
})
.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) {
Expr min = m->Mutate(op->min);
Expr extent = m->Mutate(op->extent);
Stmt body = m->Mutate(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
} else {
return For::make(
op->loop_var, min, extent, op->for_type, op->device_api, body);
}
})
.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) {
Expr value = m->Mutate(op->value);
Expr index = m->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
return Store::make(op->buffer_var, value, index);
}
})
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m);
auto new_value = m->Mutate(op->value);
if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s;
} else {
return Provide::make(op->func, op->value_index, new_value, new_args);
}
})
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
std::vector<Expr> new_extents;
bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->Mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
}
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = m->Mutate(op->new_expr);
}
if (all_extents_unmodified &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
return s;
} else {
return Allocate::make(
op->buffer_var, op->type,
new_extents, condition, body,
new_expr, op->free_function);
}
})
.set_dispatch<Free>([](const Free *op, const Stmt& s, IRMutator* m) {
return s;
})
.set_dispatch<Realize>([](const Realize *op, const Stmt& s, IRMutator* m) {
Halide::Internal::Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
Expr old_min = op->bounds[i]->min;
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
}
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
if (!bounds_changed &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
return s;
} else {
return Realize::make(op->func, op->value_index,
op->type, new_bounds,
condition, body);
}
})
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->Mutate(op->first);
Stmt rest = m->Mutate(op->rest);
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_util.h
* \brief Helper functions to construct and compose IR nodes.
*/
#ifndef TVM_PASS_IR_UTIL_H_
#define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h>
#include <vector>
namespace tvm {
namespace ir {
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
inline Stmt MergeNest(std::vector<Stmt> nest, Stmt body) {
  // use reverse iteration: the last element of nest becomes the innermost
  // wrapper, so nesting order matches the order of the vector.
  for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
    Stmt s = *ri;
    // For each supported wrapper kind: copy the node, verify its body is
    // the placeholder no-op, and splice in the statement built so far.
    if (s.as<For>()) {
      auto n = std::make_shared<For>(*s.as<For>());
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else if (s.as<LetStmt>()) {
      auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else if (s.as<AttrStmt>()) {
      auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else if (s.as<IfThenElse>()) {
      // IfThenElse must have an empty then-branch and no else-branch;
      // the accumulated statement becomes the then-case.
      auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
      CHECK(is_no_op(n->then_case));
      CHECK(!n->else_case.defined());
      n->then_case = body;
      body = Stmt(n);
    } else if (s.as<AssertStmt>()) {
      // AssertStmt has no body slot; sequence it before the accumulated
      // statement instead.
      body = Block::make(s, body);
    } else {
      LOG(FATAL) << "not supported nest type";
    }
  }
  return body;
}
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
// Overload for grouped nests: fold the groups from innermost (last)
// to outermost (first), delegating each group to the flat MergeNest.
inline Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
  for (size_t i = nest.size(); i != 0; --i) {
    body = MergeNest(nest[i - 1], body);
  }
  return body;
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
......@@ -8,7 +8,6 @@
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
public:
......@@ -26,7 +25,6 @@ class IRApplyVisit : public IRVisitor {
std::unordered_set<const Node*> visited_;
};
} // namespace
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
......@@ -36,12 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst;
}
// namespace to register the functors.
namespace {
using namespace Halide::Internal;
// Default handler for leaf nodes: there is nothing to visit.
void NoOp(const NodeRef& n, IRVisitor* v) {
}
......@@ -59,24 +51,82 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
}
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
v->Visit(op->source);
});
// Helper macro: register a handler that routes a node of concrete type
// OP to the virtual IRVisitor::Visit_ overload.
#define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
v->Visit_(op); \
})
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt* op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
});
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Free);
// Variable is a leaf node; nothing to visit.
void IRVisitor::Visit_(const Variable* op) {}
void IRVisitor::Visit_(const LetStmt *op) {
this->Visit(op->value);
this->Visit(op->body);
}
void IRVisitor::Visit_(const AttrStmt* op) {
this->Visit(op->value);
this->Visit(op->body);
}
// Visit the loop bounds, then the loop body.
void IRVisitor::Visit_(const For *op) {
  this->Visit(op->min);
  this->Visit(op->extent);
  this->Visit(op->body);
}
// Visit every extent, then the body, the condition, and (if present)
// the optional new_expr.
void IRVisitor::Visit_(const Allocate *op) {
  for (size_t i = 0; i < op->extents.size(); ++i) {
    this->Visit(op->extents[i]);
  }
  this->Visit(op->body);
  this->Visit(op->condition);
  if (op->new_expr.defined()) {
    this->Visit(op->new_expr);
  }
}
// Visit the load index; the buffer variable itself is not traversed here.
void IRVisitor::Visit_(const Load *op) {
  this->Visit(op->index);
}
void IRVisitor::Visit_(const Store *op) {
this->Visit(op->value);
this->Visit(op->index);
}
void IRVisitor::Visit_(const Let *op) {
this->Visit(op->value);
this->Visit(op->body);
}
// Free has no children; nothing to visit.
void IRVisitor::Visit_(const Free* op) {}
// Visit all call arguments; the callee reference is not traversed.
void IRVisitor::Visit_(const Call *op) {
  VisitArray(op->args, this);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
v->Visit(op->source);
})
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
.set_dispatch<FloatImm>(NoOp)
.set_dispatch<StringImm>(NoOp)
.set_dispatch<Variable>(NoOp);
.set_dispatch<StringImm>(NoOp);
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
......@@ -116,29 +166,15 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
v->Visit(op->true_value);
v->Visit(op->false_value);
})
.set_dispatch<Load>([](const Load *op, IRVisitor* v) {
v->Visit(op->index);
})
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->Visit(op->base);
v->Visit(op->stride);
})
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->Visit(op->value);
})
.set_dispatch<Call>([](const Call *op, IRVisitor* v) {
VisitArray(op->args, v);
})
.set_dispatch<Let>([](const Let *op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
v->Visit(op->condition);
v->Visit(op->message);
......@@ -146,30 +182,10 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
v->Visit(op->body);
})
.set_dispatch<For>([](const For *op, IRVisitor* v) {
v->Visit(op->min);
v->Visit(op->extent);
v->Visit(op->body);
})
.set_dispatch<Store>([](const Store *op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->index);
})
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v);
v->Visit(op->value);
})
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) {
v->Visit(op->extents[i]);
}
v->Visit(op->body);
v->Visit(op->condition);
if (op->new_expr.defined()) {
v->Visit(op->new_expr);
}
})
.set_dispatch<Free>(NoOp)
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
......@@ -193,6 +209,5 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
v->Visit(op->value);
});
} // namespace
} // namespace ir
} // namespace tvm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment