diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 12218955f110e107aae203d9a12e4f6a9ae59e0a..f3d9d811eec1ad545f3591369aeda208ff62f6d6 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -131,26 +131,29 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   }
 };
 
-runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
-  CHECK(target.length(
-) >= 4 &&
-        target.substr(0, 4) == "rocm");
-
-  TVMContext tvmCtx;
-  tvmCtx.device_type = kROCM;
-  tvmCtx.device_id = 0;
+inline int DetectROCMComputeVersion() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kROCM;
+  tvm_ctx.device_id = 0;
   TVMRetValue val;
-  tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kExist, &val);
+  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
+      tvm_ctx, tvm::runtime::kExist, &val);
   if (val.operator int() == 1) {
-    tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kComputeVersion, &val);
+    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
+    return val.operator int();
   } else {
-    val = 803;
+    return 803;
   }
+}
 
-  llvm::TargetMachine* tm = \
-      GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" + \
-                           std::to_string(val.operator int())+ target.substr(4, target.length() - 4));
-
+runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
+  CHECK(target.length() >= 4 &&
+        target.substr(0, 4) == "rocm");
+  std::ostringstream config;
+  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx"
+         << DetectROCMComputeVersion()
+         << target.substr(4, target.length() - 4);
+  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
@@ -159,7 +162,6 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
   }
 
   std::unique_ptr<llvm::Module> module = cg->Finish();
-
   llvm::SmallString<8> dataObj, data_ll, dataAsm;
   llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
   destObj.SetUnbuffered();
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 97a5bb84e06e7f1a9ba3223809b3320b0c375486..299e2d8483b07a0a387e19528134aa61dc045784 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -582,14 +582,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
     builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block);
     builder_->SetInsertPoint(then_block);
     llvm::Value* then_value = MakeValue(op->args[1]);
+    BasicBlock* then_value_block = builder_->GetInsertBlock();
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(else_block);
     llvm::Value* else_value = MakeValue(op->args[2]);
+    BasicBlock* else_value_block = builder_->GetInsertBlock();
     builder_->CreateBr(end_block);
     builder_->SetInsertPoint(end_block);
     llvm::PHINode* value = builder_->CreatePHI(then_value->getType(), 2);
-    value->addIncoming(then_value, then_block);
-    value->addIncoming(else_value, else_block);
+    value->addIncoming(then_value, then_value_block);
+    value->addIncoming(else_value, else_value_block);
     return value;
   } else {
    LOG(FATAL) << "unknown intrinsic " << op->name;
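
Note on the codegen_llvm.cc hunk above: `MakeValue()` on the then/else expression can itself emit control flow (for example a nested `if_then_else`), which leaves the builder positioned in a block other than `then_block`/`else_block`. The PHI node therefore has to take its incoming edges from whatever block the builder ends up in once the value has been emitted. The sketch below reproduces that idiom with the plain LLVM C++ IRBuilder API; `EmitSelectLike` and its trivial argument "values" are illustrative, not TVM code, but the `GetInsertBlock()`-after-emission pattern is the one the two-line fix introduces.

```cpp
// Standalone sketch (not TVM code) of the if_then_else lowering pattern.
// The block feeding each PHI edge is re-queried with GetInsertBlock()
// after the value is emitted, because emitting the value may open new
// basic blocks and move the builder away from then_block / else_block.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"

// Builds: i32 select_like(i1 %cond, i32 %a, i32 %b) { return cond ? a : b; }
llvm::Function* EmitSelectLike(llvm::Module* module) {
  llvm::LLVMContext& ctx = module->getContext();
  llvm::IRBuilder<> builder(ctx);
  llvm::FunctionType* fty = llvm::FunctionType::get(
      builder.getInt32Ty(),
      {builder.getInt1Ty(), builder.getInt32Ty(), builder.getInt32Ty()},
      /*isVarArg=*/false);
  llvm::Function* f = llvm::Function::Create(
      fty, llvm::Function::ExternalLinkage, "select_like", module);
  auto arg_it = f->arg_begin();
  llvm::Value* cond = &*arg_it++;
  llvm::Value* a = &*arg_it++;
  llvm::Value* b = &*arg_it;

  llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", f);
  llvm::BasicBlock* then_block = llvm::BasicBlock::Create(ctx, "if_then", f);
  llvm::BasicBlock* else_block = llvm::BasicBlock::Create(ctx, "if_else", f);
  llvm::BasicBlock* end_block = llvm::BasicBlock::Create(ctx, "if_end", f);

  builder.SetInsertPoint(entry);
  builder.CreateCondBr(cond, then_block, else_block);

  builder.SetInsertPoint(then_block);
  llvm::Value* then_value = a;  // stands in for MakeValue(op->args[1])
  // Re-query the current block: value emission may have moved the builder.
  llvm::BasicBlock* then_value_block = builder.GetInsertBlock();
  builder.CreateBr(end_block);

  builder.SetInsertPoint(else_block);
  llvm::Value* else_value = b;  // stands in for MakeValue(op->args[2])
  llvm::BasicBlock* else_value_block = builder.GetInsertBlock();
  builder.CreateBr(end_block);

  builder.SetInsertPoint(end_block);
  llvm::PHINode* phi = builder.CreatePHI(then_value->getType(), 2);
  phi->addIncoming(then_value, then_value_block);
  phi->addIncoming(else_value, else_value_block);
  builder.CreateRet(phi);

  llvm::verifyFunction(*f);
  return f;
}
```

In this toy function the values are plain arguments, so `GetInsertBlock()` still returns the original block; the re-query only matters once value emission can open new blocks, which is exactly the case the patch addresses.
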
diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc
index 0db5b2398e2ed15dafca7b1489444505fd3b0a98..ede882895f956ecc3862fd22e82856fc7a2e7ea4 100644
--- a/src/codegen/llvm/codegen_nvptx.cc
+++ b/src/codegen/llvm/codegen_nvptx.cc
@@ -130,12 +130,34 @@ class CodeGenNVPTX : public CodeGenLLVM {
   }
 };
 
+inline int DetectCUDAComputeVersion() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kGPU;
+  tvm_ctx.device_id = 0;
+  TVMRetValue val;
+  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
+      tvm_ctx, tvm::runtime::kExist, &val);
+  if (val.operator int() == 1) {
+    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
+        tvm_ctx, tvm::runtime::kComputeVersion, &val);
+    std::string version = val;
+    std::istringstream is(version);
+    double ver;
+    is >> ver;
+    return static_cast<int>(ver * 10);
+  } else {
+    return 20;
+  }
+}
+
 runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
   CHECK(target.length() >= 5 &&
         target.substr(0, 5) == "nvptx");
-  llvm::TargetMachine* tm = GetLLVMTargetMachine(
-      "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
-      target.substr(5, target.length() - 5));
+  std::ostringstream config;
+  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
+         << DetectCUDAComputeVersion()
+         << target.substr(5, target.length() - 5);
+  llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str());
   std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
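
`DetectROCMComputeVersion` and `DetectCUDAComputeVersion` above share one shape: ask the runtime whether a device exists, read `kComputeVersion` if it does, and fall back to a conservative default (gfx803 / sm_20) otherwise. The CUDA path additionally parses a version string such as "5.2" into the integer 52 that ends up in `-mcpu=sm_52`. The standalone sketch below mirrors that parsing and the `std::ostringstream` target-string assembly; `ParseCUDAVersion` and `BuildNVPTXTargetString` are illustrative names, not TVM APIs.

```cpp
// Hedged sketch of the compute-version handling added in this patch.
#include <iostream>
#include <sstream>
#include <string>

// "5.2" -> 52, "6.0" -> 60; unparsable input falls back to sm_20.
int ParseCUDAVersion(const std::string& version, int fallback = 20) {
  std::istringstream is(version);
  double ver = 0.0;
  if (!(is >> ver)) return fallback;  // e.g. no device reported a version
  return static_cast<int>(ver * 10);
}

// Assembles the LLVM target string the same way BuildNVPTX now does.
std::string BuildNVPTXTargetString(int compute_version,
                                   const std::string& extra_flags) {
  std::ostringstream config;
  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
         << compute_version << extra_flags;
  return config.str();
}

int main() {
  // -> "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_52"
  std::cout << BuildNVPTXTargetString(ParseCUDAVersion("5.2"), "") << "\n";
  // -> falls back to sm_20 when no version string is available
  std::cout << BuildNVPTXTargetString(ParseCUDAVersion(""), "") << "\n";
  return 0;
}
```

The fallback path is what keeps compilation working on a machine with no visible GPU; it simply targets a conservative default architecture instead of failing.
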
diff --git a/topi/python/topi/generic/injective.py b/topi/python/topi/generic/injective.py
index 6adf218d18a14f313c3ab10681c1bf1475e15f59..765f4e4f518d4825450e4d6d2d90fb780f2a4d57 100644
--- a/topi/python/topi/generic/injective.py
+++ b/topi/python/topi/generic/injective.py
@@ -22,6 +22,7 @@ def schedule_injective(outs):
     target = tvm.target.current_target(allow_none=False)
     if target.target_name != "llvm":
         raise RuntimeError("schedule_injective not registered for '%s'" % target)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     x = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
     tvm.schedule.AutoInlineInjective(s)
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 7f45ccd03e458e94bee6ce0804429bf7e1af9de0..41f8117077ca0e95ca9a2c560e3d2a3b9ecbe4e5 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -6,6 +6,7 @@ import tvm
 def _default_schedule(outs, auto_inline):
     """Default schedule for llvm."""
     target = tvm.target.current_target(allow_none=False)
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     if target.target_name != "llvm":
         raise RuntimeError("schedule_pool not registered for '%s'" % target)
     s = tvm.create_schedule([x.op for x in outs])
diff --git a/topi/recipe/gemm/cuda_gemm_square.py b/topi/recipe/gemm/cuda_gemm_square.py
index 8d9fedee054747fc7ac6655cc0acb1a9a8239c82..09fed58b842f9ad8fd8bf85dfdabaa1e2dea0a17 100644
--- a/topi/recipe/gemm/cuda_gemm_square.py
+++ b/topi/recipe/gemm/cuda_gemm_square.py
@@ -124,11 +124,11 @@ def test_gemm():
         t = timer_f(a, b, c).mean
         GFLOPS = num_flops / (t * 1e3) / 1e6
         print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
-
-    for device in ['cuda', 'opencl', 'rocm']:
+
+    for device in ["cuda", "opencl", "rocm"]:
         with tvm.build_config(auto_unroll_max_step=32,
                               auto_unroll_min_depth=0,
-                              unroll_explicit=device == 'rocm'):
+                              unroll_explicit=(device != "cuda")):
             check_device(device)
 
 if __name__ == "__main__":
diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py
index 07e55a54bcf66ec2bc85b12ef154549994c3e212..a34b7a173ff23a6acc4eeed0e1198c9b27947eef 100644
--- a/topi/tests/python/test_topi_reduce.py
+++ b/topi/tests/python/test_topi_reduce.py
@@ -74,11 +74,9 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         for _ in range(1):
             foo(data_tvm, out_tvm)
         np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
+    for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
+        check_device(device)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
 
 def test_reduce_map():
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
diff --git a/topi/tests/python/test_topi_softmax.py b/topi/tests/python/test_topi_softmax.py
index 912e2a8c3e5ddc8817d6766572d0d5f93ce6f8d1..5e54e78d6d02f8aafcd4bf49791ca65f673459e4 100644
--- a/topi/tests/python/test_topi_softmax.py
+++ b/topi/tests/python/test_topi_softmax.py
@@ -3,6 +3,7 @@ import os
 import numpy as np
 import tvm
 import topi
+import logging
 from topi.util import get_const_tuple
 
 def verify_softmax(m, n):
@@ -42,8 +43,6 @@ def verify_log_softmax(m, n):
     # confirm lower works
     s = tvm.create_schedule([B.op])
     tvm.lower(s, [A, B], simple_mode=True)
-
-
     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = topi.testing.log_softmax_python(a_np)
 
@@ -60,13 +59,15 @@ def verify_log_softmax(m, n):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm']:
+    for device in ["cuda", "opencl", "metal", "rocm"]:
         check_device(device)
 
+
 def test_log_softmax():
     verify_log_softmax(32, 10)
     verify_log_softmax(3, 4)
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
     test_softmax()
     test_log_softmax()
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index cd5bedf6c46eb794f1923de12d91bf94d90bff20..1e47fdc45ad37128839f43d945fb74204596fa8c 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -21,10 +21,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("opencl")
-    check_device("cuda")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 
 def verify_tranpose(in_shape, axes):
@@ -45,10 +43,9 @@ def verify_tranpose(in_shape, axes):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def verify_reshape(src_shape, dst_shape):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -68,10 +65,9 @@ def verify_reshape(src_shape, dst_shape):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def verify_squeeze(src_shape, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -95,10 +91,8 @@ def verify_squeeze(src_shape, axis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
 
 def verify_concatenate(shapes, axis):
     tensor_l = []
@@ -120,10 +114,9 @@ def verify_concatenate(shapes, axis):
         foo(*(data_nds + [out_nd]))
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def verify_split(src_shape, indices_or_sections, axis):
     A = tvm.placeholder(shape=src_shape, name="A")
@@ -144,10 +137,9 @@ def verify_split(src_shape, indices_or_sections, axis):
         for out_nd, out_npy in zip(out_nds, out_npys):
             np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    check_device("cuda")
-    check_device("opencl")
-    check_device("metal")
-    check_device("rocm")
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
+        check_device(device)
+
 
 def test_expand_dims():
     verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
@@ -175,6 +167,7 @@ def test_squeeze():
 
 
 def test_concatenate():
+    verify_concatenate([(2,), (2,), (2,)], 0)
     verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
     verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
     verify_concatenate([(5, 6, 7, 3),
@@ -190,9 +183,9 @@ def test_split():
     verify_split((10, 12, 24), [5, 7, 9], -1)
 
 if __name__ == "__main__":
+    test_concatenate()
     test_tranpose()
     test_expand_dims()
     test_reshape()
     test_squeeze()
-    test_concatenate()
     test_split()