From 838e7181dd5e16aa5c7a87bbfbf11b5b42a80f1a Mon Sep 17 00:00:00 2001
From: Jian Weng <werefluke@gmail.com>
Date: Wed, 19 Dec 2018 15:20:34 -0800
Subject: [PATCH] [Hybrid Script] Inter-function call supported! (#2287)

---
 python/tvm/hybrid/api.py                    |   4 +-
 python/tvm/hybrid/calls.py                  |  92 ++++++
 python/tvm/hybrid/intrin.py                 |  15 +-
 python/tvm/hybrid/parser.py                 | 307 ++++++++++----------
 python/tvm/hybrid/util.py                   |  18 ++
 python/tvm/hybrid/var_decl.py               |  15 +-
 tests/python/unittest/test_hybrid_script.py |  36 ++-
 7 files changed, 303 insertions(+), 184 deletions(-)
 create mode 100644 python/tvm/hybrid/calls.py

diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py
index 5267731f4..d43217ca5 100644
--- a/python/tvm/hybrid/api.py
+++ b/python/tvm/hybrid/api.py
@@ -24,17 +24,15 @@ def script(pyfunc):
         from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
         if _is_tvm_arg_types(args):
             src = _pruned_source(func)
-            parser = parse_python(src, args)
+            parser = parse_python(src, func.__globals__, args)
 
             input_tensors = []
             for i in args:
                 if isinstance(i, Tensor):
                     input_tensors.append(i)
-
             op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
                                          parser.outputs, parser.parsed_body)
             res = [op.output(i) for i in range(len(parser.outputs))]
-
             return res[0] if len(res) == 1 else res
 
         intersect = _enter_hybrid_runtime(func)
diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py
new file mode 100644
index 000000000..730b56f58
--- /dev/null
+++ b/python/tvm/hybrid/calls.py
@@ -0,0 +1,92 @@
+"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
+semantic support."""
+
+from .. import api as _api
+from .. import expr as _expr
+from .. import make as _make
+from ..container import Array
+from .. import ir_pass
+from ..stmt import For
+from .util import _internal_assert
+
+#pylint: disable=redefined-builtin
+
+LOOP_INTRIN = {
+    'range'    : For.Serial,
+    'unroll'   : For.Unrolled,
+    'parallel' : For.Parallel,
+    'vectorize': For.Vectorized,
+}
+
+def _range(annotation, args):
+    """Handling TVM loop types"""
+    n = len(args)
+    if n == 1:
+        low, ext = _api.const(0, dtype='int32'), args[0]
+    else:
+        _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
+        low, ext = args[0], args[1]
+    if not ir_pass.Equal(low, _api.const(0, dtype='int32')):
+        ext = ext - low
+    for_type = LOOP_INTRIN[annotation]
+    iter_var = None
+    return iter_var, low, ext, for_type
+
+
+range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
+
+
+def bind(func_id, args):
+    """Handling TVM thread binding"""
+    _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
+    _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
+    _internal_assert(isinstance(args[0], str), \
+                     "A loop bind's first argument should be a string!")
+    iter_var = _api.thread_axis(args[0])
+    low, ext = _api.const(0), args[1]
+    for_type = None
+    return iter_var, low, ext, for_type
+
+
+def _math_intrin(func_id, args):
+    from .. import intrin
+    return getattr(intrin, func_id)(*args)
+
+sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
+
+
+def _min_max(func_id, args):
+    _internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
+    return getattr(_make, func_id.title())(args[0], args[1])
+
+
+min = max = _min_max #pylint: disable=invalid-name
+
+
+def _allocate_tensor(func_id, args):
+    """Handling TVM tensor allocation.
+    You may refer hybrid.intrin.allocate for more details."""
+    n = len(args)
+    _internal_assert(isinstance(_api.convert(args[0]), Array), \
+                     "allocate's first argument should be a tuple of shape!")
+    shape = args[0]
+    for i in shape:
+        _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
+    if n > 1:
+        _internal_assert(isinstance(args[1], str),
+                         "The data type should be an str")
+        _internal_assert(args[1].startswith('int') or args[1].startswith('float'), \
+                         "The data type should be either int or float!")
+        dtype = args[1]
+    else:
+        dtype = 'float32'
+    if n > 2:
+        _internal_assert(isinstance(args[2], str), \
+                         "The data scope should be an string")
+        _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
+        scope = args[2]
+    else:
+        scope = 'global' if func_id != 'output_tensor' else 'output'
+    return (shape, dtype, scope)
+
+output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py
index 92e259585..48e92a8bf 100644
--- a/python/tvm/hybrid/intrin.py
+++ b/python/tvm/hybrid/intrin.py
@@ -1,7 +1,6 @@
-"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
+"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""
 
 import numpy
-from ..stmt import For
 
 class _range(object):
     """Base class of the loop ranges in hybrid script"""
@@ -102,15 +101,3 @@ HYBRID_GLOBALS = {
     'sigmoid'      : sigmoid,
     'popcount'     : popcount
 }
-
-
-LOOP_INTRIN = {
-    'range'    : For.Serial,
-    'unroll'   : For.Unrolled,
-    'parallel' : For.Parallel,
-    'vectorize': For.Vectorized,
-    'bind'     : None
-}
-
-
-MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py
index ba10dd8dd..26b0e141d 100644
--- a/python/tvm/hybrid/parser.py
+++ b/python/tvm/hybrid/parser.py
@@ -4,24 +4,24 @@ import ast
 import operator
 import logging
 import sys
-from .util import make_nop, halide_imm_types, is_docstring, _internal_assert
-from .intrin import LOOP_INTRIN, MATH_INTRIN
+from .util import _internal_assert
+from . import calls
+from . import util
 from .var_decl import determine_variable_usage
-from ..api import thread_axis
 from ..api import all as _all
 from ..api import any as _any
+from ..tensor import Tensor, Operation
 from .. import expr as _expr
 from .. import make as _make
-from .. import intrin
 from .. import api  as _api
 from .. import ir_pass as _ir_pass
 
 def list_to_block(visit, lst):
     """Convert a list of Python IR nodes to HalideIR Block"""
-    lst = [visit(stmt) for stmt in lst if not is_docstring(stmt)]
-    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
+    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
+    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
     if not lst:
-        return make_nop()
+        return util.make_nop()
     if len(lst) == 1:
         return lst[0]
     body = lst[0]
@@ -62,7 +62,7 @@ class HybridParser(ast.NodeVisitor):
     }
 
 
-    def __init__(self, args, usage, func_name=None):
+    def __init__(self, args, usage, symbols, func_name=None):
         """
         Parameters
         ----------
@@ -81,32 +81,49 @@ class HybridParser(ast.NodeVisitor):
         self.args = list(args)
         self.usage = usage.copy()
         self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
-        self.alloc_buffers = {} # Buffers formed by allocate instructions
+        self.alloc_buffers = {} # Buffers formed by explicit allocate instructions
         self.loops_above = {} # State variable that indicates loop levels above the current node
-        self.var_consts = {} # Variables that are determined as readonly in previous stage
+        self.variables = {} # The status of defined variables
         self.func_name = func_name # The name of the function to be lowered
         self.outputs = [] # Output tensors' name
         self.side_effect = set() # Tensors with side effects
         self.parsed_body = None # The parsed HalideIR body
-        self.returned = False
+        self.returned = False # If this function has a valid return
+        self.symbols = symbols # The global context
 
 
     def wrap_up_realize(self, node, body):
         """Wrap up all the variables which will no longer be used"""
+        pop_buf = []
+        pop_var = []
         for key, val in self.usage.items():
-            if key in self.var_consts.keys():
-                continue
             _, level, _ = val
-            if level == node:
-                if key in self._args.keys():
+            if level != node:
+                continue
+            if key in self._args.keys():
+                continue
+            if key in self.alloc_buffers.keys():
+                _buf, _scope = self.alloc_buffers[key]
+                if _scope == 'output':
                     continue
-                else:
-                    _buf, _scope = self.alloc_buffers[key]
-                _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
-                _dtype = _buf.dtype
-                _true = _api.convert(True)
-                body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
-                body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
+                pop_buf.append(key)
+            else:
+                _internal_assert(key in self.variables.keys(),
+                                 "Key should be either in one of args, buffers, and vars")
+                if not isinstance(self.variables[key], tuple):
+                    continue
+                _buf, _scope = self.variables[key]
+                pop_var.append(key)
+            _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
+            _dtype = _buf.dtype
+            _true = _api.convert(True)
+            body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
+            body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
+
+        for elem in pop_buf:
+            self.alloc_buffers.pop(elem)
+        for elem in pop_var:
+            self.variables.pop(elem)
         return body
 
 
@@ -121,7 +138,6 @@ class HybridParser(ast.NodeVisitor):
         return self.alloc_buffers[s][0]
 
 
-
     #pylint: disable=invalid-name, missing-docstring
     def visit_Module(self, node):
         _internal_assert(len(node.body) == 1, \
@@ -133,13 +149,13 @@ class HybridParser(ast.NodeVisitor):
         _internal_assert(len(node.args.args) == len(self.args), \
                          "The number of arguments passed to the \
                          function should be the same as it is defined!")
+        if self.func_name is None:
+            self.func_name = node.name
         for idx, arg in enumerate(node.args.args):
             _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
             self._args[getattr(arg, _attr)] = self.args[idx]
         res = list_to_block(self.visit, node.body)
         res = self.wrap_up_realize(node, res)
-        if self.func_name is None:
-            self.func_name = node.name
         return res
 
 
@@ -148,23 +164,22 @@ class HybridParser(ast.NodeVisitor):
 
 
     def visit_Name(self, node):
-        _id = node.id
-        if _id in self._args.keys() and isinstance(self._args[_id], (_expr.Var, _expr.ConstExpr)):
-            return self._args[_id]
-        elif _id in self.loops_above.keys():
-            return self.loops_above[_id]
-        _internal_assert(_id not in self._args.keys(), \
-                         "This id %s should be handled in visit_Subscript!" % _id)
-        _internal_assert(_id in self.usage.keys(), \
-                         "This id %s is expected to be a defined variable!" % _id)
-        # Buffer
-        if _id in self.alloc_buffers.keys():
-            _buf, _ = self.alloc_buffers[_id]
-            return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
-        # Compilation time constant
-        _internal_assert(_id in self.var_consts.keys(),
-                         "This id %s is expected to a compilation time constant!" % _id)
-        return self.var_consts[_id]
+        name = node.id
+        if name in self.loops_above.keys():
+            return self.loops_above[name]
+        elif name in self.variables.keys():
+            res = self.variables[name]
+            if isinstance(res, tuple):
+                buf = res[0]
+                if isinstance(node.ctx, ast.Load):
+                    return _make.Call(buf.dtype, buf.name, [_api.const(0)], \
+                                      _expr.Call.Halide, buf.op, buf.value_index)
+                return buf, [_api.const(0)]
+            if isinstance(node.ctx, ast.Load):
+                return res
+            return None
+        buf = self._get_buffer_from_id(name)
+        return buf
 
 
     def visit_Num(self, node):
@@ -172,18 +187,36 @@ class HybridParser(ast.NodeVisitor):
 
 
     def visit_AugAssign(self, node):
-        lhs = self.visit(node.target)
+        buf = self.visit(node.target)
         rhs = self.visit(node.value)
-        rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
-        _internal_assert(isinstance(lhs, _expr.Call), \
-                         "The LHS of an AugAssign is supposed to be a call!")
-        return _make.Provide(lhs.func, 0, rhs, lhs.args)
+        if isinstance(buf, tuple):
+            _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
+            buf, args = buf
+        else:
+            args = [_api.const(0)]
+        _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
+
+        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
+        value = HybridParser._binop_maker[type(node.op)](read, rhs)
+
+        return _make.Provide(buf.op, 0, value, args)
 
 
     def visit_Assign(self, node):
+        rhs = self.visit(node.value)
+        if isinstance(rhs, Operation):
+            rmap = {}
+            _internal_assert(len(node.targets) == rhs.num_outputs, \
+                             "Unable to detuple the outs to targets")
+            for i in range(rhs.num_outputs):
+                _internal_assert(isinstance(node.targets[i], ast.Name),
+                                 "You should bind a pure name to the tensors")
+                self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global')
+                rmap[rhs.outputs[i].op] = rhs.output(i)
+            return util.replace_io(rhs.body, rmap)
+
         _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
         lhs = node.targets[0]
-        rhs = self.visit(node.value)
         if isinstance(rhs, _expr.Expr):
             rhs = _ir_pass.Simplify(rhs)
         if isinstance(lhs, ast.Name):
@@ -194,65 +227,63 @@ class HybridParser(ast.NodeVisitor):
                              "Loop variable cannot be overwritten!")
             decl, _, rw = self.usage[lhs]
             if decl == lhs_:
-                _internal_assert(lhs not in self.var_consts.keys(), \
-                                 "A constant cannot be overwritten!")
-                _internal_assert(lhs not in self.alloc_buffers.keys(), \
+                _internal_assert(lhs not in self.variables.keys() and
+                                 lhs not in self.alloc_buffers.keys(), \
                                  "This value should not be defined before this point!")
                 if isinstance(rhs, tuple):
                     shape, dtype, scope = rhs
                     ph = _api.placeholder(shape, dtype=dtype, name=lhs)
-                    if scope != 'output':
-                        self.alloc_buffers[lhs] = (ph, scope)
-                    else:
-                        self._args[lhs] = ph
+                    self.alloc_buffers[lhs] = (ph, scope)
+                    if scope == 'output':
                         self.outputs.append(lhs)
-                    return make_nop()
-                if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
-                    self.var_consts[lhs] = rhs
+                    return util.make_nop()
+                if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
+                    self.variables[lhs] = rhs
                 else:
                     ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
-                    self.alloc_buffers[lhs] = (ph, 'global')
-            if lhs in self.var_consts.keys():
-                return make_nop()
-            _internal_assert(lhs in self.alloc_buffers.keys(), \
-                             "This variable should be defined before!")
-            tgt, _ = self.alloc_buffers[lhs]
-            return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
+                    self.variables[lhs] = (ph, 'global')
+            lhs = self.visit(lhs_)
+            if lhs is not None:
+                buf, args = lhs
+                return _make.Provide(buf.op, 0, rhs, args)
+            return util.make_nop()
         else:
-            lhs = self.visit(lhs)
-            _internal_assert(isinstance(lhs, _expr.Call), \
+            lhs, args = self.visit(lhs)
+            _internal_assert(isinstance(lhs, Tensor), \
                              "An array access's LHS is expected to be a expr.Call!")
-            #TODO: support slice later
-            buf = self._get_buffer_from_id(lhs.name, for_provide=True)
-            return _make.Provide(buf.op, 0, rhs, lhs.args)
+            res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
+            return res
 
 
     def visit_Index(self, node):
         if isinstance(node.value, ast.Tuple):
-            return [self.visit(i) for i in node.value.elts]
+            return self.visit(node.value)
         return [self.visit(node.value)]
 
 
+    def visit_Attribute(self, node):
+        _internal_assert(isinstance(node.value, ast.Name), \
+                         "For atrribute access, only both names are supported so far!")
+        buf = self._get_buffer_from_id(node.value.id)
+        return getattr(buf, node.attr)
+
+
     def visit_Subscript(self, node):
         args = self.visit(node.slice)
         if isinstance(node.value, ast.Name):
-            array = node.value.id
-            _buf = self._get_buffer_from_id(array)
-            return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, _buf.value_index)
-
-        _internal_assert(isinstance(node.value, ast.Attribute), \
-                         "Only variable and attribute's subscript supported so far")
-        _internal_assert(isinstance(node.value.value, ast.Name), \
-                         "The root of array access is expect to be a id!")
-        _internal_assert(node.value.attr == "shape", \
-                         "Attribute access so far only 'shape' is supported!")
+            buf = self.visit(node.value)
+            if isinstance(node.ctx, ast.Load):
+                return _make.Call(buf.dtype, buf.name, args, \
+                                  _expr.Call.Halide, buf.op, buf.value_index)
+            return buf, args
+
+        shape = self.visit(node.value)
         _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
         args = args[0]
         #TODO: maybe support non-constant value later?
         _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
                          "So far only constant shape access supported!")
-        buf = self._get_buffer_from_id(node.value.value.id)
-        return buf.shape[args.value]
+        return shape[args.value]
 
 
     def visit_With(self, node):
@@ -275,7 +306,7 @@ class HybridParser(ast.NodeVisitor):
         if node.orelse:
             else_body = list_to_block(self.visit, node.orelse)
         else:
-            else_body = make_nop()
+            else_body = util.make_nop()
         return _make.IfThenElse(cond, if_body, else_body)
 
 
@@ -305,13 +336,10 @@ class HybridParser(ast.NodeVisitor):
             _internal_assert(isinstance(node.op, ast.Not), \
                              "Unary is supposed to be not!")
             return operator.not_(self.visit(node.values[0]))
-        elif n == 2:
-            _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
-                             "Binary is supposed to be and/or!")
-            values = [self.visit(i) for i in node.values]
-            return HybridParser._binop_maker[type(node.op)](*values)
-        else:
-            raise ValueError("This Bool Op is not supported yet!")
+        _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
+                         "Binary is supposed to be and/or!")
+        values = [self.visit(i) for i in node.values]
+        return HybridParser._binop_maker[type(node.op)](*values)
 
 
     def visit_UnaryOp(self, node):
@@ -329,67 +357,17 @@ class HybridParser(ast.NodeVisitor):
         # Yet, no function pointer supported
         _internal_assert(isinstance(node.func, ast.Name), \
                          "Only id-function function call is supported so far!")
+
         func_id = node.func.id
-        n = len(node.args)
-        if func_id in LOOP_INTRIN.keys() and func_id != 'bind':
-            if n == 1:
-                low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0])
-            else:
-                _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
-                low, ext = self.visit(node.args[0]), self.visit(node.args[1])
-            if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
-                ext = ext - low
-            for_type = LOOP_INTRIN[func_id]
-            iter_var = None
-            return iter_var, low, ext, for_type
-        elif func_id == 'bind':
-            _internal_assert(n == 2, "A loop bind should only have 2 arguments!")
-            _internal_assert(isinstance(node.args[0], ast.Str), \
-                             "A loop bind's first argument should be a string!")
-            _vn = node.args[0].s
-            iter_var = thread_axis(node.args[0].s)
-            low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
-            for_type = None
-            return iter_var, low, ext, for_type
-        elif func_id in MATH_INTRIN:
-            return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
-        elif func_id in ['allocate', 'output_tensor']:
-            _internal_assert(isinstance(node.args[0], ast.Tuple), \
-                             "allocate's first argument should be a tuple of shape!")
-            shape = tuple(self.visit(i) for i in node.args[0].elts)
-            if func_id == 'output_tensor':
-                _internal_assert(not self.loops_above, \
-                                 "Are you sure to allocate a output buffer multiple times?")
-            for i in shape:
-                _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
-            if n > 1:
-                if isinstance(node.args[1], ast.Str):
-                    dtype = node.args[1].s
-                else:
-                    _internal_assert(isinstance(node.args[1], ast.Attribute), \
-                                     "Unable to evaluate to get data type")
-                    to_eval = node.args[1]
-                    _internal_assert(isinstance(to_eval.value, ast.Name), \
-                                     "Unable to evaluate the attribute to get data type")
-                    _internal_assert(to_eval.attr == 'dtype', \
-                                     "Only dtype attribute is supported so far")
-                    dtype = self._get_buffer_from_id(to_eval.value.id).dtype
-            else:
-                dtype = 'float32'
-            if n > 2:
-                _internal_assert(isinstance(node.args[2], ast.Str), \
-                                 "The data scope should be an string")
-                _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
-                scope = node.args[2].s
-            else:
-                scope = 'global' if func_id != 'output_tensor' else 'output'
-            return (shape, dtype, scope)
-        elif func_id == 'max' or func_id == 'min':
-            _internal_assert(n == 2, "Max/Min function should have 2 elements")
-            a, b = self.visit(node.args[0]), self.visit(node.args[1])
-            return getattr(_make, func_id.title())(a, b)
-        else:
-            raise ValueError("Function call not supported yet!")
+        args = [self.visit(i) for i in node.args]
+        try:
+            return getattr(calls, func_id)(func_id, args)
+        except AttributeError:
+            _internal_assert(func_id in self.symbols.keys(), \
+                             "The function called is not in the context either!")
+            outs = self.symbols[func_id](*args)
+            op = outs.op if isinstance(outs, Tensor) else outs[0].op
+            return op
 
 
     def visit_For(self, node):
@@ -400,7 +378,7 @@ class HybridParser(ast.NodeVisitor):
         if iter_var is None:
             _internal_assert(for_type is not None, "The loop bind function parse error!")
             offset = iter_var = _api.var(_name)
-            if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
+            if not _ir_pass.Equal(low, _api.const(0)):
                 offset = iter_var + low
             self.loops_above[_name] = offset
         else:
@@ -411,7 +389,7 @@ class HybridParser(ast.NodeVisitor):
         if for_type is None:
             res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
         else:
-            res = _make.For(iter_var, _api.const(0, dtype='int32'), ext, for_type, 0, _body)
+            res = _make.For(iter_var, _api.const(0), ext, for_type, 0, _body)
         self.loops_above.pop(_name)
         return res
 
@@ -428,14 +406,22 @@ class HybridParser(ast.NodeVisitor):
                 _internal_assert(isinstance(i, ast.Name), "What do you return?")
                 ids.append(i.id)
         _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
-        if len(ids) != len(self.outputs):
+        if len(ids) < len(self.outputs):
             logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
-        self.outputs = [self._args[i] for i in ids]
+        self.outputs = [self.alloc_buffers[i][0] for i in ids]
         self.returned = True
-        return make_nop()
+        return util.make_nop()
+
+
+    def visit_Tuple(self, node):
+        return tuple(self.visit(i) for i in node.elts)
 
 
-def parse_python(src, args):
+    def visit_Str(self, node):
+        return node.s
+
+
+def parse_python(src, symbols, args):
     """The helper function of calling the AST visitor
 
     Parameters
@@ -443,6 +429,9 @@ def parse_python(src, args):
     src : str
         The source code of the function to be parsed.
 
+    src : str
+        The symbol list of the global context of the function.
+
     args : list of Tensors or Vars
         The argument lists to the function.
         It is NOT encouraged to write a function without arguments.
@@ -454,8 +443,8 @@ def parse_python(src, args):
         The result Halide IR and the parser class instance.
     """
     root = ast.parse(src)
-    var_usage = determine_variable_usage(root, args)
-    parser = HybridParser(args, var_usage)
+    var_usage = determine_variable_usage(root, args, symbols)
+    parser = HybridParser(args, var_usage, symbols)
     parser.parsed_body = parser.visit(root)
     _internal_assert(parser.returned, 'No valid return found in the function body!')
     return parser
diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py
index 78106838f..aa86d55a6 100644
--- a/python/tvm/hybrid/util.py
+++ b/python/tvm/hybrid/util.py
@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types
 from .. import api as _api
 from .. import make as _make
 from .. import expr as _expr
+from .. import stmt as _stmt
 from ..tensor import Tensor
 
 
@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
         _globals.pop(elem)
     for k, v in intersect:
         _globals[k] = v
+
+
+def replace_io(body, rmap):
+    """Replacing tensors usage according to the dict given"""
+    from .. import ir_pass
+
+    def replace(op):
+        if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
+            buf = rmap[op.func]
+            return _make.Provide(buf.op, op.value_index, op.value, op.args)
+        elif isinstance(op, _expr.Call) and  op.func in rmap.keys():
+            buf = rmap[op.func]
+            return _make.Call(buf.dtype, buf.name, op.args, \
+                              _expr.Call.Halide, buf.op, buf.value_index)
+        return None
+
+    return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py
index 27df87874..eb893a7f2 100644
--- a/python/tvm/hybrid/var_decl.py
+++ b/python/tvm/hybrid/var_decl.py
@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
     """The vistor class to determine the declaration, r/w status, and last use of each variable"""
     #pylint: disable=invalid-name
     #pylint: disable=missing-docstring
-    def __init__(self, args):
+    def __init__(self, args, symbols):
         self.status = {}
         self.scope_level = []
         self._args = {}
         self.args = args
         self.aug_assign_ = False
+        self.symbols = symbols
 
 
     def visit_FunctionDef(self, node):
@@ -43,8 +44,10 @@ class PyVariableUsage(ast.NodeVisitor):
         #No function pointer supported so far
         _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
         func_id = node.func.id
-        _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \
-                "Function call id not in intrinsics' list")
+        _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
+                         ['range', 'max', 'min'] + \
+                         list(self.symbols.keys()), \
+                         "Function call id not in intrinsics' list")
         for elem in node.args:
             self.visit(elem)
 
@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor):
         else:
             decl, loop, usage = self.status[node.id]
             usage.add(type(node.ctx))
+            _internal_assert(loop in self.scope_level,
+                             "%s is used out of the scope it is defined!" % node.id)
             self.status[node.id] = (decl, loop, usage)
 
 
-def determine_variable_usage(root, args):
+def determine_variable_usage(root, args, symbols):
     """The helper function for calling the dedicated visitor."""
-    visitor = PyVariableUsage(args)
+    visitor = PyVariableUsage(args, symbols)
     visitor.visit(root)
     return visitor.status
diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py
index 7efbbe43e..f87c75f79 100644
--- a/tests/python/unittest/test_hybrid_script.py
+++ b/tests/python/unittest/test_hybrid_script.py
@@ -270,7 +270,7 @@ def test_bind():
         return
     @script
     def vec_add(a, b):
-        c = output_tensor((1000, ), dtype='float32')
+        c = output_tensor((1000, ), 'float32')
         for tx in bind('threadIdx.x', 1000):
             c[tx] = a[tx] + b[tx]
         return c
@@ -506,7 +506,37 @@ def test_value_index():
     module(tvm.ndarray.array(np_a), res)
     tvm.testing.assert_allclose(res.asnumpy(), ref)
 
+def test_func_call():
+    @tvm.hybrid.script
+    def foo(a, b):
+        for i in range(10):
+            a[i] = i + 1.0
+        for i in range(10):
+            b[i] = i + 1.0
+        c = outer_product(10, 10, a, b)
+        d = output_tensor(c.shape, c.dtype)
+        for i in range(10):
+            for j in range(10):
+                d[i, j] = c[i, j] + i * j
+        return d
 
+    a = tvm.placeholder((10, ), name='a')
+    b = tvm.placeholder((10, ), name='b')
+    run_and_check(foo, [a, b])
+
+def test_bool():
+    @tvm.hybrid.script
+    def foo(a):
+        b = output_tensor(a.shape, a.dtype)
+        b[0] = 1.2
+        for i in range(1, a.shape[0] - 1):
+            if a[i] * a[i - 1] < a[i] or a[i] * a[i - 1] < a[i - 1] or i * a[i] == a[i]:
+                b[i] = a[i]
+            else:
+                b[i] = 0.0
+        return b
+    a = tvm.placeholder((10, ), name='a')
+    run_and_check(foo, [a])
 
 if __name__ == "__main__":
     test_outer_product()
@@ -521,7 +551,7 @@ if __name__ == "__main__":
     test_downstream()
     test_const_param()
     test_value_index()
+    test_func_call()
+    test_bool()
     # TODO:
     # test_inplace()
-
-
-- 
GitLab