diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h
index af05b8ab1aa44e73b06652d80bb189e68dbc2ee8..714ec5ff5b1c5fa123746203f0d2b92395525f40 100644
--- a/include/tvm/lowered_func.h
+++ b/include/tvm/lowered_func.h
@@ -37,6 +37,16 @@ class LoweredFunc : public FunctionRef {
   using ContainerType = LoweredFuncNode;
 };
 
+/*! \brief specific type of lowered function */
+enum LoweredFuncType : int {
+  /*! \brief Function that can mix device and host calls */
+  kMixedFunc = 0,
+  /*! \brief Only contains host code */
+  kHostFunc = 1,
+  /*! \brief Only contains device code */
+  kDeviceFunc = 2
+};
+
 /*! \brief Node container of LoweredFunc */
 class LoweredFuncNode : public FunctionBaseNode {
  public:
@@ -72,6 +82,8 @@ class LoweredFuncNode : public FunctionBaseNode {
    *  constant Expr of given type is used.
    */
   Map<Var, Expr> handle_data_type;
+  /*! \brief The type of the function */
+  LoweredFuncType func_type{kMixedFunc};
   /*! \brief Whether this function is packed function */
   bool is_packed_func{true};
   /*! \brief The body statment of the function */
@@ -90,6 +102,7 @@ class LoweredFuncNode : public FunctionBaseNode {
     v->Visit("args", &args);
     v->Visit("thread_axis", &thread_axis);
     v->Visit("handle_data_type", &handle_data_type);
+    v->Visit("func_type", &func_type);
     v->Visit("is_packed_func", &is_packed_func);
     v->Visit("body", &body);
   }
diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 99fa5b1716e38b8a3e8d3331db9a01750b3bc817..a3a1693f61a4c2ac42471e8ddb9e452fdc05b467 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -29,13 +29,15 @@
 #define TVM_DLL
 #endif
 
-#include <stdint.h>
-#include <stddef.h>
 // TVM Runtime is DLPack compatible.
 #include <dlpack/dlpack.h>
-
+#ifdef __cplusplus
 TVM_EXTERN_C {
+#endif
+#include <stdint.h>
+#include <stddef.h>
+
 /*! \brief type of array index. */
 typedef int64_t tvm_index_t;
@@ -405,6 +407,7 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
  * \return 0 when success, -1 when failure happens
  */
 TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
+#ifdef __cplusplus
 } // TVM_EXTERN_C
-
+#endif
 #endif  // TVM_RUNTIME_C_RUNTIME_API_H_
diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py
index f780ed4ebf11aaa3fc2abf45346ed27cf7e339c0..6dfe5f122fd0b9e5cdd782d9dae02567617a400c 100644
--- a/python/tvm/_ffi/function.py
+++ b/python/tvm/_ffi/function.py
@@ -4,7 +4,7 @@ from __future__ import absolute_import
 
 import sys
 import ctypes
-from .base import _LIB, check_call, py_str, c_str, _FFI_MODE
+from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE
 
 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
diff --git a/python/tvm/build.py b/python/tvm/build.py
index b7f7eff133e441ef6422e3ce6808236824b4bc68..45eb71bef3920bc2b3d1531a6a7f9565acf04043 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -220,30 +220,58 @@ def build(sch,
     if isinstance(sch, schedule.Schedule):
         if args is None:
             raise ValueError("args must be given for build from schedule")
-        fapi = lower(sch, args,
-                     name=name,
-                     binds=binds)
+        flist = lower(sch, args,
+                      name=name,
+                      binds=binds)
+        if isinstance(flist, collections.LoweredFunc):
+            flist = [flist]
     elif isinstance(sch, collections.LoweredFunc):
         if args:
             raise ValueError("args must be done when build from LoweredFunc")
-        fapi = sch
+        flist = [sch]
+    elif isinstance(sch, (list, tuple, collections.Array)):
+        flist = sch
     else:
-        raise ValueError("sch have to be Schedule or LoweredFunc")
-    # device related lowering
-    if BuildConfig.current.detect_global_barrier:
-        fapi = ir_pass.StorageSync(fapi, "global")
-    fapi = ir_pass.StorageSync(fapi, "shared")
-    warp_size = 32 if target == "cuda" else 1
-    fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
-    fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
-    fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
-    if len(fsplits) > 1:
+        raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
+    fname_set = set()
+    for x in flist:
+        if not isinstance(x, collections.LoweredFunc):
+            raise ValueError("sch have to be Schedule, LoweredFunc or list of LoweredFunc")
+        if x.name in fname_set:
+            raise ValueError("Duplicate function name %s" % x.name)
+        fname_set.add(x.name)
+
+    fhost = []
+    fdevice = []
+    for func in flist:
+        if func.func_type == collections.LoweredFunc.MixedFunc:
+            if BuildConfig.current.detect_global_barrier:
+                func = ir_pass.StorageSync(func, "global")
+            func = ir_pass.StorageSync(func, "shared")
+            warp_size = 32 if target == "cuda" else 1
+            func = ir_pass.LowerThreadAllreduce(func, warp_size)
+            fsplits = [s for s in ir_pass.SplitHostDevice(func)]
+            fhost.append(fsplits[0])
+            for x in fsplits[1:]:
+                fdevice.append(x)
+        elif func.func_type == collections.LoweredFunc.HostFunc:
+            fhost.append(func)
+        elif func.func_type == collections.LoweredFunc.DeviceFunc:
+            fdevice.append(func)
+        else:
+            raise ValueError("unknown function type %d" % func.func_type)
+    fhost = [ir_pass.LowerPackedCall(x) for x in fhost]
+
+    if not target.startswith("llvm") and target != "stackvm" and not fdevice:
+        raise ValueError(
+            "Specified target %s, but cannot find device code, did you do bind?" % target)
+    if fdevice:
         if not target_host:
             target_host = "llvm" if module.enabled("llvm") else "stackvm"
-        mhost = codegen.build_module(fsplits[0], target_host)
+        mhost = codegen.build_module(fhost, target_host)
         if target:
-            mdev = codegen.build_module(fsplits[1:], target)
+            mdev = codegen.build_module(fdevice, target)
             mhost.import_module(mdev)
         return mhost
     else:
-        return codegen.build_module(fsplits[0], target)
+        return codegen.build_module(fhost, target)
diff --git a/python/tvm/collections.py b/python/tvm/collections.py
index f9af60035a655d8d82b15bea782ab3b4147da4b4..4909f9144abcfd5f8b15aa3dfc38f6dcdcade7a6 100644
--- a/python/tvm/collections.py
+++ b/python/tvm/collections.py
@@ -68,4 +68,6 @@ class Range(NodeBase):
 @register_node
 class LoweredFunc(NodeBase):
     """Represent a LoweredFunc in TVM."""
-    pass
+    MixedFunc = 0
+    HostFunc = 1
+    DeviceFunc = 2
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 28783174d1617b9d49a27e030502dba87b903e1c..5c85d3a4aa84929f637b49cbc0004ec7c57ec8bf 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -96,7 +96,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
 void CodeGenLLVM::InitGlobalContext() {
   gv_mod_ctx_ = new llvm::GlobalVariable(
       *module_, t_void_p_, false,
-      llvm::GlobalValue::LinkOnceODRLinkage, 0, "__tvm_module_ctx");
+      llvm::GlobalValue::LinkOnceAnyLinkage, 0, "__tvm_module_ctx");
   gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
   gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
 }
@@ -142,21 +142,12 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
 void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
   llvm::Function* f = module_->getFunction(entry_func_name);
   CHECK(f) << "Function " << entry_func_name << "does not in module";
-  CHECK(!module_->getFunction(runtime::symbol::tvm_module_main));
-  llvm::FunctionType* ftype = f->getFunctionType();
-  function_ = llvm::cast<llvm::Function>(
-      module_->getOrInsertFunction(runtime::symbol::tvm_module_main, ftype));
-  function_->setCallingConv(llvm::CallingConv::C);
-  std::vector<llvm::Value*> args;
-  for (auto it = function_->arg_begin();
-       it != function_->arg_end(); ++it) {
-    args.push_back(&(*it));
-  }
-  llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
-  builder_->SetInsertPoint(block);
-  llvm::CallInst* call = builder_->CreateCall(f, args);
-  call->setTailCall(true);
-  builder_->CreateRet(call);
+  llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
+  llvm::GlobalVariable *global = new llvm::GlobalVariable(
+      *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0,
+      runtime::symbol::tvm_module_main);
+  global->setAlignment(1);
+  global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name));
 }
 
 class FPassManager : public llvm::legacy::FunctionPassManager {
@@ -424,7 +415,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
     // create the function handle
     hptr = new llvm::GlobalVariable(
         *module_, t_tvm_func_handle_, false,
-        llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
+        llvm::GlobalValue::LinkOnceAnyLinkage, 0, ".tvm_func." + fname);
     hptr->setAlignment(align);
     hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
     func_handle_map_[fname] = hptr;
diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc
index 24ca48283ae5ce34a04f267ad27a93381b178760..74e62f2fd3df44ca8c054f2e8a313017d75d3ea9 100644
--- a/src/codegen/llvm/llvm_module.cc
+++ b/src/codegen/llvm/llvm_module.cc
@@ -36,8 +36,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   void PreCompile(const std::string& name, TVMContext ctx) final {
     if (ee_ == nullptr) LazyInitJIT();
     std::lock_guard<std::mutex> lock(mutex_);
+    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
+                                entry_func_ : name);
     BackendPackedCFunc faddr =
-        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
+        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
     CHECK(faddr != nullptr) << "Failed to Precompile function " << name;
   }
 
@@ -47,8 +49,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       const std::shared_ptr<ModuleNode>& sptr_to_self) final {
     if (ee_ == nullptr) LazyInitJIT();
     std::lock_guard<std::mutex> lock(mutex_);
+    const std::string& fname = (name == runtime::symbol::tvm_module_main ?
+                                entry_func_ : name);
     BackendPackedCFunc faddr =
-        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
+        reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(fname));
     if (faddr == nullptr) return PackedFunc();
     return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
         int ret = (*faddr)(
@@ -103,6 +107,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     CHECK_NE(funcs.size(), 0U);
     ctx_ = std::make_shared<llvm::LLVMContext>();
     CodeGenLLVM cg;
+    entry_func_ = funcs[0]->name;
     cg.Init(funcs[0]->name, tm_, ctx_.get());
     for (LoweredFunc f : funcs) {
       cg.AddFunction(f);
@@ -147,6 +152,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   }
   // The target configuration string
   std::string target_;
+  // Name of entry function.
+  std::string entry_func_;
   // JIT lock
   std::mutex mutex_;
   // execution engine
diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc
index feb0be21360b6f813fe11a9b94ad945867116be0..3a4be4929536f577cd73641ba8e644a1a2b0acf7 100644
--- a/src/codegen/source_module.cc
+++ b/src/codegen/source_module.cc
@@ -54,5 +54,10 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
       std::make_shared<SourceModuleNode>(code, fmt);
   return runtime::Module(n);
 }
+
+TVM_REGISTER_GLOBAL("module.source_module_create")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = SourceModuleCreate(args[0], args[1]);
+  });
 } // namespace codegen
 } // namespace tvm
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
index 9de57d5b84fa6088c6155fee082ce52f0ff0c914..f9f99e8fcd265dbf42041a25ddc1f708d4b3d374 100644
--- a/src/pass/lower_thread_allreduce.cc
+++ b/src/pass/lower_thread_allreduce.cc
@@ -280,6 +280,7 @@ class ThreadAllreduceBuilder : public IRMutator {
 
 LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size) {
+  CHECK_NE(f->func_type, kHostFunc);
   auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
   n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
   return LoweredFunc(n);
diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc
index a6e99f6a5c4c93e3b0bef9072738789fda8fce52..942e70339488255ca211e4f91ff8cc0b3dccebbf 100644
--- a/src/pass/split_host_device.cc
+++ b/src/pass/split_host_device.cc
@@ -155,6 +155,7 @@ class HostDeviceSplitter : public IRMutator {
   }
 
   Array<LoweredFunc> Split(LoweredFunc f) {
+    CHECK_EQ(f->func_type, kMixedFunc);
     for (auto kv : f->handle_data_type) {
       handle_data_type_[kv.first.get()] = kv.second;
     }
@@ -162,6 +163,7 @@ class HostDeviceSplitter : public IRMutator {
     std::shared_ptr<LoweredFuncNode> n =
         std::make_shared<LoweredFuncNode>(*f.operator->());
     n->body = this->Mutate(f->body);
+    n->func_type = kHostFunc;
     Array<LoweredFunc> ret{LoweredFunc(n)};
     for (LoweredFunc x : device_funcs_) {
       ret.push_back(x);
@@ -179,6 +181,7 @@ class HostDeviceSplitter : public IRMutator {
     m.visit_thread_extent_ = false;
     n->body = m.Mutate(body);
     n->name = os.str();
+    n->func_type = kDeviceFunc;
     n->thread_axis = m.thread_axis_;
     // Strictly order the arguments: Var pointers, positional arguments.
     for (Var v : m.undefined_) {
diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc
index f403564efb5497c0b05ca5f0d79be5a17bee8cc7..53f4e927c7a893f8c3ba36861deb5bd21fdc81f9 100644
--- a/src/pass/storage_sync.cc
+++ b/src/pass/storage_sync.cc
@@ -397,6 +397,7 @@ Stmt StorageSync(Stmt stmt, std::string storage_scope) {
 }
 
 LoweredFunc StorageSync(LoweredFunc f, std::string storage_scope) {
+  CHECK_NE(f->func_type, kHostFunc);
   auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
   n->body = StorageSync(f->body, storage_scope);
   return LoweredFunc(n);
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 6a003a2b7a498e4aba2b0483ed591241e150d610..9092a012773313451696182fb65a2d739f33d388 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -62,11 +62,17 @@ class CUDAModuleNode : public runtime::ModuleNode {
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
     std::string fmt = GetFileFormat(file_name, format);
-    CHECK_EQ(fmt, fmt_)
-        << "Can only save to format=" << fmt_;
     std::string meta_file = GetMetaFilePath(file_name);
-    SaveMetaDataToFile(meta_file, fmap_);
-    SaveBinaryToFile(file_name, data_);
+    if (fmt == "cu") {
+      CHECK_NE(cuda_source_.length(), 0);
+      SaveMetaDataToFile(meta_file, fmap_);
+      SaveBinaryToFile(file_name, cuda_source_);
+    } else {
+      CHECK_EQ(fmt, fmt_)
+          << "Can only save to format=" << fmt_;
+      SaveMetaDataToFile(meta_file, fmap_);
+      SaveBinaryToFile(file_name, data_);
+    }
   }
 
   void SaveToBinary(dmlc::Stream* stream) final {
diff --git a/src/runtime/dso_module.cc b/src/runtime/dso_module.cc
index 80b150857b52b1743b825f01ef61d77339d2da17..645b1b42dacc3a8e910cf06df4761f9566ac82c6 100644
--- a/src/runtime/dso_module.cc
+++ b/src/runtime/dso_module.cc
@@ -101,6 +101,17 @@ class DSOModuleNode final : public ModuleNode {
   }
 
  private:
+  BackendPackedCFunc GetFuncPtr(const std::string& name) {
+    if (name == runtime::symbol::tvm_module_main) {
+      const char* entry_name = reinterpret_cast<const char*>(
+          GetGlobalVPtr(runtime::symbol::tvm_module_main));
+      CHECK(entry_name != nullptr)
+          << "Symbol " << runtime::symbol::tvm_module_main << " is not present";
+      return GetFuncPtr_(entry_name);
+    } else {
+      return GetFuncPtr_(name);
+    }
+  }
   // Platform dependent handling.
 #if defined(_WIN32)
   // library handle
@@ -111,7 +122,7 @@ class DSOModuleNode final : public ModuleNode {
     std::wstring wname(name.begin(), name.end());
     lib_handle_ = LoadLibraryW(wname.c_str());
   }
-  BackendPackedCFunc GetFuncPtr(const std::string& name) {
+  BackendPackedCFunc GetFuncPtr_(const std::string& name) {
     return reinterpret_cast<BackendPackedCFunc>(
         GetProcAddress(lib_handle_, (LPCSTR)name.c_str()));  // NOLINT(*)
   }
@@ -129,7 +140,7 @@ class DSOModuleNode final : public ModuleNode {
   void Load(const std::string& name) {
     lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
   }
-  BackendPackedCFunc GetFuncPtr(const std::string& name) {
+  BackendPackedCFunc GetFuncPtr_(const std::string& name) {
     return reinterpret_cast<BackendPackedCFunc>(
         dlsym(lib_handle_, name.c_str()));
   }
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index 8f7b54fa5404f34487edd58586d35ad57e40b96e..0e5d8aeb5fd0b56ad44c4dbe19947c3d2485ef57 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -103,7 +103,42 @@ def test_llvm_temp_space():
             c.asnumpy(), a.asnumpy() + 1 + 1)
     check_llvm()
 
+def test_multiple_func():
+    nn = 1024
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=4)
+    s[C].parallel(xo)
+    s[C].vectorize(xi)
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        # build two functions
+        f2 = tvm.lower(s, [A, B, C], name="fadd1")
+        f1 = tvm.lower(s, [A, B, C], name="fadd2")
+        m = tvm.build([f1, f2], "llvm")
+        fadd1 = m['fadd1']
+        fadd2 = m['fadd2']
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = nn
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        fadd1(a, b, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy() + b.asnumpy())
+        fadd2(a, b, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy() + b.asnumpy())
+    check_llvm()
+
+
 if __name__ == "__main__":
+    test_multiple_func()
     test_llvm_add_pipeline()
     test_llvm_flip_pipeline()
     test_llvm_madd_pipeline()
diff --git a/tests/python/unittest/test_module_load.py b/tests/python/unittest/test_module_load.py
index 638623235fd0551114f73865c159878f8fd1f461..a614de73fed76399e3e574bab5b9d917d9bae66b 100644
--- a/tests/python/unittest/test_module_load.py
+++ b/tests/python/unittest/test_module_load.py
@@ -101,6 +101,44 @@ def test_device_module_dump():
     check_device("opencl")
     check_device("metal")
 
+
+def test_combine_module_llvm():
+    """Test combine multiple module into one shared lib."""
+    # graph
+    nn = 12
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    s = tvm.create_schedule(B.op)
+
+    def check_llvm():
+        ctx = tvm.cpu(0)
+        if not tvm.module.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        temp = util.tempdir()
+        fadd1 = tvm.build(s, [A, B], "llvm", name="myadd1")
+        fadd2 = tvm.build(s, [A, B], "llvm", name="myadd2")
+        path1 = temp.relpath("myadd1.o")
+        path2 = temp.relpath("myadd2.o")
+        path_dso = temp.relpath("mylib.so")
+        fadd1.save(path1)
+        fadd2.save(path2)
+        # create shared library with multiple functions
+        cc.create_shared(path_dso, [path1, path2])
+        m = tvm.module.load(path_dso)
+        fadd1 = m['myadd1']
+        fadd2 = m['myadd2']
+        a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)
+        fadd1(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+        fadd2(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+    check_llvm()
+
+
 if __name__ == "__main__":
+    test_combine_module_llvm()
     test_device_module_dump()
     test_dso_module_load()
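
Usage sketch (not part of the patch): a minimal illustration of the list form of tvm.build introduced above, using only APIs already exercised by test_multiple_func and test_combine_module_llvm; the function names "myadd1"/"myadd2" and the array size are placeholders, and the LLVM target is assumed to be enabled.

import numpy as np
import tvm

# One elementwise computation, lowered twice under different names.
n = tvm.convert(128)
A = tvm.placeholder((n,), name='A')
B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
s = tvm.create_schedule(B.op)
f1 = tvm.lower(s, [A, B], name="myadd1")
f2 = tvm.lower(s, [A, B], name="myadd2")

# Build both LoweredFuncs into a single host module in one call.
m = tvm.build([f1, f2], "llvm")

# Each function is retrieved from the combined module by name.
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=128).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(128, dtype=B.dtype), ctx)
m['myadd1'](a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)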