From 21e1301086f74996af2d54e81ff342aa65b6cd35 Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Tue, 21 Aug 2018 12:40:23 -0500
Subject: [PATCH] Add int8 gemm recipe (#1614)

---
 topi/recipe/gemm/gemm_int8.py | 185 ++++++++++++++++++++++++++++++++++
 1 file changed, 185 insertions(+)
 create mode 100644 topi/recipe/gemm/gemm_int8.py

diff --git a/topi/recipe/gemm/gemm_int8.py b/topi/recipe/gemm/gemm_int8.py
new file mode 100644
index 000000000..61ef97d0a
--- /dev/null
+++ b/topi/recipe/gemm/gemm_int8.py
@@ -0,0 +1,185 @@
+"Example code to perform int8 GEMM"
+import logging
+import sys
+import numpy as np
+import tvm
+from tvm import autotvm
+
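+# Set DO_TUNING = False to skip the search and replay PRETUNED_INDEX, a
+# known-good entry in the tuning space.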
+DO_TUNING = True
+PRETUNED_INDEX = 75333
+
+def intrin_dot():
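+    """Declare a 4-element int8 dot-product tensor intrinsic lowered to CUDA's __dp4a."""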
+    n = 4  # dp4a requires operands packed by 4
+    x = tvm.placeholder((n,), name='x', dtype='int8')
+    y = tvm.placeholder((n,), name='y', dtype='int8')
+    k = tvm.reduce_axis((0, n), name='k')
+
+    z = tvm.compute(
+        (1,), lambda _: tvm.sum(
+            x[k].astype('int32') * y[k].astype('int32'), axis=k))
+
+    def intrin_func(ins, outs):
+        xx, yy = ins
+        zz = outs[0]
+        ib = tvm.ir_builder.create()
+
+        dp4a = zz.vstore(0, tvm.call_pure_extern('int32', '__dp4a',
+                                                 xx.vload(0, dtype='int8x4'),
+                                                 yy.vload(0, dtype='int8x4'),
+                                                 zz.vload(0)))
+        ib.emit(dp4a)
+
+        body = ib.get()
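+        # decl_tensor_intrin expects (body, reset, update): the reset simply
+        # zero-initializes the accumulator; body and update both issue the dp4a.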
+        return body, zz.vstore(0, 0), body
+
+    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
+        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
+                                    data_alignment=cfg.data_alignment,
+                                    offset_factor=cfg.offset_factor,
+                                    scope='local') for t in [x, y, z]}
+        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
+
+
+dot = intrin_dot()
+
+
+@autotvm.template
+def gemm_int8(n, m, l):
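+    """AutoTVM template for int8 GEMM: C[i, j] = sum_k A[i, k] * B[j, k], accumulated in int32."""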
+    A = tvm.placeholder((n, l), name='A', dtype='int8')
+    B = tvm.placeholder((m, l), name='B', dtype='int8')
+
+    k = tvm.reduce_axis((0, l), name='k')
+    C = tvm.compute(
+        (n, m),
+        lambda i, j: tvm.sum(A[i, k].astype('int32') * B[j, k].astype('int32'), axis=k),
+        name='C')
+
+    cfg = autotvm.get_config()
+    s = tvm.create_schedule(C.op)
+    y, x = C.op.axis
+
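+    # Stage A and B through shared memory (AA, BB) and registers (AL, BL);
+    # accumulate the output tile in registers (CC).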
+    AA = s.cache_read(A, 'shared', [C])
+    BB = s.cache_read(B, 'shared', [C])
+    AL = s.cache_read(AA, 'local', [C])
+    BL = s.cache_read(BB, 'local', [C])
+    CC = s.cache_write(C, 'local')
+
+    k = CC.op.reduce_axis[0]
+
+    cfg.define_split('tile_k', cfg.axis(k), num_outputs=3,
+                     filter=lambda entity: entity.size[2] == 4 and \
+                     entity.size[0] * 2 >= entity.size[1])
+
+    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
+
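+    # Map the innermost 4-element reduction onto the dp4a tensor intrinsic.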
+    s[CC].tensorize(ki, dot)
+
+    block_x = tvm.thread_axis('blockIdx.x')
+    block_y = tvm.thread_axis('blockIdx.y')
+    thread_x = tvm.thread_axis('threadIdx.x')
+    thread_y = tvm.thread_axis('threadIdx.y')
+
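+    # Tile each output axis into (block, virtual thread, thread, inner) and
+    # bind the outer levels to the CUDA grid.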
+    def block_size_filter(entity):
+        return entity.size[0] >= entity.size[1] and \
+                entity.size[1] <= 16 and entity.size[3] <= 4
+    cfg.define_split('tile_y', cfg.axis(y), num_outputs=4, filter=block_size_filter)
+    cfg.define_split('tile_x', cfg.axis(x), num_outputs=4, filter=block_size_filter)
+    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, y)
+    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, x)
+
+    s[C].bind(by, block_y)
+    s[C].bind(bx, block_x)
+    s[C].bind(tyz, tvm.thread_axis('vthread'))
+    s[C].bind(txz, tvm.thread_axis('vthread'))
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
+
+    s[CC].compute_at(s[C], tx)
+
+    yo, xo = CC.op.axis
+    s[CC].reorder(ko, kt, yo, xo, ki)
+    s[CC].unroll(kt)
+
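+    # Copy A/B tiles from shared memory into registers with 4-wide vectorized,
+    # double-buffered loads.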
+    for stage in [AL, BL]:
+        s[stage].compute_at(s[CC], kt)
+        _, xi = s[stage].split(stage.op.axis[1], factor=4)
+        s[stage].vectorize(xi)
+        s[stage].double_buffer()
+
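+    # Cooperatively fetch A/B tiles into shared memory; storage_align pads each
+    # row to reduce shared-memory bank conflicts.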
+    cfg.define_knob('storage_align', [16, 48])
+    for stage in [AA, BB]:
+        s[stage].storage_align(s[stage].op.axis[0],
+                               cfg['storage_align'].val, 0)
+        s[stage].compute_at(s[CC], ko)
+
+        fused = s[stage].fuse(*s[stage].op.axis)
+        ty, tx = s[stage].split(fused, nparts=cfg['tile_y'].size[2])
+        tx, xi = s[stage].split(tx, nparts=cfg['tile_x'].size[2])
+        _, xi = s[stage].split(xi, factor=16)
+
+        s[stage].bind(ty, thread_y)
+        s[stage].bind(tx, thread_x)
+        s[stage].vectorize(xi)
+
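+    # Expose the unroll aggressiveness as a tunable knob.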
+    cfg.define_knob('auto_unroll_max_step', [512, 1500])
+    s[C].pragma(by, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[C].pragma(by, 'unroll_explicit', False)
+
+    cfg.add_flop(n*m*l*2)
+    return s, [A, B, C]
+
+
+if __name__ == '__main__':
+    N = 2048
+    n = m = l = N
+
+    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
+    task = autotvm.task.create(gemm_int8, args=(n, m, l), target='cuda')
+    print(task.config_space)
+
+    measure_option = autotvm.measure_option(
+        measure_func='local', number=10, n_parallel=8, timeout=20)
+    log_name = 'gemm_int8.log'
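+    # Explore the schedule space with the XGBoost-based tuner, logging every
+    # measured configuration to disk.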
+    if DO_TUNING:
+        tuner = autotvm.tuner.XGBTuner(task)
+        tuner.tune(n_trial=1000, measure_option=measure_option,
+                   callbacks=[autotvm.callback.log_to_file(log_name)])
+
+        dispatch_context = autotvm.apply_history_best(log_name)
+        best_config = dispatch_context.query(task.target, task.workload)
+        print('\nBest config:')
+        print(best_config)
+    else:
+        config = task.config_space.get(PRETUNED_INDEX)
+        dispatch_context = autotvm.task.ApplyConfig(config)
+        print("Using pretuned config:")
+        print(config)
+
+    with dispatch_context:
+        with tvm.target.create('cuda'):
+            s, arg_bufs = gemm_int8(n, m, l)
+            f = tvm.build(s, arg_bufs, 'cuda', name='gemm_int8')
+
+    ctx = tvm.context('cuda', 0)
+
+    a_np = np.random.randint(size=(n, l), low=-128, high=127, dtype='int8')
+    b_np = np.random.randint(size=(m, l), low=-128, high=127, dtype='int8')
+
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    c = tvm.nd.array(np.zeros((n, m), dtype='int32'), ctx)
+    f(a, b, c)
+
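+    # Verify against a numpy int32 reference; B is laid out as (m, l), so the
+    # reference result is A @ B.T.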
+    np.testing.assert_allclose(
+        c.asnumpy(),
+        np.dot(
+            a_np.astype('int32'),
+            b_np.T.astype('int32')),
+        rtol=1e-5)
+
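+    # Time the kernel and report throughput (2*n*m*l integer ops per GEMM).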
+    num_ops = 2 * l * m * n
+    num_runs = 1000
+    timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
+    t = timer_f(a, b, c).mean
+    GOPS = num_ops / t / 1e9
+    print("average time cost of %d runs = %g ms, %g GOPS." %
+          (num_runs, t * 1e3, GOPS))
-- 
GitLab