diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 1e532367a321fc918f833bdd4a13ee0acdd527f1..cf21ea95054933f028f8047589b9fb28c2e5018d 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -3,7 +3,7 @@ import ast import operator import sys -from .util import make_nop, halide_imm_types +from .util import make_nop, halide_imm_types, is_docstring from .intrin import LOOP_INTRIN, MATH_INTRIN from .var_decl import determine_variable_usage from ..api import thread_axis @@ -15,7 +15,7 @@ from .. import ir_pass as _ir_pass def list_to_block(visit, lst): """Convert a list of Python IR nodes to HalideIR Block""" - lst = list(map(visit, lst)) + lst = [visit(stmt) for stmt in lst if not is_docstring(stmt)] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] if not lst: return make_nop() diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 43d26e859560a0b62eb6e5881e527ac36c605a36..2a43957e97068933cce367401b0cad159e5d6a3e 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -1,5 +1,6 @@ """Internal utilities for parsing Python subset to HalideIR""" +import ast import inspect import numpy from .intrin import HYBRID_GLOBALS @@ -22,6 +23,11 @@ def make_nop(): return _make.Evaluate(_api.const(0, dtype='int32')) +def is_docstring(node): + """Checks if a Python AST node is a docstring""" + return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str) + + def _pruned_source(func): """Prune source code's extra leading spaces""" lines = inspect.getsource(func).split('\n') diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 0f500d7c704f97e72c5d5291e468715b94f2d1d0..ef0bcf8f72e58b6ac21a095e0eeabd6b9a415bce 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -43,6 +43,7 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'): @script def outer_product(n, m, a, b, c): + """This is a simple outer product""" for i in range(n): for j in range(m): c[i, j] = a[i] * b[j]