diff --git a/HalideIR b/HalideIR index e20e5e9abb3aa43147a90a4ffb3e190f62862970..a3698398faff7fec1c0fa4e4479357651382db75 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit e20e5e9abb3aa43147a90a4ffb3e190f62862970 +Subproject commit a3698398faff7fec1c0fa4e4479357651382db75 diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index e11ee773f1e0063a3c02ac0e3de46acde4af1b7f..352644a6f0c939781c42ed5a2efaf8766064d00b 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -70,7 +70,11 @@ Target CreateTarget(const std::string& target_name, t->thread_warp_size = 32; } else if (target_name == "rocm" || target_name == "opencl") { // For now assume rocm schedule for opencl - t->device_type = static_cast<int>(target_name == "rocm" ? kDLROCM : kDLOpenCL); + if (target_name == "opencl") { + t->device_type = kDLOpenCL; + } else { + t->device_type = kDLROCM; + } t->keys_array.push_back(ir::StringImm::make("rocm")); t->keys_array.push_back(ir::StringImm::make("gpu")); t->max_num_threads = 256; @@ -78,14 +82,21 @@ Target CreateTarget(const std::string& target_name, t->thread_warp_size = 16; } } else if (target_name == "metal" || target_name == "vulkan") { - t->device_type = static_cast<int>(target_name == "metal" ? kDLMetal : kDLVulkan); + if (target_name == "metal") { + t->device_type = kDLMetal; + } else { + t->device_type = kDLVulkan; + } t->keys_array.push_back(ir::StringImm::make(target_name)); t->keys_array.push_back(ir::StringImm::make("gpu")); t->max_num_threads = 256; } else if (target_name == "opengl") { - t->device_type = kDLGPU; + t->device_type = kOpenGL; t->keys_array.push_back(ir::StringImm::make("opengl")); - } else if (target_name == "stackvm" || target_name == "ext_dev") { + } else if (target_name == "stackvm") { + t->device_type = kDLCPU; + } else if (target_name == "ext_dev") { + t->device_type = kExtDev; } else { LOG(ERROR) << "Unknown target name " << target_name; return target::stackvm(); diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 1768d6334b5ce7d44fec7b2b51d4e374a14d6f40..3e055cd0079c8f01435df76c9dc47d6d557af35b 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -39,6 +39,7 @@ void MetalWorkspace::GetAttr( *rv = 1; break; } + case kMaxSharedMemoryPerBlock: return; case kComputeVersion: return; case kExist: break; } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 6341c9f4b83d0fc438d1e931791a687336673805..1fb52b945f9a2ed64b6343d67e1a20649b4c9af6 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -32,9 +32,9 @@ void OpenCLWorkspace::GetAttr( } CHECK_LT(index, devices.size()) << "Invalid device id " << index; - size_t value; switch (kind) { case kMaxThreadsPerBlock: { + size_t value; OPENCL_CALL(clGetDeviceInfo( devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &value, nullptr)); @@ -45,8 +45,16 @@ void OpenCLWorkspace::GetAttr( *rv = 1; break; } - case kComputeVersion: return; - case kExist: break; + case kMaxSharedMemoryPerBlock: { + cl_ulong value; + OPENCL_CALL(clGetDeviceInfo( + devices[index], CL_DEVICE_LOCAL_MEM_SIZE, + sizeof(cl_ulong), &value, nullptr)); + *rv = static_cast<int64_t>(value); + break; + } + case kComputeVersion: return; + case kExist: break; } } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 6a950069192c6b80f259835f92615f35e7cf8e41..a1d77b66c25187e8621fe7232b0f49a7e5957659 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -44,6 +44,7 @@ class ROCMDeviceAPI final : public DeviceAPI { value = 64; break; } + case kMaxSharedMemoryPerBlock: return; case kComputeVersion: { hipDeviceProp_t prop; ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index ef97e99431c28d9c4a797a551868164bfeae33e3..03de48f86fab68282c61314c5e7153fa82700da9 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -51,6 +51,13 @@ void VulkanWorkspace::GetAttr( *rv = value; break; } + case kMaxSharedMemoryPerBlock: { + VkPhysicalDeviceProperties phy_prop; + vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop); + int64_t value = phy_prop.limits.maxComputeSharedMemorySize; + *rv = value; + break; + } case kWarpSize: { *rv = 1; break; diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index fb868569501346cc6e5b1aed802a2d9ffc2c8fda..9cdfef7f6a01afeb5c224f96bed84a4fb79fa543 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -15,6 +15,15 @@ TEST(Expr, Basic) { } +TEST(ExprNodeRef, Basic) { + using namespace tvm; + Var x("x"); + Expr z = max(x + 1 + 2, 100); + const ir::Max* op = z.as<ir::Max>(); + CHECK(op->GetNodeRef().same_as(z)); +} + + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";