diff --git a/tutorials/nnvm_quick_start.py b/tutorials/nnvm_quick_start.py index e16184300e2f8309bd1aef6107609224cd44b8be..0244cbe81e5eb41a7e184c99a354f9f620090e8a 100644 --- a/tutorials/nnvm_quick_start.py +++ b/tutorials/nnvm_quick_start.py @@ -133,7 +133,7 @@ loaded_lib = tvm.module.load(path_lib) loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) -module = graph_runtime.create(loaded_json, loaded_lib, tvm.gpu(0)) +module = graph_runtime.create(loaded_json, loaded_lib, ctx) module.load_params(loaded_params) module.run(data=input_data) out = module.get_output(0).asnumpy()