From 9ce7f0a353bc092a5a683a38f4904eb83a6e5711 Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Sat, 26 Nov 2016 17:21:38 -0800 Subject: [PATCH] Check in inline and test --- include/tvm/ir.h | 1 + include/tvm/ir_pass.h | 15 ++++++--- python/tvm/__init__.py | 1 + python/tvm/_ctypes/_api.py | 24 +++++++-------- python/tvm/ir_pass.py | 1 + src/c_api/c_api.cc | 4 ++- src/c_api/c_api_ir.cc | 5 ++- src/c_api/c_api_lang.cc | 1 - src/c_api/c_api_pass.cc | 35 +++++++++++++++++++++ src/c_api/c_api_registry.h | 22 +++++++++++-- src/lang/tensor.cc | 3 +- src/pass/inline.cc | 60 ++++++++++++++++++++++++++++++++++++ src/pass/ssa.cc | 4 +-- tests/cpp/ir_ssa_test.cc | 6 ++-- tests/python/test_inline.py | 15 +++++++++ tests/python/test_ir_pass.py | 17 ++++++++++ 16 files changed, 183 insertions(+), 31 deletions(-) create mode 100644 python/tvm/ir_pass.py create mode 100644 src/c_api/c_api_pass.cc create mode 100644 src/pass/inline.cc create mode 100644 tests/python/test_inline.py create mode 100644 tests/python/test_ir_pass.py diff --git a/include/tvm/ir.h b/include/tvm/ir.h index dcf68e4a4..0ba993b02 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -18,6 +18,7 @@ namespace ir { using Halide::Internal::ExprNode; using Halide::Internal::IRNodeType; +using Halide::Internal::ForType; /*! \brief Reduction operator operator */ struct Reduce : public ExprNode<Reduce> { diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 29152cd13..def17377d 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -1,7 +1,10 @@ /*! * Copyright (c) 2016 by Contributors * \file ir_pass.h - * \brief Collection of IR pass functions and visit functions + * \brief Collection of IR pass functions + * + * All the pass functions in this file are for Stmt, + * We can use PassFunction(Evaluate(expr)) to apply it to Expr */ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ @@ -22,14 +25,14 @@ namespace ir { * \return Whether IR is in SSA form. * \note All the passes in this file uses SSA form and outputs SSA form. */ -bool VerifySSA(const IRNodeRef& ir); +bool VerifySSA(const Stmt& ir); /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. * \return The converted form. */ -Stmt ConvertSSA(const Stmt& stmt); +Stmt ConvertSSA(Stmt stmt); /*! * \brief inline all calls of f in stmt. @@ -42,8 +45,10 @@ Stmt ConvertSSA(const Stmt& stmt); * * \note All the passes in this file uses SSA form and outputs SSA form. */ -Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt); - +Stmt Inline(FunctionRef f, + Array<Var> args, + Expr body, + Stmt stmt); } // namespace ir } // namespace tvm diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 4284b4595..f1c2ea41a 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -6,6 +6,7 @@ from . import tensor as tensor from . import expr from . import stmt from . import make +from . import ir_pass from . import collections from . import schedule diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index de8468ba5..8a1c0b247 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -224,21 +224,19 @@ def _init_function_module(root_namespace): module_obj = sys.modules["%s.function" % root_namespace] module_internal = sys.modules["%s._function_internal" % root_namespace] - module_make = sys.modules["%s.make" % root_namespace] + namespace_match = { + "_make_" : sys.modules["%s.make" % root_namespace], + "_pass_" : sys.modules["%s.ir_pass" % root_namespace] + } for name in op_names: hdl = FunctionHandle() check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl))) - if name.startswith("_make_"): - fname = name[6:] - else: - fname = name - + fname = name + target_module = module_internal if name.startswith('_') else module_obj + for k, v in namespace_match.items(): + if name.startswith(k): + fname = name[len(k):] + target_module = v function = _make_function(hdl, fname) - - if name.startswith("_make_"): - setattr(module_make, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) + setattr(target_module, function.__name__, function) diff --git a/python/tvm/ir_pass.py b/python/tvm/ir_pass.py new file mode 100644 index 000000000..3ba8c219a --- /dev/null +++ b/python/tvm/ir_pass.py @@ -0,0 +1 @@ +"""Namespace of IR pass functions""" diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 7e9e32b33..9d540ed63 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -164,14 +164,16 @@ int TVMPushStack(ArgVariant arg, API_BEGIN(); ret->arg_stack.resize(ret->arg_stack.size() + 1); APIVariantValue& v = ret->arg_stack.back(); + v.type_id = static_cast<ArgVariantID>(type_id); if (type_id == kStr) { - v = arg.v_str; + v.str = arg.v_str; } else if (type_id == kNodeHandle) { v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); } else { v.v_union = arg; } + API_END_HANDLE_ERROR(ret->Clear()); } diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc index 79c6ac7e4..94b65c230 100644 --- a/src/c_api/c_api_ir.cc +++ b/src/c_api/c_api_ir.cc @@ -9,9 +9,7 @@ #include "./c_api_registry.h" namespace tvm { - -using namespace tvm::ir; -using namespace Halide::Internal; +namespace ir { using ArgStack = const std::vector<APIVariantValue>; using RetValue = APIVariantValue; @@ -135,4 +133,5 @@ REGISTER_MAKE2(Block); REGISTER_MAKE3(IfThenElse); REGISTER_MAKE1(Evaluate); +} // namespace ir } // namespace tvm diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc index b2d28164a..91bd2ef9c 100644 --- a/src/c_api/c_api_lang.cc +++ b/src/c_api/c_api_lang.cc @@ -19,7 +19,6 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_const) .set_body([](const ArgStack& args, RetValue *ret) { using Halide::Internal::make_const; - if (args.at(0).type_id == kLong) { *ret = make_const(args.at(1), args.at(0).operator int64_t()); } else if (args.at(0).type_id == kDouble) { diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc new file mode 100644 index 000000000..d3046ac91 --- /dev/null +++ b/src/c_api/c_api_pass.cc @@ -0,0 +1,35 @@ +/*! + * Copyright (c) 2016 by Contributors + * Exposre of pass functions. + * \file c_api_pass.cc + */ +#include <tvm/expr.h> +#include <tvm/ir.h> +#include <tvm/ir_pass.h> +#include "./c_api_registry.h" + +namespace tvm { +namespace ir { + +using ArgStack = const std::vector<APIVariantValue>; +using RetValue = APIVariantValue; + +// make from two arguments +#define REGISTER_PASS1(PassName) \ + TVM_REGISTER_API(_pass_## PassName) \ + .set_body([](const ArgStack& args, RetValue *ret) { \ + *ret = PassName(args.at(0)); \ + }) \ + +#define REGISTER_PASS4(PassName) \ + TVM_REGISTER_API(_pass_## PassName) \ + .set_body([](const ArgStack& args, RetValue *ret) { \ + *ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \ + }) \ + +REGISTER_PASS1(ConvertSSA); +REGISTER_PASS1(VerifySSA); +REGISTER_PASS4(Inline); + +} // namespace ir +} // namespace tvm diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 97a7e0c14..8004bfe38 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -43,6 +43,17 @@ inline Type String2Type(std::string s) { return Type(code, bits, lanes); } +inline const char* TypeId2Str(ArgVariantID type_id) { + switch (type_id) { + case kNull: return "Null"; + case kLong: return "Long"; + case kDouble: return "Double"; + case kStr: return "Str"; + case kNodeHandle: return "NodeHandle"; + default: LOG(FATAL) << "unknown type_id=" << type_id; return ""; + } +} + /*! \brief Variant container for API calls */ class APIVariantValue { public: @@ -74,6 +85,11 @@ class APIVariantValue { v_union.v_long = value; return *this; } + inline APIVariantValue& operator=(bool value) { + type_id = kLong; + v_union.v_long = value; + return *this; + } inline APIVariantValue& operator=(std::string value) { type_id = kStr; str = std::move(value); @@ -130,11 +146,13 @@ class APIVariantValue { return v_union.v_long; } inline operator bool() const { - CHECK_EQ(type_id, kLong); + CHECK_EQ(type_id, kLong) + << "expect boolean(int) but get " << TypeId2Str(type_id); return v_union.v_long != 0; } inline operator std::string() const { - CHECK_EQ(type_id, kStr); + CHECK_EQ(type_id, kStr) + << "expect Str but get " << TypeId2Str(type_id); return str; } inline operator Type() const { diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 643332701..fb02dde25 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -21,8 +21,9 @@ Expr Tensor::operator()(Array<Expr> indices) const { CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read" << "ndim = " << ndim() << ", indices.size=" << indices.size(); - return Call::make( + auto n Call::make( (*this)->dtype, (*this)->name, indices, Call::Halide, *this); + return n; } diff --git a/src/pass/inline.cc b/src/pass/inline.cc new file mode 100644 index 000000000..669324225 --- /dev/null +++ b/src/pass/inline.cc @@ -0,0 +1,60 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file inline.cc + */ +#include <tvm/ir.h> +#include <tvm/ir_mutator.h> +#include <tvm/ir_pass.h> + +namespace tvm { +namespace ir { +namespace { + +// inliner to inline a function +// the result may not be SSA, +// ConvertSSA need to be applied after this pass +class IRInline : public IRMutator { + public: + IRInline(FunctionRef f, Array<Var> args, Expr body) + : f_(f), args_(args), body_(body) {} + + Expr Mutate(Expr expr) final { + const Call* call = expr.as<Call>(); + if (call != nullptr && call->func == f_) { + return InlineCall(call); + } else { + return IRMutator::Mutate(expr); + } + } + + Stmt Mutate(Stmt stmt) final { + return IRMutator::Mutate(stmt); + } + + private: + FunctionRef f_; + Array<Var> args_; + Expr body_; + + Expr InlineCall(const Call* op) { + Expr expr = body_; + + CHECK_EQ(args_.size(), op->args.size()) + << op->args.size() << " vs " << args_.size(); + for (size_t i = 0; i < args_.size(); ++i) { + expr = Let::make(args_[i], op->args[i], expr); + } + return expr; + } +}; + +} // namespace + +Stmt Inline(FunctionRef f, + Array<Var> args, + Expr body, + Stmt stmt) { + return ConvertSSA(IRInline(f, args, body).Mutate(stmt)); +} +} // namespace ir +} // namespace tvm diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 12beffeb9..44b2454de 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -156,13 +156,13 @@ class IRConvertSSA : public IRMutator { } // namespace -bool VerifySSA(const IRNodeRef& ir) { +bool VerifySSA(const Stmt& ir) { IRVerifySSA v; v.Visit(ir); return v.is_ssa; } -Stmt ConvertSSA(const Stmt& stmt) { +Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA().Mutate(stmt); } diff --git a/tests/cpp/ir_ssa_test.cc b/tests/cpp/ir_ssa_test.cc index 0f0f9e6da..2de7dba08 100644 --- a/tests/cpp/ir_ssa_test.cc +++ b/tests/cpp/ir_ssa_test.cc @@ -10,9 +10,9 @@ TEST(IRSSA, Convert) { Var x("x"), y; Expr let = Let::make(x, 1, x + 1); - auto z = let + let; + auto z = Evaluate::make(let + let); CHECK(!ir::VerifySSA(z)); - auto z_ssa = ir::ConvertSSA(Evaluate::make(z)); + auto z_ssa = ir::ConvertSSA(z); CHECK(ir::VerifySSA(z_ssa)); } @@ -20,7 +20,7 @@ TEST(IRSSA, Basic) { using namespace Halide::Internal; using namespace tvm; Var x("x"), y; - auto z = x + y; + auto z = Evaluate::make(x + y); CHECK(ir::VerifySSA(z)); } diff --git a/tests/python/test_inline.py b/tests/python/test_inline.py new file mode 100644 index 000000000..9695a832f --- /dev/null +++ b/tests/python/test_inline.py @@ -0,0 +1,15 @@ +import tvm + +def test_inline(): + m = tvm.Var('m') + A = tvm.placeholder((m,), name='A') + T = tvm.compute((m,), lambda i,: A(i) + 10, name='T') + X = T(100) + stmt = tvm.make.Evaluate(T(10) + 11 * T(100)) + stmt = tvm.ir_pass.Inline( + T, T.source_op.iter_var, T.source_op.body, stmt) + print(stmt) + assert(tvm.ir_pass.VerifySSA(stmt)) + +if __name__ == "__main__": + test_inline() diff --git a/tests/python/test_ir_pass.py b/tests/python/test_ir_pass.py new file mode 100644 index 000000000..23262f1cc --- /dev/null +++ b/tests/python/test_ir_pass.py @@ -0,0 +1,17 @@ +import tvm + +def test_verify_ssa(): + x = tvm.Var('x') + y = tvm.Var() + z = tvm.make.Evaluate(x + y) + assert(tvm.ir_pass.VerifySSA(z)) + + +def test_convert_ssa(): + x = tvm.Var('x') + y = tvm.Var() + let = tvm.make.Let(x, 1, x + 1) + z = tvm.make.Evaluate(let + let) + assert(not tvm.ir_pass.VerifySSA(z)) + z_ssa = tvm.ir_pass.ConvertSSA(z) + assert(tvm.ir_pass.VerifySSA(z_ssa)) -- GitLab