From 83b24b5b779f79eabae4c962a3216af7dd8ce1b6 Mon Sep 17 00:00:00 2001
From: Junru Shao <junrushao1994@gmail.com>
Date: Mon, 19 Nov 2018 15:32:30 -0500
Subject: [PATCH] [TOPI] Minor fix in the LSTM recipe (#2131)

---
 topi/recipe/rnn/lstm.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/topi/recipe/rnn/lstm.py b/topi/recipe/rnn/lstm.py
index 53ccbe598..f627d6ce8 100644
--- a/topi/recipe/rnn/lstm.py
+++ b/topi/recipe/rnn/lstm.py
@@ -1,8 +1,6 @@
 """LSTM Example, still work in progress.."""
 import tvm
-import time
 import os
-import argparse
 from tvm.contrib import nvcc
 import numpy as np
 
@@ -14,16 +12,19 @@ DETECT_GLOBAL_BARRIER = PERSIST_KERNEL
 SKIP_CHECK = False
 UNROLL_WLOAD = True
 
+
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
     ptx =  nvcc.compile_cuda(code, target="ptx")
     return ptx
 
+
 def write_code(code, fname):
     with open(fname, "w") as f:
         f.write(code)
 
+
 @tvm.register_func
 def tvm_callback_cuda_postproc(code):
     if not os.path.exists("perf"):
@@ -33,16 +34,16 @@ def tvm_callback_cuda_postproc(code):
         code = open("perf/%s_manual.cu" % TASK).read()
     return code
 
+
 def lstm():
     if not PERSIST_KERNEL:
         raise ValueError("Non persist LSTM not yet supported")
-    detect_global_barrier = DETECT_GLOBAL_BARRIER
     num_thread_y = 8
-    num_thread_x = 16 * 3 / 2
+    num_thread_x = 16 * 3 // 2
     num_sm = 24
     n_num_step = 128
     num_step = tvm.var('num_step')
-    num_hidden = 1152 / 2
+    num_hidden = 1152 // 2
     batch_size = 1
     # Global transition matrix
     # Input hidden channel can be pre-caculated by a gemm
@@ -165,11 +166,9 @@ def lstm():
         flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
         ctx.sync()
         # measure time cost of second step.
-        tstart = time.time()
-        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
-        ctx.sync()
-        tgap = time.time() - tstart
-        print("Time cost=%g" % tgap)
+        evaluator = flstm.time_evaluator(flstm.entry_name, ctx, 1, repeat=1000)
+        eval_result = evaluator(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
+        print("Time cost=%g" % eval_result.mean)
 
     # set unroll_explicit for more readable code.
     with tvm.build_config(
-- 
GitLab