From 3b8e70ae03fe705805cfc1961df983deeafd05b9 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Fri, 16 Jun 2017 18:10:15 -0700 Subject: [PATCH] [RUNTIME] Move device_api to include (#185) * [RUNTIME] Move device_api to include * fix doxygen * fix device api * fx --- include/tvm/runtime/c_runtime_api.h | 89 +++++++++++++---------- {src => include/tvm}/runtime/device_api.h | 38 +++++----- src/codegen/verilog/vpi_device_api.cc | 2 +- src/runtime/c_runtime_api.cc | 24 +++++- src/runtime/cpu_device_api.cc | 2 +- src/runtime/cuda/cuda_common.h | 15 ++-- src/runtime/cuda/cuda_device_api.cc | 14 +++- src/runtime/cuda/cuda_module.cc | 3 +- src/runtime/metal/metal_common.h | 2 +- src/runtime/opencl/opencl_common.h | 4 +- src/runtime/rpc/rpc_device_api.cc | 2 +- src/runtime/rpc/rpc_session.cc | 3 +- src/runtime/rpc/rpc_session.h | 2 +- 13 files changed, 123 insertions(+), 77 deletions(-) rename {src => include/tvm}/runtime/device_api.h (84%) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 7c1724e4b..c9f7371de 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -194,43 +194,6 @@ TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, */ TVM_DLL int TVMModFree(TVMModuleHandle mod); -/*! - * \brief Backend function for modules to get function - * from its environment mod_node (its imports and global function). - * - * The user do should not call TVMFuncFree on func. - * - * \note This API is supposed to be used by backend, - * it is not supposed to be used by user. - * - * \param mod_node The module handle. - * \param func_name The name of the function. - * \param out The result function. - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, - const char* func_name, - TVMFunctionHandle *out); - -/*! - * \brief Backend function for running parallel for loop. - * - * \note This API is supposed to be used by backend, - * it is not supposed to be used by user. - * - * \param begin The start of iteration. - * \param end The end of iteration. - * \param lambda The lambda function to be executed. - * \param env The environment of lambda function. - * - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_DLL int TVMBackendParallelFor( - int64_t begin, - int64_t end, - int (*lambda)(int64_t begin, int64_t end, void* env), - void* env); - /*! * \brief Free the function when it is no longer needed. * \param func The function handle @@ -351,6 +314,44 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); TVM_DLL int TVMFuncListGlobalNames(int *out_size, const char*** out_array); +// Backend related functions. +/*! + * \brief Backend function for modules to get function + * from its environment mod_node (its imports and global function). + * + * The user do should not call TVMFuncFree on func. + * + * \note This API is supposed to be used by backend, + * it is not supposed to be used by user. + * + * \param mod_node The module handle. + * \param func_name The name of the function. + * \param out The result function. + * \return 0 when no error is thrown, -1 when failure happens + */ +TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, + const char* func_name, + TVMFunctionHandle *out); + +/*! + * \brief Backend function for running parallel for loop. + * + * \note This API is supposed to be used by backend, + * it is not supposed to be used by user. + * + * \param begin The start of iteration. + * \param end The end of iteration. + * \param lambda The lambda function to be executed. + * \param env The environment of lambda function. + * + * \return 0 when no error is thrown, -1 when failure happens + */ +TVM_DLL int TVMBackendParallelFor( + int64_t begin, + int64_t end, + int (*lambda)(int64_t begin, int64_t end, void* env), + void* env); + // Array related apis for quick proptyping /*! * \brief Allocate a nd-array's memory, @@ -368,6 +369,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, TVMType dtype, TVMContext ctx, TVMArrayHandle* out); + /*! * \brief Free the TVM Array. * \param handle The array handle to be freed. @@ -385,6 +387,19 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream); + +/*! + * \brief Set the runtime stream of current thread to be stream. + * The subsequent calls to the same device_type + * will use the setted stream handle. + * The specific type of stream is runtime device dependent. + * + * \param ctx The context. + * \param handle The stream handle. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMSetStream(TVMContext ctx, TVMStreamHandle handle); + /*! * \brief Wait until all computations on stream completes. * \param ctx The ctx to be synchronized. diff --git a/src/runtime/device_api.h b/include/tvm/runtime/device_api.h similarity index 84% rename from src/runtime/device_api.h rename to include/tvm/runtime/device_api.h index f7ffd8677..21bd8e4c3 100644 --- a/src/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -1,24 +1,29 @@ /*! * Copyright (c) 2016 by Contributors * \file device_api.h - * \brief Device specific API + * \brief Abstract device memory management API */ #ifndef TVM_RUNTIME_DEVICE_API_H_ #define TVM_RUNTIME_DEVICE_API_H_ -#include <tvm/base.h> -#include <tvm/runtime/c_runtime_api.h> #include <string> +#include "./packed_func.h" +#include "./c_runtime_api.h" namespace tvm { namespace runtime { - +/*! + * \brief the query type into GetAttr + */ enum DeviceAttrKind : int { kExist = 0, kMaxThreadsPerBlock = 1, kWarpSize = 2 }; - +/*! + * \brief TVM Runtime Device API, abstracts the device + * specific interface for memory management. + */ class DeviceAPI { public: /*! \brief virtual destructor */ @@ -34,6 +39,7 @@ class DeviceAPI { * \param ctx The device context * \param kind The result kind * \param rv The return value. + * \sa DeviceAttrKind */ virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0; /*! @@ -53,7 +59,6 @@ class DeviceAPI { virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0; /*! * \brief copy data from one place to another - * \param dev The device to perform operation. * \param from The source array. * \param from_offset The byte offeset in the from. * \param to The target array. @@ -77,6 +82,12 @@ class DeviceAPI { * \param stream The stream to be sync. */ virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0; + /*! + * \brief Set the stream + * \param ctx The context to set stream. + * \param stream The stream to be set. + */ + virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {} /*! * \brief Get device API base don context. * \param ctx The context @@ -88,21 +99,6 @@ class DeviceAPI { /*! \brief The device type bigger than this is RPC device */ constexpr int kRPCSessMask = 128; - -/*! - * \brief The name of Device API factory. - * \param type The device type. - */ -inline std::string DeviceName(int type) { - switch (type) { - case kCPU: return "cpu"; - case kGPU: return "gpu"; - case kOpenCL: return "opencl"; - case kMetal: return "metal"; - case kVPI: return "vpi"; - default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; - } -} } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_DEVICE_API_H_ diff --git a/src/codegen/verilog/vpi_device_api.cc b/src/codegen/verilog/vpi_device_api.cc index a57534850..dd3fdea05 100644 --- a/src/codegen/verilog/vpi_device_api.cc +++ b/src/codegen/verilog/vpi_device_api.cc @@ -4,12 +4,12 @@ * \brief Simulated VPI RAM device. */ #include <tvm/runtime/registry.h> +#include <tvm/runtime/device_api.h> #include <tvm/packed_func_ext.h> #include <cstdlib> #include <unordered_map> #include <map> #include <queue> -#include "../../runtime/device_api.h" #include "./vpi_session.h" namespace tvm { diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index d8dd53d31..bb302fefb 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -8,7 +8,7 @@ #include <tvm/runtime/packed_func.h> #include <tvm/runtime/module.h> #include <tvm/runtime/registry.h> -#include <dmlc/timer.h> +#include <tvm/runtime/device_api.h> #include <array> #include <algorithm> #include <string> @@ -16,11 +16,25 @@ #include <thread> #include <mutex> #include "./runtime_base.h" -#include "./device_api.h" namespace tvm { namespace runtime { +/*! + * \brief The name of Device API factory. + * \param type The device type. + */ +inline std::string DeviceName(int type) { + switch (type) { + case kCPU: return "cpu"; + case kGPU: return "gpu"; + case kOpenCL: return "opencl"; + case kMetal: return "metal"; + case kVPI: return "vpi"; + default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; + } +} + class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; @@ -380,6 +394,12 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, API_END(); } +int TVMSetStream(TVMContext ctx, TVMStreamHandle stream) { + API_BEGIN(); + DeviceAPIManager::Get(ctx)->SetStream(ctx, stream); + API_END(); +} + int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { API_BEGIN(); DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream); diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 7909936a2..39b757a12 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -4,9 +4,9 @@ */ #include <dmlc/logging.h> #include <tvm/runtime/registry.h> +#include <tvm/runtime/device_api.h> #include <cstdlib> #include <cstring> -#include "./device_api.h" namespace tvm { namespace runtime { diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index acc8809b0..8dfb8de35 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -34,13 +34,14 @@ namespace runtime { << "CUDA: " << cudaGetErrorString(e); \ } - -/*! - * \brief Compile code into ptx using NVRTC - * \param code The cuda code. - * \return The PTX code. - */ -std::string NVRTCCompile(const std::string& code); +/*! \brief Thread local workspace */ +class CUDAThreadEntry { + public: + /*! \brief The cuda stream */ + cudaStream_t stream{nullptr}; + // get the threadlocal workspace + static CUDAThreadEntry* ThreadLocal(); +}; } // namespace runtime } // namespace tvm #endif // TVM_CUDA_RUNTIME diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index f9d93cf48..2c73ab658 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -4,13 +4,14 @@ * \brief GPU specific API */ #include <tvm/runtime/config.h> +#include <tvm/runtime/device_api.h> #if TVM_CUDA_RUNTIME #include <dmlc/logging.h> +#include <dmlc/thread_local.h> #include <tvm/runtime/registry.h> #include <cuda_runtime.h> #include "./cuda_common.h" -#include "../device_api.h" namespace tvm { namespace runtime { @@ -92,6 +93,11 @@ class CUDADeviceAPI final : public DeviceAPI { CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream))); } + void SetStream(TVMContext ctx, TVMStreamHandle stream) final { + CUDAThreadEntry::ThreadLocal() + ->stream = static_cast<cudaStream_t>(stream); + } + private: static void GPUCopy(const void* from, void* to, @@ -106,6 +112,12 @@ class CUDADeviceAPI final : public DeviceAPI { } }; +typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore; + +CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { + return CUDAThreadStore::Get(); +} + TVM_REGISTER_GLOBAL("device_api.gpu") .set_body([](TVMArgs args, TVMRetValue* rv) { static CUDADeviceAPI inst; diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 28a36078f..29623bf58 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -167,6 +167,7 @@ class CUDAWrappedFunc { if (fcache_[device_id] == nullptr) { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } + CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); CUDA_DRIVER_CALL(cuLaunchKernel( fcache_[device_id], @@ -176,7 +177,7 @@ class CUDAWrappedFunc { wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), - 0, nullptr, void_args, 0)); + 0, strm, void_args, 0)); } private: diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 03910569e..be9b852f6 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -16,11 +16,11 @@ #include <tvm/runtime/config.h> #include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/packed_func.h> +#include <tvm/runtime/device_api.h> #include <dmlc/logging.h> #include <mutex> #include <string> #include <vector> -#include "../device_api.h" namespace tvm { namespace runtime { diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 45f93708f..1ef0123f8 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -9,10 +9,10 @@ #include <tvm/runtime/config.h> #include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/packed_func.h> +#include <tvm/runtime/device_api.h> #include <dmlc/logging.h> -#if TVM_OPENCL_RUNTIME -#include "../device_api.h" +#if TVM_OPENCL_RUNTIME #ifdef __APPLE__ #include <OpenCL/opencl.h> #else diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index ce6c02bf1..e8cc5b94a 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -4,8 +4,8 @@ */ #include <dmlc/logging.h> #include <tvm/runtime/registry.h> +#include <tvm/runtime/device_api.h> #include "./rpc_session.h" -#include "../device_api.h" namespace tvm { namespace runtime { diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index ca77c02af..0dc8c91ea 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -4,11 +4,12 @@ * \brief RPC session for remote function call. */ #include <tvm/runtime/packed_func.h> +#include <tvm/runtime/device_api.h> +#include <tvm/runtime/registry.h> #include <memory> #include <array> #include <chrono> #include "./rpc_session.h" -#include "../device_api.h" namespace tvm { namespace runtime { diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index f56a9f87c..bca1ee406 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -7,9 +7,9 @@ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ #include <tvm/runtime/packed_func.h> +#include <tvm/runtime/device_api.h> #include <mutex> #include <string> -#include "../device_api.h" #include "../../common/socket.h" namespace tvm { -- GitLab