From d173e6374b87a70ea2cf4378a304bcb8fa9fa545 Mon Sep 17 00:00:00 2001 From: Jian Weng <werefluke@gmail.com> Date: Tue, 4 Sep 2018 22:45:17 -0700 Subject: [PATCH] [Tutorial] tutorial to writing a costumized pass (#1671) --- docs/conf.py | 1 + python/tvm/build_module.py | 2 +- tutorials/dev/README.txt | 3 + tutorials/dev/low_level_custom_pass.py | 153 +++++++++++++++++++++++++ 4 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 tutorials/dev/README.txt create mode 100644 tutorials/dev/low_level_custom_pass.py diff --git a/docs/conf.py b/docs/conf.py index 989d26f87..e3f7f6a82 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -192,6 +192,7 @@ subsection_order = ExplicitOrder( ['../tutorials/language', '../tutorials/optimize', '../tutorials/autotvm', + '../tutorials/dev', '../tutorials/vta', '../tutorials/topi', '../tutorials/deployment', diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 777654af6..70935cde1 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -368,7 +368,7 @@ def lower(sch, cfg.unroll_explicit) for f in lower_phase2: stmt = f(stmt) - # Phase 2 + # Phase 3 stmt = ir_pass.Simplify(stmt) stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.RemoveNoOp(stmt) diff --git a/tutorials/dev/README.txt b/tutorials/dev/README.txt new file mode 100644 index 000000000..a35828064 --- /dev/null +++ b/tutorials/dev/README.txt @@ -0,0 +1,3 @@ +Developer Tutorials +------------------- + diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py new file mode 100644 index 000000000..617093d4a --- /dev/null +++ b/tutorials/dev/low_level_custom_pass.py @@ -0,0 +1,153 @@ +""" +Writing a Customized Pass +========================= +**Author**: `Jian Weng <https://were.github.io>`_ + +TVM is a framework that abstracts away the heterogenity of machine learning accelerators. +Sometimes users may want customize some analysis and IR transformations +to adapt TVM to their own specialized hardware. This tutorial helps users write +a customized pass in TVM. + Prerequisites +------------- +Before reading this tutorial, we assume readers have already known these topics well: +- Writing an algorithm in TVM and schedule it. Otherwise, see example tutorials like + `Optimize GeMM on CPU <https://docs.tvm.ai/tutorials/optimize/opt_gemm.html>_`. +- The basic structure of HalideIR. Otherwise, see ``HalideIR/src/ir/IR.h`` to learn what + attributes of IR nodes are defined. +- Visitor design pattern. Otherwise, check the + `Python AST module <https://docs.python.org/3/library/ast.html>_` to see how an AST + visitor is implemented. +- How a HalideIR/Schedule is lowered to either a LoweredFunc class or a LLVM module. Otherwise, + take a look at ``python/tvm/build_module.py`` to get some basics. +""" + +from __future__ import absolute_import, print_function +import tvm +import numpy as np + +###################################################################### +# We first write a very simple vector add and build it with the default schedule. Then, we use +# our customized lowering pass to manipulate the IR directly instead of using schedule premitives. +# + +n = tvm.const(128) +a = tvm.placeholder((n, ), name="a") +b = tvm.placeholder((n, ), name="b") +c = tvm.compute((n, ), lambda i: a[i] + b[i], name='c') + +sch = tvm.create_schedule(c.op) +ir = tvm.lower(sch, [a, b, c], simple_mode=True) +print(ir) + +###################################################################### +# Writing a Pass +# -------------- +# Essentially, an "IR transformation pass" is a function which maps a statement to a new statement. +# Thus, we define this vectorize function and implement it step by step. +# + +###################################################################### +# TVM already provides two class for users to both analyze and transform IR. +# +# IR Visitor +# ~~~~~~~~~~ +# We can use ``tvm.ir_pass.PostOrderVisit(stmt, func)`` to gather information from the Halide IR. +# ``func`` is a function callback. This function will be called before exiting the current IR node, +# i.e. post-order visit. Then we leverage side effects to store the result of IR visit, because the +# return value of ``func`` will be ignored. +# +# .. note:: +# +# You MUST use some array to store the result of IR visit. Even the value is a single variable. +# This is mainly due to the constraints in the Python-C runtime. The variable values will be +# refreshed every recursion but the array values will be preserved. +# + +loops = [] +def find_width8(op): + """ Find all the 'For' nodes whose extent can be divided by 8. """ + if isinstance(op, tvm.stmt.For): + if isinstance(op.extent, tvm.expr.IntImm): + if op.extent.value % 8 == 0: + loops.append(op) + +##################################################################### +# IR Transformation +# ~~~~~~~~~~~~~~~~~ +# The transformation interface is slightly different from the visitor interface. There is only a +# post-order callback in the visitor, but transformation visitor supports both a pre-order and a +# post-order callback. If you want to keep the origin IR node, just return None. If you want to +# change the current node to some node, use TVM IR maker interface to build it and return +# this value. +# +# .. note:: +# +# If the pre-order function is called and returns a value which is not None, the post-order +# function will be skipped. +# + +def vectorize8(op): + """ Split can vectorize the loops found in `find_width8`. """ + if op in loops: + extent = op.extent.value + name = op.loop_var.name + lo, li = tvm.var(name + '.outer'), tvm.var(name + '.inner') + body = tvm.ir_pass.Substitute(op.body, {op.loop_var: lo * 8 + li}) + body = tvm.make.For(li, 0, 8, tvm.stmt.For.Vectorized, 0, body) + body = tvm.make.For(lo, 0, extent // 8, tvm.stmt.For.Serial, 0, body) + return body + return None + +def vectorize(stmt): + global loops + + tvm.ir_pass.PostOrderVisit(stmt, find_width8) + + if not loops: + return stmt + + # The last list arugment indicates what kinds of nodes will be transformed. + # Thus, in this case only `For` nodes will call `vectorize8` + stmt = tvm.ir_pass.IRTransform(stmt, None, vectorize8, ['For']) + + return stmt + +##################################################################### +# Glue to Lowering +# ---------------- +# So far, we are done with writing this IR transformation pass. What we need to do next is to glue +# this pass to TVM's lower pass. We can first call this function directly as a sanity check. +# + +print(vectorize(ir)) + +##################################################################### +# In TVM, there is a property called ``BuildConfig``. You can use this property to customize your +# own lowering options. In this case, we inject the pass written above into the TVM standard lowering +# pass by feeding **a list of tuple** as argument to ``add_lower_pass``. "Tuple" indicates different +# phases of lowering. In TVM, there are four phases of lowering and user-customized ones will be +# called after each phase is done. +# +# .. note:: +# Here are the essential transformations done by each phase: +# - Phase 0 generates the raw IR and loop levels. +# - Phase 1 flattens the array storage. +# - Phase 2 transforms loops, like unroll, vectorization and thread-binding. +# - Phase 3 does some cleanup work. +# +# Thus, a good place to put this transformation pass is just after Phase 1. +# + +with tvm.build_config(add_lower_pass=[(1, vectorize)]) as cfg: + print(tvm.lower(sch, [a, b, c], simple_mode=True)) + +##################################################################### +# Quick View +# ---------- +# This tutorial gives a quick view of writing a customized IR transformation pass: +# - Use ``tvm.ir_pass.PostOrderVisit`` to gather information on each IR nodes. +# - Use ``tvm.ir_pass.IRTransform`` to transform IR nodes. +# - Wrap up two above to write an IR-transformation function. +# - Use ``tvm.build_config`` to put this function to TVM lowering pass +# + -- GitLab