diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8bfca8020c3c69d68f14c3225c4b9175173f44ef..98bbc5b650d3106cc6f73bd773c184984aa12ef3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -57,6 +57,7 @@ include_directories("3rdparty/compiler-rt")
 # initial variables
 set(TVM_LINKER_LIBS "")
 set(TVM_RUNTIME_LINKER_LIBS "")
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
 # Generic compilation options
 if(MSVC)
diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst
index fdaed2b5be403cbd09674f38b47409c6ac31f8d8..f8da87d8cfd2c5a0b6b5696f1a64db70a19a8bf9 100644
--- a/docs/langref/hybrid_script.rst
+++ b/docs/langref/hybrid_script.rst
@@ -22,13 +22,15 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun
 
     @tvm.hybrid.script
     def outer_product(a, b, c):
+        c = output_tensor((100, 99), 'float32')
         for i in range(a.shape[0]):
             for j in range(b.shape[0]):
                 c[i, j] = a[i] * b[j]
-    a = numpy.random.rand(100)
-    b = numpy.random.rand(99)
-    c = numpy.zeros((100, 99))
-    outer_product(a, b, c)
+        return c
+    a = numpy.random.randn(100)
+    b = numpy.random.randn(99)
+    c = outer_product(a, b)
+
 
 This decorator will import `Keywords`_ required spontaneously when software emulation.
 After software emulation is done, the imported keywords will be cleaned up. Users do not need
@@ -40,25 +42,25 @@ or ``numpy`` numeric type.
 
 Backend Compilation
 ~~~~~~~~~~~~~~~~~~~
+Using ``tvm.hybrid.parse`` directly is not encouraged; users are encouraged to use the second interface below.
 The current parse interface looks like:
 
 .. code-block:: python
 
     a = tvm.placeholder((100, ), name='a')
     b = tvm.placeholder((99, ), name='b')
-    c = tvm.placeholder((100, 99), name='c')
-    tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function
+    parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function
 
-If we pass these tvm tensors to this function, it returns a op node:
-**Under construction, we are still deciding what kind of node should be returned.**
+If we pass these TVM tensors to the decorated function, it returns the output tensor(s) of the corresponding op node:
 
 .. code-block:: python
 
     a = tvm.placeholder((100, ), name='a')
     b = tvm.placeholder((99, ), name='b')
-    c = tvm.placeholder((100, 99), name='c')
-    op = outer_product(a, b, c) # return the corresponding op node
+    c = outer_product(a, b) # return the output tensor(s) of the operator
+
+**Under construction, we are still deciding what kind of node should be returned.**
 
 Tuning
 ~~~~~~
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 1a1d28ab71bbc27689737ea9533491ce3ffd7565..02cd0d016f39f6f6c95763b25931a092ae6fe0e8 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -450,6 +450,69 @@ class ExternOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode);
 };
 
+/*!
+ * \brief A computation operator generated by hybrid script.
+ */
+class HybridOpNode : public OperationNode {
+ public:
+  /*! \brief The input tensors */
+  Array<Tensor> inputs;
+  /*! \brief Symbolic placeholder representation of the outputs */
+  Array<Tensor> outputs;
+  /*! \brief The statement that generates the computation. This is
+   * slightly different from the body in ExternOpNode: every output
+   * tensor keeps the name the user specified in the script, but
+   * during compilation these placeholders are replaced by the
+   * actual output tensors. */
+  Stmt body;
+
+  /*!
\brief constructor */ + HybridOpNode() {} + // override functions + int num_outputs() const final; + Array<IterVar> root_iter_vars() const final; + Type output_dtype(size_t i) const final; + Array<Expr> output_shape(size_t i) const final; + Array<Tensor> InputTensors() const final; + Operation ReplaceInputs( + const Operation& self, + const std::unordered_map<Tensor, Tensor>& rmap) const final; + void PropBoundToInputs( + const Operation& self, + const std::unordered_map<const Variable*, IntSet>& dom_map, + std::unordered_map<Tensor, TensorDom>* out_dom_map) const final; + void GatherBound( + const Operation& self, + const std::unordered_map<Tensor, TensorDom>& tensor_dom, + std::unordered_map<IterVar, Range>* out_dom_map) const final; + Stmt BuildRealize( + const Stage& stage, + const std::unordered_map<IterVar, Range>& realize_map, + const Stmt& body) const final; + Stmt BuildProvide( + const Stage& stage, + const std::unordered_map<IterVar, Range>& dom_map, + bool debug_keep_trivial_loop) const final; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("name", &name); + v->Visit("tag", &tag); + v->Visit("attrs", &attrs); + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + v->Visit("body", &body); + } + EXPORT static Operation make(std::string name, + std::string tag, + Map<std::string, NodeRef> attrs, + Array<Tensor> inputs, + Array<Tensor> outputs, + Stmt body); + + static constexpr const char* _type_key = "HybridOp"; + TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode); +}; + /*! \brief The compute function to specify the input source of a Tensor */ using FCompute = std::function<Expr (const Array<Var>& i)>; diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 2bb7442bab7657cd50dd51537b54822364c5dce2..d65642340bad35d693014817f252eac08bf005de 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -340,11 +340,6 @@ def lower(sch, bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) stmt = ir_pass.InjectPrefetch(stmt) - else: - #So far there is no op for hybrid script, so a plain ir body is given - if not isinstance(sch, _stmt.Stmt): - raise ValueError("sch should be either a Schedule or a Stmt") - stmt = sch for f in lower_phase0: stmt = f(stmt) diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/hybrid/__init__.py index e0a39c562f0fb06c73c7799db97ef1cf8c3e1717..6c137490c38e2962aa5c4df7840f4155a9ad9351 100644 --- a/python/tvm/hybrid/__init__.py +++ b/python/tvm/hybrid/__init__.py @@ -7,4 +7,5 @@ python semantic emulation. 2. Developers can build HalideIR by writing Python code. """ -from .api import script, parse +from .api import script +from .parser import parse_python diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py index 48e192d4ba39e78b7c04ed5ddfdb972a16776c25..5267731f4f5210c6dec6c8b9f297997cdc78146c 100644 --- a/python/tvm/hybrid/api.py +++ b/python/tvm/hybrid/api.py @@ -1,9 +1,12 @@ """APIs of lowering the Python subset to HalideIR""" from __future__ import absolute_import as _abs -import types from .._ffi.base import decorate +from .. import _api_internal as _tvm_internal +from ..tensor import Tensor + from .parser import parse_python +from .util import _pruned_source def script(pyfunc): @@ -17,40 +20,26 @@ def script(pyfunc): hybrid_func : function A decorated hybrid script function. 
""" - def wrapped_func(func, *args, **kwargs): + def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types if _is_tvm_arg_types(args): - return parse(func, args) + src = _pruned_source(func) + parser = parse_python(src, args) + + input_tensors = [] + for i in args: + if isinstance(i, Tensor): + input_tensors.append(i) + + op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors, + parser.outputs, parser.parsed_body) + res = [op.output(i) for i in range(len(parser.outputs))] + + return res[0] if len(res) == 1 else res intersect = _enter_hybrid_runtime(func) value = func(*args, **kwargs) _restore_runtime(func, intersect) return value - return decorate(pyfunc, wrapped_func) - - -def parse(func, args): - """Parse a subset of Python to HalideIR - Parameters - ---------- - func : str or types.FunctionType - If it is a string, parse the source code - If it is a function, parse the function - - args : list of Buffer or Tensor or Var - The argument lists to the function. - Leave it None if no buffer is related to the function to be parsed - - Returns - ------- - root : Stmt - The result Halide IR and the parser class instance. - """ - from .util import _pruned_source - if isinstance(func, str): - src = func - else: - assert isinstance(func, types.FunctionType) - src = _pruned_source(func) - return parse_python(src, args) + return decorate(pyfunc, wrapped_func) diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py index b3fb64579b6061d993e4232e3f3d8b1d4ac9cdec..92e259585b7a7520bb67471ad074b0e1a7fb4d15 100644 --- a/python/tvm/hybrid/intrin.py +++ b/python/tvm/hybrid/intrin.py @@ -48,6 +48,7 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar """ return numpy.zeros(shape).astype(dtype) +output_tensor = allocate #pylint: disable=invalid-name def popcount(x): """ @@ -87,18 +88,19 @@ def sigmoid(x): HYBRID_GLOBALS = { - 'unroll' : unroll, - 'vectorize' : vectorize, - 'parallel' : parallel, - 'allocate' : allocate, - 'bind' : bind, - 'sqrt' : numpy.sqrt, - 'log' : numpy.log, - 'tanh' : numpy.tanh, - 'power' : numpy.power, - 'exp' : numpy.exp, - 'sigmoid' : sigmoid, - 'popcount' : popcount + 'unroll' : unroll, + 'vectorize' : vectorize, + 'parallel' : parallel, + 'allocate' : allocate, + 'output_tensor': output_tensor, + 'bind' : bind, + 'sqrt' : numpy.sqrt, + 'log' : numpy.log, + 'tanh' : numpy.tanh, + 'power' : numpy.power, + 'exp' : numpy.exp, + 'sigmoid' : sigmoid, + 'popcount' : popcount } diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index cf21ea95054933f028f8047589b9fb28c2e5018d..a16f5abd434991518fc9fa19bcae9f519b8279d5 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -2,8 +2,9 @@ import ast import operator +import logging import sys -from .util import make_nop, halide_imm_types, is_docstring +from .util import make_nop, halide_imm_types, is_docstring, _internal_assert from .intrin import LOOP_INTRIN, MATH_INTRIN from .var_decl import determine_variable_usage from ..api import thread_axis @@ -72,15 +73,17 @@ class HybridParser(ast.NodeVisitor): The name of the function to be lowered; if not provided, the compiler will use the name in the AST """ - self.args = args[:] + self.args = list(args) self.usage = usage.copy() self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer) - self.var_buffers = {} # Buffers formed by mutatble variables self.alloc_buffers 
= {} # Buffers formed by allocate instructions self.loops_above = {} # State variable that indicates loop levels above the current node self.var_consts = {} # Variables that are determined as readonly in previous stage self.func_name = func_name # The name of the function to be lowered - self.iter_axis = [] + self.outputs = [] # Output tensors' name + self.side_effect = set() # Tensors with side effects + self.parsed_body = None # The parsed HalideIR body + self.returned = False def wrap_up_realize(self, node, body): @@ -90,9 +93,8 @@ class HybridParser(ast.NodeVisitor): continue _, level, _ = val if level == node: - if key in self.var_buffers.keys(): - _buf = self.var_buffers[key] - _scope = 'global' + if key in self._args.keys(): + continue else: _buf, _scope = self.alloc_buffers[key] _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape] @@ -103,12 +105,13 @@ class HybridParser(ast.NodeVisitor): return body - def _get_buffer_from_id(self, s): - if s not in self._args.keys() and s not in self.alloc_buffers.keys(): - raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s) - if s in self._args.keys() and s in self.alloc_buffers.keys(): - raise ValueError("%s, a buffer cannot be both argument and allocated!" % s) + def _get_buffer_from_id(self, s, for_provide=False): + _internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1, + "This %s is expected to be in either \ + argument list or allocated buffer!" % s) if s in self._args.keys(): + if for_provide: + self.side_effect.add(self._args[s]) return self._args[s] return self.alloc_buffers[s][0] @@ -116,15 +119,15 @@ class HybridParser(ast.NodeVisitor): #pylint: disable=invalid-name, missing-docstring def visit_Module(self, node): - if len(node.body) != 1: - raise ValueError("Only one-function source code can be fed to this parser!") + _internal_assert(len(node.body) == 1, \ + "Only one-function source code can be fed to this parser!") return self.visit(node.body[0]) def visit_FunctionDef(self, node): - if len(node.args.args) != len(self.args): - raise ValueError("The number of arguments passed to the function\ - should be the same as it is defined!") + _internal_assert(len(node.args.args) == len(self.args), \ + "The number of arguments passed to the \ + function should be the same as it is defined!") for idx, arg in enumerate(node.args.args): _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible self._args[getattr(arg, _attr)] = self.args[idx] @@ -145,17 +148,17 @@ class HybridParser(ast.NodeVisitor): return self._args[_id] elif _id in self.loops_above.keys(): return self.loops_above[_id] - if _id in self._args.keys(): - raise ValueError("This id %s should be handled in visit_Subscript!" % _id) - if _id not in self.usage.keys(): - raise ValueError("This id %s is expected to be a defined variable!" % _id) + _internal_assert(_id not in self._args.keys(), \ + "This id %s should be handled in visit_Subscript!" % _id) + _internal_assert(_id in self.usage.keys(), \ + "This id %s is expected to be a defined variable!" % _id) # Buffer - if _id in self.var_buffers.keys(): - _buf = self.var_buffers[_id] + if _id in self.alloc_buffers.keys(): + _buf, _ = self.alloc_buffers[_id] return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0) # Compilation time constant - if _id not in self.var_consts.keys(): - raise ValueError("This id %s is expected to a compilation time constant!" 
% _id) + _internal_assert(_id in self.var_consts.keys(), + "This id %s is expected to a compilation time constant!" % _id) return self.var_consts[_id] @@ -164,8 +167,7 @@ class HybridParser(ast.NodeVisitor): def visit_Assign(self, node): - if len(node.targets) != 1: - raise ValueError("So far only one-valued assignment is supported!") + _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") lhs = node.targets[0] rhs = self.visit(node.value) if isinstance(rhs, _expr.Expr): @@ -174,36 +176,40 @@ class HybridParser(ast.NodeVisitor): #TODO: support defined intermediate buffer later lhs_ = lhs lhs = lhs.id - if lhs in self.loops_above.keys(): - raise ValueError("You CAN NEVER overwrite a loop variable!") + _internal_assert(lhs not in self.loops_above.keys(), \ + "Loop variable cannot be overwritten!") decl, _, rw = self.usage[lhs] if decl == lhs_: - if lhs in self.var_consts.keys(): - raise ValueError("BUG: A constant cannot be overwritten!") - if lhs in self.var_buffers.keys() or lhs in self.alloc_buffers.keys(): - raise ValueError("BUG: This value should not be defined before this point!") + _internal_assert(lhs not in self.var_consts.keys(), \ + "A constant cannot be overwritten!") + _internal_assert(lhs not in self.alloc_buffers.keys(), \ + "This value should not be defined before this point!") if isinstance(rhs, tuple): shape, dtype, scope = rhs ph = _api.placeholder(shape, dtype=dtype, name=lhs) - self.alloc_buffers[lhs] = (ph, scope) + if scope != 'output': + self.alloc_buffers[lhs] = (ph, scope) + else: + self._args[lhs] = ph + self.outputs.append(lhs) return make_nop() if isinstance(rhs, halide_imm_types) and ast.Store not in rw: self.var_consts[lhs] = rhs else: - self.var_buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) + ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) + self.alloc_buffers[lhs] = (ph, 'global') if lhs in self.var_consts.keys(): return make_nop() - else: - if lhs not in self.var_buffers.keys(): - raise ValueError("BUG: This variable should be defined before!") - tgt = self.var_buffers[lhs] - return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)]) + _internal_assert(lhs in self.alloc_buffers.keys(), \ + "This variable should be defined before!") + tgt, _ = self.alloc_buffers[lhs] + return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)]) else: lhs = self.visit(lhs) - if not isinstance(lhs, _expr.Call): - raise ValueError("An array access's LHS is expected to be a expr.Call!") + _internal_assert(isinstance(lhs, _expr.Call), \ + "An array access's LHS is expected to be a expr.Call!") #TODO: support slice later - buf = self._get_buffer_from_id(lhs.name) + buf = self._get_buffer_from_id(lhs.name, for_provide=True) return _make.Provide(buf.op, 0, rhs, lhs.args) @@ -219,21 +225,20 @@ class HybridParser(ast.NodeVisitor): array = node.value.id _buf = self._get_buffer_from_id(array) return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) - elif isinstance(node.value, ast.Attribute): - if not isinstance(node.value.value, ast.Name): - raise ValueError("The root of array access is expect to be a id!") - if node.value.attr != "shape": - raise ValueError("Attribute access so far only 'shape' is supported!") - if len(args) != 1: - raise ValueError("For 'shape' access the argument should be only one!") - args = args[0] - #TODO: maybe support non-constant value later? 
- if not isinstance(args, (_expr.IntImm, _expr.UIntImm)): - raise ValueError("So far only constant shape access supported!") - buf = self._get_buffer_from_id(node.value.value.id) - return buf.shape[args.value] - else: - raise ValueError("Not supported yet!") + + _internal_assert(isinstance(node.value, ast.Attribute), \ + "Only variable and attribute's subscript supported so far") + _internal_assert(isinstance(node.value.value, ast.Name), \ + "The root of array access is expect to be a id!") + _internal_assert(node.value.attr == "shape", \ + "Attribute access so far only 'shape' is supported!") + _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!") + args = args[0] + #TODO: maybe support non-constant value later? + _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \ + "So far only constant shape access supported!") + buf = self._get_buffer_from_id(node.value.value.id) + return buf.shape[args.value] def visit_With(self, node): @@ -241,14 +246,11 @@ class HybridParser(ast.NodeVisitor): context = node.context_expr option = node.optional_vars else: - if len(node.items) != 1: - raise ValueError("Only one with element is supported so far!") + _internal_assert(len(node.items) == 1, "Only one with element is supported so far!") context = node.items[0].context_expr option = node.items[0].optional_vars - if not isinstance(context, ast.Call): - raise ValueError("The object must be a Python function call!") - if not isinstance(option, ast.Name): - raise ValueError("The object after 'as' must be an id!") + _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!") + _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!") self.annotation[option.id] = context.func.id return list_to_block(self.visit, node.body) @@ -272,10 +274,8 @@ class HybridParser(ast.NodeVisitor): def visit_Compare(self, node): lhs = self.visit(node.left) - if len(node.ops) != 1: - raise ValueError("Only one compare op is supported!") - if len(node.comparators) != 1: - raise ValueError("Only one comparator is supported!") + _internal_assert(len(node.ops) == 1, "Only one compare op is supported!") + _internal_assert(len(node.comparators) == 1, "Only one comparator is supported!") rhs = self.visit(node.comparators[0]) return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs) @@ -293,16 +293,15 @@ class HybridParser(ast.NodeVisitor): def visit_Call(self, node): # Yet, no function pointer supported - if not isinstance(node.func, ast.Name): - raise ValueError("Only id-function function call is supported so far!") + _internal_assert(isinstance(node.func, ast.Name), \ + "Only id-function function call is supported so far!") func_id = node.func.id n = len(node.args) if func_id in LOOP_INTRIN.keys() and func_id != 'bind': if n == 1: low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0]) else: - if n != 2: - raise ValueError("A loop intrinsic should only have 1 or 2 arguments!") + _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!") low, ext = self.visit(node.args[0]), self.visit(node.args[1]) if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): ext = ext - low @@ -310,10 +309,9 @@ class HybridParser(ast.NodeVisitor): iter_var = None return iter_var, low, ext, for_type elif func_id == 'bind': - if n != 2: - raise ValueError("A loop bind should only have 2 arguments!") - if not isinstance(node.args[0], ast.Str): - raise ValueError("A loop bind's first argument should be a string!") + 
_internal_assert(n == 2, "A loop bind should only have 2 arguments!") + _internal_assert(isinstance(node.args[0], ast.Str), \ + "A loop bind's first argument should be a string!") _vn = node.args[0].s iter_var = thread_axis(node.args[0].s) low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1]) @@ -321,29 +319,39 @@ class HybridParser(ast.NodeVisitor): return iter_var, low, ext, for_type elif func_id in MATH_INTRIN: return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args]) - elif func_id == 'allocate': - if not isinstance(node.args[0], ast.Tuple): - raise ValueError("allocate's first argument should be a tuple of shape!") + elif func_id in ['allocate', 'output_tensor']: + _internal_assert(isinstance(node.args[0], ast.Tuple), \ + "allocate's first argument should be a tuple of shape!") shape = tuple(self.visit(i) for i in node.args[0].elts) + if func_id == 'output_tensor': + _internal_assert(not self.loops_above, \ + "Are you sure to allocate a output buffer multiple times?") for i in shape: - if not isinstance(i, _expr.Expr): - raise ValueError("The shape should be an expression") + _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression") if n > 1: - if not isinstance(node.args[1], ast.Str): - raise ValueError("The data type should be an string") - dtype = node.args[1].s + if isinstance(node.args[1], ast.Str): + dtype = node.args[1].s + else: + _internal_assert(isinstance(node.args[1], ast.Attribute), \ + "Unable to evaluate to get data type") + to_eval = node.args[1] + _internal_assert(isinstance(to_eval.value, ast.Name), \ + "Unable to evaluate the attribute to get data type") + _internal_assert(to_eval.attr == 'dtype', \ + "Only dtype attribute is supported so far") + dtype = self._get_buffer_from_id(to_eval.value.id).dtype else: dtype = 'float32' if n > 2: - if not isinstance(node.args[2], ast.Str): - raise ValueError("The data type should be an string") + _internal_assert(isinstance(node.args[2], ast.Str), \ + "The data scope should be an string") + _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope") scope = node.args[2].s else: - scope = 'global' + scope = 'global' if func_id != 'output_tensor' else 'output' return (shape, dtype, scope) elif func_id == 'max' or func_id == 'min': - if n != 2: - raise ValueError("Max/Min function should have 2 elements") + _internal_assert(n == 2, "Max/Min function should have 2 elements") a, b = self.visit(node.args[0]), self.visit(node.args[1]) return getattr(_make, func_id.title())(a, b) else: @@ -352,19 +360,17 @@ class HybridParser(ast.NodeVisitor): def visit_For(self, node): iter_var, low, ext, for_type = self.visit(node.iter) - if not isinstance(node.target, ast.Name): - raise ValueError("The loop iterator should be a variable!") + _internal_assert(isinstance(node.target, ast.Name), \ + "The loop iterator should be a variable!") _name = node.target.id if iter_var is None: - if for_type is None: - raise ValueError("The loop bind function parse error!") + _internal_assert(for_type is not None, "The loop bind function parse error!") offset = iter_var = _api.var(_name) if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): offset = iter_var + low self.loops_above[_name] = offset else: - if for_type is not None: - raise ValueError("The loop iterating function parse error!") + _internal_assert(for_type is None, "The loop iterating function parse error!") self.loops_above[_name] = iter_var.var _body = list_to_block(self.visit, node.body) _body = 
self.wrap_up_realize(node, _body) @@ -376,10 +382,46 @@ class HybridParser(ast.NodeVisitor): return res + def visit_Return(self, node): + _internal_assert(not self.loops_above, "Return should not be in a loop body!") + ids = [] + if isinstance(node.value, ast.Name): + ids.append(node.value.id) + else: + _internal_assert(isinstance(node.value, ast.Tuple), \ + "You should return either a single tensor or a tuple") + for i in node.value.elts: + _internal_assert(isinstance(i, ast.Name), "What do you return?") + ids.append(i.id) + _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples") + if len(ids) != len(self.outputs): + logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!') + self.outputs = [self._args[i] for i in ids] + self.returned = True + return make_nop() + + def parse_python(src, args): - """The helper function of calling the AST visitor""" + """The helper function of calling the AST visitor + + Parameters + ---------- + src : str + The source code of the function to be parsed. + + args : list of Tensors or Vars + The argument lists to the function. + It is NOT encouraged to write a function without arguments. + It is NOT encouraged to write a function with side effect. + + Returns + ------- + root : Stmt + The result Halide IR and the parser class instance. + """ root = ast.parse(src) var_usage = determine_variable_usage(root, args) parser = HybridParser(args, var_usage) - halide_ir = parser.visit(root) - return halide_ir + parser.parsed_body = parser.visit(root) + _internal_assert(parser.returned, 'No valid return found in the function body!') + return parser diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 2a43957e97068933cce367401b0cad159e5d6a3e..e38f466381ff419886d0e17fc80a7035f320f2ef 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -2,6 +2,8 @@ import ast import inspect +import logging +import sys import numpy from .intrin import HYBRID_GLOBALS from .._ffi.base import numeric_types @@ -30,10 +32,17 @@ def is_docstring(node): def _pruned_source(func): """Prune source code's extra leading spaces""" - lines = inspect.getsource(func).split('\n') - leading_space = len(lines[0]) - len(lines[0].lstrip(' ')) - lines = [line[leading_space:] for line in lines] - return '\n'.join(lines) + try: + lines = inspect.getsource(func).split('\n') + leading_space = len(lines[0]) - len(lines[0].lstrip(' ')) + lines = [line[leading_space:] for line in lines] + return '\n'.join(lines) + except IOError as err: + if sys.version_info[0] == 2 and str(err) == 'could not get source code': + logging.log(logging.CRITICAL, \ + 'This module is not fully operated under Python2... ' \ + 'Please move to Python3!') + raise err def _is_tvm_arg_types(args): @@ -70,3 +79,12 @@ def _restore_runtime(func, intersect): _globals.pop(elem) for k, v in intersect: _globals[k] = v + +def _internal_assert(cond, err): + """Simplify the code segment like if not XXX then raise an error""" + if not cond: + raise ValueError(err) + +# Almost the same functionality as the one above, but in this case, +# the error is caused by users inproper usage. 
+_user_assert = _internal_assert diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py index df38bac1acba5f41c23658b635bae86210484615..586ef95461ea8dcbd8a072e511ff79f4ff05d57a 100644 --- a/python/tvm/hybrid/var_decl.py +++ b/python/tvm/hybrid/var_decl.py @@ -3,6 +3,7 @@ import ast import sys from .intrin import HYBRID_GLOBALS +from .util import _internal_assert class PyVariableUsage(ast.NodeVisitor): @@ -18,8 +19,8 @@ class PyVariableUsage(ast.NodeVisitor): def visit_FunctionDef(self, node): self.scope_level.append(node) - if len(node.args.args) != len(self.args): - raise ValueError('#arguments passed should be the same as #arguments defined') + _internal_assert(len(node.args.args) == len(self.args), \ + '#arguments passed should be the same as #arguments defined') for idx, arg in enumerate(node.args.args): _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible self._args[getattr(arg, _attr)] = self.args[idx] @@ -28,8 +29,8 @@ class PyVariableUsage(ast.NodeVisitor): def visit_For(self, node): - if not isinstance(node.target, ast.Name): - raise ValueError("For's iterator should be an id") + _internal_assert(isinstance(node.target, ast.Name), \ + "For's iterator should be an id") self.visit(node.iter) self.scope_level.append(node) for i in node.body: @@ -39,11 +40,10 @@ class PyVariableUsage(ast.NodeVisitor): def visit_Call(self, node): #No function pointer supported so far - if not isinstance(node.func, ast.Name): - raise ValueError("Function call should be an id") + _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id") func_id = node.func.id - if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']: - raise ValueError("Function call id not in intrinsics' list") + _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \ + "Function call id not in intrinsics' list") for elem in node.args: self.visit(elem) @@ -56,12 +56,12 @@ class PyVariableUsage(ast.NodeVisitor): if node.id in fors: return # The loop variable cannot be overwritten when iteration - if isinstance(node.ctx, ast.Store) and node.id in fors: - raise ValueError("Iter var cannot be overwritten") + _internal_assert(not isinstance(node.ctx, ast.Store) or node.id not in fors, \ + "Iter var cannot be overwritten") if node.id not in self.status.keys(): - if not isinstance(node.ctx, ast.Store): - raise ValueError('In Python, "first store" indicates "declaration"') + _internal_assert(isinstance(node.ctx, ast.Store), \ + 'Undeclared variable %s' % node.id) self.status[node.id] = (node, self.scope_level[-1], set()) else: decl, loop, usage = self.status[node.id] diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index f32b70eb9a12e80e665a2557a5da7cc93e82d6a2..9a98e9a6e769192dc400135ff9117fd9cf6d385a 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -180,3 +180,8 @@ class ScanOp(Operation): class ExternOp(Operation): """Extern operation.""" pass + +@register_node +class HybridOp(Operation): + """Hybrid operation.""" + pass diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 3525e23b8b20793c77645566c19d8d175a7270b0..e30111e938bd117981f2a3ac15549c7b38561cc7 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -313,6 +313,16 @@ TVM_REGISTER_API("_ExternOp") args[6]); }); +TVM_REGISTER_API("_HybridOp") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = HybridOpNode::make(args[0], + args[1], + args[2], + args[3], + args[4], + args[5]); + }); + TVM_REGISTER_API("_OpGetOutput") 
.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator Operation().output( diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4dbb2c0b964fb98ecd54246f5b9b1a65df0b87b2 --- /dev/null +++ b/src/op/hybrid_op.cc @@ -0,0 +1,189 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Hybrid computation rule. + * \file hybrid_op.cc + */ +#include <tvm/operation.h> +#include <tvm/arithmetic.h> +#include <tvm/ir.h> +#include <tvm/ir_mutator.h> +#include <unordered_set> +#include "op_util.h" + +namespace tvm { +using namespace ir; +// HybridOpNode +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch<HybridOpNode>([](const HybridOpNode *op, IRPrinter *p) { + p->stream << "hybrid(" << op->name << ", " << op << ")"; + }); + +TVM_REGISTER_NODE_TYPE(HybridOpNode); + +int HybridOpNode::num_outputs() const { + return static_cast<int>(outputs.size()); +} + +Array<IterVar> HybridOpNode::root_iter_vars() const { + return {}; +} + +Type HybridOpNode::output_dtype(size_t i) const { + return outputs[i]->dtype; +} + +Array<Expr> HybridOpNode::output_shape(size_t i) const { + return outputs[i]->shape; +} + + +Operation HybridOpNode::make(std::string name, + std::string tag, + Map<std::string, NodeRef> attrs, + Array<Tensor> inputs, + Array<Tensor> outputs, + Stmt body) { + if (!attrs.defined()) { + attrs = Map<std::string, NodeRef>(); + } + auto n = make_node<HybridOpNode>(); + n->name = std::move(name); + n->tag = std::move(tag); + n->attrs = std::move(attrs); + n->inputs = std::move(inputs); + n->outputs = std::move(outputs); + n->body = std::move(body); + Operation res = Operation(n); + return res; +} + +Array<Tensor> HybridOpNode::InputTensors() const { + return inputs; +} + +Operation HybridOpNode::ReplaceInputs( + const Operation& self, + const std::unordered_map<Tensor, Tensor>& rmap) const { + CHECK_EQ(self.operator->(), this); + auto n = make_node<HybridOpNode>(*this); + n->body = op::ReplaceTensor(this->body, rmap); + for (size_t i = 0; i < n->inputs.size(); ++i) { + Tensor t = n->inputs[i]; + if (rmap.count(t)) { + n->inputs.Set(i, rmap.at(t)); + } + } + + if (body.same_as(n->body) && + inputs.same_as(n->inputs)) { + return self; + } else { + return Operation(n); + } +} + +void HybridOpNode::PropBoundToInputs( + const Operation& self, + const std::unordered_map<const Variable*, IntSet>& dom_map, + std::unordered_map<Tensor, TensorDom>* out_dom_map) const { + for (Tensor t : this->inputs) { + auto it = out_dom_map->find(t); + if (it == out_dom_map->end()) continue; + TensorDom& dom = it->second; + for (size_t i = 0; i < t->shape.size(); ++i) { + dom.data[i].emplace_back(IntSet::range( + Range::make_by_min_extent( + make_const(t->shape[i].type(), 0), t->shape[i]))); + } + } +} + +void HybridOpNode::GatherBound( + const Operation& self, + const std::unordered_map<Tensor, TensorDom>& tensor_dom, + std::unordered_map<IterVar, Range>* out_dom_map) const { +} + +Stmt HybridOpNode::BuildRealize( + const Stage& stage, + const std::unordered_map<IterVar, Range>& realize_map, + const Stmt& body) const { + CHECK_EQ(stage->op.get(), this); + Stmt realize_body = body; + for (int k = 0; k < num_outputs(); ++k) { + Tensor t = stage->op.output(k); + HalideIR::Internal::Region bounds; + for (size_t i = 0; i < t->shape.size(); ++i) { + bounds.push_back( + Range::make_by_min_extent( + make_const(t->shape[i].type(), 0), t->shape[i])); + } + realize_body = ir::Realize::make( + t->op, t->value_index, t->dtype, + bounds, const_true(), 
realize_body);
+  }
+  return realize_body;
+}
+
+Stmt HybridOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+  Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
+  auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
+    Array<NodeRef> bind_spec;
+    Array<Expr> tuple;
+    bind_spec.push_back(buffer);
+    bind_spec.push_back(tensor);
+    for (size_t k = 0; k < buffer->shape.size(); ++k) {
+      tuple.push_back(make_const(buffer->shape[k].type(), 0));
+      tuple.push_back(buffer->shape[k]);
+    }
+    ret = AttrStmt::make(
+        bind_spec, attr::buffer_bind_scope,
+        Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
+  };
+  for (int i = static_cast<int>(outputs.size()) - 1; i >= 0; --i) {
+    Buffer buffer = decl_buffer(
+        outputs[i]->shape,
+        outputs[i]->dtype);
+    f_push_bind(buffer, stage->op.output(i));
+  }
+  for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
+    Buffer buffer = decl_buffer(
+        inputs[i]->shape,
+        inputs[i]->dtype);
+    f_push_bind(buffer, inputs[i]);
+  }
+
+  std::unordered_map<Tensor, Tensor> rmap;
+  for (int i = 0; i < this->num_outputs(); ++i) {
+    rmap[outputs[i]] = stage->op.output(i);
+  }
+  auto n = make_node<HybridOpNode>(*this);
+  /*
+   * These two lines of code replace the tensors' reads & writes.
+   * This is the simplest way I (@were) can come up with to glue
+   * hybrid scripts to the structure of TVM op.
+   * NAMING CONFLICT: In hybrid script all the tensors have their own
+   * names specified by the users. However, in TVM op, all the output
+   * tensors' names are the same as the op's name. I cannot change the
+   * name to the op's name in the function body after the op node is
+   * formed, because:
+   *   1. Output tensors all point to the corresponding op node.
+   *   2. Once an OpNode is wrapped up by an Operation node, it can
+   *      no longer be changed.
+   * This is a chicken-and-egg paradox: it is impossible to put the output
+   * tensors into the function body without forming the op node, and the
+   * function body is immutable after the node is formed.
+   *
+   * Finally, I decided to resolve this issue "lazily". During the
+   * compilation pipeline, these tensors will be replaced when the function
+   * body is formed and passed to the next stage of compilation.
+   */
+  ret = op::ReplaceTensor(ret, rmap);
+  ret = op::ReplaceProvideTensor(ret, rmap);
+  return ret;
+}
+}  // namespace tvm
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index ba83997a0a16646adbbdc0e51ee240f81f238d81..886f7c9123032c9c79098eb855a03bb83cf83950 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -164,6 +164,37 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
   return nest;
 }
 
+// replacer to replace tensors' usage in Provide
+class ProviderReplacer : public ir::IRMutator {
+ public:
+  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
+      : vmap_(vmap) {}
+
+  Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
+    Tensor t = Operation(op->func.node_).output(op->value_index);
+    auto it = vmap_.find(t);
+    if (it != vmap_.end()) {
+      Stmt ret = ir::Provide::make(
+          it->second->op, it->second->value_index, op->value, op->args);
+      found = true;
+      return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  // whether a replacement has been found
+ bool found{false}; + + private: + const std::unordered_map<Tensor, Tensor>& vmap_; +}; + +Stmt ReplaceProvideTensor(Stmt stmt, + const std::unordered_map<Tensor, Tensor>& replace) { + ProviderReplacer repl(replace); + Stmt ret = repl.Mutate(stmt); + return repl.found ? ret : stmt; +} // replacer to replace tensors class TensorReplacer : public ir::IRMutator { diff --git a/src/op/op_util.h b/src/op/op_util.h index 558e8d4e7324c48e0c57b126d7f9148341ca8889..6971f14eef731310449637644f444119ab0a5833 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -49,14 +49,22 @@ MakeLoopNest(const Stage& stage, std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates); /*! - * \brief Replace the tensor reference in stmt by the replace map. + * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map. + * \param stmt The statement to be processed. + * \param replace The replacement rule. + */ +Stmt ReplaceProvideTensor(Stmt stmt, + const std::unordered_map<Tensor, Tensor>& replace); + +/*! + * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. * \param stmt The statement to be processed. * \param replace The replacement rule. */ Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace); /*! - * \brief Replace the tensor reference in expr by the replace map. + * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. * \param expr The expression to be processed. * \param replace The replacement rule. */ diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 3124586ca34390ff75d6dbd0ab3560c0369623b4..9156e40f949f0702a9bcdf05f2f04e08f546d064 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -3,7 +3,7 @@ from tvm.hybrid import script from tvm.hybrid.intrin import HYBRID_GLOBALS @nose.tools.nottest -def run_and_check(func, args, outs, var_dict={}, target='llvm'): +def run_and_check(func, args, var_dict={}, target='llvm'): def tvm_val_2_py_val(val): val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) @@ -14,39 +14,50 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'): emu_args = [] nd_args = [] - to_check = [] for i in args: if isinstance(i, tvm.tensor.Tensor): shape = [tvm_val_2_py_val(j) for j in i.shape] - if i in outs: - emu_args.append(numpy.zeros(shape).astype(i.dtype)) - nd_args.append(tvm.nd.array(emu_args[-1], ctx)) - to_check.append((nd_args[-1], emu_args[-1])) - else: - emu_args.append(numpy.random.randn(*shape).astype(i.dtype)) - nd_args.append(tvm.nd.array(emu_args[-1], ctx)) + emu_args.append(numpy.random.randn(*shape).astype(i.dtype)) + nd_args.append(tvm.nd.array(emu_args[-1], ctx)) else: assert isinstance(i, tvm.expr.Var) emu_args.append(tvm_val_2_py_val(i)) nd_args.append(emu_args[-1]) - func(*emu_args) - - lowerd_func = tvm.lower(func(*args), args) - module = tvm.build(lowerd_func, target=target) + outs = func(*args) + op = outs[0].op if isinstance(outs, list) else outs.op + sch = tvm.create_schedule(op) + module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target) assert module + + out_tensors = [] + for i in range(op.num_outputs): + output = op.output(i) + shape = [tvm_val_2_py_val(j) for j in output.shape] + nd_args.append(tvm.nd.array(numpy.zeros(shape).astype(output.dtype), ctx)) + out_tensors.append(nd_args[-1]) + + ref_data = func(*emu_args) + if isinstance(ref_data, 
numpy.ndarray): + ref_data = [ref_data] + module(*nd_args) - for nd, np in to_check: + for nd, np in zip(out_tensors, ref_data): tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) @script -def outer_product(n, m, a, b, c): - """This is a simple outer product""" +def outer_product(n, m, a, b): + """This is a simple outer product. + Actually this function is not required to be documented. + I write this docstring to test skipping docstring functionality. + """ + c = output_tensor((n, m), a.dtype) for i in range(n): for j in range(m): c[i, j] = a[i] * b[j] + return c #Test global function #Test bridge between frontend and backend @@ -55,8 +66,14 @@ def test_outer_product(): m = tvm.var('m') a = tvm.placeholder((n, ), name='a') b = tvm.placeholder((m, ), name='b') - c = tvm.placeholder((n, m), name='c') - ir = outer_product(n, m, a, b, c) + + try: + c = outer_product(n, m, a, b) + ir = c.op.body + except IOError as err: + assert sys.version_info[0] == 2 and str(err) == 'could not get source code' + return + #Check for i in (0, n) assert isinstance(ir, tvm.stmt.For) assert ir.loop_var.name == 'i' @@ -81,10 +98,8 @@ def test_outer_product(): assert mul.a.name == 'a' assert mul.b.name == 'b' - func = tvm.lower(ir, [n, m, a, b, c]) - func = tvm.build(func) - run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001}) + run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101}) for key, _ in HYBRID_GLOBALS.items(): assert key not in globals().keys() @@ -94,19 +109,25 @@ def test_outer_product(): #Test allocation of local variable def test_fanout(): @script - def fanout(n, a, b): + def fanout(n, a): three = 3.0 + b = output_tensor((a.shape[0] - 3, ), a.dtype) for i in range(a.shape[0] - 3): sigma = 0.0 for j in range(3): sigma = sigma + a[i + j] sigma = sigma / three b[i] = sigma + return b n = tvm.var('n') a = tvm.placeholder((n, ), 'float32', name='a') - b = tvm.placeholder((n-3, ), 'float32', name='b') - ir = fanout(n, a, b) + try: + b = fanout(n, a) + ir = b.op.body + except IOError as err: + assert sys.version_info[0] == 2 and str(err) == 'could not get source code' + return #Check for i in (0, n-3) assert isinstance(ir, tvm.stmt.For) @@ -163,38 +184,31 @@ def test_fanout(): assert len(write.value.args) == 1 assert write.value.args[0].value == 0 - run_and_check(fanout, [n, a, b], [b], {n: 10}) - - -@script -def failure(): - for i in range(1, 100): - i = 0 - -def test_failure(): - try: - tvm.hybrid.parse(failure, []) - except IOError as err: - assert sys.version_info[0] == 2 - print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err)) - except Exception as err: - assert str(err) == 'You CAN NEVER overwrite a loop variable!' 
+ run_and_check(fanout, [n, a], {n: 10}) def test_looptype(): @script def looptype(a, b, c): + d = output_tensor((8, ), 'int32') + e = output_tensor((8, ), 'int32') + f = output_tensor((8, ), 'int32') for i in parallel(8): - a[i] = i + d[i] = a[i] for j in vectorize(8): - b[j] = j + e[j] = b[j] for k in unroll(8): - c[k] = k + f[k] = c[k] + return d, e, f a = tvm.placeholder((8, ), name='a', dtype='int32') b = tvm.placeholder((8, ), name='b', dtype='int32') c = tvm.placeholder((8, ), name='c', dtype='int32') - ir = looptype(a, b, c) + try: + d, e, f = looptype(a, b, c) + ir = d.op.body + except: + return iloop = ir.first jloop = ir.rest.first kloop = ir.rest.rest @@ -202,24 +216,26 @@ def test_looptype(): assert jloop.for_type == tvm.stmt.For.Vectorized assert kloop.for_type == tvm.stmt.For.Unrolled - run_and_check(looptype, [a, b, c], [a, b, c]) + run_and_check(looptype, [a, b, c]) def test_if(): @script - def if_then_else(a, b): + def if_then_else(a): + b = output_tensor((10, ), 'int32') + c = output_tensor((10, ), 'int32') for i in range(10): if i % 2 == 0: - a[i] = -1 + c[i] = a[i] else: - a[i] = 1 + c[i] = b[i] for i in unroll(10): b[i] = -1 if i % 2 == 0 else 1 + return b, c a = tvm.placeholder((10, ), dtype='int32', name='a') - b = tvm.placeholder((10, ), dtype='int32', name='b') - run_and_check(if_then_else, [a, b], [a, b]) + run_and_check(if_then_else, [a]) def test_bind(): @@ -227,55 +243,66 @@ def test_bind(): print('[Warning] No GPU found! Skip bind test!') return @script - def vec_add(a, b, c): + def vec_add(a, b): + c = output_tensor((1000, ), dtype='float32') for tx in bind('threadIdx.x', 1000): c[tx] = b[tx] + c[tx] + return c a = tvm.placeholder((1000, ), dtype='float32', name='a') b = tvm.placeholder((1000, ), dtype='float32', name='b') - c = tvm.placeholder((1000, ), dtype='float32', name='c') - run_and_check(vec_add, [a, b, c], [c], target='cuda') + run_and_check(vec_add, [a, b], target='cuda') def test_math_intrin(): @script def intrin_real(a): - a[0] = sqrt(a[0]) - a[1] = log(a[1]) - a[2] = exp(a[2]) - a[3] = sigmoid(a[3]) - a[4] = power(a[4], a[5]) - a[5] = tanh(a[5]) - a[6] = min(a[4], a[5]) - a[7] = max(a[5], a[6]) + b = output_tensor((8, ), 'float32') + b[0] = sqrt(a[0]) + b[1] = log(a[1]) + b[2] = exp(a[2]) + b[3] = sigmoid(a[3]) + b[4] = power(a[4], a[5]) + b[5] = tanh(a[5]) + b[6] = min(a[4], a[5]) + b[7] = max(a[5], a[6]) + return b a8 = tvm.placeholder((8, ), dtype='float32', name='a') - ir = intrin_real(a8) - func = tvm.build(tvm.lower(ir, [a8])) + b8 = intrin_real(a8) + sch = tvm.create_schedule(b8.op) + func = tvm.build(sch, [a8, b8]) assert func a = numpy.arange(2, 10).astype('float32') tvm_a = tvm.ndarray.array(a) - func(tvm_a) - intrin_real(a) - tvm.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5) + tvm_b = tvm.ndarray.array(numpy.zeros((8, ), dtype='float32')) + b = intrin_real(a) + func(tvm_a, tvm_b) + tvm.testing.assert_allclose(b, tvm_b.asnumpy(), rtol=1e-5) @script def intrin_int(a): - a[0] = popcount(a[0]) + b = output_tensor((1, ), 'int32') + b[0] = popcount(a[0]) + return b a1 = tvm.placeholder((1, ), dtype='int32') - ir = intrin_int(a1) - func = tvm.build(tvm.lower(ir, [a1])) + b1 = intrin_int(a1) + sch = tvm.create_schedule(b1.op) + func = tvm.build(sch, [a1, b1]) assert func - a = numpy.array([1234567890]).astype('int32') + a = numpy.array([114514]).astype('int32') tvm_a = tvm.ndarray.array(a) - intrin_int(a) - func(tvm_a) - assert tvm_a.asnumpy()[0] == a[0] + tvm_b = tvm.ndarray.array(numpy.array([0]).astype('int32')) + b = 
intrin_int(a) + func(tvm_a, tvm_b) + assert tvm_b.asnumpy()[0] == b[0] +# test non caconical loops def test_non_zero(): @tvm.hybrid.script - def blur(a, b): + def blur(a): + b = output_tensor((30, 30), 'float32') for i in range(2, 32): for j in range(2, 32): s = 0.0 @@ -283,29 +310,28 @@ def test_non_zero(): for dj in range(3): s = s + a[i-di, j-dj] b[i-2, j-2] = s / 9.0 - try: - a = tvm.placeholder((32, 32), 'float32', 'a') - b = tvm.placeholder((30, 30), 'float32', 'b') - run_and_check(blur, [a, b], [b]) - except IOError as err: - assert sys.version_info[0] == 2 - print('[Warning] Case test_non_zero is skipped by Python2 because "%s"' % str(err)) + return b + + a = tvm.placeholder((32, 32), 'float32', 'a') + run_and_check(blur, [a]) @tvm.hybrid.script - def triangle(a, b, c): + def triangle(a, b): + c = output_tensor((10, 10), dtype='float32') for i in range(10): for j in range(i, 10): c[i, j] = a[i] * b[j] + return c a = tvm.placeholder((10, ), dtype='float32', name='a') b = tvm.placeholder((10, ), dtype='float32', name='b') - c = tvm.placeholder((10, 10), dtype='float32', name='c') - run_and_check(triangle, [a, b, c], [c]) + run_and_check(triangle, [a, b]) def test_allocate(): @tvm.hybrid.script - def blur2d(a, b): + def blur2d(a): + b = output_tensor((30, 30), 'float32') for i in range(30): ha = allocate((3, 30), 'float32') for j in range(3): @@ -313,15 +339,15 @@ def test_allocate(): ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2] for j in range(30): b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0 + return b a = tvm.placeholder((32, 32), 'float32', 'a') - b = tvm.placeholder((30, 30), 'float32', 'b') - - run_and_check(blur2d, [a, b], [b]) + run_and_check(blur2d, [a]) if tvm.gpu().exist: @tvm.hybrid.script - def share_vec_add(a, b, c): + def share_vec_add(a, b): + c = output_tensor((256, ), 'float32') shared = allocate((256, ), 'float32', 'shared') for i in bind("threadIdx.x", 256): shared[i] = a[i] @@ -330,23 +356,81 @@ def test_allocate(): local[i] = b[i] for i in bind("threadIdx.x", 256): c[i] = shared[i] + local[i] + return c a = tvm.placeholder((256, ), dtype='float32', name='a') b = tvm.placeholder((256, ), dtype='float32', name='b') - c = tvm.placeholder((256, ), dtype='float32', name='c') - run_and_check(share_vec_add, [a, b, c], [c], target='cuda') + run_and_check(share_vec_add, [a, b], target='cuda') else: print('[Warning] No GPU found! 
Skip shared mem test!') +def test_upstream(): + @tvm.hybrid.script + def upstream(a): + b = output_tensor((20, ), 'float32') + for i in range(20): + b[i] = a[i] * i + return b + + a = tvm.placeholder((20, ), 'float32') + b = tvm.placeholder((20, ), 'float32') + c = tvm.compute((20, ), lambda x: a[x] + b[x]) + d = upstream(c) + sch = tvm.create_schedule([c.op, d.op]) + ir = tvm.lower(sch, [a, b, d], simple_mode=True) + func = tvm.build(sch, [a, b, d]) + assert(func) + + a = numpy.random.randn(20).astype('float32') + b = numpy.random.randn(20).astype('float32') + ref = numpy.zeros((20, ), 'float32') + for i in range(20): + ref[i] = (a[i] + b[i]) * i + + tvm_a = tvm.nd.array(a) + tvm_b = tvm.nd.array(b) + tvm_d = tvm.nd.array(numpy.zeros((20, )).astype('float32')) + + func(tvm_a, tvm_b, tvm_d) + tvm.testing.assert_allclose(tvm_d.asnumpy(), ref, 1e-5, 1e-5) + +def test_downstream(): + @tvm.hybrid.script + def downstream(a): + b = output_tensor((20, ), 'float32') + for i in range(20): + b[i] = a[i] * i + return b + + a = tvm.placeholder((20, ), 'float32') + b = downstream(a) + c = tvm.compute((20, ), lambda x: b[x] + 1.0) + sch = tvm.create_schedule(c.op) + module = tvm.build(sch, [a, c]) + assert module + + a = numpy.random.randn(20).astype('float32') + ref = numpy.zeros((20, )).astype('float32') + for i in range(20): + ref[i] = (a[i] * i) + 1.0 + + tvm_a = tvm.nd.array(a) + tvm_c = tvm.nd.array(numpy.zeros((20, )).astype('float32')) + module(tvm_a, tvm_c) + tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5) + if __name__ == "__main__": test_outer_product() test_fanout() - test_failure() test_looptype() test_if() test_bind() test_math_intrin() test_non_zero() test_allocate() + #test_inplace() + test_upstream() + test_downstream() +