From c89cd59a87c88e806f6479b1c6c46538ed577a57 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 8 May 2017 22:19:20 -0700
Subject: [PATCH] [MODULE/DSO] Support pack everything into one shared
 library. (#133)

* [MODULE/DSO] Support pack everything into one shared library.

* fix osx load
---
 dmlc-core                                 |  2 +-
 include/tvm/codegen.h                     | 10 +++++
 include/tvm/runtime/module.h              | 24 +++++++++++-
 python/tvm/module.py                      | 33 +++++++++++++++-
 src/api/api_codegen.cc                    |  5 +++
 src/codegen/codegen.cc                    | 48 +++++++++++++++++++++++
 src/codegen/llvm/llvm_module.cc           |  6 ++-
 src/codegen/source_module.cc              | 10 ++++-
 src/codegen/stack_vm/stack_vm_module.cc   |  6 ++-
 src/codegen/verilog/verilog_module.cc     |  6 ++-
 src/runtime/cuda/cuda_module.cc           | 30 ++++++++++++--
 src/runtime/dso_module.cc                 | 36 ++++++++++++++++-
 src/runtime/file_util.cc                  | 13 ++++++
 src/runtime/meta_data.h                   |  6 +++
 src/runtime/metal/metal_module.mm         | 28 +++++++++++--
 src/runtime/opencl/opencl_module.cc       | 31 +++++++++++++--
 tests/python/integration/test_ewise.py    |  1 -
 tests/python/unittest/test_module_load.py | 34 ++++++++++++++++
 tutorials/python/get_started.py           | 17 +++++++-
 19 files changed, 321 insertions(+), 25 deletions(-)

diff --git a/dmlc-core b/dmlc-core
index 2b75a0ce6..a6c570121 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 2b75a0ce6f191ad0fcb5319039b41e990968542a
+Subproject commit a6c5701219e635fea808d264aefc5b03c3aec314
diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h
index d23d9e338..32fa446f9 100644
--- a/include/tvm/codegen.h
+++ b/include/tvm/codegen.h
@@ -31,6 +31,16 @@ using runtime::TVMRetValue;
  */
 runtime::Module Build(const Array<LoweredFunc>& funcs,
                       const std::string& target);
+/*!
+ * \brief Pack the imported device libraries into a C source file.
+ *  Compiling the C file and linking it with the host library
+ *  allows the DSO loader to automatically discover and import
+ *  the dependencies from the shared library.
+ *
+ * \param m The host module with the imports.
+ * \return The C source code representation of the file.
+ */
+std::string PackImportsToC(const runtime::Module& m);
 
 }  // namespace codegen
 }  // namespace tvm
diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index f6364174a..ca6244835 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -8,6 +8,7 @@
 #ifndef TVM_RUNTIME_MODULE_H_
 #define TVM_RUNTIME_MODULE_H_
 
+#include <dmlc/io.h>
 #include <memory>
 #include <vector>
 #include <string>
@@ -58,6 +59,8 @@ class Module {
                   const std::string& format);
   /*! \return internal container */
   inline ModuleNode* operator->();
+  /*! \return internal container */
+  inline const ModuleNode* operator->() const;
 
  private:
   std::shared_ptr<ModuleNode> node_;
@@ -111,6 +114,14 @@ class ModuleNode {
    */
   virtual void SaveToFile(const std::string& file_name,
                           const std::string& format) = 0;
+  /*!
+   * \brief Save the module to binary stream.
+   * \param stream The binary stream to save to.
+   * \note It is recommended to implement this for device modules,
+   *       but not necessarily host modules.
+   *       We can use this to do AOT loading of bundled device functions.
+   */
+  virtual void SaveToBinary(dmlc::Stream* stream) = 0;
   /*!
    * \brief Get the source code of module, when available.
    * \param format Format of the source code, can be empty by default.
@@ -118,7 +129,6 @@ class ModuleNode {
    */
   virtual std::string GetSource(
       const std::string& format = "") = 0;
-
   /*!
   * \brief Get a function from current environment
   *  The environment includes all the imports as well as Global functions.
@@ -132,10 +142,12 @@ class ModuleNode {
     return imports_;
   }
 
- private:
+ protected:
   friend class Module;
   /*! \brief The modules this module depend on */
   std::vector<Module> imports_;
+
+ private:
   /*! \brief Cache used by GetImport */
   std::unordered_map<std::string, std::unique_ptr<PackedFunc> > import_cache_;
@@ -145,6 +157,10 @@
 namespace symbol {
 /*! \brief Global variable to store module context. */
 constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
+/*! \brief Global variable to store device module blob */
+constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
+/*! \brief Number of bytes of device module blob. */
+constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
 /*! \brief global function to set device */
 constexpr const char* tvm_set_device = "__tvm_set_device";
 /*! \brief Auxiliary counter to global barrier. */
@@ -160,6 +176,10 @@
 inline ModuleNode* Module::operator->() {
   return node_.get();
 }
 
+inline const ModuleNode* Module::operator->() const {
+  return node_.get();
+}
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/python/tvm/module.py b/python/tvm/module.py
index bf601b03f..ff0df4397 100644
--- a/python/tvm/module.py
+++ b/python/tvm/module.py
@@ -2,7 +2,7 @@
 from __future__ import absolute_import as _abs
 from ._ffi.function import ModuleBase, _set_class_module
 from ._ffi.function import _init_api
-
+from .contrib import cc_compiler as _cc, util as _util
 
 class Module(ModuleBase):
     """Module container of all TVM generated functions"""
@@ -44,15 +44,46 @@ class Module(ModuleBase):
     def save(self, file_name, fmt=""):
         """Save the module to file.
 
+        This does not save the dependent device modules.
+        See also export_library.
+
         Parameters
         ----------
         file_name : str
             The name of the file.
+
         fmt : str
             The format of the file.
+
+        See Also
+        --------
+        Module.export_library : export the module to a shared library.
         """
         _SaveToFile(self, file_name, fmt)
 
+    def export_library(self, file_name):
+        """Export the module and its imported device code into one library.
+
+        This function only works on host llvm modules.
+        It will pack all the imported modules into one shared library.
+
+        Parameters
+        ----------
+        file_name : str
+            The name of the shared library.
+ """ + if self.type_key != "llvm": + raise ValueError("Only llvm support export shared") + temp = _util.tempdir() + path_obj = temp.relpath("lib.o") + self.save(path_obj) + files = [path_obj] + if self.imported_modules: + path_cc = temp.relpath("devc.cc") + with open(path_cc, "w") as f: + f.write(_PackImportsToC(self)) + files.append(path_cc) + _cc.create_shared(file_name, files) + def load(path, fmt=""): """Load module from file diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 37e0717f1..f296acff7 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -20,5 +20,10 @@ TVM_REGISTER_API("codegen._Build") *ret = Build(args[0], args[1]); } }); + +TVM_REGISTER_API("module._PackImportsToC") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = PackImportsToC(args[0]); + }); } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index 03cf8ff43..c83d3bd0f 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -7,6 +7,9 @@ #include <tvm/ir_pass.h> #include <tvm/runtime/registry.h> #include <tvm/runtime/module.h> +#include <dmlc/memory_io.h> +#include <sstream> +#include <iostream> namespace tvm { namespace codegen { @@ -32,5 +35,50 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, return m; } +std::string PackImportsToC(const runtime::Module& mod) { + std::string bin; + dmlc::MemoryStringStream ms(&bin); + dmlc::Stream* stream = &ms; + uint64_t sz = static_cast<uint64_t>(mod->imports().size()); + stream->Write(sz); + for (runtime::Module im : mod->imports()) { + CHECK_EQ(im->imports().size(), 0U) + << "Only support simply one-level hierachy"; + std::string tkey = im->type_key(); + std::string bin; + stream->Write(tkey); + im->SaveToBinary(stream); + } + // translate to C program + std::ostringstream os; + os << "#ifdef __cplusplus\n" + << "extern \"C\" {\n" + << "#endif\n"; + os << "extern const char " << runtime::symbol::tvm_dev_mblob << "[];\n"; + os << "extern const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes << ";\n"; + os << "const char " << runtime::symbol::tvm_dev_mblob + << "[" << bin.length() << "] = {\n "; + os << std::hex; + size_t nunit = 80 / 4; + for (size_t i = 0; i < bin.length(); ++i) { + // sperators + if (i != 0) { + if (i % nunit == 0) { + os << ",\n "; + } else { + os << ","; + } + } + int c = bin[i]; + os << "0x" << (c & 0xff); + } + os << "\n};\n" + << "const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes + << " = " << std::dec << bin.length() << "UL;\n" + << "#ifdef __cplusplus\n" + << "}\n" + << "#endif\n"; + return os.str(); +} } // namespace codegen } // namespace tvm diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index 27117137a..9d18ff319 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -19,7 +19,7 @@ using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; -class LLVMModuleNode : public runtime::ModuleNode { +class LLVMModuleNode final : public runtime::ModuleNode { public: ~LLVMModuleNode() { module_.reset(); @@ -84,6 +84,10 @@ class LLVMModuleNode : public runtime::ModuleNode { dest.close(); } + void SaveToBinary(dmlc::Stream* stream) final { + LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; + } + std::string GetSource(const std::string& format) final { std::string type_str; llvm::raw_string_ostream rso(type_str); diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc index 0bd727a27..feb0be213 100644 --- 
+++ b/src/codegen/source_module.cc
@@ -13,7 +13,7 @@ using runtime::TVMArgs;
 using runtime::TVMRetValue;
 using runtime::PackedFunc;
 // Simulator function
-class SourceModuleNode : public runtime::ModuleNode {
+class SourceModuleNode final : public runtime::ModuleNode {
  public:
   SourceModuleNode(std::string code,
                    std::string fmt)
@@ -30,10 +30,16 @@ class SourceModuleNode : public runtime::ModuleNode {
         << " build TVM with \'" << fmt_ << "\' runtime support";
     return PackedFunc();
   }
+
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
-    LOG(FATAL) << "not implemented";
+    LOG(FATAL) << "SourceModule: SaveToFile not supported";
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final {
+    LOG(FATAL) << "SourceModule: SaveToBinary not supported";
   }
+
   std::string GetSource(const std::string& format) final {
     return code_;
   }
diff --git a/src/codegen/stack_vm/stack_vm_module.cc b/src/codegen/stack_vm/stack_vm_module.cc
index ab647be6a..42680fcaa 100644
--- a/src/codegen/stack_vm/stack_vm_module.cc
+++ b/src/codegen/stack_vm/stack_vm_module.cc
@@ -35,7 +35,11 @@ class StackVMModuleNode : public runtime::ModuleNode {
 
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
-    LOG(FATAL) << "StackVM do not support SaveToFile";
+    LOG(FATAL) << "StackVMModule: SaveToFile not supported";
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final {
+    LOG(FATAL) << "StackVMModule: SaveToBinary not supported";
   }
 
   std::string GetSource(const std::string& format) final {
diff --git a/src/codegen/verilog/verilog_module.cc b/src/codegen/verilog/verilog_module.cc
index 53215ad91..15e96731a 100644
--- a/src/codegen/verilog/verilog_module.cc
+++ b/src/codegen/verilog/verilog_module.cc
@@ -59,7 +59,11 @@ class VerilogModuleNode : public runtime::ModuleNode {
 
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
-    LOG(FATAL) << "not implemented";
+    LOG(FATAL) << "VerilogModule: SaveToFile not supported";
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final {
+    LOG(FATAL) << "VerilogModule: SaveToBinary not supported";
   }
 
   std::string GetSource(const std::string& format) final {
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index bf4cbf98a..6a003a2b7 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -69,6 +69,12 @@ class CUDAModuleNode : public runtime::ModuleNode {
     SaveBinaryToFile(file_name, data_);
   }
 
+  void SaveToBinary(dmlc::Stream* stream) final {
+    stream->Write(fmt_);
+    stream->Write(fmap_);
+    stream->Write(data_);
+  }
+
   std::string GetSource(const std::string& format) final {
     if (format == fmt_) return data_;
     if (cuda_source_.length() != 0) {
@@ -242,8 +248,8 @@ Module CUDAModuleCreate(
 }
 
 // Load module from module.
-Module CUDAModuleLoad(const std::string& file_name,
-                      const std::string& format) {
+Module CUDAModuleLoadFile(const std::string& file_name,
+                          const std::string& format) {
   std::string data;
   std::unordered_map<std::string, FunctionInfo> fmap;
   std::string fmt = GetFileFormat(file_name, format);
@@ -253,14 +259,30 @@ Module CUDAModuleLoad(const std::string& file_name,
   return CUDAModuleCreate(data, fmt, fmap, std::string());
 }
 
+Module CUDAModuleLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::string data;
+  std::unordered_map<std::string, FunctionInfo> fmap;
+  std::string fmt;
+  stream->Read(&fmt);
+  stream->Read(&fmap);
+  stream->Read(&data);
+  return CUDAModuleCreate(data, fmt, fmap, std::string());
+}
+
 TVM_REGISTER_GLOBAL("module.loadfile_cubin")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = CUDAModuleLoad(args[0], args[1]);
+    *rv = CUDAModuleLoadFile(args[0], args[1]);
   });
 
 TVM_REGISTER_GLOBAL("module.loadfile_ptx")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = CUDAModuleLoad(args[0], args[1]);
+    *rv = CUDAModuleLoadFile(args[0], args[1]);
+  });
+
+TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = CUDAModuleLoadBinary(args[0]);
   });
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/dso_module.cc b/src/runtime/dso_module.cc
index 6147f8612..ad1a88803 100644
--- a/src/runtime/dso_module.cc
+++ b/src/runtime/dso_module.cc
@@ -3,6 +3,7 @@
  * \file dso_module.cc
  * \brief Module to load from dynamic shared library.
  */
+#include <dmlc/memory_io.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/packed_func.h>
@@ -19,7 +20,7 @@ namespace runtime {
 
 // Module to load from dynamic shared libary.
 // This is the default module TVM used for hostside AOT
-class DSOModuleNode : public ModuleNode {
+class DSOModuleNode final : public ModuleNode {
  public:
   ~DSOModuleNode() {
     if (lib_handle_) Unload();
@@ -49,7 +50,11 @@ class DSOModuleNode : public ModuleNode {
 
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
-    LOG(FATAL) << "Cannot save dso to another file";
+    LOG(FATAL) << "DSOModule: SaveToFile not supported";
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final {
+    LOG(FATAL) << "DSOModule: SaveToBinary not supported";
   }
 
   std::string GetSource(const std::string& format) final {
@@ -66,6 +71,33 @@ class DSOModuleNode : public ModuleNode {
     if (ctx_addr != nullptr) {
       *ctx_addr = this;
     }
+    // Load the imported modules
+    const char* dev_mblob =
+        reinterpret_cast<const char*>(
+            GetGlobalVPtr(runtime::symbol::tvm_dev_mblob));
+    const unsigned long* dev_mblob_nbytes =  // NOLINT(*)
+        reinterpret_cast<const unsigned long*>(  // NOLINT(*)
+            GetGlobalVPtr(runtime::symbol::tvm_dev_mblob_nbytes));
+
+    if (dev_mblob != nullptr) {
+      CHECK(dev_mblob_nbytes != nullptr);
+      dmlc::MemoryFixedSizeStream fs(
+          (void*)dev_mblob, dev_mblob_nbytes[0]);  // NOLINT(*)
+      dmlc::Stream* stream = &fs;
+      uint64_t size;
+      CHECK(stream->Read(&size));
+      for (uint64_t i = 0; i < size; ++i) {
+        std::string tkey;
+        CHECK(stream->Read(&tkey));
+        std::string fkey = "module.loadbinary_" + tkey;
+        const PackedFunc* f = Registry::Get(fkey);
+        CHECK(f != nullptr)
+            << "Loader of " << tkey << "("
+            << fkey << ") is not present.";
+        Module m = (*f)(static_cast<void*>(stream));
+        this->imports_.push_back(m);
+      }
+    }
   }
 
  private:
diff --git a/src/runtime/file_util.cc b/src/runtime/file_util.cc
index 257f12b1c..a6d0d7dc1 100644
--- a/src/runtime/file_util.cc
+++ b/src/runtime/file_util.cc
@@ -36,6 +36,19 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
   }
 }
 
+void FunctionInfo::Save(dmlc::Stream* writer) const {
+  writer->Write(name);
+  writer->Write(arg_types);
+  writer->Write(thread_axis_tags);
+}
+
+bool FunctionInfo::Load(dmlc::Stream* reader) {
+  if (!reader->Read(&name)) return false;
+  if (!reader->Read(&arg_types)) return false;
+  if (!reader->Read(&thread_axis_tags)) return false;
+  return true;
+}
+
 std::string GetFileFormat(const std::string& file_name,
                           const std::string& format) {
   std::string fmt = format;
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index 6632bced3..418e36547 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -29,7 +29,13 @@ struct FunctionInfo {
 
   void Save(dmlc::JSONWriter *writer) const;
   void Load(dmlc::JSONReader *reader);
+  void Save(dmlc::Stream *writer) const;
+  bool Load(dmlc::Stream *reader);
 };
 }  // namespace runtime
 }  // namespace tvm
+
+namespace dmlc {
+DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::FunctionInfo, true);
+}  // namespace dmlc
 #endif  // TVM_RUNTIME_META_DATA_H_
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index f946ad8c7..4d8b231f3 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -6,6 +6,7 @@
 
 #if TVM_METAL_RUNTIME
 
+#include <dmlc/memory_io.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/module.h>
 #include <array>
@@ -54,6 +55,11 @@ class MetalModuleNode final :public runtime::ModuleNode {
     SaveBinaryToFile(file_name, data_);
   }
 
+  void SaveToBinary(dmlc::Stream* stream) final {
+    stream->Write(fmt_);
+    stream->Write(fmap_);
+    stream->Write(data_);
+  }
   std::string GetSource(const std::string& format) final {
     if (format == fmt_) return data_;
     if (source_.length() != 0) {
@@ -261,8 +267,8 @@ Module MetalModuleCreate(
 }
 
 // Load module from module.
-Module MetalModuleLoad(const std::string& file_name,
-                       const std::string& format) {
+Module MetalModuleLoadFile(const std::string& file_name,
+                           const std::string& format) {
   std::string data;
   std::unordered_map<std::string, FunctionInfo> fmap;
   std::string fmt = GetFileFormat(file_name, format);
@@ -272,9 +278,25 @@ Module MetalModuleLoad(const std::string& file_name,
   return MetalModuleCreate(data, fmt, fmap, "");
 }
 
+Module MetalModuleLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::string data;
+  std::unordered_map<std::string, FunctionInfo> fmap;
+  std::string fmt;
+  stream->Read(&fmt);
+  stream->Read(&fmap);
+  stream->Read(&data);
+  return MetalModuleCreate(data, fmt, fmap, "");
+}
+
 TVM_REGISTER_GLOBAL("module.loadfile_metal")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = MetalModuleLoad(args[0], args[1]);
+    *rv = MetalModuleLoadFile(args[0], args[1]);
+  });
+
+TVM_REGISTER_GLOBAL("module.loadbinary_metal")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = MetalModuleLoadBinary(args[0]);
   });
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc
index 72c9550e8..d0965d9fe 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -7,6 +7,7 @@
 
 #if TVM_OPENCL_RUNTIME
 
+#include <dmlc/memory_io.h>
 #include <tvm/runtime/registry.h>
 #include <vector>
 #include <string>
@@ -78,6 +79,12 @@ class OpenCLModuleNode : public ModuleNode {
     SaveBinaryToFile(file_name, data_);
   }
 
+  void SaveToBinary(dmlc::Stream* stream) final {
+    stream->Write(fmt_);
+    stream->Write(fmap_);
+    stream->Write(data_);
+  }
+
   std::string GetSource(const std::string& format) final {
     if (format == fmt_) return data_;
     if (fmt_ == "cl") {
@@ -272,8 +279,8 @@ Module OpenCLModuleCreate(
 }
 
 // Load module from module.
-Module OpenCLModuleLoad(const std::string& file_name,
-                        const std::string& format) {
+Module OpenCLModuleLoadFile(const std::string& file_name,
+                            const std::string& format) {
   std::string data;
   std::unordered_map<std::string, FunctionInfo> fmap;
   std::string fmt = GetFileFormat(file_name, format);
@@ -283,14 +290,30 @@ Module OpenCLModuleLoad(const std::string& file_name,
   return OpenCLModuleCreate(data, fmt, fmap);
 }
 
+Module OpenCLModuleLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::string data;
+  std::unordered_map<std::string, FunctionInfo> fmap;
+  std::string fmt;
+  stream->Read(&fmt);
+  stream->Read(&fmap);
+  stream->Read(&data);
+  return OpenCLModuleCreate(data, fmt, fmap);
+}
+
 TVM_REGISTER_GLOBAL("module.loadfile_cl")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = OpenCLModuleLoad(args[0], args[1]);
+    *rv = OpenCLModuleLoadFile(args[0], args[1]);
   });
 
 TVM_REGISTER_GLOBAL("module.loadfile_clbin")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    *rv = OpenCLModuleLoad(args[0], args[1]);
+    *rv = OpenCLModuleLoadFile(args[0], args[1]);
+  });
+
+TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = OpenCLModuleLoadBinary(args[0]);
   });
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 5e4c7bf38..0c3ccaeaa 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -95,7 +95,6 @@ def test_add():
                           device, name="myadd")
     ctx = tvm.context(device, 0)
-    print(fadd.imported_modules[0].get_source())
     # launch the kernel.
     n = 1024
     a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
diff --git a/tests/python/unittest/test_module_load.py b/tests/python/unittest/test_module_load.py
index 602cb25a0..638623235 100644
--- a/tests/python/unittest/test_module_load.py
+++ b/tests/python/unittest/test_module_load.py
@@ -68,5 +68,39 @@ def test_dso_module_load():
         "python %s %s %s" % (path_runtime_py, path_dso, dtype),
         shell=True)
 
+
+def test_device_module_dump():
+    # graph
+    n = tvm.convert(1024)
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+    s = tvm.create_schedule(B.op)
+    # create iter var and assign them tags.
+    num_thread = 8
+    bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
+    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        temp = util.tempdir()
+        f = tvm.build(s, [A, B], device, name="myadd")
+        path_dso = temp.relpath("dev_lib.so")
+        f.export_library(path_dso)
+
+        f1 = tvm.module.load(path_dso)
+        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
+        f1(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+
+    check_device("cuda")
+    check_device("opencl")
+    check_device("metal")
+
 if __name__ == "__main__":
+    test_device_module_dump()
     test_dso_module_load()
diff --git a/tutorials/python/get_started.py b/tutorials/python/get_started.py
index 96e82af19..9c00ea6d8 100644
--- a/tutorials/python/get_started.py
+++ b/tutorials/python/get_started.py
@@ -190,8 +190,7 @@
 print(temp.listdir())
 # The CPU(host) module is directly saved as a shared library(so).
 # There can be multiple customed format on the device code.
 # In our example, device code is stored in ptx, as well as a meta
-# data json file. In the future we can consider pack every binary
-# into one shared library.
+# data json file. They can be loaded and linked separately via import.
 #
 ######################################################################
@@ -207,6 +206,20 @@
 fadd1.import_module(fadd1_dev)
 fadd1(a, b, c)
 np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+######################################################################
+# Pack Everything into One Library
+# --------------------------------
+# In the above example, we store the device and host code separately.
+# TVM also supports exporting everything as one shared library.
+# Under the hood, we pack the device modules into binary blobs and link
+# them together with the host code.
+# Currently we support packing of Metal, OpenCL and CUDA modules.
+#
+fadd_cuda.export_library(temp.relpath("myadd_pack.so"))
+fadd2 = tvm.module.load(temp.relpath("myadd_pack.so"))
+fadd2(a, b, c)
+np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
 ######################################################################
 # .. note:: Runtime API and Thread-Safety
 #
-- 
GitLab
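
For reference, a minimal end-to-end sketch of the export path introduced by this patch, mirroring the unit test and tutorial above. It assumes an LLVM- and CUDA-enabled build; the file name "myadd_pack.so" is illustrative only.

    import numpy as np
    import tvm
    from tvm.contrib import util

    # Describe and schedule a simple elementwise kernel.
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=8)
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

    ctx = tvm.context("cuda", 0)
    if ctx.exist:
        # The host code becomes an llvm module; the CUDA code appears in
        # f.imported_modules and is serialized through SaveToBinary.
        f = tvm.build(s, [A, B], "cuda", name="myadd")
        temp = util.tempdir()
        path_dso = temp.relpath("myadd_pack.so")
        # Packs the host object plus the device blob into one shared library.
        f.export_library(path_dso)

        # Loading the library rediscovers the device modules automatically
        # through the module.loadbinary_* registry.
        f1 = tvm.module.load(path_dso)
        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        f1(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)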