diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index adfe198ebf549ee74b06ae5c27e5f0729cbfb4c6..a0048a2ed771769fbff722369fbdbad997c0a42b 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -47,6 +47,15 @@ TVM_REGISTER_API("ir_pass.CanonicalSimplify")
     }
   });
 
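+// Expose ir_pass.Substitute to the frontend. Like ir_pass.Equal below,
+// it dispatches on whether the first argument is a Stmt or an Expr.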
+TVM_REGISTER_API("ir_pass.Substitute")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    if (args[0].IsNodeType<Stmt>()) {
+      *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, Expr>());
+    } else {
+      *ret = Substitute(args[0].operator Expr(), args[1].operator Map<Var, Expr>());
+    }
+  });
+
 TVM_REGISTER_API("ir_pass.Equal")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     if (args[0].IsNodeType<Stmt>()) {
diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py
index 06f24f2adaf50609059c36e70ff1db0b8ec521ee..0f500d7c704f97e72c5d5291e468715b94f2d1d0 100644
--- a/tests/python/unittest/test_hybrid_script.py
+++ b/tests/python/unittest/test_hybrid_script.py
@@ -1,7 +1,45 @@
-import tvm, inspect, sys, traceback, numpy
+import tvm, inspect, sys, traceback, numpy, nose
 from tvm.hybrid import script
 from tvm.hybrid.intrin import HYBRID_GLOBALS
 
+@nose.tools.nottest
+def run_and_check(func, args, outs, var_dict={}, target='llvm'):
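+    """Run a hybrid function twice, once emulated in pure Python and once
+    compiled via tvm.lower/tvm.build, and compare the tensors listed in outs.
+    var_dict maps the free tvm.expr.Var arguments to concrete values."""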
+    def tvm_val_2_py_val(val):
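+        # Substitute free variables via var_dict, then constant-fold to a Python value.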
+        val = tvm.ir_pass.Substitute(val, var_dict)
+        val = tvm.ir_pass.Simplify(val)
+        assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm))
+        return val.value
+
+    ctx = tvm.context(target, 0)
+
+    emu_args = []
+    nd_args = []
+    to_check = []
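+    # Build the numpy emulation and TVM runtime argument lists side by side;
+    # output tensors start zeroed and are recorded in to_check for comparison.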
+    for i in args:
+        if isinstance(i, tvm.tensor.Tensor):
+            shape = [tvm_val_2_py_val(j) for j in i.shape]
+            if i in outs:
+                emu_args.append(numpy.zeros(shape).astype(i.dtype))
+                nd_args.append(tvm.nd.array(emu_args[-1], ctx))
+                to_check.append((nd_args[-1], emu_args[-1]))
+            else:
+                emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
+                nd_args.append(tvm.nd.array(emu_args[-1], ctx))
+        else:
+            assert isinstance(i, tvm.expr.Var)
+            emu_args.append(tvm_val_2_py_val(i))
+            nd_args.append(emu_args[-1])
+
+    func(*emu_args)
+
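+    # Lower, build, and run the compiled module on the same inputs.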
+    lowered_func = tvm.lower(func(*args), args)
+    module = tvm.build(lowered_func, target=target)
+    assert module
+    module(*nd_args)
+
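+    # The compiled results must match the emulated ones on every output tensor.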
+    for nd, np in to_check:
+        numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
+
 
 @script
 def outer_product(n, m, a, b, c):
@@ -45,18 +83,7 @@ def test_outer_product():
     func = tvm.lower(ir, [n, m, a, b, c])
     func = tvm.build(func)
 
-    _n = 999
-    _m = 1001
-    _a = numpy.random.rand(_n).astype('float32')
-    _b = numpy.random.rand(_m).astype('float32')
-    c_python = numpy.zeros((_n, _m), dtype='float32')
-    outer_product(_n, _m, _a, _b, c_python)
-
-    tvm_a = tvm.ndarray.array(_a)
-    tvm_b = tvm.ndarray.array(_b)
-    tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32'))
-    func(_n, _m, tvm_a, tvm_b, tvm_c)
-    numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5)
+    run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001})
 
     for key, _ in HYBRID_GLOBALS.items():
         assert key not in globals().keys()
@@ -135,19 +162,7 @@ def test_fanout():
     assert len(write.value.args) == 1
     assert write.value.args[0].value == 0
 
-    func = tvm.build(tvm.lower(ir, [n, a, b]))
-    assert func
-
-    np_a = numpy.random.randn(10).astype('float32')
-    np_b = numpy.zeros(7).astype('float32')
-
-    nd_a = tvm.ndarray.array(np_a)
-    nd_b = tvm.ndarray.array(np_b)
-
-    fanout(10, np_a, np_b)
-    func(10, nd_a, nd_b)
-
-    numpy.testing.assert_allclose(nd_b.asnumpy(), np_b, rtol=1e-5, atol=1e-5)
+    run_and_check(fanout, [n, a, b], [b], {n: 10})
 
 
 @script
@@ -160,7 +175,7 @@ def test_failure():
         tvm.hybrid.parse(failure, [])
     except IOError as err:
         assert sys.version_info[0] == 2
-        print('[Warning] Python2 cannot do the failure case because "%s"' % str(err))
+        print('[Warning] Case test_failure is skipped under Python 2 because "%s"' % str(err))
     except Exception as err:
         assert str(err) == 'You CAN NEVER overwrite a loop variable!'
 
@@ -186,22 +201,7 @@ def test_looptype():
     assert jloop.for_type == tvm.stmt.For.Vectorized
     assert kloop.for_type == tvm.stmt.For.Unrolled
 
-    func = tvm.build(tvm.lower(ir, [a, b, c]))
-
-    np_a = numpy.zeros((8, )).astype('int32')
-    np_b = numpy.zeros((8, )).astype('int32')
-    np_c = numpy.zeros((8, )).astype('int32')
-
-    nd_a = tvm.ndarray.array(np_a)
-    nd_b = tvm.ndarray.array(np_b)
-    nd_c = tvm.ndarray.array(np_c)
-
-    looptype(np_a, np_b, np_c)
-    func(nd_a, nd_b, nd_c)
-
-    numpy.testing.assert_allclose(np_a, nd_a.asnumpy())
-    numpy.testing.assert_allclose(np_b, nd_b.asnumpy())
-    numpy.testing.assert_allclose(np_c, nd_c.asnumpy())
+    run_and_check(looptype, [a, b, c], [a, b, c])
 
 
 def test_if():
@@ -217,26 +217,13 @@ def test_if():
 
     a = tvm.placeholder((10, ), dtype='int32', name='a')
     b = tvm.placeholder((10, ), dtype='int32', name='b')
-    ir = if_then_else(a, b)
-    func = tvm.lower(ir, [a, b])
-    func = tvm.build(func)
-    assert func
-
-    _a = numpy.zeros((10, ), dtype = 'int32')
-    _b = numpy.zeros((10, ), dtype = 'int32')
-    if_then_else(_a, _b)
 
-    tvm_a = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32'))
-    tvm_b = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32'))
-    func(tvm_a, tvm_b)
+    run_and_check(if_then_else, [a, b], [a, b])
 
-    numpy.testing.assert_allclose(tvm_a.asnumpy(), _a, rtol=1e-5)
-    numpy.testing.assert_allclose(tvm_b.asnumpy(), _b, rtol=1e-5)
-    numpy.testing.assert_allclose(tvm_a.asnumpy(), tvm_b.asnumpy(), rtol=1e-5)
 
 def test_bind():
     if not tvm.gpu(0).exist:
-        print('No GPU found! Skip this test!')
+        print('[Warning] No GPU found! Skipping the bind test!')
         return
     @script
     def vec_add(a, b, c):
@@ -246,24 +233,8 @@ def test_bind():
     a = tvm.placeholder((1000, ), dtype='float32', name='a')
     b = tvm.placeholder((1000, ), dtype='float32', name='b')
     c = tvm.placeholder((1000, ), dtype='float32', name='c')
-    ir = vec_add(a, b, c)
 
-    func = tvm.lower(ir, [a, b, c])
-    func = tvm.build(func, target = 'cuda')
-
-    _a = numpy.random.rand(1000).astype('float32')
-    _b = numpy.random.rand(1000).astype('float32')
-    _c = numpy.zeros((1000, ), dtype = 'float32')
-
-
-    tvm_a = tvm.ndarray.array(_a, tvm.gpu(0))
-    tvm_b = tvm.ndarray.array(_b, tvm.gpu(0))
-    tvm_c = tvm.ndarray.array(_c, tvm.gpu(0))
-
-    func(tvm_a, tvm_b, tvm_c)
-    vec_add(_a, _b, _c)
-
-    numpy.testing.assert_allclose(_c, tvm_c.asnumpy(), rtol=1e-5)
+    run_and_check(vec_add, [a, b, c], [c], target='cuda')
 
 def test_math_intrin():
     @script
@@ -277,9 +248,9 @@ def test_math_intrin():
         a[6] = min(a[4], a[5])
         a[7] = max(a[5], a[6])
 
-    a6 = tvm.placeholder((8, ), dtype='float32', name='a')
-    ir = intrin_real(a6)
-    func = tvm.build(tvm.lower(ir, [a6]))
+    a8 = tvm.placeholder((8, ), dtype='float32', name='a')
+    ir = intrin_real(a8)
+    func = tvm.build(tvm.lower(ir, [a8]))
     assert func
     a = numpy.arange(2, 10).astype('float32')
     tvm_a = tvm.ndarray.array(a)
@@ -312,23 +283,12 @@ def test_non_zero():
                         s = s + a[i-di, j-dj]
                 b[i-2, j-2] = s / 9.0
     try:
-        np_a = numpy.random.randn(32, 32).astype('float32')
-        np_b = numpy.zeros((30, 30), dtype='float32')
-        blur(np_a, np_b)
-
-        ph_a = tvm.placeholder((32, 32), 'float32', 'a')
-        ph_b = tvm.placeholder((30, 30), 'float32', 'b')
-        ir = tvm.hybrid.parse(blur, [ph_a, ph_b])
-        func = tvm.lower(ir, [ph_a, ph_b])
-        func = tvm.build(func)
-
-        nd_a = tvm.ndarray.array(np_a)
-        nd_b = tvm.ndarray.array(np_b)
-        func(nd_a, nd_b)
-
-        numpy.testing.assert_allclose(np_b, nd_b.asnumpy(), atol=1e-5, rtol=1e-5)
-    except IOError:
-        print('[Warning] Non-zero first test skipped by Python2')
+        a = tvm.placeholder((32, 32), 'float32', 'a')
+        b = tvm.placeholder((30, 30), 'float32', 'b')
+        run_and_check(blur, [a, b], [b])
+    except IOError as err:
+        assert sys.version_info[0] == 2
+        print('[Warning] Case test_non_zero is skipped under Python 2 because "%s"' % str(err))
 
     @tvm.hybrid.script
     def triangle(a, b, c):
@@ -340,20 +300,7 @@ def test_non_zero():
     b = tvm.placeholder((10, ), dtype='float32', name='b')
     c = tvm.placeholder((10, 10), dtype='float32', name='c')
 
-    np_a = numpy.random.randn(10).astype('float32')
-    np_b = numpy.random.randn(10).astype('float32')
-    np_c = numpy.zeros((10, 10)).astype('float32')
-
-    nd_a = tvm.ndarray.array(np_a)
-    nd_b = tvm.ndarray.array(np_b)
-    nd_c = tvm.ndarray.array(np_c)
-
-    triangle(np_a, np_b, np_c)
-
-    func = tvm.build(tvm.lower(triangle(a, b, c), [a, b, c]))
-    assert func
-    func(nd_a, nd_b, nd_c)
-    numpy.testing.assert_allclose(nd_c.asnumpy(), np_c)
+    run_and_check(triangle, [a, b, c], [c])
 
 def test_allocate():
     @tvm.hybrid.script
@@ -369,19 +316,27 @@ def test_allocate():
     a = tvm.placeholder((32, 32), 'float32', 'a')
     b = tvm.placeholder((30, 30), 'float32', 'b')
 
-    func = tvm.build(tvm.lower(blur2d(a, b), [a, b]))
-    assert func
-
-    np_a = numpy.random.randn(32, 32).astype('float32')
-    np_b = numpy.zeros((30, 30)).astype('float32')
-
-    nd_a = tvm.ndarray.array(np_a)
-    nd_b = tvm.ndarray.array(np_b)
-
-    func(nd_a, nd_b)
-    blur2d(np_a, np_b)
+    run_and_check(blur2d, [a, b], [b])
+
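+    # With a GPU available, also exercise the 'shared' and 'local' scopes of allocate.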
+    if tvm.gpu().exist:
+        @tvm.hybrid.script
+        def share_vec_add(a, b, c):
+            shared = allocate((256, ), 'float32', 'shared')
+            for i in bind("threadIdx.x", 256):
+                shared[i] = a[i]
+            local = allocate((256, ), 'float32', 'local')
+            for i in bind("threadIdx.x", 256):
+                local[i] = b[i]
+            for i in bind("threadIdx.x", 256):
+                c[i] = shared[i] + local[i]
+
+        a = tvm.placeholder((256, ), dtype='float32', name='a')
+        b = tvm.placeholder((256, ), dtype='float32', name='b')
+        c = tvm.placeholder((256, ), dtype='float32', name='c')
+        run_and_check(share_vec_add, [a, b, c], [c], target='cuda')
+    else:
+        print('[Warning] No GPU found! Skipping the shared memory test!')
 
-    numpy.testing.assert_allclose(nd_b.asnumpy(), np_b, atol=1e-5, rtol=1e-5)
 
 if __name__ == "__main__":
     test_outer_product()