From d0041efdbc2aeb527f2a8c63804132b592352bd9 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Thu, 15 Jun 2017 11:17:47 -0700 Subject: [PATCH] [MODULE] support load back of .ll file into llvm module (#183) --- src/codegen/llvm/llvm_common.h | 2 ++ src/codegen/llvm/llvm_module.cc | 27 +++++++++++++++++++++-- tests/python/unittest/test_module_load.py | 3 +-- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/codegen/llvm/llvm_common.h b/src/codegen/llvm/llvm_common.h index 353f8af39..845f5041a 100644 --- a/src/codegen/llvm/llvm_common.h +++ b/src/codegen/llvm/llvm_common.h @@ -10,6 +10,7 @@ #include <llvm/ExecutionEngine/MCJIT.h> #include <llvm/Bitcode/BitcodeWriter.h> +#include <llvm/Support/SourceMgr.h> #include <llvm/IR/Value.h> #include <llvm/IR/Intrinsics.h> @@ -36,6 +37,7 @@ #include <llvm/Support/TargetSelect.h> #include <llvm/Target/TargetMachine.h> #include <llvm/Target/TargetOptions.h> +#include <llvm/IRReader/IRReader.h> #include <utility> #include <string> diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index 8ccbc892f..cfa4510d1 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -92,7 +92,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { void Init(const Array<LoweredFunc>& funcs, std::string target) { InitializeLLVM(); tm_ = GetLLVMTargetMachine(target); - target_ = target; CHECK_NE(funcs.size(), 0U); ctx_ = std::make_shared<llvm::LLVMContext>(); CodeGenLLVM cg; @@ -106,6 +105,20 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_ = module_.get(); } + void LoadIR(const std::string& file_name) { + InitializeLLVM(); + ctx_ = std::make_shared<llvm::LLVMContext>(); + llvm::SMDiagnostic err; + module_ = llvm::parseIRFile(file_name, err, *ctx_); + CHECK(module_.get() != nullptr) + << "Fail to load ir file " << file_name; + std::string target = module_->getTargetTriple(); + mptr_ = module_.get(); + std::ostringstream os; + os << "llvm -target " << target; + tm_ = GetLLVMTargetMachine(os.str()); + } + private: void LazyInitJIT() { CHECK(ee_ == nullptr); @@ -127,7 +140,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; ee_ = builder.create(tm); - CHECK(ee_ != nullptr) << "Failed to initialize git engine for " << mptr_->getTargetTriple(); ee_->runStaticConstructorsDestructors(false); @@ -135,6 +147,10 @@ class LLVMModuleNode final : public runtime::ModuleNode { void** ctx_addr = reinterpret_cast<void**>( ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx)); + // setup context address. + entry_func_ = + reinterpret_cast<const char*>( + ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main)); if (ctx_addr != nullptr) { *ctx_addr = this; } @@ -163,6 +179,13 @@ TVM_REGISTER_API("codegen.build_llvm") n->Init(args[0], args[1]); *rv = runtime::Module(n); }); + +TVM_REGISTER_API("module.loadfile_ll") +.set_body([](TVMArgs args, TVMRetValue* rv) { + std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>(); + n->LoadIR(args[0]); + *rv = runtime::Module(n); + }); } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/tests/python/unittest/test_module_load.py b/tests/python/unittest/test_module_load.py index a614de73f..711c46958 100644 --- a/tests/python/unittest/test_module_load.py +++ b/tests/python/unittest/test_module_load.py @@ -51,8 +51,7 @@ def test_dso_module_load(): cc.create_shared(path_dso, [path_obj]) f1 = tvm.module.load(path_dso) - f2 = tvm.module.load(path_dso) - + f2 = tvm.module.load(path_ll) a = tvm.nd.array(np.zeros(10, dtype=dtype)) f1(a) np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0])) -- GitLab