diff --git a/python/tvm/contrib/rpc/proxy.py b/python/tvm/contrib/rpc/proxy.py index ed9d89dc9b680aedc249a91e350b1423ac2eb173..857e69433a1c2b5b650052726e10659b3f2ef430 100644 --- a/python/tvm/contrib/rpc/proxy.py +++ b/python/tvm/contrib/rpc/proxy.py @@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""): def _connect(key): conn = yield websocket.websocket_connect(url) on_message = create_on_message(conn) - temp = _server_env() + temp = _server_env(None) # Start connecton conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True) key = "server:" + key diff --git a/python/tvm/contrib/rpc/server.py b/python/tvm/contrib/rpc/server.py index 6759f13b6753538e48f40ecccb9f6e1664919652..b49a6900b295527870e3def6d4e971b2535fd312 100644 --- a/python/tvm/contrib/rpc/server.py +++ b/python/tvm/contrib/rpc/server.py @@ -11,6 +11,7 @@ Server is TCP based with the following protocol: from __future__ import absolute_import import os +import ctypes import socket import select import struct @@ -21,12 +22,13 @@ import time from ..._ffi.function import register_func from ..._ffi.base import py_str +from ..._ffi.libinfo import find_lib_path from ...module import load as _load_module from .. import util from . import base from . base import TrackerCode -def _server_env(): +def _server_env(load_library): """Server environment function return temp dir""" temp = util.tempdir() # pylint: disable=unused-variable @@ -41,13 +43,21 @@ def _server_env(): m = _load_module(path) logging.info("load_module %s", path) return m + + libs = [] + load_library = load_library.split(":") if load_library else [] + for file_name in load_library: + file_name = find_lib_path(file_name)[0] + libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) + logging.info("Load additional library %s", file_name) + temp.libs = libs return temp -def _serve_loop(sock, addr): +def _serve_loop(sock, addr, load_library): """Server loop""" sockfd = sock.fileno() - temp = _server_env() + temp = _server_env(load_library) base._ServerLoop(sockfd) temp.remove() logging.info("Finish serving %s", addr) @@ -62,7 +72,7 @@ def _parse_server_opt(opts): return ret -def _listen_loop(sock, port, rpc_key, tracker_addr): +def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): """Lisenting loop of the server master.""" def _accept_conn(listen_sock, tracker_conn, ping_period=2): """Accept connection from the other places. @@ -162,7 +172,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): # step 3: serving logging.info("RPCServer: connection from %s", addr) - server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr)) + server_proc = multiprocessing.Process(target=_serve_loop, args=(conn, addr, load_library)) server_proc.deamon = True server_proc.start() # close from our side. @@ -174,7 +184,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): server_proc.terminate() -def _connect_proxy_loop(addr, key): +def _connect_proxy_loop(addr, key, load_library): key = "server:" + key retry_count = 0 max_retry = 5 @@ -198,7 +208,7 @@ def _connect_proxy_loop(addr, key): opts = _parse_server_opt(remote_key.split()[1:]) logging.info("RPCProxy connected to %s", str(addr)) process = multiprocessing.Process( - target=_serve_loop, args=(sock, addr)) + target=_serve_loop, args=(sock, addr, load_library)) process.deamon = True process.start() sock.close() @@ -256,6 +266,9 @@ class Server(object): key : str, optional The key used to identify the server in Proxy connection. + + load_library : str, optional + List of additional libraries to be loaded during execution. """ def __init__(self, host, @@ -264,7 +277,8 @@ class Server(object): is_proxy=False, use_popen=False, tracker_addr=None, - key=""): + key="", + load_library=None): try: if base._ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") @@ -283,6 +297,8 @@ class Server(object): assert key cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key] + if load_library: + cmd += ["--load-libary", load_library] self.proc = multiprocessing.Process( target=subprocess.check_call, args=(cmd,)) self.proc.deamon = True @@ -308,12 +324,12 @@ class Server(object): self.sock = sock self.proc = multiprocessing.Process( target=_listen_loop, args=( - self.sock, self.port, key, tracker_addr)) + self.sock, self.port, key, tracker_addr, load_library)) self.proc.deamon = True self.proc.start() else: self.proc = multiprocessing.Process( - target=_connect_proxy_loop, args=((host, port), key)) + target=_connect_proxy_loop, args=((host, port), key, load_library)) self.proc.deamon = True self.proc.start() diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 7cbe17722bf57c34968bf0c51aefdfe6c25a1abb..84bad0f5422e2efabab675ac63ba860574d98f4c 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -1,12 +1,8 @@ """Start an RPC server""" from __future__ import absolute_import -import logging import argparse -import os -import ctypes from ..contrib import rpc -from .._ffi.libinfo import find_lib_path def main(): """Main funciton""" @@ -19,26 +15,12 @@ def main(): help='The end search port of the PRC') parser.add_argument('--key', type=str, default="", help="RPC key used to identify the connection type.") - parser.add_argument('--with-executor', type=bool, default=False, - help="Whether to load executor runtime") parser.add_argument('--load-library', type=str, default="", help="Additional library to load") parser.add_argument('--tracker', type=str, default="", help="Report to RPC tracker") args = parser.parse_args() - logging.basicConfig(level=logging.INFO) - load_library = [lib for lib in args.load_library.split(":") if len(lib) != 0] - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/") - libs = [] - if args.with_executor: - load_library += ["libtvm_graph_exec.so"] - for file_name in load_library: - file_name = find_lib_path(file_name, apps_path)[0] - libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) - logging.info("Load additional library %s", file_name) - if args.tracker: url, port = args.tracker.split(":") port = int(port) @@ -53,8 +35,8 @@ def main(): args.port, args.port_end, key=args.key, - tracker_addr=tracker_addr) - server.libs += libs + tracker_addr=tracker_addr, + load_library=args.load_library) server.proc.join() if __name__ == "__main__": diff --git a/src/codegen/codegen_source_base.h b/src/codegen/codegen_source_base.h index bc99eeeb1d331288f6f7361f55c9bb045eb3c636..89c5bbc05ce4cfbf14ee47f66db0370ca7e1f037 100644 --- a/src/codegen/codegen_source_base.h +++ b/src/codegen/codegen_source_base.h @@ -10,6 +10,7 @@ #include <tvm/codegen.h> #include <string> #include <vector> +#include <functional> #include <unordered_map> #include "../runtime/meta_data.h" @@ -111,17 +112,19 @@ class CodeGenSourceBase { runtime::Module SourceModuleCreate(std::string code, std::string fmt); /*! - * \brief Create a source module for viewing and limited saving - * \param code The code to be viewed. + * \brief Create a source module for viewing and limited saving for device. + * \param data The code data to be viewed. * \param fmt The code. format. * \param fmap The map function information map of each function. * \param type_key The type_key of the runtime module of this source code + * \param fget_source a closure to replace default get source behavior. */ runtime::Module DeviceSourceModuleCreate( - std::string code, + std::string data, std::string fmt, std::unordered_map<std::string, runtime::FunctionInfo> fmap, - std::string type_key); + std::string type_key, + std::function<std::string(const std::string&)> fget_source = nullptr); } // namespace codegen } // namespace tvm #endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index fa42cefa07a7fd6f049dfeed4db59aa8c7c0500c..97858079faaa0f21a2159dbe361bc53c81402ce4 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -4,15 +4,18 @@ * \brief AMDGPU code generator. */ #ifdef TVM_LLVM_VERSION -#if TVM_ROCM_RUNTIME #include <tvm/runtime/device_api.h> #include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/registry.h> #include "./codegen_llvm.h" #include "../build_common.h" +#include "../codegen_source_base.h" #include "../../pass/ir_util.h" + +#if TVM_ROCM_RUNTIME #include "../../runtime/rocm/rocm_module.h" +#endif // TVM_ROCM_RUNTIME namespace tvm { namespace codegen { @@ -131,19 +134,27 @@ class CodeGenAMDGPU : public CodeGenLLVM { } }; -inline int DetectROCMComputeVersion() { +inline int DetectROCMComputeVersion(const std::string& target) { + size_t pos = target.find("=gfx"); + if (pos != std::string::npos) { + int value; + std::stringstream is(target.substr(pos + 4)); + if (is >> value) return value; + } TVMContext tvm_ctx; tvm_ctx.device_type = kDLROCM; tvm_ctx.device_id = 0; - TVMRetValue val; - tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr( - tvm_ctx, tvm::runtime::kExist, &val); - if (val.operator int() == 1) { - tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val); - return val.operator int(); - } else { - return 803; + tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true); + if (api != nullptr) { + TVMRetValue val; + api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); + if (val.operator int() == 1) { + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val); + return val.operator int(); + } } + LOG(WARNING) << "Cannot find -mcpu to specify rocm compute version assume gfx803"; + return 803; } runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { @@ -151,7 +162,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { target.substr(0, 4) == "rocm"); std::ostringstream config; config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" - << DetectROCMComputeVersion() + << DetectROCMComputeVersion(target) << target.substr(4, target.length() - 4); llvm::TargetMachine* tm = GetLLVMTargetMachine(config.str()); std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU()); @@ -216,7 +227,19 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { std::string hsaco = (*f)(arr); std::string ll(data_ll.begin(), data_ll.end()); +#if TVM_ROCM_RUNTIME return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly); +#else + LOG(WARNING) << "ROCM runtime is not enabled, return a source module..."; + auto fget_source = [ll, assembly](const std::string& format) { + if (format.length() == 0) return assembly; + if (format == "ll" || format == "llvm") return format; + if (format == "asm") return assembly; + return std::string(""); + }; + return DeviceSourceModuleCreate( + hsaco, "hsaco", ExtractFuncInfo(funcs), "hsaco", fget_source); +#endif // TVM_ROCM_RUNTIME } TVM_REGISTER_API("codegen.build_rocm") @@ -226,5 +249,4 @@ TVM_REGISTER_API("codegen.build_rocm") } // namespace codegen } // namespace tvm -#endif // TVM_ROCM_RUNTIME #endif // TVM_LLVM_VERSION diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc index 23c0cbd8466e7288182861ab78659bf49fb71ef4..69dbda49976bd5c781f0f3bf15c0664fd0874644 100644 --- a/src/codegen/source_module.cc +++ b/src/codegen/source_module.cc @@ -54,46 +54,71 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { } // supports limited save without cross compile -class DeviceSourceModuleNode final : public SourceModuleNode { +class DeviceSourceModuleNode final : public runtime::ModuleNode { public: - DeviceSourceModuleNode(std::string code, + DeviceSourceModuleNode(std::string data, std::string fmt, std::unordered_map<std::string, FunctionInfo> fmap, - std::string type_key) - : SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {} + std::string type_key, + std::function<std::string(const std::string&)> fget_source) + : data_(data), + fmt_(fmt), + fmap_(fmap), + type_key_(type_key), + fget_source_(fget_source) {} + + PackedFunc GetFunction( + const std::string& name, + const std::shared_ptr<ModuleNode>& sptr_to_self) final { + LOG(FATAL) << "Source module cannot execute, to get executable module" + << " build TVM with \'" << fmt_ << "\' runtime support"; + return PackedFunc(); + } + + std::string GetSource(const std::string& format) final { + if (fget_source_ != nullptr) { + return fget_source_(format); + } else { + return data_; + } + } const char* type_key() const { return type_key_.c_str(); } void SaveToFile(const std::string& file_name, - const std::string& format) final { + const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); - SaveBinaryToFile(file_name, code_); + SaveBinaryToFile(file_name, data_); } void SaveToBinary(dmlc::Stream* stream) final { stream->Write(fmt_); stream->Write(fmap_); - stream->Write(code_); + stream->Write(data_); } private: + std::string data_; + std::string fmt_; std::unordered_map<std::string, FunctionInfo> fmap_; std::string type_key_; + std::function<std::string(const std::string&)> fget_source_; }; runtime::Module DeviceSourceModuleCreate( - std::string code, + std::string data, std::string fmt, std::unordered_map<std::string, FunctionInfo> fmap, - std::string type_key) { + std::string type_key, + std::function<std::string(const std::string&)> fget_source) { std::shared_ptr<DeviceSourceModuleNode> n = - std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key); + std::make_shared<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source); return runtime::Module(n); } diff --git a/src/runtime/module.cc b/src/runtime/module.cc index b3cdd9c95ba66ede0dbb527acf7761a2f557c4cd..d800d6a42ee73b9b8853ac0ef56db77ef11c2920 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -121,9 +121,9 @@ bool RuntimeEnabled(const std::string& target) { } else if (target == "vpi" || target == "verilog") { f_name = "device_api.vpi"; } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") { - f_name = "codegen.build_nvptx"; + f_name = "device_api.gpu"; } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") { - f_name = "codegen.build_rocm"; + f_name = "device_api.rocm"; } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") { const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled"); if (pf == nullptr) return false; diff --git a/topi/tests/python/test_topi_l2norm.py b/topi/tests/python/test_topi_l2norm.py index aa7970125b4406e85a062588c964628569c4cc23..182099ff93674df4a6fd40bf3956ef5b41b7c494 100644 --- a/topi/tests/python/test_topi_l2norm.py +++ b/topi/tests/python/test_topi_l2norm.py @@ -41,13 +41,13 @@ def verify_l2norm(n, c, h, w, eps, axis=None): b_np = l2norm_instance_python(a_np, eps, axis) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) with tvm.target.create(device): s = topi.generic.schedule_l2norm(B) - ctx = tvm.context(device, 0) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) f = tvm.build(s, [A, B], device) diff --git a/topi/tests/python/test_topi_lrn.py b/topi/tests/python/test_topi_lrn.py index 6c4714077e1ebb880c4c2b94a71e51d177a4c4da..596e5747a6c5cec205463c686a68f4ba3ee79c7c 100644 --- a/topi/tests/python/test_topi_lrn.py +++ b/topi/tests/python/test_topi_lrn.py @@ -70,13 +70,13 @@ def verify_lrn(shape, size, axis, bias, alpha, beta): b_np = lrn_python(a_np, size, axis, bias, alpha, beta) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) with tvm.target.create(device): s = topi.generic.schedule_lrn(B) - ctx = tvm.context(device, 0) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) f = tvm.build(s, [A, B], device) diff --git a/topi/tests/python_cpp/test_topi_dense.py b/topi/tests/python_cpp/test_topi_dense.py index 6ebd6948ee717f25c232819da1535522a84e8996..f2369af4319aba90755ee83f2b156ea64740ce51 100644 --- a/topi/tests/python_cpp/test_topi_dense.py +++ b/topi/tests/python_cpp/test_topi_dense.py @@ -29,7 +29,8 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): a_np, b_np, c_np, d_np = get_ref_data() def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -40,7 +41,6 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True): s = topi.cpp.rocm.schedule_dense(target, [D]) else: s = topi.cpp.cuda.schedule_dense(target, [D]) - ctx = tvm.context(device, 0) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(c_np, ctx) diff --git a/topi/tests/python_cpp/test_topi_pooling.py b/topi/tests/python_cpp/test_topi_pooling.py index e45b53dc0dec45a10785e242c24887192cffd523..ce7e65343e2ced5ad9046ec754154cc4e4ba578d 100644 --- a/topi/tests/python_cpp/test_topi_pooling.py +++ b/topi/tests/python_cpp/test_topi_pooling.py @@ -48,7 +48,8 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): b_np = np.maximum(b_np, 0.0) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -57,7 +58,6 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): s = topi.cpp.generic.default_schedule(target, [B], False) else: s = topi.cpp.cuda.schedule_pool(target, [B]) - ctx = tvm.context(device, 0) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) f = tvm.build(s, [A, B], device) diff --git a/topi/tests/python_cpp/test_topi_reduce.py b/topi/tests/python_cpp/test_topi_reduce.py index b4c630395f6054bd7d913032e601b697e6b8c2bc..b4904d5380ca156f3560f68f0a02e9a02621c6fb 100644 --- a/topi/tests/python_cpp/test_topi_reduce.py +++ b/topi/tests/python_cpp/test_topi_reduce.py @@ -46,7 +46,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): raise NotImplementedError def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -56,7 +57,6 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): else: s = topi.cpp.cuda.schedule_reduce(target, [B]) - ctx = tvm.context(device, 0) foo = tvm.build(s, [A, B], device, name="sum") # Test in_npy = np.random.uniform(size=in_shape).astype(np.float32) diff --git a/topi/tests/python_cpp/test_topi_transform.py b/topi/tests/python_cpp/test_topi_transform.py index 68fad8dae707285c546925d1370118a3a0be0829..0f46be6a7b2e53ba827ef2f4a465da4e32232698 100644 --- a/topi/tests/python_cpp/test_topi_transform.py +++ b/topi/tests/python_cpp/test_topi_transform.py @@ -7,7 +7,8 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): A = tvm.placeholder(shape=in_shape, name="A") B = topi.cpp.expand_dims(A, axis, num_newaxis) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -16,7 +17,6 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis): s = topi.cpp.generic.schedule_injective(target, [B]) else: s = topi.cpp.cuda.schedule_injective(target, [B]) - ctx = tvm.context(device, 0) foo = tvm.build(s, [A, B], device, name="expand_dims") data_npy = np.random.uniform(size=in_shape).astype(A.dtype) out_npy = data_npy.reshape(out_shape) @@ -33,7 +33,8 @@ def verify_tranpose(in_shape, axes): A = tvm.placeholder(shape=in_shape, name="A") B = topi.cpp.transpose(A, axes) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -59,7 +60,8 @@ def verify_reshape(src_shape, dst_shape): A = tvm.placeholder(shape=src_shape, name="A") B = topi.cpp.reshape(A, dst_shape) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -68,7 +70,6 @@ def verify_reshape(src_shape, dst_shape): s = topi.cpp.generic.schedule_injective(target, [B]) else: s = topi.cpp.cuda.schedule_injective(target, [B]) - ctx = tvm.context(device, 0) foo = tvm.build(s, [A, B], device, name="reshape") data_npy = np.random.normal(size=src_shape).astype(A.dtype) out_npy = np.reshape(data_npy, newshape=dst_shape) @@ -85,7 +86,8 @@ def verify_squeeze(src_shape, axis): A = tvm.placeholder(shape=src_shape, name="A") B = topi.cpp.squeeze(A, axis) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -94,7 +96,6 @@ def verify_squeeze(src_shape, axis): s = topi.cpp.generic.schedule_injective(target, [B]) else: s = topi.cpp.cuda.schedule_injective(target, [B]) - ctx = tvm.context(device, 0) foo = tvm.build(s, [A, B], device, name="squeeze") data_npy = np.random.normal(size=src_shape).astype(A.dtype) out_npy = np.squeeze(data_npy, axis=axis) @@ -116,7 +117,8 @@ def verify_concatenate(shapes, axis): tensor_l.append(tvm.placeholder(shape, name="A" + str(i))) out_tensor = topi.cpp.concatenate(tensor_l, axis) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) @@ -125,7 +127,6 @@ def verify_concatenate(shapes, axis): s = topi.cpp.generic.schedule_injective(target, [out_tensor]) else: s = topi.cpp.cuda.schedule_injective(target, [out_tensor]) - ctx = tvm.context(device, 0) foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate") data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes] out_npy = np.concatenate(data_npys, axis=axis) @@ -143,7 +144,8 @@ def verify_split(src_shape, indices_or_sections, axis): tensor_l = topi.cpp.split(A, indices_or_sections, axis) tensor_l = list(tensor_l) def check_device(device): - if not tvm.module.enabled(device): + ctx = tvm.context(device, 0) + if not ctx.exist: print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device)