From d0041efdbc2aeb527f2a8c63804132b592352bd9 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Thu, 15 Jun 2017 11:17:47 -0700
Subject: [PATCH] [MODULE] support load back of .ll file into llvm module
 (#183)

---
 src/codegen/llvm/llvm_common.h            |  2 ++
 src/codegen/llvm/llvm_module.cc           | 27 +++++++++++++++++++++--
 tests/python/unittest/test_module_load.py |  3 +--
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/src/codegen/llvm/llvm_common.h b/src/codegen/llvm/llvm_common.h
index 353f8af39..845f5041a 100644
--- a/src/codegen/llvm/llvm_common.h
+++ b/src/codegen/llvm/llvm_common.h
@@ -10,6 +10,7 @@
 #include <llvm/ExecutionEngine/MCJIT.h>
 
 #include <llvm/Bitcode/BitcodeWriter.h>
+#include <llvm/Support/SourceMgr.h>
 
 #include <llvm/IR/Value.h>
 #include <llvm/IR/Intrinsics.h>
@@ -36,6 +37,7 @@
 #include <llvm/Support/TargetSelect.h>
 #include <llvm/Target/TargetMachine.h>
 #include <llvm/Target/TargetOptions.h>
+#include <llvm/IRReader/IRReader.h>
 
 #include <utility>
 #include <string>
diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc
index 8ccbc892f..cfa4510d1 100644
--- a/src/codegen/llvm/llvm_module.cc
+++ b/src/codegen/llvm/llvm_module.cc
@@ -92,7 +92,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   void Init(const Array<LoweredFunc>& funcs, std::string target) {
     InitializeLLVM();
     tm_ = GetLLVMTargetMachine(target);
-    target_ = target;
     CHECK_NE(funcs.size(), 0U);
     ctx_ = std::make_shared<llvm::LLVMContext>();
     CodeGenLLVM cg;
@@ -106,6 +105,20 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     mptr_ = module_.get();
   }
 
+  void LoadIR(const std::string& file_name) {
+    InitializeLLVM();
+    ctx_ = std::make_shared<llvm::LLVMContext>();
+    llvm::SMDiagnostic err;
+    module_ = llvm::parseIRFile(file_name, err, *ctx_);
+    CHECK(module_.get() != nullptr)
+        << "Fail to load ir file " << file_name;
+    std::string target = module_->getTargetTriple();
+    mptr_ = module_.get();
+    std::ostringstream os;
+    os << "llvm -target " << target;
+    tm_ = GetLLVMTargetMachine(os.str());
+  }
+
  private:
   void LazyInitJIT() {
     CHECK(ee_ == nullptr);
@@ -127,7 +140,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {
         << " and ExecutionEngine ("
         << layout.getStringRepresentation() << ")";
     ee_ = builder.create(tm);
-
     CHECK(ee_ != nullptr)
         << "Failed to initialize git engine for " << mptr_->getTargetTriple();
     ee_->runStaticConstructorsDestructors(false);
@@ -135,6 +147,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     void** ctx_addr =
         reinterpret_cast<void**>(
             ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx));
+    // setup context address.
+    entry_func_ =
+        reinterpret_cast<const char*>(
+            ee_->getGlobalValueAddress(runtime::symbol::tvm_module_main));
     if (ctx_addr != nullptr) {
       *ctx_addr = this;
     }
@@ -163,6 +179,13 @@ TVM_REGISTER_API("codegen.build_llvm")
     n->Init(args[0], args[1]);
     *rv = runtime::Module(n);
   });
+
+TVM_REGISTER_API("module.loadfile_ll")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
+    n->LoadIR(args[0]);
+    *rv = runtime::Module(n);
+  });
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_LLVM_VERSION
diff --git a/tests/python/unittest/test_module_load.py b/tests/python/unittest/test_module_load.py
index a614de73f..711c46958 100644
--- a/tests/python/unittest/test_module_load.py
+++ b/tests/python/unittest/test_module_load.py
@@ -51,8 +51,7 @@ def test_dso_module_load():
     cc.create_shared(path_dso, [path_obj])
 
     f1 = tvm.module.load(path_dso)
-    f2 = tvm.module.load(path_dso)
-
+    f2 = tvm.module.load(path_ll)
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f1(a)
     np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
-- 
GitLab