diff --git a/HalideIR b/HalideIR index f0deabe56bc20e60899e44b432d4a628a90161f3..2b3ea8f5207152340014fd0a1ab12816ac48c326 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit f0deabe56bc20e60899e44b432d4a628a90161f3 +Subproject commit 2b3ea8f5207152340014fd0a1ab12816ac48c326 diff --git a/include/tvm/domain.h b/include/tvm/domain.h index ddea36881c847ec2c2451cfb2cd2f318e5daab13..a2c42a31f106721f917b3edfc5dde1591c877d66 100644 --- a/include/tvm/domain.h +++ b/include/tvm/domain.h @@ -46,6 +46,7 @@ class RDomain : public NodeRef { public: /*! \brief constructor*/ RDomain() {} + explicit RDomain(std::shared_ptr<Node> n) : NodeRef(n) {} /*! * constructor by domain * \param domain The domain of reduction. diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 6995016f404cf824ae3ab068a253d3d5910bc861..031386d96dbe1d49a98e90395e47c7be1f154f81 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -48,4 +48,4 @@ struct Reduce : public ExprNode<Reduce> { } // namespace ir } // namespace tvm -#endif // TVM_IR_NODE_H_ +#endif // TVM_IR_H_ diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index 9cf38940d5c1fc44a65833df09b8701663b5a963..0cd03670bea24036bd0a954cd0cb0708e941ee85 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -61,6 +61,9 @@ class NodeBase(object): """ self.handle = handle + def __repr__(self): + return _function_internal.format_str(self) + def __del__(self): check_call(_LIB.TVMNodeFree(self.handle)) diff --git a/python/tvm/collections.py b/python/tvm/collections.py index 0063aff1968515efa66a0d42c4635f302d7fd30e..08350e75f31fc54f3f4c6e1587d793fc67d2039b 100644 --- a/python/tvm/collections.py +++ b/python/tvm/collections.py @@ -22,3 +22,8 @@ class Range(NodeBase): def __repr__(self): return ('Range(min='+ str(self.min) + ', extent=' + str(self.extent) + ')') + + +@register_node +class RDomain(NodeBase): + pass diff --git a/python/tvm/expr.py b/python/tvm/expr.py index e1010463473492bed21c210aefe8583fe33ed351..25ab709d04cb3039707b94e04afe6da68591244f 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -1,12 +1,8 @@ from __future__ import absolute_import as _abs from ._ctypes._api import NodeBase, register_node -from . import function as _func from . import make as _make class Expr(NodeBase): - def __repr__(self): - return _func.format_str(self) - def __add__(self, other): return _make.Add(self, other) @@ -52,9 +48,14 @@ class CmpExpr(Expr): class LogicalExpr(Expr): pass + + @register_node("Variable") class Var(Expr): + pass +@register_node +class Reduce(Expr): pass @register_node diff --git a/python/tvm/function.py b/python/tvm/function.py index c25a90266a2ee0853d95ae9edcbe19941e71f39b..50b1f4ac5c444f4079b963ab8678320a55d9e1d3 100644 --- a/python/tvm/function.py +++ b/python/tvm/function.py @@ -1,8 +1,10 @@ from __future__ import absolute_import as _abs from numbers import Number as _Number, Integral as _Integral from ._ctypes._api import _init_function_module -from .import _function_internal -from .import make as _make +from . import _function_internal +from . import make as _make +from . import expr as _expr +from . import collections as _collections int32 = "int32" float32 = "float32" @@ -76,4 +78,98 @@ def Tensor(shape, fcompute=None, dtype=None, name="TensorObj"): shape, name, dtype, None, None) +def RDomain(dom): + """Create a reduction domain given domain + + Parameters + ---------- + dom : list of Range or list of pairs + The reduction domain. + + Returns + ------- + rdom : RDomain + The result rdomain + """ + if not isinstance(dom, (list, tuple)): + dom = [dom] + elif not isinstance(dom[0], (list, tuple)): + dom = [dom] + dnorm = [] + for x in dom: + if isinstance(x, (list, tuple)): + if len(x) != 2: + raise ValueError("need to list of ranges") + dnorm.append(Range(x[0], x[1])) + else: + dnorm.append(x) + dnorm = convert(dnorm) + return _function_internal._RDomain(dnorm) + + +def sum(expr, rdom): + """Create a sum expression over rdom + + Parameters + ---------- + expr : Expr + The source expression. + + rdom : RDomain + The reduction domainx + """ + assert isinstance(rdom, _collections.RDomain) + x = _make.Reduce("Add", expr, rdom) + return x + +def sum(expr, rdom): + """Create a sum expression over rdom + + Parameters + ---------- + expr : Expr + The source expression. + + rdom : RDomain + The reduction domainx + """ + assert isinstance(expr, _expr.Expr) + assert isinstance(rdom, _collections.RDomain) + x = _make.Reduce("Add", expr, rdom) + return x + +def min(expr, rdom): + """Create a min expression over rdom + + Parameters + ---------- + expr : Expr + The source expression. + + rdom : RDomain + The reduction domainx + """ + assert isinstance(expr, _expr.Expr) + assert isinstance(rdom, _collections.RDomain) + x = _make.Reduce("Min", expr, rdom) + return x + + +def max(expr, rdom): + """Create a min expression over rdom + + Parameters + ---------- + expr : Expr + The source expression. + + rdom : RDomain + The reduction domainx + """ + assert isinstance(expr, _expr.Expr) + assert isinstance(rdom, _collections.RDomain) + x = _make.Reduce("Max", expr, rdom) + return x + + _init_function_module("tvm") diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index 42575bc5748fe1511be36c30d6d0cb61d4648a27..b7b3d9f956e3e750bb6986bc2c81cac53672c705 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -18,7 +18,7 @@ namespace tvm { using ArgStack = const std::vector<APIVariantValue>; using RetValue = APIVariantValue; -TVM_REGISTER_API(format_str) +TVM_REGISTER_API(_format_str) .set_body([](const ArgStack& args, RetValue *ret) { using Halide::Internal::BaseExprNode; using Halide::Internal::BaseStmtNode; diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc index 99628fb2ee5cfe1440568e3191561f8a3dd5a193..291e600952414082206d7c823d794203b03fc36a 100644 --- a/src/c_api/c_api_ir.cc +++ b/src/c_api/c_api_ir.cc @@ -4,11 +4,13 @@ * \file c_api_ir.cc */ #include <tvm/expr.h> +#include <tvm/ir.h> #include <ir/IROperator.h> #include "./c_api_registry.h" namespace tvm { +using namespace tvm::ir; using namespace Halide::Internal; using ArgStack = const std::vector<APIVariantValue>; @@ -29,6 +31,12 @@ TVM_REGISTER_API(_make_For) args.at(5)); }); +TVM_REGISTER_API(_make_Reduce) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = Reduce::make(args.at(0), + args.at(1), + args.at(2)); + }); TVM_REGISTER_API(_make_Call) .set_body([](const ArgStack& args, RetValue *ret) { diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 4307387aaa13c25b2bc4e5dcfca190f74bb52d5e..a702d8f5f7b117aa666e7986ee85c4ce9f3360c0 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -17,11 +17,22 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); namespace Halide { namespace Internal { +using tvm::ir::Reduce; + template<> -void ExprNode<tvm::ir::Reduce>::accept(IRVisitor *v) const { +void ExprNode<Reduce>::accept(IRVisitor *v) const { LOG(FATAL) << "Reduce do not work with IRVisitor yet"; } +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) { + p->stream << "reduce(" + << op->op + << ", "; + p->print(op->source); + p->stream << ", rdom=" << op->rdom << ")"; +}); + } // namespace Internal } // namespace Halide @@ -31,7 +42,7 @@ namespace ir { // reduce TVM_REGISTER_NODE_TYPE(Reduce); -Expr make(std::string op, Expr source, RDomain rdom) { +Expr Reduce::make(std::string op, Expr source, RDomain rdom) { auto n = std::make_shared<Reduce>(); CHECK(source.defined()); n->type = source.type(); diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py index 29561ead88ffeae7538a7e95e901f1518ddfee71..06847ae58865eb10ef81efba3c32d247fa4906ca 100644 --- a/tests/python/test_tensor.py +++ b/tests/python/test_tensor.py @@ -11,6 +11,17 @@ def test_tensor(): assert(tuple(T.shape) == (m, n, l)) assert(A.source is None) +def test_tensor_reduce(): + m = tvm.Var('m') + n = tvm.Var('n') + l = tvm.Var('l') + A = tvm.Tensor((m, l), name='A') + B = tvm.Tensor((n, l), name='B') + T = tvm.Tensor((m, n, l), lambda i, j, k: A(i, k) * B(j, k)) + rd = tvm.RDomain(tvm.Range(A.shape[1])) + C = tvm.Tensor((m, n), lambda i, j: tvm.sum(T(i, j, rd.index[0]), rdom=rd)) + print(tvm.format_str(C.source)) if __name__ == "__main__": test_tensor() + test_tensor_reduce()