Skip to content
Snippets Groups Projects
Commit 9ce7f0a3 authored by tqchen's avatar tqchen
Browse files

Check in inline and test

parent 2fafa935
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@ namespace ir {
using Halide::Internal::ExprNode;
using Halide::Internal::IRNodeType;
using Halide::Internal::ForType;
/*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> {
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
* \brief Collection of IR pass functions
*
* All the pass functions in this file are for Stmt,
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
......@@ -22,14 +25,14 @@ namespace ir {
* \return Whether IR is in SSA form.
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
bool VerifySSA(const IRNodeRef& ir);
bool VerifySSA(const Stmt& ir);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
*/
Stmt ConvertSSA(const Stmt& stmt);
Stmt ConvertSSA(Stmt stmt);
/*!
* \brief inline all calls of f in stmt.
......@@ -42,8 +45,10 @@ Stmt ConvertSSA(const Stmt& stmt);
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt);
Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt);
} // namespace ir
} // namespace tvm
......
......@@ -6,6 +6,7 @@ from . import tensor as tensor
from . import expr
from . import stmt
from . import make
from . import ir_pass
from . import collections
from . import schedule
......
......@@ -224,21 +224,19 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
module_make = sys.modules["%s.make" % root_namespace]
namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace]
}
for name in op_names:
hdl = FunctionHandle()
check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
if name.startswith("_make_"):
fname = name[6:]
else:
fname = name
fname = name
target_module = module_internal if name.startswith('_') else module_obj
for k, v in namespace_match.items():
if name.startswith(k):
fname = name[len(k):]
target_module = v
function = _make_function(hdl, fname)
if name.startswith("_make_"):
setattr(module_make, function.__name__, function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
setattr(target_module, function.__name__, function)
"""Namespace of IR pass functions"""
......@@ -164,14 +164,16 @@ int TVMPushStack(ArgVariant arg,
API_BEGIN();
ret->arg_stack.resize(ret->arg_stack.size() + 1);
APIVariantValue& v = ret->arg_stack.back();
v.type_id = static_cast<ArgVariantID>(type_id);
if (type_id == kStr) {
v = arg.v_str;
v.str = arg.v_str;
} else if (type_id == kNodeHandle) {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle);
} else {
v.v_union = arg;
}
API_END_HANDLE_ERROR(ret->Clear());
}
......
......@@ -9,9 +9,7 @@
#include "./c_api_registry.h"
namespace tvm {
using namespace tvm::ir;
using namespace Halide::Internal;
namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
......@@ -135,4 +133,5 @@ REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
} // namespace ir
} // namespace tvm
......@@ -19,7 +19,6 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::make_const;
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
......
/*!
* Copyright (c) 2016 by Contributors
* Exposre of pass functions.
* \file c_api_pass.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
namespace tvm {
namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0)); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \
}) \
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
} // namespace ir
} // namespace tvm
......@@ -43,6 +43,17 @@ inline Type String2Type(std::string s) {
return Type(code, bits, lanes);
}
inline const char* TypeId2Str(ArgVariantID type_id) {
switch (type_id) {
case kNull: return "Null";
case kLong: return "Long";
case kDouble: return "Double";
case kStr: return "Str";
case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_id=" << type_id; return "";
}
}
/*! \brief Variant container for API calls */
class APIVariantValue {
public:
......@@ -74,6 +85,11 @@ class APIVariantValue {
v_union.v_long = value;
return *this;
}
inline APIVariantValue& operator=(bool value) {
type_id = kLong;
v_union.v_long = value;
return *this;
}
inline APIVariantValue& operator=(std::string value) {
type_id = kStr;
str = std::move(value);
......@@ -130,11 +146,13 @@ class APIVariantValue {
return v_union.v_long;
}
inline operator bool() const {
CHECK_EQ(type_id, kLong);
CHECK_EQ(type_id, kLong)
<< "expect boolean(int) but get " << TypeId2Str(type_id);
return v_union.v_long != 0;
}
inline operator std::string() const {
CHECK_EQ(type_id, kStr);
CHECK_EQ(type_id, kStr)
<< "expect Str but get " << TypeId2Str(type_id);
return str;
}
inline operator Type() const {
......
......@@ -21,8 +21,9 @@ Expr Tensor::operator()(Array<Expr> indices) const {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
return Call::make(
auto n Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this);
return n;
}
......
/*!
* Copyright (c) 2016 by Contributors
* \file inline.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
namespace {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline : public IRMutator {
public:
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expr Mutate(Expr expr) final {
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
return InlineCall(call);
} else {
return IRMutator::Mutate(expr);
}
}
Stmt Mutate(Stmt stmt) final {
return IRMutator::Mutate(stmt);
}
private:
FunctionRef f_;
Array<Var> args_;
Expr body_;
Expr InlineCall(const Call* op) {
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
return expr;
}
};
} // namespace
Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt) {
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
}
} // namespace ir
} // namespace tvm
......@@ -156,13 +156,13 @@ class IRConvertSSA : public IRMutator {
} // namespace
bool VerifySSA(const IRNodeRef& ir) {
bool VerifySSA(const Stmt& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
Stmt ConvertSSA(const Stmt& stmt) {
Stmt ConvertSSA(Stmt stmt) {
return IRConvertSSA().Mutate(stmt);
}
......
......@@ -10,9 +10,9 @@ TEST(IRSSA, Convert) {
Var x("x"), y;
Expr let = Let::make(x, 1, x + 1);
auto z = let + let;
auto z = Evaluate::make(let + let);
CHECK(!ir::VerifySSA(z));
auto z_ssa = ir::ConvertSSA(Evaluate::make(z));
auto z_ssa = ir::ConvertSSA(z);
CHECK(ir::VerifySSA(z_ssa));
}
......@@ -20,7 +20,7 @@ TEST(IRSSA, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
auto z = Evaluate::make(x + y);
CHECK(ir::VerifySSA(z));
}
......
import tvm
def test_inline():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A(i) + 10, name='T')
X = T(100)
stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
stmt = tvm.ir_pass.Inline(
T, T.source_op.iter_var, T.source_op.body, stmt)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
if __name__ == "__main__":
test_inline()
import tvm
def test_verify_ssa():
x = tvm.Var('x')
y = tvm.Var()
z = tvm.make.Evaluate(x + y)
assert(tvm.ir_pass.VerifySSA(z))
def test_convert_ssa():
x = tvm.Var('x')
y = tvm.Var()
let = tvm.make.Let(x, 1, x + 1)
z = tvm.make.Evaluate(let + let)
assert(not tvm.ir_pass.VerifySSA(z))
z_ssa = tvm.ir_pass.ConvertSSA(z)
assert(tvm.ir_pass.VerifySSA(z_ssa))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment