diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 5a585a19cccf04ede1de6cc4e51ebf68baf403ce..bde7e6b2741875adbae1f5813e020e6a6f3ffaf3 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -254,9 +254,14 @@ PackedFunc OpenCLModuleNode::GetFunction( for (size_t i = 0; i < info.arg_types.size(); ++i) { TVMType t = info.arg_types[i]; CHECK_EQ(t.lanes, 1U); - uint32_t bits = t.bits; - CHECK_EQ(bits % 8, 0U); - arg_size[i] = bits / 8; + if (t.code == kHandle) { + // specially store pointer type size in OpenCL driver + arg_size[i] = sizeof(void*); + } else { + uint32_t bits = t.bits; + CHECK_EQ(bits % 8, 0U); + arg_size[i] = bits / 8; + } } // initialize the wrapped func. f.Init(this, sptr_to_self, kid_map_.at(name),