"""Example code to do square matrix multiplication on Android Phone."""
# Provenance (recovered from the patch header this file was extracted from):
#   commit  35485307722ffb3ce58b8d21651e74b6a73b6968
#   author  Yizhi Liu <javelinjs@gmail.com>
#   date    Fri, 10 Nov 2017 01:56:29 +0800
#   subject [PATCH] android gemm for topi/recipe (#628)
#   path    topi/recipe/gemm/android_gemm_square.py
import os

import numpy as np
import tvm
from tvm.contrib import rpc, util, ndk

# Set to be address of tvm proxy.
proxy_host = os.environ["TVM_ANDROID_RPC_PROXY_HOST"]
proxy_port = 9090
key = "android"

# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
target = "llvm -target=%s-linux-android" % arch

# Element type used for every host-side array created in `evaluate`.
dtype = 'float32'


def ngflops(N):
    """Return 2 * N**3 scaled by 1e-9.

    `evaluate` divides this by the measured seconds per run to print a
    GFLOPS figure for an N x N by N x N matrix product.
    """
    return 2.0 * float(N * N * N) / (10**9)


def evaluate(func, ctx, N, times):
    """Time `func` on random N x N inputs and check the result against numpy.

    Parameters
    ----------
    func : module returned by `remote.load_module`
        Called as ``func(a, b, c)`` through its time evaluator.
    ctx : context the arrays are allocated on
    N : int
        Side length of the square matrices.
    times : int
        Passed as ``number=`` to ``time_evaluator``.

    Raises
    ------
    AssertionError
        If the computed product differs from ``a_np.dot(b_np)`` beyond
        two decimal places.
    """
    a_np = np.random.uniform(size=(N, N)).astype(dtype)
    b_np = np.random.uniform(size=(N, N)).astype(dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx)

    time_f = func.time_evaluator(func.entry_name, ctx, number=times)
    cost = time_f(a, b, c).mean
    gf = ngflops(N) / cost
    print('%g secs/op, %g GFLOPS' % (cost, gf))
    np.testing.assert_almost_equal(c.asnumpy(), a_np.dot(b_np), decimal=2)


def test_gemm_gpu(N, times, bn, num_block, num_thread):
    """Build an OpenCL N x N GEMM, ship it to the RPC device and evaluate it.

    Parameters
    ----------
    N : int
        Side length of the square matrices; must be a multiple of `bn`.
    times : int
        Number of timed runs, forwarded to `evaluate`.
    bn : int
        Width of the innermost dimension that `B` is repacked into.
    num_block, num_thread : int
        `nparts` values used when splitting axes in the schedule below.

    Raises
    ------
    AssertionError
        If the size preconditions on N do not hold, or if the device result
        does not match the numpy reference (raised from `evaluate`).
    """
    assert bn <= N
    # packedB has N // bn tiles of width bn; a remainder would make
    # packedB[k, jj / bn, jj % bn] index past the last tile.
    assert N % bn == 0
    assert num_thread * num_thread * 16 <= N
    assert num_block * num_block * 2 <= N

    A = tvm.placeholder((N, N), name='A')
    B = tvm.placeholder((N, N), name='Btmp')
    k = tvm.reduce_axis((0, N), name='k')

    # `//` keeps the shape extent an integer on Python 3 as well as Python 2
    # (plain `/` on two Python ints yields a float on Python 3).
    packedB = tvm.compute(
        (N, N // bn, bn),
        lambda x, y, z: B[x, y * bn + z],
        name='B')

    # NOTE(review): `jj` is a TVM index expression here, so `jj / bn` goes
    # through TVM's operator overload rather than Python int division --
    # confirm it is integer division for the TVM version in use.
    C = tvm.compute(
        (N, N),
        lambda ii, jj: tvm.sum(A[ii, k] * packedB[k, jj / bn, jj % bn], axis=k),
        name='C')

    s = tvm.create_schedule(C.op)
    CC = s.cache_write(C, "local")

    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    thread_x = tvm.thread_axis("threadIdx.x")
    thread_y = tvm.thread_axis("threadIdx.y")

    thread_xz = tvm.thread_axis((0, 2), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, 2), "vthread", name="vy")

    # Schedule for the repacking stage: outer parts of the first two axes are
    # bound to threadIdx.y / threadIdx.x, innermost axis vectorized by 8.
    pby, pbi = s[packedB].split(packedB.op.axis[0], nparts=num_thread)
    pbx, pbj = s[packedB].split(packedB.op.axis[1], nparts=num_thread)
    s[packedB].bind(pby, thread_y)
    s[packedB].bind(pbx, thread_x)
    _, pbk = s[packedB].split(packedB.op.axis[2], factor=8)
    s[packedB].vectorize(pbk)

    # Schedule for the output stage.
    by, yi = s[C].split(C.op.axis[0], nparts=num_block)
    bx, xi = s[C].split(C.op.axis[1], nparts=num_thread)

    s[C].bind(by, block_y)
    s[C].bind(bx, thread_y)
    s[C].reorder(by, bx, yi, xi)

    tyz, yi = s[C].split(yi, nparts=2)
    ty, yi = s[C].split(yi, nparts=num_block)
    txz, xi = s[C].split(xi, nparts=2)
    tx, xi = s[C].split(xi, nparts=num_thread)

    s[C].reorder(tyz, txz, ty, tx, yi, xi)
    s[C].bind(tyz, thread_yz)
    s[C].bind(txz, thread_xz)

    s[C].bind(ty, block_x)
    s[C].bind(tx, thread_x)

    xyi, xxi = s[C].split(xi, factor=8)
    s[C].reorder(tyz, txz, ty, tx, yi, xyi, xxi)
    s[C].vectorize(xxi)

    # Schedule for the "local" write cache of C, computed at C's `yi` axis.
    s[CC].compute_at(s[C], yi)
    yo, xo = CC.op.axis
    s[CC].reorder(k, yo, xo)
    _, xi = s[CC].split(xo, factor=8)
    s[CC].vectorize(xi)

    _, ki = s[CC].split(k, factor=2)
    s[CC].unroll(ki)

    print(tvm.lower(s, [A, B, C], simple_mode=True))

    # Device code targets OpenCL; host code is cross-compiled for `target`
    # and linked into a shared library through the NDK helper.
    f = tvm.build(s, [A, B, C], "opencl", target_host=target, name="gemm_gpu")
    temp = util.tempdir()
    path_dso = temp.relpath("gemm_gpu.so")
    f.export_library(path_dso, ndk.create_shared)

    # connect to the proxy
    remote = rpc.connect(proxy_host, proxy_port, key=key)
    ctx = remote.cl(0)
    remote.upload(path_dso)
    f = remote.load_module("gemm_gpu.so")

    evaluate(f, ctx, N, times)


if __name__ == "__main__":
    test_gemm_gpu(1024, times=5, bn=8, num_block=2, num_thread=8)