From 83b24b5b779f79eabae4c962a3216af7dd8ce1b6 Mon Sep 17 00:00:00 2001
From: Junru Shao <junrushao1994@gmail.com>
Date: Mon, 19 Nov 2018 15:32:30 -0500
Subject: [PATCH] [TOPI] Minor fix in the LSTM recipe (#2131)

---
 topi/recipe/rnn/lstm.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/topi/recipe/rnn/lstm.py b/topi/recipe/rnn/lstm.py
index 53ccbe598..f627d6ce8 100644
--- a/topi/recipe/rnn/lstm.py
+++ b/topi/recipe/rnn/lstm.py
@@ -1,8 +1,6 @@
 """LSTM Example, still work in progress.."""
 import tvm
-import time
 import os
-import argparse
 from tvm.contrib import nvcc
 import numpy as np
 
@@ -14,16 +12,19 @@ DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
 SKIP_CHECK = False
 UNROLL_WLOAD = True
 
+
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
     ptx = nvcc.compile_cuda(code, target="ptx")
     return ptx
 
+
 def write_code(code, fname):
     with open(fname, "w") as f:
         f.write(code)
 
+
 @tvm.register_func
 def tvm_callback_cuda_postproc(code):
     if not os.path.exists("perf"):
@@ -33,16 +34,16 @@ def tvm_callback_cuda_postproc(code):
         code = open("perf/%s_manual.cu" % TASK).read()
     return code
 
+
 def lstm():
     if not PERSIST_KERNEL:
         raise ValueError("Non persist LSTM not yet supported")
-    detect_global_barrier = DETECT_GLOBAL_BARRIER
     num_thread_y = 8
-    num_thread_x = 16 * 3 / 2
+    num_thread_x = 16 * 3 // 2
     num_sm = 24
     n_num_step = 128
     num_step = tvm.var('num_step')
-    num_hidden = 1152 / 2
+    num_hidden = 1152 // 2
     batch_size = 1
     # Global transition matrix
     # Input hidden channel can be pre-caculated by a gemm
@@ -165,11 +166,9 @@ def lstm():
     flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
     ctx.sync()
     # measure time cost of second step.
-    tstart = time.time()
-    flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
-    ctx.sync()
-    tgap = time.time() - tstart
-    print("Time cost=%g" % tgap)
+    evaluator = flstm.time_evaluator(flstm.entry_name, ctx, 1, repeat=1000)
+    eval_result = evaluator(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
+    print("Time cost=%g" % eval_result.mean)
 
     # set unroll_explicit for more readable code.
     with tvm.build_config(
-- 
GitLab
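
Note on the timing change in the last hunk: time_evaluator is TVM's built-in benchmarking helper. It wraps a built function so that each call runs it a given number of times per measurement and repeats the measurement many times, and the printed .mean averages those repeats, smoothing out the launch overhead and clock jitter that the removed single time.time() pair could not. Below is a minimal sketch of the same pattern on a stand-in vector-add kernel; the kernel is illustrative only and not part of the patch, and it assumes the pre-0.5 tvm API that this recipe uses.

    # Minimal sketch: benchmark a trivial kernel with time_evaluator.
    # The vector-add below is a hypothetical stand-in, not the LSTM recipe.
    import numpy as np
    import tvm

    n = 1024
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    fadd = tvm.build(s, [A, B], target="llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)

    # number=1 run per measurement, repeat=100 measurements; .mean averages
    # the repeats, which is what the patched recipe prints instead of a
    # single time.time() gap.
    evaluator = fadd.time_evaluator(fadd.entry_name, ctx, number=1, repeat=100)
    print("Time cost=%g" % evaluator(a, b).mean)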