diff --git a/src/codegen/build_metal.cc b/src/codegen/build_metal.cc index f2a7e14f9a9f9ce41e7820ed83cd1488f64ea64d..42aa0965ec9d2e29cd92714fe3148f7bb9b51cf0 100644 --- a/src/codegen/build_metal.cc +++ b/src/codegen/build_metal.cc @@ -35,7 +35,7 @@ runtime::Module BuildMetal(Array<LoweredFunc> funcs) { return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source); #else LOG(WARNING) << "Metal runtime not enabled, return a source module..."; - return SourceModuleCreate(code, "metal"); + return DeviceSourceModuleCreate(code, "metal", ExtractFuncInfo(funcs), "metal"); #endif // TVM_METAL_RUNTIME } diff --git a/src/codegen/build_opencl.cc b/src/codegen/build_opencl.cc index 499c88a009cdc1ea425e2981b2762393497bd079..51779d3f7a3e0b1ccc96ea048f398a88a42e975a 100644 --- a/src/codegen/build_opencl.cc +++ b/src/codegen/build_opencl.cc @@ -27,7 +27,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) { return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs)); #else LOG(WARNING) << "OpenCL runtime not enabled, return a source module..."; - return SourceModuleCreate(code, "cl"); + return DeviceSourceModuleCreate(code, "cl", ExtractFuncInfo(funcs), "opencl"); #endif // TVM_OPENCL_RUNTIME } diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index e1f003d32b10157bbcd89ff675a20f75f3cb3a9e..d289d627b310ac9bfe3b547fc1e3700a4473ec45 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -38,7 +38,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { stream->Write(sz); for (runtime::Module im : mod->imports()) { CHECK_EQ(im->imports().size(), 0U) - << "Only support simply one-level hierachy"; + << "Only support simply one-level hierarchy"; std::string tkey = im->type_key(); std::string bin; stream->Write(tkey); diff --git a/src/codegen/codegen_source_base.h b/src/codegen/codegen_source_base.h index 0ee5b71d017c44aa884d2aaa07d88e87858aa918..bc99eeeb1d331288f6f7361f55c9bb045eb3c636 100644 --- a/src/codegen/codegen_source_base.h +++ b/src/codegen/codegen_source_base.h @@ -11,6 +11,7 @@ #include <string> #include <vector> #include <unordered_map> +#include "../runtime/meta_data.h" namespace tvm { namespace codegen { @@ -108,6 +109,19 @@ class CodeGenSourceBase { * \param fmt The code. format. */ runtime::Module SourceModuleCreate(std::string code, std::string fmt); + +/*! + * \brief Create a source module for viewing and limited saving + * \param code The code to be viewed. + * \param fmt The code. format. + * \param fmap The map function information map of each function. + * \param type_key The type_key of the runtime module of this source code + */ +runtime::Module DeviceSourceModuleCreate( + std::string code, + std::string fmt, + std::unordered_map<std::string, runtime::FunctionInfo> fmap, + std::string type_key); } // namespace codegen } // namespace tvm #endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc index 1ad2168ae06ee8fa7f5833fbfbf1a6e36e14bb0b..23c0cbd8466e7288182861ab78659bf49fb71ef4 100644 --- a/src/codegen/source_module.cc +++ b/src/codegen/source_module.cc @@ -5,6 +5,8 @@ */ #include <tvm/runtime/packed_func.h> #include "./codegen_source_base.h" +#include "../runtime/file_util.h" +#include "../runtime/meta_data.h" namespace tvm { namespace codegen { @@ -12,8 +14,14 @@ namespace codegen { using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; + +using runtime::GetFileFormat; +using runtime::GetMetaFilePath; +using runtime::FunctionInfo; +using runtime::SaveBinaryToFile; + // Simulator function -class SourceModuleNode final : public runtime::ModuleNode { +class SourceModuleNode : public runtime::ModuleNode { public: SourceModuleNode(std::string code, std::string fmt) @@ -21,6 +29,7 @@ class SourceModuleNode final : public runtime::ModuleNode { const char* type_key() const { return "source"; } + PackedFunc GetFunction( const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) final { @@ -33,7 +42,7 @@ class SourceModuleNode final : public runtime::ModuleNode { return code_; } - private: + protected: std::string code_; std::string fmt_; }; @@ -44,6 +53,50 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { return runtime::Module(n); } +// supports limited save without cross compile +class DeviceSourceModuleNode final : public SourceModuleNode { + public: + DeviceSourceModuleNode(std::string code, + std::string fmt, + std::unordered_map<std::string, FunctionInfo> fmap, + std::string type_key) + : SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {} + + const char* type_key() const { + return type_key_.c_str(); + } + + void SaveToFile(const std::string& file_name, + const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + CHECK_EQ(fmt, fmt_) + << "Can only save to format=" << fmt_; + std::string meta_file = GetMetaFilePath(file_name); + SaveMetaDataToFile(meta_file, fmap_); + SaveBinaryToFile(file_name, code_); + } + + void SaveToBinary(dmlc::Stream* stream) final { + stream->Write(fmt_); + stream->Write(fmap_); + stream->Write(code_); + } + + private: + std::unordered_map<std::string, FunctionInfo> fmap_; + std::string type_key_; +}; + +runtime::Module DeviceSourceModuleCreate( + std::string code, + std::string fmt, + std::unordered_map<std::string, FunctionInfo> fmap, + std::string type_key) { + std::shared_ptr<DeviceSourceModuleNode> n = + std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key); + return runtime::Module(n); +} + TVM_REGISTER_GLOBAL("module.source_module_create") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = SourceModuleCreate(args[0], args[1]); diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 85c50e3e9755259cc86ba6e342567ae0479ea45a..54da2d9b0443f2fd8bee75d7ea0bd6087f6844fe 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -16,7 +16,7 @@ namespace tvm { namespace runtime { /*! - * \brief create a cuda module from data. + * \brief create a opencl module from data. * * \param data The module data. * \param fmt The format of the data, can be "clbin", "cl"