From edf09673f8a04a2ee77a30a9202453dade6c5dbb Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Fri, 14 Sep 2018 11:19:38 -0500
Subject: [PATCH] [TOPI] Add dp4a intrinsic to CUDA (#1707)

---
 topi/python/topi/cuda/tensor_intrin.py | 62 ++++++++++++++++++++++++++
 topi/recipe/gemm/gemm_int8.py          | 38 ++--------------
 2 files changed, 65 insertions(+), 35 deletions(-)
 create mode 100644 topi/python/topi/cuda/tensor_intrin.py

diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py
new file mode 100644
index 000000000..26ae7587c
--- /dev/null
+++ b/topi/python/topi/cuda/tensor_intrin.py
@@ -0,0 +1,62 @@
+"""Tensor intrinsics on CUDA."""
+#pylint: disable=invalid-name
+import tvm
+
+
+def dp4a(x_scope='local', y_scope='local', z_scope='local'):
+    """
+    Int8 dot product reduced by every 4 elements using __dp4a
+
+    Parameters
+    ----------
+    x_scope : str, optional
+        The storage scope of buffer for lhs
+    y_scope : str, optional
+        The storage scope of buffer for rhs
+    z_scope : str, optional
+        The storage scope of buffer for result
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The dp4a TensorIntrin that can be used in tensorizing schedule.
+    """
+    # Declared computation: z[0] = sum over k in [0, n) of int32(x[k]) * int32(y[k]).
+    n = 4  # dp4a requires operands packed by 4
+    x = tvm.placeholder((n,), name='x', dtype='int8')
+    y = tvm.placeholder((n,), name='y', dtype='int8')
+
+    k = tvm.reduce_axis((0, n), name='rc')
+
+    z = tvm.compute((1,), lambda i: tvm.sum(
+        x[k].astype('int32') * y[k].astype('int32'), axis=[k]))
+    # Builds the three statements (body, reset, update) handed to decl_tensor_intrin.
+    def _intrin_func(ins, outs):
+        def _instr(index):  # index selects the variant: 0 = body, 1 = reset, 2 = update
+            xx, yy = ins
+            zz = outs[0]
+            # Reset just stores 0 into z[0]; it needs no loads and no __dp4a call.
+            if index == 1:
+                return zz.vstore(0, 0)
+
+            ib = tvm.ir_builder.create()
+            # Read each operand as one int8x4 vector starting at element 0.
+            vec_x = xx.vload(0, dtype='int8x4')
+            vec_y = yy.vload(0, dtype='int8x4')
+            prev_z = 0 if index == 0 else zz.vload(0)
+            # Third __dp4a argument: literal 0 for body, current z[0] for update.
+            new_z = tvm.call_pure_extern('int32', '__dp4a', vec_x, vec_y, prev_z)
+            ib.emit(zz.vstore(0, new_z))
+
+            return ib.get()
+
+        return _instr(0), _instr(1), _instr(2) # body, reset, update
+    # Bind x, y, z to buffers in the requested scopes (data_alignment=4, offset_factor=1).
+    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
+        scopes = {x: x_scope, y: y_scope, z: z_scope}
+        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
+                                    data_alignment=cfg.data_alignment,
+                                    offset_factor=cfg.offset_factor,
+                                    scope=scopes[t]) for t in [x, y, z]}
+
+        return tvm.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
diff --git a/topi/recipe/gemm/gemm_int8.py b/topi/recipe/gemm/gemm_int8.py
index 4cce2735c..ed735dad9 100644
--- a/topi/recipe/gemm/gemm_int8.py
+++ b/topi/recipe/gemm/gemm_int8.py
@@ -4,44 +4,12 @@ import sys
 import numpy as np
 import tvm
 from tvm import autotvm
+from topi.cuda.tensor_intrin import dp4a
 
 DO_TUNING = True
 PRETUNED_INDEX = 75333
 
-def intrin_dot():
-    n = 4  # dp4a requires operands packed by 4
-    x = tvm.placeholder((n,), name='x', dtype='int8')
-    y = tvm.placeholder((n,), name='y', dtype='int8')
-    k = tvm.reduce_axis((0, n), name='k')
-
-    z = tvm.compute(
-        (1,), lambda _: tvm.sum(
-            x[k].astype('int32') * y[k].astype('int32'), axis=k))
-
-    def intrin_func(ins, outs):
-        xx, yy = ins
-        zz = outs[0]
-        ib = tvm.ir_builder.create()
-
-        dp4a = zz.vstore(0, tvm.call_pure_extern('int32', '__dp4a',
-                                                 xx.vload(0, dtype='int8x4'),
-                                                 yy.vload(0, dtype='int8x4'),
-                                                 zz.vload(0)))
-        ib.emit(dp4a)
-
-        body = ib.get()
-        return body, zz.vstore(0, 0), body
-
-    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
-        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
-                                    data_alignment=cfg.data_alignment,
-                                    offset_factor=cfg.offset_factor,
-                                    scope='local') for t in [x, y, z]}
-        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
-
-
-dot = intrin_dot()
-
+intrin_dp4a = dp4a('local', 'local', 'local')
 
 @autotvm.template
 def gemm_int8(n, m, l):
@@ -70,7 +38,7 @@ def gemm_int8(n, m, l):
 
     ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
 
-    s[CC].tensorize(ki, dot)
+    s[CC].tensorize(ki, intrin_dp4a)
 
     block_x = tvm.thread_axis('blockIdx.x')
     block_y = tvm.thread_axis('blockIdx.y')
-- 
GitLab