diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js index 347532f5bdbf51f972cff96ee6e13cc93aaabe1d..bc12f4f39c143361a983e42ee7fcd379e123aa5f 100644 --- a/web/tvm_runtime.js +++ b/web/tvm_runtime.js @@ -656,6 +656,8 @@ var tvm_runtime = tvm_runtime || {}; v = convertFunc(v); this.temp.push(v); this.setHandle(i, v._tvm_function.handle, kFuncHandle); + } else if (v instanceof TVMModule) { + this.setHandle(i, v.handle, kModuleHandle); } else { throwError("Unsupported argument type " + tp); } @@ -977,6 +979,107 @@ var tvm_runtime = tvm_runtime || {}; }; var loadModuleFromFile = this.loadModuleFromFile; + /** + * Wrapper runtime module. + * Wraps around set_input, load_params, run, and get_output. + * + * @class + * @memberof tvm + */ + function GraphModule(tvm_graph_module, ctx) { + CHECK(tvm_graph_module instanceof TVMModule, + "tvm_graph_module must be TVMModule"); + CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); + + this.tvm_graph_module = tvm_graph_module; + this.ctx = ctx; + this._set_input = tvm_graph_module.getFunction("set_input"); + this._load_params = tvm_graph_module.getFunction("load_params"); + this._run = tvm_graph_module.getFunction("run"); + this._get_output = tvm_graph_module.getFunction("get_output"); + }; + + GraphModule.prototype = { + /** + * Set input to graph module. + * + * @param {string} key The name of the input. + * @param {NDArray} value The input value. + */ + "set_input" : function(key, value) { + CHECK(typeof key == "string", "key must be string"); + CHECK(value instanceof NDArray, "value must be NDArray"); + this._set_input(key, value); + }, + + /** + * Load parameters from serialized byte array of parameter dict. + * + * @param {Uint8Array} params The serialized parameter dict. + */ + "load_params" : function(params) { + CHECK(params instanceof Uint8Array, "params must be Uint8Array"); + this._load_params(params); + }, + + /** + * Load parameters from serialized base64 string of parameter dict. + * + * @param {string} base64_params The serialized parameter dict. + */ + "load_base64_params" : function(base64_params) { + CHECK(typeof base64_params == "string", "base64_params must be string"); + var decoded_string = atob(base64_params); + var decoded_u8 = new Uint8Array(decoded_string.length); + for (var i = 0; i < decoded_string.length; i++) { + decoded_u8[i] = decoded_string[i].charCodeAt(0); + } + this.load_params(decoded_u8); + }, + + /** + * Run forward execution of the graph. + */ + "run" : function() { + this._run(); + }, + + /** + * Get index-th output to out. + * + * @param {NDArray} out The output array container. + * @return {NDArray} The output array container. + */ + "get_output" : function(index, out) { + CHECK(typeof index == "number", "index must be number"); + CHECK(out instanceof NDArray, "out must be NDArray"); + this._get_output(new TVMConstant(index, "int32"), out); + return out; + } + }; + + /** + * Create a runtime executor module given a graph and a module. + * @param {string} graph_json_str The Json string of the graph. + * @param {TVMModule} libmod The TVM module. + * @param {TVMContext} ctx The context to deploy the module. + * @return {GraphModule} Runtime graph module for executing the graph. + */ + this.createGraphRuntime = function(graph_json_str, libmod, ctx) { + CHECK(typeof graph_json_str == "string", "graph_json_str must be string"); + CHECK(libmod instanceof TVMModule, "libmod must be TVMModule"); + CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); + + var fcreate = getGlobalFunc("tvm.graph_runtime.create"); + CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create"); + + var tvm_graph_module = fcreate(graph_json_str, libmod, + new TVMConstant(ctx.device_type, "int32"), + new TVMConstant(ctx.device_id, "int32")); + + return new GraphModule(tvm_graph_module, ctx); + }; + //----------------------------------------- // Class defintions // ----------------------------------------