From 1d7c52f2b23a53c14dfc3c69f4ef3e57d1e557a7 Mon Sep 17 00:00:00 2001 From: Jian Weng <werefluke@gmail.com> Date: Mon, 27 Aug 2018 13:33:27 -0700 Subject: [PATCH] add docstring skip in hybrid script (#1668) * add docstring skip in hybrid script * fix lint --- python/tvm/hybrid/parser.py | 4 ++-- python/tvm/hybrid/util.py | 6 ++++++ tests/python/unittest/test_hybrid_script.py | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 1e532367a..cf21ea950 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -3,7 +3,7 @@ import ast import operator import sys -from .util import make_nop, halide_imm_types +from .util import make_nop, halide_imm_types, is_docstring from .intrin import LOOP_INTRIN, MATH_INTRIN from .var_decl import determine_variable_usage from ..api import thread_axis @@ -15,7 +15,7 @@ from .. import ir_pass as _ir_pass def list_to_block(visit, lst): """Convert a list of Python IR nodes to HalideIR Block""" - lst = list(map(visit, lst)) + lst = [visit(stmt) for stmt in lst if not is_docstring(stmt)] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] if not lst: return make_nop() diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 43d26e859..2a43957e9 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -1,5 +1,6 @@ """Internal utilities for parsing Python subset to HalideIR""" +import ast import inspect import numpy from .intrin import HYBRID_GLOBALS @@ -22,6 +23,11 @@ def make_nop(): return _make.Evaluate(_api.const(0, dtype='int32')) +def is_docstring(node): + """Checks if a Python AST node is a docstring""" + return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str) + + def _pruned_source(func): """Prune source code's extra leading spaces""" lines = inspect.getsource(func).split('\n') diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 0f500d7c7..ef0bcf8f7 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -43,6 +43,7 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'): @script def outer_product(n, m, a, b, c): + """This is a simple outer product""" for i in range(n): for j in range(m): c[i, j] = a[i] * b[j] -- GitLab