From d4a46898e03fae0cbb74564bf2d3e274dcc76734 Mon Sep 17 00:00:00 2001
From: xqdan <danxiaoqiang@126.com>
Date: Sat, 20 Jan 2018 11:57:39 +0800
Subject: [PATCH] Support dump ir for each pass (#693) (#791)

* Support dump ir for each pass(#693)

* expose DumpIR

* fix comments

* fix comments
---
 python/tvm/build_module.py                | 89 ++++++++++++++++++++++-
 tests/python/unittest/test_pass_unroll.py | 19 ++++-
 2 files changed, 105 insertions(+), 3 deletions(-)
 mode change 100644 => 100755 python/tvm/build_module.py

diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
old mode 100644
new mode 100755
index fe6b01bb4..8b52b11d8
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -5,18 +5,96 @@ LoweredFunc and compiled Module.
 """
 from __future__ import absolute_import as _abs
 import warnings
+import types
 
 from . import api
 from . import tensor
 from . import schedule
 from . import expr
 from . import ir_pass
+from . import stmt as _stmt
 from . import container
 from . import module
 from . import codegen
 from . import ndarray
 from . import target as _target
 
+class DumpIR(object):
+    """Dump IR for each pass.
+       With it, you can dump ir just like gcc/llvm.
+
+       How to use:
+       -----------
+       .. code-block:: python
+
+          with tvm.build_config(dump_pass_ir=True)
+              run()
+
+    """
+    scope_level = 0
+    def __init__(self):
+        self._pass_id = 0
+        self._recover_list = []
+
+    def decorate(self, func):
+        ''' decorate the pass function'''
+        def dump(*args, **kwargs):
+            '''dump function'''
+            retv = func(*args, **kwargs)
+            if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
+                return retv
+            pname = str(self._pass_id) + "_" + func.func_name + "_ir.cc"
+            with open(pname, "a") as f:
+                out = retv.body if isinstance(retv, container.LoweredFunc) else retv
+                f.write(str(out))
+                if isinstance(retv, container.Array):
+                    for x in retv:
+                        out = x.body if isinstance(x, container.LoweredFunc) else x
+                        f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
+                self._pass_id += 1
+            return retv
+        return dump
+
+    def decorate_irpass(self):
+        '''decorate ir_pass and ScheduleOps'''
+        self._old_sgpass = schedule.ScheduleOps
+        schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
+        vset = vars(ir_pass)
+        k = v = 0
+        def recover():
+            vset[k] = v
+        for k, v in vset.items():
+            self._recover_list.append(recover)
+            vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v
+
+    def decorate_custompass(self):
+        ''' decorate add_lower_pass pass in BuildConfig'''
+        cfg = BuildConfig.current
+        self._old_custom_pass = cfg.add_lower_pass
+        custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
+        pass_list = [(x[0], self.decorate(x[1])) for x in custom_pass]
+        BuildConfig.current.add_lower_pass = pass_list
+
+    def enter(self):
+        '''only decorate outermost nest'''
+        if DumpIR.scope_level > 0:
+            return
+        self.decorate_irpass()
+        self.decorate_custompass()
+        self._pass_id = 0
+        DumpIR.scope_level += 1
+
+    def exit(self):
+        '''recover outermost nest'''
+        if DumpIR.scope_level > 1:
+            return
+        # recover decorated functions
+        for f in self._recover_list:
+            f()
+        schedule.ScheduleOps = self._old_sgpass
+        BuildConfig.current.add_lower_pass = self._old_custom_pass
+        DumpIR.scope_level -= 1
+
 class BuildConfig(object):
     """Configuration scope to set a build config option.
 
@@ -37,10 +115,12 @@ class BuildConfig(object):
         "data_alignment": -1,
         "restricted_func": True,
         "double_buffer_split_loop": 1,
-        "add_lower_pass": None
+        "add_lower_pass": None,
+        "dump_pass_ir": False
     }
     def __init__(self, **kwargs):
         self._old_scope = None
+        self._dump_ir = DumpIR()
         for k, _ in kwargs.items():
             if k not in BuildConfig.defaults:
                 raise ValueError(
@@ -59,10 +139,14 @@ class BuildConfig(object):
         attr.update(self._attr)
         self._attr = attr
         BuildConfig.current = self
+        if self.dump_pass_ir is True:
+            self._dump_ir.enter()
         return self
 
     def __exit__(self, ptype, value, trace):
         assert self._old_scope
+        if self.dump_pass_ir is True:
+            self._dump_ir.exit()
         BuildConfig.current = self._old_scope
 
 
@@ -115,6 +199,8 @@ def build_config(**kwargs):
         phase contains an integer on which optimization pass we apply the pass.
         Additional lowering passes to be applied before make_api.
 
+    dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False
+
     Returns
     -------
     config: BuildConfig
@@ -247,7 +333,6 @@ def lower(sch,
         return stmt
     return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
 
-
 def build(sch,
           args=None,
           target=None,
diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py
index 9e52a455e..b158113e6 100644
--- a/tests/python/unittest/test_pass_unroll.py
+++ b/tests/python/unittest/test_pass_unroll.py
@@ -1,4 +1,5 @@
 import tvm
+import os
 
 def test_unroll_loop():
     dtype = 'int64'
@@ -24,4 +25,20 @@ def test_unroll_loop():
 
 
 if __name__ == "__main__":
-    test_unroll_loop()
+    with tvm.build_config(dump_pass_ir=True):
+        test_unroll_loop()
+
+    def end_with(*suffix):
+        ends = suffix
+        def run(s):
+            f = map(s.endswith, ends)
+            if True in f: return s
+        return run
+
+    file_list = os.listdir('./')
+    cc_file = end_with('.cc')
+    cc_file = filter(cc_file, file_list)
+    assert len(cc_file) == 3
+    for i in cc_file:
+        os.remove(i)
+    
-- 
GitLab