diff --git a/examples/graph_executor/python/tvm_graph/build.py b/examples/graph_executor/python/tvm_graph/build.py index 6ba51853b8880d9c49ff48b7a4edbb9e60d2975e..228713d86649894c3339f74f1c2ee18b74173deb 100644 --- a/examples/graph_executor/python/tvm_graph/build.py +++ b/examples/graph_executor/python/tvm_graph/build.py @@ -29,7 +29,7 @@ def build(sym, target, shape, dtype="float32"): def bind(g, ctx): - m = _create_exec(g.handle, ctx) + m = _create_exec(g.handle, ctx.device_type, ctx.device_id) return m diff --git a/examples/graph_executor/src/graph_executor.cc b/examples/graph_executor/src/graph_executor.cc index 868cfa42b2094ba6674ffedec6669a0592d190bb..d649017cf60ccd526bd726fd6272b30487dd0f89 100644 --- a/examples/graph_executor/src/graph_executor.cc +++ b/examples/graph_executor/src/graph_executor.cc @@ -263,7 +263,9 @@ tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) { TVM_REGISTER_GLOBAL("tvm_graph._create_executor") .set_body([](TVMArgs args, TVMRetValue *rv) { void* graph_handle = args[0]; - TVMContext ctx = args[1]; + int device_type = args[1]; + int device_id = args[2]; + TVMContext ctx{static_cast<DLDeviceType>(device_type), device_id}; nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0]; *rv = CreateExecutor(g, ctx); });