From 155e955fe14f0d320dc76258ae2c3bbe92d4bb71 Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Thu, 25 Oct 2018 00:37:13 +0800
Subject: [PATCH] Fix int8x4 broadcast value codegen in cuda (#1959)
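
TVM's CUDA codegen represents an int8x4 vector as a single 32-bit int,
so the generic Broadcast path below would emit a call to a nonexistent
make_ constructor for this type. Instead, pack four copies of the
constant's low byte into one 32-bit literal: Broadcast(int8(0xAB),
lanes=4) now generates (int)2880154539, i.e. the bit pattern 0xABABABAB.

A self-contained host-side sketch of the same packing, for illustration
only (local names are for the example, not part of the patch):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      // Pack four copies of the low byte, as the codegen change does.
      int64_t v = 0xAB & 0xFF;
      v = (v << 24) | (v << 16) | (v << 8) | v;
      // Matches the emitted "(int)" cast; wraps to the bit pattern
      // 0xABABABAB on two's-complement targets.
      int packed = static_cast<int>(v);
      // Reinterpret the 32-bit value as four int8 lanes.
      int8_t lanes[4];
      std::memcpy(lanes, &packed, sizeof(packed));
      for (int i = 0; i < 4; ++i)
        std::printf("lane %d = 0x%02X\n", i, (unsigned)(uint8_t)lanes[i]);
      return 0;  // prints 0xAB for all four lanes
    }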

---
 src/codegen/codegen_cuda.cc                | 10 ++++++++++
 tests/python/unittest/test_codegen_cuda.py | 23 ++++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 2ed8d8e3f..0ab56a116 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -273,6 +273,16 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
 }
 
 void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+  if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) {
+    // make_int8x4: pack the byte into all four lanes of a 32-bit int.
+    const int64_t *p = as_const_int(op->value);
+    CHECK(p) << "Broadcast of non-constant int8x4 value is not supported";
+    int64_t v = *p & 0xFF;                     // low byte of the broadcast constant
+    v = (v << 24) | (v << 16) | (v << 8) | v;  // replicate it into all four lanes
+    os << "(int)" << v;                        // emit as a single 32-bit literal
+    return;
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->type, os);
diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py
index a0b1cf445..d3b770790 100644
--- a/tests/python/unittest/test_codegen_cuda.py
+++ b/tests/python/unittest/test_codegen_cuda.py
@@ -87,7 +87,30 @@ def test_cuda_vectorize_load():
     check_cuda("int8", 64, 8)
     check_cuda("int8", 64, 16)
 
+def test_cuda_make_int8x4():
+    def check_cuda(n, value):
+        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+            print("skip because cuda is not enabled.")
+            return
+        lanes = 4
+        dtype = 'int8'
+        ctx = tvm.gpu(0)
+        A = tvm.compute((n, lanes), lambda i, j: tvm.const(value, dtype=dtype))
+        s = tvm.create_schedule(A.op)
+        y, x = s[A].op.axis
+        s[A].vectorize(x)  # 4 int8 lanes -> an int8x4 broadcast of the constant
+        s[A].bind(y, tvm.thread_axis("blockIdx.x"))
+        fun = tvm.build(s, [A], "cuda", name="make_int8x4")
+        np_a = np.full((n, lanes), value, dtype=dtype)
+        a = tvm.nd.empty(np_a.shape, dtype, ctx)
+        fun(a)
+        np.testing.assert_equal(a.asnumpy(), np_a)
+    check_cuda(64, 0xAB)
+    check_cuda(64, 0)
+    check_cuda(64, -3)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
+    test_cuda_make_int8x4()
-- 
GitLab