From 589a26515f10e222bc523692dd152023c810a200 Mon Sep 17 00:00:00 2001
From: Zhixun Tan <phisiart@gmail.com>
Date: Tue, 27 Feb 2018 20:20:16 -0800
Subject: [PATCH] Add test case: Create a static WebGL library and run it in
 the browser. (#932)

* Add test case: Create a static WebGL library and run it in the browser.

* Add documentation for loadModuleFromFile

* Modify emscripten.createjs
---
 python/tvm/contrib/emscripten.py           |  2 +
 python/tvm/module.py                       |  8 +++-
 src/runtime/system_lib_module.cc           | 22 +++++++--
 tests/webgl/test_static_webgl_library.html | 55 ++++++++++++++++++++++
 tests/webgl/test_static_webgl_library.py   | 49 +++++++++++++++++++
 web/tvm_runtime.js                         | 37 +++++++++++++++
 6 files changed, 169 insertions(+), 4 deletions(-)
 create mode 100644 tests/webgl/test_static_webgl_library.html
 create mode 100644 tests/webgl/test_static_webgl_library.py

diff --git a/python/tvm/contrib/emscripten.py b/python/tvm/contrib/emscripten.py
index d770ce116..d263e472c 100644
--- a/python/tvm/contrib/emscripten.py
+++ b/python/tvm/contrib/emscripten.py
@@ -60,3 +60,5 @@ def create_js(output,
         msg = "Compilation error:\n"
         msg += out
         raise RuntimeError(msg)
+
+create_js.object_format = "bc"
diff --git a/python/tvm/module.py b/python/tvm/module.py
index d8b018b82..6459733fa 100644
--- a/python/tvm/module.py
+++ b/python/tvm/module.py
@@ -84,6 +84,8 @@ class Module(ModuleBase):
 
         fcompile : function(target, file_list, kwargs), optional
             Compilation function to use create dynamic library.
+            If fcompile has attribute object_format, will compile host library
+            to that format. Otherwise, will use default format "o".
 
         kwargs : dict, optiona;
             Additional arguments passed to fcompile
@@ -95,7 +97,11 @@ class Module(ModuleBase):
         if self.type_key != "llvm":
             raise ValueError("Module[%s]: Only llvm support export shared" % self.type_key)
         temp = _util.tempdir()
-        path_obj = temp.relpath("lib.o")
+        if fcompile is not None and hasattr(fcompile, "object_format"):
+            object_format = fcompile.object_format
+        else:
+            object_format = "o"
+        path_obj = temp.relpath("lib." + object_format)
         self.save(path_obj)
         files = [path_obj]
         is_system_lib = self.get_function("__tvm_is_system_module")()
diff --git a/src/runtime/system_lib_module.cc b/src/runtime/system_lib_module.cc
index db06f57e8..1f9c8ac8e 100644
--- a/src/runtime/system_lib_module.cc
+++ b/src/runtime/system_lib_module.cc
@@ -13,8 +13,8 @@ namespace runtime {
 
 class SystemLibModuleNode : public ModuleNode {
  public:
-  SystemLibModuleNode() {
-  }
+  SystemLibModuleNode() = default;
+
   const char* type_key() const final {
     return "system_lib";
   }
@@ -23,6 +23,13 @@ class SystemLibModuleNode : public ModuleNode {
       const std::string& name,
       const std::shared_ptr<ModuleNode>& sptr_to_self) final {
     std::lock_guard<std::mutex> lock(mutex_);
+
+    if (module_blob_ != nullptr) {
+      // If we previously recorded submodules, load them now.
+      ImportModuleBlob(reinterpret_cast<const char*>(module_blob_), &imports_);
+      module_blob_ = nullptr;
+    }
+
     auto it = tbl_.find(name);
     if (it != tbl_.end()) {
       return WrapPackedFunc(
@@ -38,7 +45,14 @@ class SystemLibModuleNode : public ModuleNode {
       void** ctx_addr = reinterpret_cast<void**>(ptr);
       *ctx_addr = this;
     } else if (name == symbol::tvm_dev_mblob) {
-      ImportModuleBlob(reinterpret_cast<const char*>(ptr), &imports_);
+      // Record pointer to content of submodules to be loaded.
+      // We defer loading submodules to the first call to GetFunction().
+      // The reason is that RegisterSymbol() gets called when initializing the
+      // syslib (i.e. library loading time), and the registeries aren't ready
+      // yet. Therefore, we might not have the functionality to load submodules
+      // now.
+      CHECK(module_blob_ == nullptr) << "Resetting mobule blob?";
+      module_blob_ = ptr;
     } else {
       auto it = tbl_.find(name);
       if (it != tbl_.end()) {
@@ -65,6 +79,8 @@ class SystemLibModuleNode : public ModuleNode {
   std::mutex mutex_;
   // Internal symbol table
   std::unordered_map<std::string, void*> tbl_;
+  // Module blob to be imported
+  void* module_blob_{nullptr};
 };
 
 TVM_REGISTER_GLOBAL("module._GetSystemLib")
diff --git a/tests/webgl/test_static_webgl_library.html b/tests/webgl/test_static_webgl_library.html
new file mode 100644
index 000000000..39bcb5fff
--- /dev/null
+++ b/tests/webgl/test_static_webgl_library.html
@@ -0,0 +1,55 @@
+<html>
+
+<head>
+  <meta charset="UTF-8">
+  <title>TVM RPC Test Page</title>
+</head>
+
+<body>
+  <h1>TVM Test Page</h1>
+  <div id="log"></div>
+  <canvas id="canvas"></canvas>
+  <script>
+    var Module = {};
+    Module["canvas"] = document.getElementById("canvas");
+  </script>
+  <script src="identity_static.js"></script>
+  <script src="tvm_runtime.js"></script>
+  <script>
+    var tvm = tvm_runtime.create(Module);
+    tvm.logger = function (message) {
+      console.log(message);
+      var d = document.createElement("div");
+      d.innerHTML = message;
+      document.getElementById("log").appendChild(d);
+    };
+
+    function randomArray(length, max) {
+      return Array.apply(null, Array(length)).map(function () {
+        return Math.random() * max;
+      });
+    }
+
+    setTimeout(function () {
+      this.syslib = tvm.systemLib();
+      this.identity = this.syslib.getFunction("identity");
+
+      this.n = 16;
+      this.a = randomArray(this.n, 1);
+      this.ctx = tvm.context("opengl", 0);
+      this.A = tvm.empty(this.n, "float32", ctx).copyFrom(this.a);
+      this.B = tvm.empty(this.n, "float32", ctx);
+      identity(this.A, this.B);
+
+      this.a = this.A.asArray();
+      this.b = this.B.asArray();
+      for (var i = 0; i < n; ++i) {
+        tvm.assert(this.a[i] == this.b[i]);
+      }
+      this.identity.release();
+    }, 1000);
+
+  </script>
+</body>
+
+</html>
\ No newline at end of file
diff --git a/tests/webgl/test_static_webgl_library.py b/tests/webgl/test_static_webgl_library.py
new file mode 100644
index 000000000..262416c42
--- /dev/null
+++ b/tests/webgl/test_static_webgl_library.py
@@ -0,0 +1,49 @@
+"""Create a static WebGL library and run it in the browser."""
+
+from __future__ import absolute_import, print_function
+
+import os, shutil, SimpleHTTPServer, SocketServer
+import tvm
+from tvm.contrib import emscripten, util
+import numpy as np
+
+def try_static_webgl_library():
+    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+
+    # Change to lib/ which contains "libtvm_runtime.bc".
+    os.chdir(os.path.join(curr_path, "../../lib"))
+
+    # Create OpenGL module.
+    n = tvm.var("n")
+    A = tvm.placeholder((n,), name='A', dtype="float")
+    B = tvm.compute((n,), lambda *i: A[i], name="B")
+
+    s = tvm.create_schedule(B.op)
+    s[B].opengl()
+
+    target_host = "llvm -target=asmjs-unknown-emscripten -system-lib"
+    f = tvm.build(s, [A, B], name="identity", target="opengl",
+                  target_host=target_host)
+
+    # Create a JS library that contains both the module and the tvm runtime.
+    path_dso = "identity_static.js"
+    f.export_library(path_dso, emscripten.create_js, options=[
+        "-s", "USE_GLFW=3",
+        "-s", "USE_WEBGL2=1",
+        "-lglfw",
+    ])
+
+    # Create "tvm_runtime.js" and "identity_static.html" in lib/
+    shutil.copyfile(os.path.join(curr_path, "../../web/tvm_runtime.js"),
+                    "tvm_runtime.js")
+    shutil.copyfile(os.path.join(curr_path, "test_static_webgl_library.html"),
+                    "identity_static.html")
+
+    port = 8080
+    handler = SimpleHTTPServer.SimpleHTTPRequestHandler
+    httpd = SocketServer.TCPServer(("", port), handler)
+    print("Please open http://localhost:" + str(port) + "/identity_static.html")
+    httpd.serve_forever()
+
+if __name__ == "__main__":
+    try_static_webgl_library()
diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js
index df9cba94a..347532f5b 100644
--- a/web/tvm_runtime.js
+++ b/web/tvm_runtime.js
@@ -229,6 +229,14 @@ var tvm_runtime = tvm_runtime || {};
       "number"  // size_t nbytes
      ]);
 
+    var TVMModLoadFromFile = Module.cwrap
+    ("TVMModLoadFromFile",
+     "number",
+     ["string", // const char* file_name
+      "string", // const char* format
+      "number"  // TVMModuleHandle* out
+     ])
+
     //-----------------------------------------
     // Static utility functions
     // ----------------------------------------
@@ -940,6 +948,35 @@ var tvm_runtime = tvm_runtime || {};
       }
       return new RPCServer(counter);
     };
+
+    /**
+     * Load a TVM module from a library file.
+     * The file must be present in the Emscripten virtual file system.
+     * For example, you can pass "--preload-file file" or "--preload-file dir/"
+     * to "emcc" when compiling the TVM library, in order to populate files into
+     * the file system.
+     * For more detail, see:
+     * https://kripken.github.io/emscripten-site/docs/porting/files/packaging_files
+     * @param {string} file_name Path of the file to be loaded. The path refers
+     * to the Emscripten virtual file system.
+     * @param {string} format The format of the file.
+     * @return {tvm.TVMModule} The loaded module.
+     */
+    this.loadModuleFromFile = function (file_name, format) {
+      // alloc
+      var out = new RefTVMValue();
+      TVM_CALL(TVMModLoadFromFile(file_name, format, out.data));
+      var out_handle = out.asHandle();
+      // release
+      out.release();
+      if (out_handle != 0) {
+        return new TVMModule(out_handle);
+      } else {
+        return null;
+      }
+    };
+    var loadModuleFromFile = this.loadModuleFromFile;
+
     //-----------------------------------------
     // Class defintions
     // ----------------------------------------
-- 
GitLab