diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index 70935cde18162eeea54253395bf76d85a9bfcb9e..8e0d16286d6ad57ec6f5be8408a5980b65ea5097 100755
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -384,8 +384,14 @@ def build(sch,
           target=None,
           target_host=None,
           name="default_function",
-          binds=None):
-    """Build a function with arguments as signiture.
+          binds=None,
+          postpone_host_codegen=False):
+    """Build a function with arguments as signature. Code will be generated
+    for a device specified by the target. For homogeneous execution, a module
+    that contains both host and device code is returned. For heterogeneous
+    execution, a list of lowered functions for the host and a module containing
+    device code are returned; actual code generation for the host module is
+    postponed until code generation has finished for all devices.
 
     Parameters
     ----------
@@ -414,10 +420,18 @@ def build(sch,
         Dictionary that maps the binding of symbolic buffer to Tensor.
         By default, a new buffer is created for each tensor in the argument.
 
+    postpone_host_codegen : bool, optional
+        Whether code generation for the host module should be postponed.
+        This is set to True for heterogeneous execution; it defaults to
+        False.
+
     Returns
     -------
-    f : Function, or pair of functions
-       The result function.
+    ret : tvm.module, or (list of LoweredFunc, tvm.module) tuple
+        A module that combines both host and device code is returned when
+        postpone_host_codegen is not set. Otherwise, a list of lowered
+        functions for the host and a module containing only device code are
+        returned.
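+
+    Example
+    -------
+    A minimal sketch of the heterogeneous flow; the schedules, argument
+    lists, and targets below are placeholders:
+
+    .. code-block:: python
+
+        # Device code is generated now; host code generation is postponed.
+        fhost_dev, mdev = tvm.build(s_dev, args_dev, target="cuda",
+                                    postpone_host_codegen=True)
+        fhost_cpu, _ = tvm.build(s_cpu, args_cpu, target="llvm",
+                                 postpone_host_codegen=True)
+        # Combine all host functions and import the device module.
+        mhost = tvm.codegen.build_module(fhost_dev + fhost_cpu, "llvm")
+        mhost.import_module(mdev)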
 
     Note
     ----
@@ -498,9 +512,15 @@ def build(sch,
     fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
     fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
     fhost = [ir_pass.CombineContextCall(x) for x in fhost]
-    mhost = codegen.build_module(fhost, str(target_host))
 
+    # Build the device module first. When host code generation is postponed,
+    # the host functions are returned together with the device module. All
+    # device modules will be imported into the host module after all of them
+    # are collected.
+    mdev = codegen.build_module(fdevice, str(target_device)) if fdevice else None
+    if postpone_host_codegen:
+        return fhost, mdev
+
+    mhost = codegen.build_module(fhost, str(target_host))
     if fdevice:
-        mdev = codegen.build_module(fdevice, str(target_device))
         mhost.import_module(mdev)
     return mhost
diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py
index e49b966e6a1e7416a3bbfcee077cd192a50a1f3e..f0e83eec0bb8eee9c929f6565e20b8cb3ac8c1fa 100644
--- a/python/tvm/contrib/graph_runtime.py
+++ b/python/tvm/contrib/graph_runtime.py
@@ -3,26 +3,24 @@ import numpy as np
 
 from .._ffi.base import string_types
 from .._ffi.function import get_global_func
+from .._ffi.runtime_ctypes import TVMContext
 from ..rpc import base as rpc_base
-from .. import ndarray as nd
-
 
 def create(graph_json_str, libmod, ctx):
     """Create a runtime executor module given a graph and module.
-
     Parameters
     ----------
     graph_json_str : str or graph class
         The graph to be deployed in json format output by nnvm graph.
         The graph can only contain one operator(tvm_op) that
         points to the name of PackedFunc in the libmod.
-
     libmod : tvm.Module
         The module of the corresponding function
-
-    ctx : TVMContext
-        The context to deploy the module, can be local or remote.
-
+    ctx : TVMContext or list of TVMContext
+        The context to deploy the module. It can be local or remote when
+        there is only one TVMContext. Otherwise, the first context in the
+        list is used as the fallback context. All contexts must be provided
+        for heterogeneous execution.
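+        For example, passing ``[tvm.cpu(0), tvm.gpu(0)]`` makes
+        ``tvm.cpu(0)`` the fallback context.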
     Returns
     -------
     graph_module : GraphModule
@@ -33,17 +31,42 @@ def create(graph_json_str, libmod, ctx):
             graph_json_str = graph_json_str._tvm_graph_json()
         except AttributeError:
             raise ValueError("Type %s is not supported" % type(graph_json_str))
-    device_type = ctx.device_type
-    device_id = ctx.device_id
-    if device_type >= rpc_base.RPC_SESS_MASK:
-        assert libmod.type_key == "rpc"
-        assert rpc_base._SessTableIndex(libmod) == ctx._rpc_sess._tbl_index
+    if isinstance(ctx, TVMContext):
+        ctx = [ctx]
+    elif not isinstance(ctx, (list, tuple)):
+        raise ValueError("ctx has to be the type of TVMContext or a list of "
+                         "TVMCTVMContext")
+    for cur_ctx in ctx:
+        if not isinstance(cur_ctx, TVMContext):
+            raise ValueError("ctx has to be the type of TVMContext or a list "
+                             "of TVMContext")
+
+    # device_type_id[0] and device_type_id[1] are used as the primary/fallback
+    # context type and id. All remaining pairs are used as device contexts
+    # for heterogeneous execution.
+    num_rpc_ctx = 0
+    device_type_id = []
+    for cur_ctx in ctx:
+        device_type = cur_ctx.device_type
+        if device_type >= rpc_base.RPC_SESS_MASK:
+            assert libmod.type_key == "rpc"
+            assert rpc_base._SessTableIndex(
+                libmod) == cur_ctx._rpc_sess._tbl_index
+            num_rpc_ctx += 1
+            device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK
+        device_type_id.append(device_type)
+        device_type_id.append(cur_ctx.device_id)
+
+    if 0 < num_rpc_ctx < len(ctx):
+        raise ValueError("Either all or none of the contexts should be rpc.")
+
+    if num_rpc_ctx == len(ctx):
         hmod = rpc_base._ModuleHandle(libmod)
-        fcreate = ctx._rpc_sess.get_function("tvm.graph_runtime.remote_create")
-        device_type = device_type % rpc_base.RPC_SESS_MASK
-        return GraphModule(fcreate(graph_json_str, hmod, device_type, device_id), ctx)
+        fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
+        return GraphModule(fcreate(graph_json_str, hmod, *device_type_id))
+
     fcreate = get_global_func("tvm.graph_runtime.create")
-    return GraphModule(fcreate(graph_json_str, libmod, device_type, device_id), ctx)
+    return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))
 
 
 class GraphModule(object):
@@ -58,18 +81,13 @@ class GraphModule(object):
     module : Module
         The interal tvm module that holds the actual graph functions.
 
-    ctx : TVMContext
-        The context this module is under
-
     Attributes
     ----------
     module : Module
         The interal tvm module that holds the actual graph functions.
-
-    ctx : TVMContext
-        The context this module is under
     """
-    def __init__(self, module, ctx):
+
+    def __init__(self, module):
         self.module = module
         self._set_input = module["set_input"]
         self._run = module["run"]
@@ -81,7 +99,6 @@ class GraphModule(object):
         except AttributeError:
             pass
         self._load_params = module["load_params"]
-        self.ctx = ctx
 
     def set_input(self, key=None, value=None, **params):
         """Set inputs to the module via kwargs
@@ -98,14 +115,14 @@ class GraphModule(object):
            Additonal arguments
         """
         if key:
-            self._set_input(key, nd.array(value, ctx=self.ctx))
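+            # Copy the value into the runtime-owned buffer so the data is
+            # placed on the device that holds this input.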
+            self._get_input(key).copyfrom(value)
 
         if params:
             # upload big arrays first to avoid memory issue in rpc mode
             keys = list(params.keys())
             keys.sort(key=lambda x: -np.prod(params[x].shape))
             for k in keys:
-                self._set_input(k, nd.array(params[k], ctx=self.ctx))
+                self._get_input(k).copyfrom(params[k])
 
     def run(self, **input_dict):
         """Run forward execution of the graph
@@ -177,7 +194,8 @@ class GraphModule(object):
         if hasattr(self, '_debug_get_output'):
             self._debug_get_output(node, out)
         else:
-            raise RuntimeError("Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 0")
+            raise RuntimeError(
+                "Please compile runtime with USE_GRAPH_RUNTIME_DEBUG = 1")
         return out
 
     def load_params(self, params_bytes):
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 162d616dea8a120fc2d8551c70a8552dcde791b9..a48047fe369cd0f3d168f25c665d0a9246cee34c 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -2,22 +2,26 @@
  *  Copyright (c) 2017 by Contributors
  * \file graph_runtime.cc
  */
+#include "graph_runtime.h"
+
+#include <dlpack/dlpack.h>
+#include <dmlc/json.h>
+#include <dmlc/memory_io.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/device_api.h>
-#include <dmlc/memory_io.h>
-#include <dmlc/json.h>
-#include <numeric>
+#include <tvm/runtime/serializer.h>
+
 #include <algorithm>
-#include <vector>
 #include <functional>
-#include "graph_runtime.h"
+#include <numeric>
+#include <vector>
 
 namespace tvm {
 namespace runtime {
 
-/*! \brief macro to do C API call */
+/*! \brief Macro to do C API call. */
 #define TVM_CCALL(func)                                            \
   {                                                                \
     int ret = (func);                                              \
@@ -34,7 +38,7 @@ namespace runtime {
 class GraphRuntime : public ModuleNode {
  public:
   /*!
-   * \brief Get member function to front-end
+   * \brief Get member function to front-end.
    * \param name The name of the function.
    * \param sptr_to_self The pointer to the module node.
    * \return The corresponding member function.
@@ -58,12 +62,13 @@ class GraphRuntime : public ModuleNode {
   /*!
    * \brief Initialize the graph executor with graph and context.
    * \param graph_json The execution graph.
-   * \param module The module containing the compiled functions.
-   * \param ctx The context where the graph should sit on
+   * \param module The module containing the compiled functions for the host
+   * processor.
+   * \param ctxs The contexts of the host and devices on which graph nodes
+   * will be executed.
    */
-  void Init(const std::string& graph_json,
-            tvm::runtime::Module module,
-            TVMContext ctx) {
+  void Init(const std::string& graph_json, const tvm::runtime::Module& module,
+            const std::vector<TVMContext>& ctxs) {
 #ifndef _LIBCPP_SGX_NO_IOSTREAMS
     std::istringstream is(graph_json);
 #else
@@ -72,10 +77,11 @@ class GraphRuntime : public ModuleNode {
     dmlc::JSONReader reader(&is);
     this->Load(&reader);
     module_ = module;
-    ctx_ = ctx;
+    ctxs_ = ctxs;
     this->SetupStorage();
     this->SetupOpExecs();
   }
+
   /*!
    * \brief Get the input index given the name of input.
    * \param name The name of the input.
@@ -92,7 +98,7 @@ class GraphRuntime : public ModuleNode {
     return -1;
   }
   /*!
-   * \brief set index-th input to the graph.
+   * \brief Set index-th input to the graph.
    * \param index The input index.
    * \param data_in The input data.
    */
@@ -134,7 +140,7 @@ class GraphRuntime : public ModuleNode {
   /*!
    * \brief Copy index-th output to data_out.
    * \param index The output index.
-   * \param data_out the output data.
+   * \param data_out The output data.
    */
   void CopyOutputTo(int index, DLTensor* data_out) {
     CHECK_LT(static_cast<size_t>(index), outputs_.size());
@@ -172,8 +178,8 @@ class GraphRuntime : public ModuleNode {
    * from begining upto the index-th node and return output of index-th node.
    * This is costly operation and suggest to use only for debug porpose.
    *
-   * \param index: The  index of the node.
-   * \param data_out the node data.
+   * \param index The index of the node.
+   * \param data_out The node data.
    */
   void DebugGetNodeOutput(int index, DLTensor* data_out) {
     CHECK_LT(static_cast<size_t>(index), nodes_.size());
@@ -188,7 +194,7 @@ class GraphRuntime : public ModuleNode {
   }
 #endif
   /*!
-   * \brief Load parameters from binary stream
+   * \brief Load parameters from binary stream.
    * \param strm The input stream.
    */
   void LoadParams(dmlc::Stream* strm);
@@ -202,6 +208,12 @@ class GraphRuntime : public ModuleNode {
   }
 
  private:
+  // Memory pool entry.
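+  // Each entry records the largest byte size requested for a storage slot
+  // and the device type the slot must be allocated on (-1 when unassigned).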
+  struct PoolEntry {
+    size_t size;
+    int device_type;
+    PoolEntry(int s, int dev_type) : size(s), device_type(dev_type) {}
+  };
   // Node entry
   struct NodeEntry {
     uint32_t node_id;
@@ -260,7 +272,6 @@ class GraphRuntime : public ModuleNode {
     // JSON Loader
     void Load(dmlc::JSONReader *reader) {
       reader->BeginObject();
-      std::unordered_map<std::string, std::string> dict;
       int bitmask = 0;
       std::string key;
       while (reader->NextObjectItem(&key)) {
@@ -287,6 +298,7 @@ class GraphRuntime : public ModuleNode {
   struct GraphAttr {
     size_t storage_num_not_alloctaed{0};
     std::vector<int> storage_id;
+    std::vector<int> device_index;
     std::vector<std::string> dltype;
     std::vector<std::vector<int64_t> > shape;
     // The graph attribute fields.
@@ -322,6 +334,14 @@ class GraphRuntime : public ModuleNode {
           reader->Read(&shape);
           CHECK(!reader->NextArrayItem());
           bitmask |= 4;
+        } else if (key == "device_index") {
+          reader->BeginArray();
+          CHECK(reader->NextArrayItem());
+          reader->Read(&type);
+          CHECK_EQ(type, "list_int");
+          CHECK(reader->NextArrayItem());
+          reader->Read(&device_index);
+          CHECK(!reader->NextArrayItem());
         } else {
           reader->BeginArray();
           CHECK(reader->NextArrayItem());
@@ -372,13 +392,14 @@ class GraphRuntime : public ModuleNode {
   }
   /*! \brief Setup the temporal storage */
   void SetupStorage();
-  /*! \brief Setup the executors */
+  /*! \brief Setup the executors. */
   void SetupOpExecs();
   /*!
    * \brief Create a executtion function given input.
-   * \param attrs The node attributes
+   * \param attrs The node attributes.
    * \param args The arguments to the functor, including inputs and outputs.
-   * \param num_inputs Number of inputs
+   * \param num_inputs Number of inputs.
    * \return The created executor.
    */
   std::function<void()> CreateTVMOp(const TVMOpParam& attrs,
@@ -392,7 +413,7 @@ class GraphRuntime : public ModuleNode {
   uint32_t entry_id(const NodeEntry& e) const {
     return entry_id(e.node_id, e.index);
   }
-  // Number of node entries
+  // Number of node entries.
   uint32_t num_node_entries() const {
     return node_row_ptr_.back();
   }
@@ -400,25 +421,25 @@ class GraphRuntime : public ModuleNode {
   uint32_t num_nodes() const {
     return static_cast<uint32_t>(nodes_.size());
   }
-  // The graph nodes.
+  /*! \brief The graph nodes. */
   std::vector<Node> nodes_;
-  // The argument nodes.
+  /*! \brief The argument nodes. */
   std::vector<uint32_t> input_nodes_;
-  // used or quick entry indexing
+  /*! \brief Used for quick entry indexing. */
   std::vector<uint32_t> node_row_ptr_;
-  // output entries
+  /*! \brief Output entries. */
   std::vector<NodeEntry> outputs_;
-  // Additional graph attributes
+  /*! \brief Additional graph attributes. */
   GraphAttr attrs_;
-  /*! \brief The code module */
+  /*! \brief The code module that contains both host and device code. */
   tvm::runtime::Module module_;
-  /*! \brief execution context */
-  TVMContext ctx_;
-  /*! \brief common storage pool */
+  /*! \brief Execution context of all devices including the host. */
+  std::vector<TVMContext> ctxs_;
+  /*! \brief Common storage pool for all devices. */
   std::vector<NDArray> storage_pool_;
-  /*! \brief data entry of each node */
+  /*! \brief Data entry of each node. */
   std::vector<NDArray> data_entry_;
-  /*! \brief operator on each node */
+  /*! \brief Operator on each node. */
   std::vector<std::function<void()> > op_execs_;
 };
 
@@ -458,12 +479,17 @@ void GraphRuntime::SetupStorage() {
   for (const std::string& s_type : attrs_.dltype) {
     vtype.push_back(tvm::runtime::String2TVMType(s_type));
   }
-  data_entry_.resize(num_node_entries());
-  // size of each storage pool entry
-  std::vector<size_t> pool_entry_bytes;
+
+  // Size and device type of each storage pool entry.
+  std::vector<PoolEntry> pool_entry;
   // Find the maximum space size.
   for (size_t i = 0; i < attrs_.shape.size(); ++i) {
     int storage_id = attrs_.storage_id[i];
+    // Use the fallback device if no device index is available.
+    int device_type = static_cast<int>(ctxs_[0].device_type);
+    if (!attrs_.device_index.empty()) {
+      device_type = attrs_.device_index[i];
+    }
     size_t size = 1;
     for (int64_t sz : attrs_.shape[i]) {
       size *= static_cast<size_t>(sz);
@@ -474,23 +500,42 @@ void GraphRuntime::SetupStorage() {
     CHECK_EQ(bits % 8U, 0U);
     size_t bytes = (bits / 8U) * size;
 
-    size_t sid = static_cast<size_t>(storage_id);
-    if (sid >= pool_entry_bytes.size()) {
-      pool_entry_bytes.resize(sid + 1, 0);
+    uint32_t sid = static_cast<uint32_t>(storage_id);
+    if (sid >= pool_entry.size()) {
+      pool_entry.resize(sid + 1, {0, -1});
+    } else {
+      CHECK(pool_entry[sid].device_type == -1 ||
+            pool_entry[sid].device_type == device_type)
+          << "The same pool entry cannot be assigned to multiple devices";
     }
-    pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes);
+    pool_entry[sid].size = std::max(pool_entry[sid].size, bytes);
+    pool_entry[sid].device_type = device_type;
   }
+
   // Allocate the space.
-  for (size_t i = 0; i < pool_entry_bytes.size(); ++i) {
+  for (const auto& pit : pool_entry) {
     std::vector<int64_t> shape;
-    shape.push_back(static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4);
-    storage_pool_.push_back(NDArray::Empty(shape, DLDataType {kDLFloat, 32, 1}, ctx_));
+    // This lookup is fast since there are usually only a couple of
+    // devices available on the same hardware.
+    const auto& cit =
+        std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) {
+          return pit.device_type == static_cast<int>(c.device_type);
+        });
+    TVMContext ctx = cit == ctxs_.end() ? ctxs_[0] : *cit;
+    shape.push_back(static_cast<int64_t>(pit.size + 3) / 4);
+    storage_pool_.push_back(
+        NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx));
   }
-  // Assign the pooled entries.
+
+  // Assign the pooled entries. A unified memory pool is used to simplify
+  // memory assignment for each node entry. The allocated memory on each
+  // device is mapped to this pool.
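+  // For instance, with storage_id = [0, 1, 0], entries 0 and 2 are views of
+  // pool slot 0 and must therefore reside on the same device.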
+  data_entry_.resize(num_node_entries());
   for (size_t i = 0; i < data_entry_.size(); ++i) {
     int storage_id = attrs_.storage_id[i];
     CHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
-    data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
+    data_entry_[i] =
+        storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);
   }
 }
 
@@ -508,8 +553,8 @@ void GraphRuntime::SetupOpExecs() {
       uint32_t eid = this->entry_id(nid, index);
       args.push_back(*(data_entry_[eid].operator->()));
     }
-    CHECK_EQ(inode.op_type, "tvm_op")
-        << "Can only take tvm_op as op";
+    CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";
+
     op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size());
   }
 }
@@ -543,13 +588,26 @@ std::function<void()> GraphRuntime::CreateTVMOp(
       t->shape = &(arg_ptr->shape_data[i]);
     }
   }
+
   if (param.func_name == "__nop") {
     return [](){};
+  } else if (param.func_name == "__copy") {
+    // Perform cross device data copy.
+    // Directly copy data from the input to the output.
+    auto fexec = [arg_ptr]() {
+      DLTensor* from = static_cast<DLTensor*>(arg_ptr->arg_values[0].v_handle);
+      DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
+      TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
+    };
+    return fexec;
   }
-  // get compiled function from module.
+
+  // Get compiled function from the module that contains both host and device
+  // code.
   tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, false);
   CHECK(pf != nullptr) << "no such function in module: " << param.func_name;
-  auto fexec = [arg_ptr, pf] () {
+
+  auto fexec = [arg_ptr, pf]() {
     TVMRetValue rv;
     TVMArgs targs(arg_ptr->arg_values.data(),
                   arg_ptr->arg_tcodes.data(),
@@ -562,7 +620,7 @@ std::function<void()> GraphRuntime::CreateTVMOp(
 PackedFunc GraphRuntime::GetFunction(
     const std::string& name,
     const std::shared_ptr<ModuleNode>& sptr_to_self) {
-  // return member functions during query.
+  // Return member functions during query.
   if (name == "set_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         if (args[0].type_code() == kStr) {
@@ -618,29 +676,53 @@ PackedFunc GraphRuntime::GetFunction(
   }
 }
 
-Module GraphRuntimeCreate(std::string sym_json,
-                          tvm::runtime::Module m,
-                          int device_type,
-                          int device_id) {
-  TVMContext ctx;
-  ctx.device_type = static_cast<DLDeviceType>(device_type);
-  ctx.device_id   = device_id;
+Module GraphRuntimeCreate(const std::string& sym_json,
+                          const tvm::runtime::Module& m,
+                          const std::vector<TVMContext>& ctxs) {
   std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>();
-  exec->Init(sym_json, m, ctx);
+  exec->Init(sym_json, m, ctxs);
   return Module(exec);
 }
 
+// Get all contexts for the host and other runtime devices.
+std::vector<TVMContext> GetAllContext(const TVMArgs& args) {
+  // The first context parsed below serves as the fallback device.
+  std::vector<TVMContext> ret;
+  TVMContext ctx;
+  for (int i = 2; i < args.num_args; i += 2) {
+    int dev_type = args[i];
+    ctx.device_type = static_cast<DLDeviceType>(dev_type);
+    ctx.device_id = args[i + 1];
+    ret.push_back(ctx);
+  }
+  return ret;
+}
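+// For example (assuming the DLPack device codes kDLCPU = 1 and kDLGPU = 2),
+// a call packed as (json, mod, 1, 0, 2, 0) yields the contexts
+// [{kDLCPU, 0}, {kDLGPU, 0}], with the CPU context as the fallback.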
+
+// The 4-argument version is currently reserved to keep supporting calls
+// from tvm4j and javascript, since they don't have heterogeneous
+// execution support yet. For heterogeneous execution, at least 6 arguments
+// will be passed in: the arguments after the module are interpreted as
+// (device_type, device_id) pairs, one pair per context.
+// Eventually, we will probably only pass TVMContext for all the languages.
 TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-    *rv = GraphRuntimeCreate(args[0], args[1], args[2], args[3]);
+  .set_body([](TVMArgs args, TVMRetValue* rv) {
+    CHECK_GE(args.num_args, 4)
+        << "The expected number of arguments for graph_runtime.create is "
+           "at least 4, but it has "
+        << args.num_args;
+    const auto& contexts = GetAllContext(args);
+    *rv = GraphRuntimeCreate(args[0], args[1], contexts);
   });
 
 TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
+  .set_body([](TVMArgs args, TVMRetValue* rv) {
+    CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
+                                  "graph_runtime.remote_create is "
+                                  "at least 4, but it has "
+                               << args.num_args;
     void* mhandle = args[1];
-    *rv = GraphRuntimeCreate(args[0],
-                             *static_cast<tvm::runtime::Module*>(mhandle),
-                             args[2], args[3]);
+    const auto& contexts = GetAllContext(args);
+    *rv = GraphRuntimeCreate(
+        args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
   });
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py
new file mode 100644
index 0000000000000000000000000000000000000000..b916ee28571778690f034119d819b67f0cdbeac7
--- /dev/null
+++ b/tests/python/unittest/test_runtime_heterogeneous.py
@@ -0,0 +1,405 @@
+# pylint: disable=too-many-locals
+"""Unit tests for heterogeneous runtime"""
+import json
+import numpy as np
+
+import tvm
+from tvm.contrib import graph_runtime, util
+import topi
+
+def get_simplex_graph(host_dev_type, device_dev_type):
+    r""" Return the hand-crafted json object where only one copy node is
+    inserted. This node copies data from the target device to cpu.
+    The network is constructed as follows:
+                 A    B
+                  \  /
+             elemwise_add  (gpu)
+                     \
+                     copy      C
+                       \      /
+                     elemwise_sub  (cpu)
+
+    Parameters
+    ----------
+    host_dev_type : int
+        The device type of the host processor, e.g. cpu.
+    device_dev_type : int
+        The device type of the device processor, e.g. gpu, opencl, etc.
+
+    Returns
+    -------
+    json : str
+        The graph encoded as a JSON string.
+    """
+    # Construct each node in the graph.
+    var_a = {"op": "null", "name": "A", "inputs": []}
+    var_b = {"op": "null", "name": "B", "inputs": []}
+    elemwise_add = {
+        "op": "tvm_op", "name": "elemwise_add",
+        "attrs": {
+            "flatten_data": "1",
+            "func_name": "elemwise_add",
+            "num_inputs": "2",
+            "num_outputs": "1"
+        },
+        "inputs": [[0, 0, 0], [1, 0, 0]]
+    }
+    copy = {
+        "op": "tvm_op",
+        "name": "__copy_add_to_sub",
+        "attrs": {
+            "flatten_data": "0",
+            "func_name": "__copy",
+            "num_inputs": "1",
+            "num_outputs": "1"
+        },
+        "inputs": [[2, 0, 0]]
+    }
+    var_c = {"op": "null", "name": "C", "inputs": []}
+    elemwise_sub = {
+        "op": "tvm_op", "name": "elemwise_sub",
+        "attrs": {
+            "flatten_data": "0",
+            "func_name": "elemwise_sub",
+            "num_inputs": "2",
+            "num_outputs": "1"
+        },
+        "inputs": [[3, 0, 0], [4, 0, 0]]
+    }
+
+    # Group the nodes.
+    nodes = [var_a, var_b, elemwise_add, copy, var_c, elemwise_sub]
+    arg_nodes = [0, 1, 4]
+    node_row_ptr = [0, 1, 2, 3, 4, 5, 6]
+    heads = [[5, 0, 0]]
+    shape = (4,)
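+    # Each of the six node entries gets its own storage slot. The
+    # device_index below places the first three entries (A, B, and the add
+    # output) on the device and the remaining ones on the host.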
+    attrs = {
+        "storage_id": ["list_int", [3, 4, 0, 1, 5, 2]],
+        "shape": ["list_shape", [shape, shape, shape, shape, shape, shape]],
+        "device_index": ["list_int", [device_dev_type, device_dev_type,
+                                      device_dev_type, host_dev_type,
+                                      host_dev_type, host_dev_type]],
+        "dtype": ["list_int", [0, 0, 0, 0, 0, 0]],
+        "dltype": ["list_str", ["float32", "float32", "float32",
+                                "float32", "float32", "float32"]]
+    }
+
+    # Construct the graph.
+    graph = {"nodes": nodes,
+             "arg_nodes": arg_nodes,
+             "node_row_ptr": node_row_ptr,
+             "heads": heads,
+             "attrs": attrs}
+    return json.dumps(graph)
+
+
+def test_simplex_data_transferring():
+    r"""
+    Test the heterogeneous execution of a simple network where data is
+    transferred from the target device to the host processor at runtime.
+    The host processor is always assumed to be cpu, and the device varies.
+    """
+    host = "cpu"
+    target_host = "llvm"
+    host_ctx = tvm.context(host)
+    if not tvm.module.enabled(target_host):
+        print("Skip test because llvm is not enabled.")
+        return
+
+    def check_device(device, target_device):
+        if not tvm.module.enabled(target_device):
+            print("Skip test because {} is not enabled.".format(target_device))
+            return
+
+        device_ctx = tvm.context(device)
+        graph = get_simplex_graph(host_ctx.device_type, device_ctx.device_type)
+        shape = (4,)
+
+        # Create module for add whose target is the device.
+        tensor_a = tvm.placeholder(shape, name="A")
+        tensor_b = tvm.placeholder(shape, name="B")
+        elemwise_add = tvm.compute(shape, lambda *i: tensor_a(*i)
+                                   + tensor_b(*i), name="elemwise_add")
+        target = topi.cpp.TEST_create_target(device)
+        schedule_add = topi.cpp.cuda.schedule_injective(target, [elemwise_add])
+        lower_add = tvm.lower(schedule_add, [tensor_a, tensor_b, elemwise_add],
+                              name="elemwise_add")
+        host_funcs_add, lib_add = tvm.build(lower_add, target=target_device,
+                                            name="elemwise_add",
+                                            postpone_host_codegen=True)
+
+        # Insert copy. Neither compute nor schedule is required for the copy
+        # node. The copy is performed at runtime and simply moves data from
+        # the input to the output.
+        tensor_copy = tvm.placeholder(shape, name="__copy")
+
+        # Create module for sub whose target is the host.
+        tensor_c = tvm.placeholder(shape, name="C")
+        elemwise_sub = tvm.compute(shape, lambda *i: tensor_copy(*i)
+                                   - tensor_c(*i), name="elemwise_sub")
+        schedule_sub = tvm.create_schedule(elemwise_sub.op)
+        lower_sub = tvm.lower(schedule_sub, [tensor_copy, tensor_c,
+                                             elemwise_sub],
+                              name="elemwise_sub")
+
+        host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host,
+                                            name="elemwise_sub",
+                                            postpone_host_codegen=True)
+        host_funcs = host_funcs_add + host_funcs_sub
+        mhost = tvm.codegen.build_module(host_funcs, target_host)
+        if lib_add:
+            mhost.import_module(lib_add)
+        if lib_sub:
+            mhost.import_module(lib_sub)
+
+        ctx = [host_ctx, device_ctx]
+        mod = graph_runtime.create(graph, mhost, ctx)
+        params = {}
+        params["A"] = tensor_a = np.random.uniform(
+            size=shape).astype(tensor_a.dtype)
+        params["B"] = tensor_b = np.random.uniform(
+            size=shape).astype(tensor_b.dtype)
+        params["C"] = tensor_c = np.random.uniform(
+            size=shape).astype(tensor_c.dtype)
+        mod.set_input(**params)
+        mod.run()
+        out = mod.get_output(0, tvm.nd.empty(shape))
+        np.testing.assert_equal(
+            out.asnumpy(), (tensor_a + tensor_b) - tensor_c)
+
+    dev_tar = {"cuda": "cuda", "opencl": "opencl"}
+    for device, target in dev_tar.items():
+        check_device(device, target)
+
+
+def get_duplex_graph(host_dev_type, device_dev_type):
+    r""" Return the hand-crafted json object where two copy nodes are inserted.
+    Data is transferred back-and-forth between the target device and the CPU.
+    The network is constructed as follows:
+                 A    B
+                  \  /
+             elemwise_add  (gpu)
+                     \
+                     copy        C
+                       \        /
+                      elemwise_sub  (cpu)
+                         \
+                         copy          D
+                           \          /
+                           elemwise_add  (gpu)
+
+    Parameters
+    ----------
+    host_dev_type : int
+        The device type of the host processor, e.g. cpu.
+    device_dev_type : int
+        The device type of the device processor, e.g. gpu, opencl, etc.
+
+    Returns
+    -------
+    json : str
+        The graph encoded as a JSON string.
+    """
+    # Construct each node in the graph.
+    var_a = {"op": "null", "name": "A", "inputs": []}
+    var_b = {"op": "null", "name": "B", "inputs": []}
+    elemwise_add0 = {
+        "op": "tvm_op", "name": "elemwise_add0",
+        "attrs": {
+            "flatten_data": "1",
+            "func_name": "elemwise_add0",
+            "num_inputs": "2",
+            "num_outputs": "1"
+        },
+        "inputs": [[0, 0, 0], [1, 0, 0]]
+    }
+    copy_add_sub = {
+        "op": "tvm_op",
+        "name": "__copy_add_to_sub",
+        "attrs": {
+            "flatten_data": "0",
+            "func_name": "__copy",
+            "num_inputs": "1",
+            "num_outputs": "1"
+        },
+        "inputs": [[2, 0, 0]]
+    }
+    var_c = {"op": "null", "name": "C", "inputs": []}
+    elemwise_sub = {
+        "op": "tvm_op", "name": "elemwise_sub",
+        "attrs": {
+            "flatten_data": "0",
+            "func_name": "elemwise_sub",
+            "num_inputs": "2",
+            "num_outputs": "1"
+        },
+        "inputs": [[3, 0, 0], [4, 0, 0]]
+    }
+    copy_sub_add = {
+        "op": "tvm_op",
+        "name": "__copy_sub_to_add",
+        "attrs": {
+            "flatten_data": "0",
+            "func_name": "__copy",
+            "num_inputs": "1",
+            "num_outputs": "1"
+        },
+        "inputs": [[5, 0, 0]]
+    }
+    var_d = {"op": "null", "name": "D", "inputs": []}
+    elemwise_add1 = {
+        "op": "tvm_op", "name": "elemwise_add1",
+        "attrs": {
+            "flatten_data": "0",
+            "func_name": "elemwise_add1",
+            "num_inputs": "2",
+            "num_outputs": "1"
+        },
+        "inputs": [[6, 0, 0], [7, 0, 0]]
+    }
+
+    # Group the nodes.
+    nodes = [var_a, var_b, elemwise_add0, copy_add_sub, var_c, elemwise_sub,
+             copy_sub_add, var_d, elemwise_add1]
+    arg_nodes = [0, 1, 4, 7]
+    node_row_ptr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    heads = [[8, 0, 0]]
+    shape = (4,)
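+    # Note that entries 2 and 6 (the output of elemwise_add0 and the output
+    # of copy_sub_add) share storage slot 0; both are placed on the device,
+    # satisfying the runtime's one-device-per-slot requirement.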
+    attrs = {
+        "storage_id": ["list_int", [4, 5, 0, 1, 6, 2, 0, 7, 3]],
+        "shape": ["list_shape", [shape, shape, shape, shape, shape, shape,
+                                 shape, shape, shape]],
+        "device_index": ["list_int", [device_dev_type, device_dev_type,
+                                      device_dev_type,
+                                      host_dev_type, host_dev_type, host_dev_type,
+                                      device_dev_type, device_dev_type,
+                                      device_dev_type]],
+        "dtype": ["list_int", [0, 0, 0, 0, 0, 0, 0, 0, 0]],
+        "dltype": ["list_str", ["float32", "float32", "float32",
+                                "float32", "float32", "float32",
+                                "float32", "float32", "float32"]]
+    }
+
+    # Construct the graph.
+    graph = {"nodes": nodes,
+             "arg_nodes": arg_nodes,
+             "node_row_ptr": node_row_ptr,
+             "heads": heads,
+             "attrs": attrs}
+    return json.dumps(graph)
+
+
+def test_duplex_data_transferring():
+    r"""
+    Test the heterogeneous execution of a simple network where data
+    transfer occurs back-and-forth between the target device and host
+    processor.
+    The host processor is always assumed to be cpu, and the target device
+    varies.
+    """
+    host = "cpu"
+    target_host = "llvm"
+    host_ctx = tvm.context(host)
+    if not tvm.module.enabled(target_host):
+        print("Skip test because llvm is not enabled.")
+        return
+
+    def check_device(device, target_device):
+        if not tvm.module.enabled(target_device):
+            print("Skip test because {} is not enabled.".format(target_device))
+            return
+
+        device_ctx = tvm.context(device)
+        graph = get_duplex_graph(host_ctx.device_type, device_ctx.device_type)
+        shape = (4,)
+
+        # Insert copy nodes for data transferring between add and sub nodes.
+        # Transfers data from gpu to cpu.
+        copy_add_sub = tvm.placeholder(shape, name="__copy0")
+        # Transfers data from cpu to gpu.
+        copy_sub_add = tvm.placeholder(shape, name="__copy1")
+
+        # Create a module containing adds on the device.
+        tensor_a = tvm.placeholder(shape, name="A")
+        tensor_b = tvm.placeholder(shape, name="B")
+        tensor_d = tvm.placeholder(shape, name="D")
+        elemwise_add0 = tvm.compute(shape, lambda *i: tensor_a(*i)
+                                    + tensor_b(*i), name="elemwise_add0")
+        elemwise_add1 = tvm.compute(shape, lambda *i: copy_sub_add(*i)
+                                    + tensor_d(*i), name="elemwise_add1")
+        target = topi.cpp.TEST_create_target(device)
+        add_schedule0 = topi.cpp.cuda.schedule_injective(
+            target, [elemwise_add0])
+        lower_add0 = tvm.lower(
+            add_schedule0, [tensor_a, tensor_b, elemwise_add0],
+            name="elemwise_add0")
+        add_schedule1 = topi.cpp.cuda.schedule_injective(
+            target, [elemwise_add1])
+        lower_add1 = tvm.lower(
+            add_schedule1, [tensor_d, copy_sub_add, elemwise_add1],
+            name="elemwise_add1")
+        host_funcs_add, lib_add = tvm.build([lower_add0, lower_add1],
+                                            target=target_device,
+                                            postpone_host_codegen=True)
+
+        # Create module for sub whose target is the host.
+        tensor_c = tvm.placeholder(shape, name="C")
+        elemwise_sub = tvm.compute(shape, lambda *i: copy_add_sub(*i)
+                                   - tensor_c(*i), name="elemwise_sub")
+        sub_schedule = tvm.create_schedule(elemwise_sub.op)
+        lower_sub = tvm.lower(sub_schedule, [copy_add_sub, tensor_c,
+                                             elemwise_sub],
+                              name="elemwise_sub")
+        host_funcs_sub, lib_sub = tvm.build(lower_sub, target=target_host,
+                                            postpone_host_codegen=True)
+        host_funcs = host_funcs_add + host_funcs_sub
+        mhost = tvm.codegen.build_module(host_funcs, target_host)
+        if lib_add:
+            mhost.import_module(lib_add)
+        if lib_sub:
+            mhost.import_module(lib_sub)
+
+        ctx = [host_ctx, device_ctx]
+        params = {}
+        params["A"] = tensor_a = np.random.uniform(
+            size=shape).astype(tensor_a.dtype)
+        params["B"] = tensor_b = np.random.uniform(
+            size=shape).astype(tensor_b.dtype)
+        params["C"] = tensor_c = np.random.uniform(
+            size=shape).astype(tensor_c.dtype)
+        params["D"] = tensor_d = np.random.uniform(
+            size=shape).astype(tensor_d.dtype)
+
+        def check_verify():
+            mod = graph_runtime.create(graph, mhost, ctx)
+            mod.set_input(**params)
+            mod.run()
+            out = mod.get_output(0, tvm.nd.empty(shape))
+            np.testing.assert_equal(
+                out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
+
+        def check_load_module():
+            temp = util.tempdir()
+            path_lib = temp.relpath("deploy.so")
+            mhost.export_library(path_lib)
+            with open(temp.relpath("deploy.json"), "w") as out_file:
+                out_file.write(graph)
+            loaded_lib = tvm.module.load(path_lib)
+            loaded_graph = open(temp.relpath("deploy.json")).read()
+            mod = graph_runtime.create(loaded_graph, loaded_lib, ctx)
+            mod.set_input(**params)
+            mod.run()
+            out = mod.get_output(0, tvm.nd.empty(shape))
+            np.testing.assert_equal(
+                out.asnumpy(), tensor_a + tensor_b - tensor_c + tensor_d)
+
+        check_verify()
+        check_load_module()
+
+    dev_tar = {"cuda": "cuda", "opencl": "opencl"}
+    for device, target in dev_tar.items():
+        check_device(device, target)
+
+if __name__ == "__main__":
+    test_simplex_data_transferring()
+    test_duplex_data_transferring()