diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 12218955f110e107aae203d9a12e4f6a9ae59e0a..f3d9d811eec1ad545f3591369aeda208ff62f6d6 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -131,25 +131,29 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   }
 };
 
-runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
-  CHECK(target.length() >= 4 &&
-        target.substr(0, 4) == "rocm");
-
-  TVMContext tvmCtx;
-  tvmCtx.device_type = kROCM;
-  tvmCtx.device_id = 0;
+inline int DetectROCMComputeVersion() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kROCM;
+  tvm_ctx.device_id = 0;
   TVMRetValue val;
-  tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kExist, &val);
+  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
+      tvm_ctx, tvm::runtime::kExist, &val);
   if (val.operator int() == 1) {
-    tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kComputeVersion, &val);
+    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
+    return val.operator int();
   } else {
-    val = 803;
+    return 803;
   }
+}
 
-  llvm::TargetMachine* tm = \
-    GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" + \
-    std::to_string(val.operator int())+ target.substr(4, target.length() - 4));
-
+runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
+  CHECK(target.length() >= 4 &&
+        target.substr(0, 4) == "rocm");
+  std::ostringstream config;
+  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
+         << DetectROCMComputeVersion()
+         << target.substr(4, target.length() - 4);
+  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
@@ -159,7 +163,6 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
   }
 
   std::unique_ptr<llvm::Module> module = cg->Finish();
-
   llvm::SmallString<8> dataObj, data_ll, dataAsm;
   llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
   destObj.SetUnbuffered();
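
Note: the hunk above folds the device probe into DetectROCMComputeVersion(): query kExist first, read kComputeVersion only when a ROCm device is actually present, and otherwise fall back to 803 (gfx803). A minimal sketch of the same probe from the Python side, assuming the TVMContext properties `exist` and `compute_version` (hypothetical names here) mirror the kExist/kComputeVersion attributes used above:

    import tvm

    ctx = tvm.rocm(0)
    # Fall back to gfx803 when no ROCm device is visible, as the C++ above does.
    gfx = ctx.compute_version if ctx.exist else 803
    print("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx%s" % gfx)
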
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 97a5bb84e06e7f1a9ba3223809b3320b0c375486..299e2d8483b07a0a387e19528134aa61dc045784 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -582,14 +582,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
     builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
     builder_->SetInsertPoint(then_block);
     llvm::Value* then_value = MakeValue(op->args[1]);
+    BasicBlock* then_value_block = builder_->GetInsertBlock();
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(else_block);
     llvm::Value* else_value = MakeValue(op->args[2]);
+    BasicBlock* else_value_block = builder_->GetInsertBlock();
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(end_block);
     llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
-    value->addIncoming(then_value, then_block);
-    value->addIncoming(else_value, else_block);
+    value->addIncoming(then_value, then_value_block);
+    value->addIncoming(else_value, else_value_block);
     return value;
   } else {
     LOG(FATAL) << "unknown intrinsic " << op->name;
diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc
index 0db5b2398e2ed15dafca7b1489444505fd3b0a98..ede882895f956ecc3862fd22e82856fc7a2e7ea4 100644
--- a/src/codegen/llvm/codegen_nvptx.cc
+++ b/src/codegen/llvm/codegen_nvptx.cc
@@ -130,12 +130,34 @@ class CodeGenNVPTX : public CodeGenLLVM {
   }
 };
 
+inline int DetectCUDAComputeVersion() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kGPU;
+  tvm_ctx.device_id = 0;
+  TVMRetValue val;
+  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
+      tvm_ctx, tvm::runtime::kExist, &val);
+  if (val.operator int() == 1) {
+    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
+        tvm_ctx, tvm::runtime::kComputeVersion, &val);
+    std::string version = val;
+    std::istringstream is(version);
+    double ver;
+    is >> ver;
+    return static_cast<int>(ver * 10);
+  } else {
+    return 20;
+  }
+}
+
 runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
   CHECK(target.length() >= 5 &&
         target.substr(0, 5) == "nvptx");
-  llvm::TargetMachine* tm = GetLLVMTargetMachine(
-      "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
-      target.substr(5, target.length() - 5));
+  std::ostringstream config;
+  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
+         << DetectCUDAComputeVersion()
+         << target.substr(5, target.length() - 5);
+  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
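
Note: DetectCUDAComputeVersion() mirrors the ROCm helper. If a CUDA device exists, the "major.minor" capability string is parsed and mapped to major*10+minor for the -mcpu=sm_ flag; with no device visible it falls back to sm_20 so cross-compilation still produces valid PTX. The mapping, sketched in Python:

    # A minimal sketch of the version mapping, assuming the runtime reports
    # the capability as a "major.minor" string (e.g. "6.1").
    def sm_arch(version_str):
        # "6.1" -> 61, matching static_cast<int>(ver * 10) above.
        return int(float(version_str) * 10)

    print("sm_%d" % sm_arch("6.1"))  # sm_61
    print("sm_%d" % sm_arch("3.5"))  # sm_35
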
diff --git a/topi/python/topi/generic/injective.py b/topi/python/topi/generic/injective.py
index 6adf218d18a14f313c3ab10681c1bf1475e15f59..765f4e4f518d4825450e4d6d2d90fb780f2a4d57 100644
--- a/topi/python/topi/generic/injective.py
+++ b/topi/python/topi/generic/injective.py
@@ -22,6 +22,7 @@ def schedule_injective(outs):
     target = tvm.target.current_target(allow_none=False)
     if target.target_name != "llvm":
         raise RuntimeError("schedule_injective not registered for '%s'" % target)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     x = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
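
Note: the added isinstance guard lets schedule_injective accept either a bare Tensor or a list of Tensors (the nn.py hunk below applies the same guard). A usage sketch, hedged: it assumes the 0.x `tvm.target.create` context manager, since current_target(allow_none=False) requires an active target:

    import tvm
    import topi

    A = tvm.placeholder((16,), name="A")
    B = tvm.compute((16,), lambda i: A[i] + 1, name="B")
    with tvm.target.create("llvm"):
        s1 = topi.generic.schedule_injective(B)    # bare Tensor now accepted
        s2 = topi.generic.schedule_injective([B])  # list form, as before
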
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 7f45ccd03e458e94bee6ce0804429bf7e1af9de0..41f8117077ca0e95ca9a2c560e3d2a3b9ecbe4e5 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -6,6 +6,7 @@ import tvm
 def _default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
     target = tvm.target.current_target(allow_none=False)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name != "llvm":
         raise RuntimeError("schedule_pool not registered for '%s'" % target)
     s = tvm.create_schedule([x.op for x in outs])
diff --git a/topi/recipe/gemm/cuda_gemm_square.py b/topi/recipe/gemm/cuda_gemm_square.py
index 8d9fedee054747fc7ac6655cc0acb1a9a8239c82..09fed58b842f9ad8fd8bf85dfdabaa1e2dea0a17 100644
--- a/topi/recipe/gemm/cuda_gemm_square.py
+++ b/topi/recipe/gemm/cuda_gemm_square.py
@@ -124,11 +124,11 @@ def test_gemm():
         t = timer_f(a, b, c).mean
         GFLOPS = num_flops / (t * 1e3) / 1e6
         print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
-        
-    for device in ['cuda', 'opencl', 'rocm']:
+
+    for device in ["cuda", "opencl", "rocm"]:
         with tvm.build_config(auto_unroll_max_step=32,
                               auto_unroll_min_depth=0,
-                              unroll_explicit=device == 'rocm'):
+                              unroll_explicit=(device != "cuda")):
             check_device(device)
 
 if __name__ == "__main__":
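
Note: the recipe now requests explicit unrolling for every backend except CUDA, where nvcc already honors unroll pragmas. A hedged illustration of what the flag changes, via the lowered IR (0.x API; with unroll_explicit=True the UnrollLoop pass expands the body, with False the loop stays annotated for the downstream compiler):

    import tvm

    A = tvm.placeholder((4,), name="A")
    B = tvm.compute((4,), lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)
    s[B].unroll(B.op.axis[0])
    for explicit in [True, False]:
        with tvm.build_config(unroll_explicit=explicit):
            print(tvm.lower(s, [A, B], simple_mode=True))
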
diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py
index 07e55a54bcf66ec2bc85b12ef154549994c3e212..a34b7a173ff23a6acc4eeed0e1198c9b27947eef 100644
--- a/topi/tests/python/test_topi_reduce.py
+++ b/topi/tests/python/test_topi_reduce.py
@@ -74,11 +74,9 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         for _ in range(1):
             foo(data_tvm, out_tvm)
         np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
+    for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
+        check_device(device)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
 
 def test_reduce_map():
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
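
Note: collapsing the four calls into a device loop (and adding "llvm") relies on check_device skipping targets that are unavailable on the host. That guard lives in the elided body above, in the usual pattern of these tests, e.g.:

    import tvm

    def check_device(device):
        # Expected pattern (hedged): bail out early when a backend is not
        # compiled in, so the device loop is safe on any host.
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        # ... per-device test body runs only when the target is enabled
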
diff --git a/topi/tests/python/test_topi_softmax.py b/topi/tests/python/test_topi_softmax.py
index 912e2a8c3e5ddc8817d6766572d0d5f93ce6f8d1..5e54e78d6d02f8aafcd4bf49791ca65f673459e4 100644
--- a/topi/tests/python/test_topi_softmax.py
+++ b/topi/tests/python/test_topi_softmax.py
@@ -3,6 +3,7 @@ import os
 import numpy as np
 import tvm
 import topi
+import logging
 from topi.util import get_const_tuple
 
 def verify_softmax(m, n):
@@ -42,8 +43,6 @@ def verify_log_softmax(m, n):
     # confirm lower works
     s = tvm.create_schedule([B.op])
     tvm.lower(s, [A, B], simple_mode=True)
-
-
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.log_softmax_python(a_np)
 
@@ -60,13 +59,15 @@ def verify_log_softmax(m, n):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ["cuda", "opencl", "metal", "rocm"]:
         check_device(device)
 
+
 def test_log_softmax():
     verify_log_softmax(32, 10)
     verify_log_softmax(3, 4)
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
     test_softmax()
     test_log_softmax()
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index cd5bedf6c46eb794f1923de12d91bf94d90bff20..1e47fdc45ad37128839f43d945fb74204596fa8c 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -21,10 +21,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 
 def verify_tranpose(in_shape, axes):
@@ -45,10 +43,9 @@ def verify_tranpose(in_shape, axes):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def verify_reshape(src_shape, dst_shape):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -68,10 +65,9 @@ def verify_reshape(src_shape, dst_shape):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def verify_squeeze(src_shape, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -95,10 +91,8 @@ def verify_squeeze(src_shape, axis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def verify_concatenate(shapes, axis):
     tensor_l = []
@@ -120,10 +114,9 @@ def verify_concatenate(shapes, axis):
         foo(*(data_nds + [out_nd]))
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def verify_split(src_shape, indices_or_sections, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -144,10 +137,9 @@ def verify_split(src_shape, indices_or_sections, axis):
         for out_nd, out_npy in zip(out_nds, out_npys):
             np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def test_expand_dims():
     verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
@@ -175,6 +167,7 @@ def test_squeeze():
 
 
 def test_concatenate():
+    verify_concatenate([(2,), (2,), (2,)], 0)
     verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
     verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
     verify_concatenate([(5, 6, 7, 3),
@@ -190,9 +183,9 @@ def test_split():
     verify_split((10, 12, 24), [5, 7, 9], -1)
 
 if __name__ == "__main__":
+    test_concatenate()
     test_tranpose()
     test_expand_dims()
     test_reshape()
     test_squeeze()
-    test_concatenate()
     test_split()