diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 70935cde18162eeea54253395bf76d85a9bfcb9e..8e0d16286d6ad57ec6f5be8408a5980b65ea5097 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -384,8 +384,14 @@ def build(sch, target=None, target_host=None, name="default_function", - binds=None): - """Build a function with arguments as signiture. + binds=None, + postpone_host_codegen=False): + """Build a function with arguments as signature. Code will be generated + for a device specified by the target. For homogeneous execution, a module + that contains both host and device code is returned. For heterogeneous + execution, a list of lowered functions for the host and a module containing + device code are returned, but actual code generation for the host module is + postponed after code generation is finished for all devices. Parameters ---------- @@ -414,10 +420,18 @@ def build(sch, Dictionary that maps the binding of symbolic buffer to Tensor. By default, a new buffer is created for each tensor in the argument. + postpone_host_codegen : bool, optional + A bool value that indicates if code generation for the host module + should be postponed. This variable is set to be true for heterogeneous + execution. Otherwise, it is defaulted to false. + Returns ------- - f : Function, or pair of functions - The result function. + ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple + A module that combines both host and device code is returned when + postpone_host_codegen is not set. Otherwise, a list of lowered + functions for the host and a module contains only device code are + returned. Note ---- @@ -498,9 +512,15 @@ def build(sch, fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] - mhost = codegen.build_module(fhost, str(target_host)) + # Append fhost to the device module and return the updated module. All + # device modules will be imported to the host module after all of them are + # collected. + mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None + if postpone_host_codegen: + return fhost, mdev + + mhost = codegen.build_module(fhost, str(target_host)) if fdevice: - mdev = codegen.build_module(fdevice, str(target_device)) mhost.import_module(mdev) return mhost diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index e49b966e6a1e7416a3bbfcee077cd192a50a1f3e..f0e83eec0bb8eee9c929f6565e20b8cb3ac8c1fa 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -3,26 +3,24 @@ import numpy as np from .._ffi.base import string_types from .._ffi.function import get_global_func +from .._ffi.runtime_ctypes import TVMContext from ..rpc import base as rpc_base -from .. import ndarray as nd - def create(graph_json_str, libmod, ctx): """Create a runtime executor module given a graph and module. - Parameters ---------- graph_json_str : str or graph class The graph to be deployed in json format output by nnvm graph. The graph can only contain one operator(tvm_op) that points to the name of PackedFunc in the libmod. - libmod : tvm.Module The module of the corresponding function - - ctx : TVMContext - The context to deploy the module, can be local or remote. - + ctx : TVMContext or list of TVMContext + The context to deploy the module. It can be local or remote when there + is only one TVMContext. 
Otherwise, the first context in the list will + be used as this purpose. All context should be given for heterogeneous + execution. Returns ------- graph_module : GraphModule @@ -33,17 +31,42 @@ def create(graph_json_str, libmod, ctx): graph_json_str = graph_json_str._tvm_graph_json() except AttributeError: raise ValueError("Type %s is not supported" % type(graph_json_str)) - device_type = ctx.device_type - device_id = ctx.device_id - if device_type >= rpc_base.RPC_SESS_MASK: - assert libmod.type_key == "rpc" - assert rpc_base._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index + if isinstance(ctx, TVMContext): + ctx = [ctx] + elif not isinstance(ctx, (list, tuple)): + raise ValueError("ctx has to be the type of TVMContext or a list of " + "TVMCTVMContext") + for cur_ctx in ctx: + if not isinstance(cur_ctx, TVMContext): + raise ValueError("ctx has to be the type of TVMContext or a list " + "of TVMContext") + + # device_type_id[0], device_type_id[1] are used as the primary/fallback + # context type and id. All other ones are used as device context for + # heterogeneous execution. + num_rpc_ctx = 0 + device_type_id = [] + for cur_ctx in ctx: + device_type = cur_ctx.device_type + if device_type >= rpc_base.RPC_SESS_MASK: + assert libmod.type_key == "rpc" + assert rpc_base._SessTableIndex( + libmod) == cur_ctx._rpc_sess._tbl_index + num_rpc_ctx += 1 + device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK + device_type_id.append(device_type) + device_type_id.append(cur_ctx.device_id) + + if 0 < num_rpc_ctx < len(ctx): + raise ValueError("Either all or none of the contexts should be rpc.") + + if num_rpc_ctx == len(ctx): hmod = rpc_base._ModuleHandle(libmod) - fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create") - device_type = device_type % rpc_base.RPC_SESS_MASK - return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx) + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create") + return GraphModule(fcreate(graph_json_str, hmod, *device_type_id)) + fcreate = get_global_func("tvm.graph_runtime.create") - return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx) + return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) class GraphModule(object): @@ -58,18 +81,13 @@ class GraphModule(object): module : Module The interal tvm module that holds the actual graph functions. - ctx : TVMContext - The context this module is under - Attributes ---------- module : Module The interal tvm module that holds the actual graph functions. 
- - ctx : TVMContext - The context this module is under """ - def __init__(self, module, ctx): + + def __init__(self, module): self.module = module self._set_input = module["set_input"] self._run = module["run"] @@ -81,7 +99,6 @@ class GraphModule(object): except AttributeError: pass self._load_params = module["load_params"] - self.ctx = ctx def set_input(self, key=None, value=None, **params): """Set inputs to the module via kwargs @@ -98,14 +115,14 @@ class GraphModule(object): Additonal arguments """ if key: - self._set_input(key, nd.array(value, ctx=self.ctx)) + self._get_input(key).copyfrom(value) if params: # upload big arrays first to avoid memory issue in rpc mode keys = list(params.keys()) keys.sort(key=lambda x: -np.prod(params[x].shape)) for k in keys: - self._set_input(k, nd.array(params[k], ctx=self.ctx)) + self._get_input(k).copyfrom(params[k]) def run(self, **input_dict): """Run forward execution of the graph @@ -177,7 +194,8 @@ class GraphModule(object): if hasattr(self, '_debug_get_output'): self._debug_get_output(node, out) else: - raise RuntimeError("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0") + raise RuntimeError( + "Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0") return out def load_params(self, params_bytes): diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 162d616dea8a120fc2d8551c70a8552dcde791b9..a48047fe369cd0f3d168f25c665d0a9246cee34c 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -2,22 +2,26 @@ * Copyright (c) 2017 by Contributors * \file graph_runtime.cc */ +#include "graph_runtime.h" + +#include <dlpack/dlpack.h> +#include <dmlc/json.h> +#include <dmlc/memory_io.h> +#include <tvm/runtime/device_api.h> +#include <tvm/runtime/ndarray.h> #include <tvm/runtime/packed_func.h> #include <tvm/runtime/registry.h> -#include <tvm/runtime/ndarray.h> -#include <tvm/runtime/device_api.h> -#include <dmlc/memory_io.h> -#include <dmlc/json.h> -#include <numeric> +#include <tvm/runtime/serializer.h> + #include <algorithm> -#include <vector> #include <functional> -#include "graph_runtime.h" +#include <numeric> +#include <vector> namespace tvm { namespace runtime { -/*! \brief macro to do C API call */ +/*! \brief Macro to do C API call. */ #define TVM_CCALL(func) \ { \ int ret = (func); \ @@ -34,7 +38,7 @@ namespace runtime { class GraphRuntime : public ModuleNode { public: /*! - * \brief Get member function to front-end + * \brief Get member function to front-end. * \param name The name of the function. * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. @@ -58,12 +62,13 @@ class GraphRuntime : public ModuleNode { /*! * \brief Initialize the graph executor with graph and context. * \param graph_json The execution graph. - * \param module The module containing the compiled functions. - * \param ctx The context where the graph should sit on + * \param module The module containing the compiled functions for the host + * processor. + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. 
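+ * The first context in ctxs is used as the fallback device for graph nodes
+ * that carry no device_index annotation.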
*/ - void Init(const std::string& graph_json, - tvm::runtime::Module module, - TVMContext ctx) { + void Init(const std::string& graph_json, const tvm::runtime::Module& module, + const std::vector<TVMContext>& ctxs) { #ifndef _LIBCPP_SGX_NO_IOSTREAMS std::istringstream is(graph_json); #else @@ -72,10 +77,11 @@ class GraphRuntime : public ModuleNode { dmlc::JSONReader reader(&is); this->Load(&reader); module_ = module; - ctx_ = ctx; + ctxs_ = ctxs; this->SetupStorage(); this->SetupOpExecs(); } + /*! * \brief Get the input index given the name of input. * \param name The name of the input. @@ -92,7 +98,7 @@ class GraphRuntime : public ModuleNode { return -1; } /*! - * \brief set index-th input to the graph. + * \brief Set index-th input to the graph. * \param index The input index. * \param data_in The input data. */ @@ -134,7 +140,7 @@ class GraphRuntime : public ModuleNode { /*! * \brief Copy index-th output to data_out. * \param index The output index. - * \param data_out the output data. + * \param data_out The output data. */ void CopyOutputTo(int index, DLTensor* data_out) { CHECK_LT(static_cast<size_t>(index), outputs_.size()); @@ -172,8 +178,8 @@ class GraphRuntime : public ModuleNode { * from begining upto the index-th node and return output of index-th node. * This is costly operation and suggest to use only for debug porpose. * - * \param index: The index of the node. - * \param data_out the node data. + * \param index The index of the node. + * \param data_out The node data. */ void DebugGetNodeOutput(int index, DLTensor* data_out) { CHECK_LT(static_cast<size_t>(index), nodes_.size()); @@ -188,7 +194,7 @@ class GraphRuntime : public ModuleNode { } #endif /*! - * \brief Load parameters from binary stream + * \brief Load parameters from binary stream. * \param strm The input stream. */ void LoadParams(dmlc::Stream* strm); @@ -202,6 +208,12 @@ class GraphRuntime : public ModuleNode { } private: + // Memory pool entry. + struct PoolEntry { + size_t size; + int device_type; + PoolEntry(int s, int dev_type) : size(s), device_type(dev_type) {} + }; // Node entry struct NodeEntry { uint32_t node_id; @@ -260,7 +272,6 @@ class GraphRuntime : public ModuleNode { // JSON Loader void Load(dmlc::JSONReader *reader) { reader->BeginObject(); - std::unordered_map<std::string, std::string> dict; int bitmask = 0; std::string key; while (reader->NextObjectItem(&key)) { @@ -287,6 +298,7 @@ class GraphRuntime : public ModuleNode { struct GraphAttr { size_t storage_num_not_alloctaed{0}; std::vector<int> storage_id; + std::vector<int> device_index; std::vector<std::string> dltype; std::vector<std::vector<int64_t> > shape; // The graph attribute fields. @@ -322,6 +334,14 @@ class GraphRuntime : public ModuleNode { reader->Read(&shape); CHECK(!reader->NextArrayItem()); bitmask |= 4; + } else if (key == "device_index") { + reader->BeginArray(); + CHECK(reader->NextArrayItem()); + reader->Read(&type); + CHECK_EQ(type, "list_int"); + CHECK(reader->NextArrayItem()); + reader->Read(&device_index); + CHECK(!reader->NextArrayItem()); } else { reader->BeginArray(); CHECK(reader->NextArrayItem()); @@ -372,13 +392,14 @@ class GraphRuntime : public ModuleNode { } /*! \brief Setup the temporal storage */ void SetupStorage(); - /*! \brief Setup the executors */ + /*! \brief Setup the executors. */ void SetupOpExecs(); /*! * \brief Create a executtion function given input. - * \param attrs The node attributes + * \param attrs The node attributes. * \param args The arguments to the functor, including inputs and outputs. 
- * \param num_inputs Number of inputs + * \param num_inputs Number of inputs. + * \param dev_type The device type of the tvm_op. * \return The created executor. */ std::function<void()> CreateTVMOp(const TVMOpParam& attrs, @@ -392,7 +413,7 @@ class GraphRuntime : public ModuleNode { uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); } - // Number of node entries + // Number of node entries. uint32_t num_node_entries() const { return node_row_ptr_.back(); } @@ -400,25 +421,25 @@ class GraphRuntime : public ModuleNode { uint32_t num_nodes() const { return static_cast<uint32_t>(nodes_.size()); } - // The graph nodes. + /*! \brief The graph nodes. */ std::vector<Node> nodes_; - // The argument nodes. + /*! \brief The argument nodes. */ std::vector<uint32_t> input_nodes_; - // used or quick entry indexing + /*! \brief Used for quick entry indexing. */ std::vector<uint32_t> node_row_ptr_; - // output entries + /*! \brief Output entries. */ std::vector<NodeEntry> outputs_; - // Additional graph attributes + /*! \brief Additional graph attributes. */ GraphAttr attrs_; - /*! \brief The code module */ + /*! \brief The code module that contains both host and device code. */ tvm::runtime::Module module_; - /*! \brief execution context */ - TVMContext ctx_; - /*! \brief common storage pool */ + /*! \brief Execution context of all devices including the host. */ + std::vector<TVMContext> ctxs_; + /*! \brief Common storage pool for all devices. */ std::vector<NDArray> storage_pool_; - /*! \brief data entry of each node */ + /*! \brief Data entry of each node. */ std::vector<NDArray> data_entry_; - /*! \brief operator on each node */ + /*! \brief Operator on each node. */ std::vector<std::function<void()> > op_execs_; }; @@ -458,12 +479,17 @@ void GraphRuntime::SetupStorage() { for (const std::string& s_type : attrs_.dltype) { vtype.push_back(tvm::runtime::String2TVMType(s_type)); } - data_entry_.resize(num_node_entries()); - // size of each storage pool entry - std::vector<size_t> pool_entry_bytes; + + // Size and device type of each storage pool entry. + std::vector<PoolEntry> pool_entry; // Find the maximum space size. for (size_t i = 0; i < attrs_.shape.size(); ++i) { int storage_id = attrs_.storage_id[i]; + // Use the fallback device if no device index is available. + int device_type = static_cast<int>(ctxs_[0].device_type); + if (!attrs_.device_index.empty()) { + device_type = attrs_.device_index[i]; + } size_t size = 1; for (int64_t sz : attrs_.shape[i]) { size *= static_cast<size_t>(sz); @@ -474,23 +500,42 @@ void GraphRuntime::SetupStorage() { CHECK_EQ(bits % 8U, 0U); size_t bytes = (bits / 8U) * size; - size_t sid = static_cast<size_t>(storage_id); - if (sid >= pool_entry_bytes.size()) { - pool_entry_bytes.resize(sid + 1, 0); + uint32_t sid = static_cast<uint32_t>(storage_id); + if (sid >= pool_entry.size()) { + pool_entry.resize(sid + 1, {0, -1}); + } else { + CHECK(pool_entry[sid].device_type == -1 || + pool_entry[sid].device_type == device_type) + << "The same pool entry cannot be assigned to multiple devices"; } - pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes); + pool_entry[sid].size = std::max(pool_entry[sid].size, bytes); + pool_entry[sid].device_type = device_type; } + // Allocate the space. 
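+ // Each pool entry is allocated on the device type recorded for it above;
+ // entries whose device type matches none of the given contexts fall back
+ // to ctxs_[0].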
- for (size_t i = 0; i < pool_entry_bytes.size(); ++i) { + for (const auto& pit : pool_entry) { std::vector<int64_t> shape; - shape.push_back(static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4); - storage_pool_.push_back(NDArray::Empty(shape, DLDataType {kDLFloat, 32, 1}, ctx_)); + // This for loop is very fast since there are usually only a couple of + // devices available on the same hardware. + const auto& cit = + std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) { + return pit.device_type == static_cast<int>(c.device_type); + }); + TVMContext ctx = cit == ctxs_.end() ? ctxs_[0] : *cit; + shape.push_back(static_cast<int64_t>(pit.size + 3) / 4); + storage_pool_.push_back( + NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); } - // Assign the pooled entries. + + // Assign the pooled entries. A unified memory pool is used to simplifiy + // memory assignment for each node entry. The allocated memory on each device + // is mapped to this pool. + data_entry_.resize(num_node_entries()); for (size_t i = 0; i < data_entry_.size(); ++i) { int storage_id = attrs_.storage_id[i]; CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size()); - data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]); + data_entry_[i] = + storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]); } } @@ -508,8 +553,8 @@ void GraphRuntime::SetupOpExecs() { uint32_t eid = this->entry_id(nid, index); args.push_back(*(data_entry_[eid].operator->())); } - CHECK_EQ(inode.op_type, "tvm_op") - << "Can only take tvm_op as op"; + CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; + op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size()); } } @@ -543,13 +588,26 @@ std::function<void()> GraphRuntime::CreateTVMOp( t->shape = &(arg_ptr->shape_data[i]); } } + if (param.func_name == "__nop") { return [](){}; + } else if (param.func_name == "__copy") { + // Perform cross device data copy. + // Directly copy data from the input to the output. + auto fexec = [arg_ptr]() { + DLTensor* from = static_cast<DLTensor*>(arg_ptr->arg_values[0].v_handle); + DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle); + TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr)); + }; + return fexec; } - // get compiled function from module. + + // Get compiled function from the module that contains both host and device + // code. tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, false); CHECK(pf != nullptr) << "no such function in module: " << param.func_name; - auto fexec = [arg_ptr, pf] () { + + auto fexec = [arg_ptr, pf]() { TVMRetValue rv; TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(), @@ -562,7 +620,7 @@ std::function<void()> GraphRuntime::CreateTVMOp( PackedFunc GraphRuntime::GetFunction( const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) { - // return member functions during query. + // Return member functions during query. 
if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (args[0].type_code() == kStr) { @@ -618,29 +676,53 @@ PackedFunc GraphRuntime::GetFunction( } } -Module GraphRuntimeCreate(std::string sym_json, - tvm::runtime::Module m, - int device_type, - int device_id) { - TVMContext ctx; - ctx.device_type = static_cast<DLDeviceType>(device_type); - ctx.device_id = device_id; +Module GraphRuntimeCreate(const std::string& sym_json, + const tvm::runtime::Module& m, + const std::vector<TVMContext>& ctxs) { std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>(); - exec->Init(sym_json, m, ctx); + exec->Init(sym_json, m, ctxs); return Module(exec); } +// Get all context for the host and other runtime devices. +std::vector<TVMContext> GetAllContext(const TVMArgs& args) { + // Reserve the first item as the fallback device. + std::vector<TVMContext> ret; + TVMContext ctx; + for (int i = 2; i < args.num_args; i += 2) { + int dev_type = args[i]; + ctx.device_type = static_cast<DLDeviceType>(dev_type); + ctx.device_id = args[i + 1]; + ret.push_back(ctx); + } + return ret; +} + +// 4-argument version is currently reserved to keep support of calling +// from tvm4j and javascript, since they don't have heterogeneous +// execution support yet. For heterogenenous execution, at least 5 arguments will +// be passed in. The third one is the number of devices. +// Eventually, we will only probably pass TVMContext for all the languages. TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = GraphRuntimeCreate(args[0], args[1], args[2], args[3]); + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) + << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + const auto& contexts = GetAllContext(args); + *rv = GraphRuntimeCreate(args[0], args[1], contexts); }); TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create") -.set_body([](TVMArgs args, TVMRetValue *rv) { + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for " + "graph_runtime.remote_create is " + "at least 4, but it has " + << args.num_args; void* mhandle = args[1]; - *rv = GraphRuntimeCreate(args[0], - *static_cast<tvm::runtime::Module*>(mhandle), - args[2], args[3]); + const auto& contexts = GetAllContext(args); + *rv = GraphRuntimeCreate( + args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts); }); } // namespace runtime } // namespace tvm diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py new file mode 100644 index 0000000000000000000000000000000000000000..b916ee28571778690f034119d819b67f0cdbeac7 --- /dev/null +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -0,0 +1,405 @@ +# pylint: disable=too-many-locals +"""Unit tests for heterogeneous runtime""" +import json +import numpy as np + +import tvm +from tvm.contrib import graph_runtime, util +import topi + +def get_simplex_graph(host_dev_type, device_dev_type): + r""" Return the hand-crafted json object where only one copy node is + inserted. This node copies data from the target device to cpu. + The network is constructed as following: + A B + \ / + elemwise_add (gpu) + \ + copy C + \ / + elemwise_sub (cpu) + + Parameters + ---------- + host_dev_type : int + The device type of the host processor, e.g. cpu. 
+ device_dev_type : int + The device type of the device processor, e.g. gpu, opencl, etc. + + Returns + ------- + json : json + A json encoded object. + """ + # Construct each node in the graph. + var_a = {"op": "null", "name": "A", "inputs": []} + var_b = {"op": "null", "name": "B", "inputs": []} + elemwise_add = { + "op": "tvm_op", "name": "elemwise_add", + "attrs": { + "flatten_data": "1", + "func_name": "elemwise_add", + "num_inputs": "2", + "num_outputs": "1" + }, + "inputs": [[0, 0, 0], [1, 0, 0]] + } + copy = { + "op": "tvm_op", + "name": "__copy_add_to_sub", + "attrs": { + "flatten_data": "0", + "func_name": "__copy", + "num_inputs": "1", + "num_outputs": "1" + }, + "inputs": [[2, 0, 0]] + } + var_c = {"op": "null", "name": "C", "inputs": []} + elemwise_sub = { + "op": "tvm_op", "name": "elemwise_sub", + "attrs": { + "flatten_data": "0", + "func_name": "elemwise_sub", + "num_inputs": "2", + "num_outputs": "1" + }, + "inputs": [[3, 0, 0], [4, 0, 0]] + } + + # Group the nodes. + nodes = [var_a, var_b, elemwise_add, copy, var_c, elemwise_sub] + arg_nodes = [0, 1, 4] + node_row_ptr = [0, 1, 2, 3, 4, 5, 6] + heads = [[5, 0, 0]] + shape = (4,) + attrs = { + "storage_id": ["list_int", [3, 4, 0, 1, 5, 2]], + "shape": ["list_shape", [shape, shape, shape, shape, shape, shape]], + "device_index": ["list_int", [device_dev_type, device_dev_type, + device_dev_type, host_dev_type, + host_dev_type, host_dev_type]], + "dtype": ["list_int", [0, 0, 0, 0, 0, 0]], + "dltype": ["list_str", ["float32", "float32", "float32", + "float32", "float32", "float32"]] + } + + # Construct the graph. + graph = {"nodes": nodes, + "arg_nodes": arg_nodes, + "node_row_ptr": node_row_ptr, + "heads": heads, + "attrs": attrs} + return json.dumps(graph) + + +def test_simplex_data_transferring(): + r""" + Test the heterogeneous execution of a simple network where data + transferring is from the target device to the host processor at runtime. + The host processor is always assumed to be cpu, and the device varies. + """ + host = "cpu" + target_host = "llvm" + host_ctx = tvm.context(host) + if not tvm.module.enabled(target_host): + print("Skip test because llvm is not enabled.") + return + + def check_device(device, target_device): + if not tvm.module.enabled(target_device): + print("Skip test because {} is not enabled.".format(target_device)) + return + + device_ctx = tvm.context(device) + graph = get_simplex_graph(host_ctx.device_type, device_ctx.device_type) + shape = (4,) + + # Create module for add whose target is the device. + tensor_a = tvm.placeholder(shape, name="A") + tensor_b = tvm.placeholder(shape, name="B") + elemwise_add = tvm.compute(shape, lambda *i: tensor_a(*i) + + tensor_b(*i), name="elemwise_add") + target = topi.cpp.TEST_create_target(device) + schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add]) + lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add], + name="elemwise_add") + host_funcs_add, lib_add = tvm.build(lower_add, target=target_device, + name="elemwise_add", + postpone_host_codegen=True) + + # Insert copy. Neither compute nor schedule is required for the copy + # node. The compute will be performed at runtime which is just data + # copy from the input to the output. + tensor_copy = tvm.placeholder(shape, name="__copy") + + # Create module for sub whose target is the host. 
+ tensor_c = tvm.placeholder(shape, name="C") + elemwise_sub = tvm.compute(shape, lambda *i: tensor_copy(*i) + - tensor_c(*i), name="elemwise_sub") + schedule_sub = tvm.create_schedule(elemwise_sub.op) + lower_sub = tvm.lower(schedule_sub, [tensor_copy, tensor_c, + elemwise_sub], + name="elemwise_sub") + + host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host, + name="elemwise_sub", + postpone_host_codegen=True) + host_funcs = host_funcs_add + host_funcs_sub + mhost = tvm.codegen.build_module(host_funcs, target_host) + if lib_add: + mhost.import_module(lib_add) + if lib_sub: + mhost.import_module(lib_sub) + + ctx = [host_ctx, device_ctx] + mod = graph_runtime.create(graph, mhost, ctx) + params = {} + params["A"] = tensor_a = np.random.uniform( + size=shape).astype(tensor_a.dtype) + params["B"] = tensor_b = np.random.uniform( + size=shape).astype(tensor_b.dtype) + params["C"] = tensor_c = np.random.uniform( + size=shape).astype(tensor_c.dtype) + mod.set_input(**params) + mod.run() + out = mod.get_output(0, tvm.nd.empty(shape)) + np.testing.assert_equal( + out.asnumpy(), (tensor_a + tensor_b) - tensor_c) + + dev_tar = {"cuda": "cuda", "opencl": "opencl"} + for device, target in dev_tar.items(): + check_device(device, target) + + +def get_duplex_graph(host_dev_type, device_dev_type): + r""" Return the hand-crafted json object where two copy nodes are inserted. + Data transferring happens back-and-forth between the target device and CPU. + The network is constructed as following: + A B + \ / + elemwise_add (gpu) + \ + copy C + \ / + elemwise_sub (cpu) + \ + copy D + \ / + elemwise_add (gpu) + + Parameters + ---------- + host_dev_type : int + The device type of the host processor, e.g. cpu. + device_dev_type : int + The device type of the device processor, e.g. gpu, opencl, etc. + + Returns + ------- + json : json + A json encoded object. + """ + # Construct each node in the graph. + var_a = {"op": "null", "name": "A", "inputs": []} + var_b = {"op": "null", "name": "B", "inputs": []} + elemwise_add0 = { + "op": "tvm_op", "name": "elemwise_add0", + "attrs": { + "flatten_data": "1", + "func_name": "elemwise_add0", + "num_inputs": "2", + "num_outputs": "1" + }, + "inputs": [[0, 0, 0], [1, 0, 0]] + } + copy_add_sub = { + "op": "tvm_op", + "name": "__copy_add_to_sub", + "attrs": { + "flatten_data": "0", + "func_name": "__copy", + "num_inputs": "1", + "num_outputs": "1" + }, + "inputs": [[2, 0, 0]] + } + var_c = {"op": "null", "name": "C", "inputs": []} + elemwise_sub = { + "op": "tvm_op", "name": "elemwise_sub", + "attrs": { + "flatten_data": "0", + "func_name": "elemwise_sub", + "num_inputs": "2", + "num_outputs": "1" + }, + "inputs": [[3, 0, 0], [4, 0, 0]] + } + copy_sub_add = { + "op": "tvm_op", + "name": "__copy_sub_to_add", + "attrs": { + "flatten_data": "0", + "func_name": "__copy", + "num_inputs": "1", + "num_outputs": "1" + }, + "inputs": [[5, 0, 0]] + } + var_d = {"op": "null", "name": "D", "inputs": []} + elemwise_add1 = { + "op": "tvm_op", "name": "elemwise_add1", + "attrs": { + "flatten_data": "0", + "func_name": "elemwise_add1", + "num_inputs": "2", + "num_outputs": "1" + }, + "inputs": [[6, 0, 0], [7, 0, 0]] + } + + # Group the nodes. 
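+ # The position of a node in this list is its node id; each "inputs" entry
+ # references a producer as [node_id, output_index, version].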
+ nodes = [var_a, var_b, elemwise_add0, copy_add_sub, var_c, elemwise_sub, + copy_sub_add, var_d, elemwise_add1] + arg_nodes = [0, 1, 4, 7] + node_row_ptr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + heads = [[8, 0, 0]] + shape = (4,) + attrs = { + "storage_id": ["list_int", [4, 5, 0, 1, 6, 2, 0, 7, 3]], + "shape": ["list_shape", [shape, shape, shape, shape, shape, shape, + shape, shape, shape]], + "device_index": ["list_int", [device_dev_type, device_dev_type, + device_dev_type, + host_dev_type, host_dev_type, host_dev_type, + device_dev_type, device_dev_type, + device_dev_type]], + "dtype": ["list_int", [0, 0, 0, 0, 0, 0, 0, 0, 0]], + "dltype": ["list_str", ["float32", "float32", "float32", + "float32", "float32", "float32", + "float32", "float32", "float32"]] + } + + # Construct the graph. + graph = {"nodes": nodes, + "arg_nodes": arg_nodes, + "node_row_ptr": node_row_ptr, + "heads": heads, + "attrs": attrs} + return json.dumps(graph) + + +def test_duplex_data_transferring(): + r""" + Test the heterogeneous execution of a simple network where data + transferring occurs back-and-forth between the target device and host + processor. + The host processor is always assumed to be cpu, and the target device + varies. + """ + host = "cpu" + target_host = "llvm" + host_ctx = tvm.context(host) + if not tvm.module.enabled(target_host): + print("Skip test because llvm is not enabled.") + return + + def check_device(device, target_device): + if not tvm.module.enabled(target_device): + print("Skip test because {} is not enabled.".format(target_device)) + return + + device_ctx = tvm.context(device) + graph = get_duplex_graph(host_ctx.device_type, device_ctx.device_type) + shape = (4,) + + # Insert copy nodes for data transferring between add and sub nodes. + # Transfers data from gpu to cpu. + copy_add_sub = tvm.placeholder(shape, name="__copy0") + # Transfers data from cpu to gpu. + copy_sub_add = tvm.placeholder(shape, name="__copy1") + + # Create a module containing adds on the device. + tensor_a = tvm.placeholder(shape, name="A") + tensor_b = tvm.placeholder(shape, name="B") + tensor_d = tvm.placeholder(shape, name="D") + elemwise_add0 = tvm.compute(shape, lambda *i: tensor_a(*i) + + tensor_b(*i), name="elemwise_add0") + elemwise_add1 = tvm.compute(shape, lambda *i: copy_sub_add(*i) + + tensor_d(*i), name="elemwise_add1") + target = topi.cpp.TEST_create_target(device) + add_schedule0 = topi.cpp.cuda.schedule_injective( + target, [elemwise_add0]) + lower_add0 = tvm.lower( + add_schedule0, [tensor_a, tensor_b, elemwise_add0], + name="elemwise_add0") + add_schedule1 = topi.cpp.cuda.schedule_injective( + target, [elemwise_add1]) + lower_add1 = tvm.lower( + add_schedule1, [tensor_d, copy_sub_add, elemwise_add1], + name="elemwise_add1") + host_funcs_add, lib_add = tvm.build([lower_add0, lower_add1], + target=target_device, + postpone_host_codegen=True) + + # Create module for sub whose target is the host. 
+ tensor_c = tvm.placeholder(shape, name="C") + elemwise_sub = tvm.compute(shape, lambda *i: copy_add_sub(*i) + - tensor_c(*i), name="elemwise_sub") + sub_schedule = tvm.create_schedule(elemwise_sub.op) + lower_sub = tvm.lower(sub_schedule, [copy_add_sub, tensor_c, + elemwise_sub], + name="elemwise_sub") + host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host, + postpone_host_codegen=True) + host_funcs = host_funcs_add + host_funcs_sub + mhost = tvm.codegen.build_module(host_funcs, target_host) + if lib_add: + mhost.import_module(lib_add) + if lib_sub: + mhost.import_module(lib_sub) + + ctx = [host_ctx, device_ctx] + params = {} + params["A"] = tensor_a = np.random.uniform( + size=shape).astype(tensor_a.dtype) + params["B"] = tensor_b = np.random.uniform( + size=shape).astype(tensor_b.dtype) + params["C"] = tensor_c = np.random.uniform( + size=shape).astype(tensor_c.dtype) + params["D"] = tensor_d = np.random.uniform( + size=shape).astype(tensor_d.dtype) + + def check_verify(): + mod = graph_runtime.create(graph, mhost, ctx) + mod.set_input(**params) + mod.run() + out = mod.get_output(0, tvm.nd.empty(shape)) + np.testing.assert_equal( + out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d) + + def check_load_module(): + temp = util.tempdir() + path_lib = temp.relpath("deploy.so") + mhost.export_library(path_lib) + with open(temp.relpath("deploy.json"), "w") as out_file: + out_file.write(graph) + loaded_lib = tvm.module.load(path_lib) + loaded_graph = open(temp.relpath("deploy.json")).read() + mod = graph_runtime.create(loaded_graph, loaded_lib, ctx) + mod.set_input(**params) + mod.run() + out = mod.get_output(0, tvm.nd.empty(shape)) + np.testing.assert_equal( + out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d) + + check_verify() + check_load_module() + + dev_tar = {"cuda": "cuda", "opencl": "opencl"} + for device, target in dev_tar.items(): + check_device(device, target) + +if __name__ == "__main__": + test_simplex_data_transferring() + test_duplex_data_transferring()
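For reference, a minimal usage sketch of the new API surface, distilled from the tests above; lower_add, lower_sub, graph_json, and params are assumed to be prepared exactly as in test_simplex_data_transferring, and the device target is assumed to be CUDA.

# Hedged sketch only: postpone host code generation for each target, merge
# the collected host functions into one host module, import the device
# module(s), and create the runtime with a list of contexts whose first
# entry is the host/fallback context.
host_funcs_add, lib_add = tvm.build(lower_add, target="cuda",
                                    name="elemwise_add",
                                    postpone_host_codegen=True)
host_funcs_sub, lib_sub = tvm.build(lower_sub, target="llvm",
                                    name="elemwise_sub",
                                    postpone_host_codegen=True)
mhost = tvm.codegen.build_module(host_funcs_add + host_funcs_sub, "llvm")
if lib_add:
    mhost.import_module(lib_add)
if lib_sub:
    mhost.import_module(lib_sub)
mod = graph_runtime.create(graph_json, mhost, [tvm.cpu(0), tvm.gpu(0)])
mod.set_input(**params)
mod.run()
out = mod.get_output(0, tvm.nd.empty((4,)))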