diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 9310d46dce88ed7b8e75e4471ce5c64b77afefde..ff18bf5bd0f1ed5ddffc44baf4ff13b07feca4f3 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -8,6 +8,21 @@ from topi import util as util from .environment import get_env +def _match_pragma(stmt, key): + """Internal helper to match stmt to pragma stmt. + + Parameters + ---------- + stmt : Stmt + The AttrStmt + + key : str + The pragma key + """ + return ((stmt.attr_key == "pragma_" + key) or + (stmt.attr_key == "pragma_scope" and stmt.value.value == key)) + + def fold_uop_loop(stmt_in): """Detect and fold uop loop. @@ -255,7 +270,7 @@ def inject_skip_copy(stmt_in): Transformed statement """ def _do_fold(stmt): - if (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy"): + if _match_pragma(stmt, "skip_dma_copy"): return tvm.make.Evaluate(0) return None return tvm.ir_pass.IRTransform( @@ -277,12 +292,12 @@ def inject_coproc_sync(stmt_in): """ success = [False] def _do_fold(stmt): - if stmt.attr_key == "pragma_scope" and stmt.value.value == "coproc_sync": + if _match_pragma(stmt, "coproc_sync"): success[0] = True sync = tvm.make.Call( "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0) return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync)) - elif stmt.attr_key == "pragma_scope" and stmt.value.value == "trim_loop": + elif _match_pragma(stmt, "trim_loop"): op = stmt.body assert isinstance(op, tvm.stmt.For) return tvm.make.For( @@ -561,7 +576,7 @@ def annotate_alu_coproc_scope(stmt_in): """ env = get_env() def _do_fold(stmt): - if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"): + if _match_pragma(stmt, "alu"): irb = tvm.ir_builder.create() irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(env.dev.QID_COMPUTE)) @@ -569,7 +584,7 @@ def annotate_alu_coproc_scope(stmt_in): tvm.make.StringImm("VTAPushALUOp")) irb.emit(stmt) return irb.get() - elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"): + elif _match_pragma(stmt, "skip_alu"): return tvm.make.Evaluate(0) return stmt @@ -631,7 +646,7 @@ def inject_alu_intrin(stmt_in): return rev_src_coeff, rev_dst_coeff, rev_extents - if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"): + if _match_pragma(stmt, "alu"): # Get to the innermost loop body loop_body = stmt.body nest_size = 0