From d4a46898e03fae0cbb74564bf2d3e274dcc76734 Mon Sep 17 00:00:00 2001 From: xqdan <danxiaoqiang@126.com> Date: Sat, 20 Jan 2018 11:57:39 +0800 Subject: [PATCH] Support dump ir for each pass (#693) (#791) * Support dump ir for each pass(#693) * expose DumpIR * fix comments * fix comments --- python/tvm/build_module.py | 89 ++++++++++++++++++++++- tests/python/unittest/test_pass_unroll.py | 19 ++++- 2 files changed, 105 insertions(+), 3 deletions(-) mode change 100644 => 100755 python/tvm/build_module.py diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py old mode 100644 new mode 100755 index fe6b01bb4..8b52b11d8 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -5,18 +5,96 @@ LoweredFunc and compiled Module. """ from __future__ import absolute_import as _abs import warnings +import types from . import api from . import tensor from . import schedule from . import expr from . import ir_pass +from . import stmt as _stmt from . import container from . import module from . import codegen from . import ndarray from . import target as _target +class DumpIR(object): + """Dump IR for each pass. + With it, you can dump ir just like gcc/llvm. + + How to use: + ----------- + .. code-block:: python + + with tvm.build_config(dump_pass_ir=True) + run() + + """ + scope_level = 0 + def __init__(self): + self._pass_id = 0 + self._recover_list = [] + + def decorate(self, func): + ''' decorate the pass function''' + def dump(*args, **kwargs): + '''dump function''' + retv = func(*args, **kwargs) + if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)): + return retv + pname = str(self._pass_id) + "_" + func.func_name + "_ir.cc" + with open(pname, "a") as f: + out = retv.body if isinstance(retv, container.LoweredFunc) else retv + f.write(str(out)) + if isinstance(retv, container.Array): + for x in retv: + out = x.body if isinstance(x, container.LoweredFunc) else x + f.write("---------%s\n%s\n-----------\n"%(x.name, str(out))) + self._pass_id += 1 + return retv + return dump + + def decorate_irpass(self): + '''decorate ir_pass and ScheduleOps''' + self._old_sgpass = schedule.ScheduleOps + schedule.ScheduleOps = self.decorate(schedule.ScheduleOps) + vset = vars(ir_pass) + k = v = 0 + def recover(): + vset[k] = v + for k, v in vset.items(): + self._recover_list.append(recover) + vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v + + def decorate_custompass(self): + ''' decorate add_lower_pass pass in BuildConfig''' + cfg = BuildConfig.current + self._old_custom_pass = cfg.add_lower_pass + custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] + pass_list = [(x[0], self.decorate(x[1])) for x in custom_pass] + BuildConfig.current.add_lower_pass = pass_list + + def enter(self): + '''only decorate outermost nest''' + if DumpIR.scope_level > 0: + return + self.decorate_irpass() + self.decorate_custompass() + self._pass_id = 0 + DumpIR.scope_level += 1 + + def exit(self): + '''recover outermost nest''' + if DumpIR.scope_level > 1: + return + # recover decorated functions + for f in self._recover_list: + f() + schedule.ScheduleOps = self._old_sgpass + BuildConfig.current.add_lower_pass = self._old_custom_pass + DumpIR.scope_level -= 1 + class BuildConfig(object): """Configuration scope to set a build config option. @@ -37,10 +115,12 @@ class BuildConfig(object): "data_alignment": -1, "restricted_func": True, "double_buffer_split_loop": 1, - "add_lower_pass": None + "add_lower_pass": None, + "dump_pass_ir": False } def __init__(self, **kwargs): self._old_scope = None + self._dump_ir = DumpIR() for k, _ in kwargs.items(): if k not in BuildConfig.defaults: raise ValueError( @@ -59,10 +139,14 @@ class BuildConfig(object): attr.update(self._attr) self._attr = attr BuildConfig.current = self + if self.dump_pass_ir is True: + self._dump_ir.enter() return self def __exit__(self, ptype, value, trace): assert self._old_scope + if self.dump_pass_ir is True: + self._dump_ir.exit() BuildConfig.current = self._old_scope @@ -115,6 +199,8 @@ def build_config(**kwargs): phase contains an integer on which optimization pass we apply the pass. Additional lowering passes to be applied before make_api. + dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False + Returns ------- config: BuildConfig @@ -247,7 +333,6 @@ def lower(sch, return stmt return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) - def build(sch, args=None, target=None, diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py index 9e52a455e..b158113e6 100644 --- a/tests/python/unittest/test_pass_unroll.py +++ b/tests/python/unittest/test_pass_unroll.py @@ -1,4 +1,5 @@ import tvm +import os def test_unroll_loop(): dtype = 'int64' @@ -24,4 +25,20 @@ def test_unroll_loop(): if __name__ == "__main__": - test_unroll_loop() + with tvm.build_config(dump_pass_ir=True): + test_unroll_loop() + + def end_with(*suffix): + ends = suffix + def run(s): + f = map(s.endswith, ends) + if True in f: return s + return run + + file_list = os.listdir('./') + cc_file = end_with('.cc') + cc_file = filter(cc_file, file_list) + assert len(cc_file) == 3 + for i in cc_file: + os.remove(i) + -- GitLab