diff --git a/examples/graph_executor/src/op_attr_types.h b/examples/graph_executor/src/op_attr_types.h index 2da9e2391b570dd3a2fd85df57da3a914d7fa8fb..c7b4a55e5eba1c09b2bfc188f4ca21450737e887 100644 --- a/examples/graph_executor/src/op_attr_types.h +++ b/examples/graph_executor/src/op_attr_types.h @@ -37,8 +37,7 @@ using DLTypeVector = std::vector<DLDataType>; */ using FTVMCompute = std::function< Array<Tensor> - (const NodeAttrs& attrs, - const Array<Tensor>& inputs)>; + (const NodeAttrs& attrs, const Array<Tensor>& inputs)>; /*! * \brief Build the computation schedule for diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc new file mode 100644 index 0000000000000000000000000000000000000000..2225cbb1752eee93cd928dcbc74c5ec15511d57d --- /dev/null +++ b/src/codegen/llvm/codegen_arm.cc @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_arm.cc + * \brief ARM specific code generator + */ +#ifdef TVM_LLVM_VERSION +#include "./codegen_llvm.h" + +namespace tvm { +namespace codegen { + +// ARM specific code generator, this is used as an example on +// how to override behavior llvm code generator for specific target +class CodeGenARM final : public CodeGenLLVM { + public: + void InitTarget(llvm::TargetMachine* tm) final { + // set native vector bits. + native_vector_bits_ = 16 * 8; + CodeGenLLVM::InitTarget(tm); + } +}; + +TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") +.set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenARM(); + *rv = static_cast<void*>(cg); + }); + +} // namespace codegen +} // namespace tvm +#endif // TVM_LLVM_VERSION diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index e32dc9fa7de882d27840da96eefe729a1dc56e16..1924f212f1c693162779d42c6abb77bf43cd3392 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -13,6 +13,18 @@ namespace tvm { namespace codegen { +std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) { + std::string target = tm->getTarget().getName(); + std::string factory_name = "tvm.codegen.llvm.target_" + target; + const PackedFunc* f = runtime::Registry::Get(factory_name); + if (f != nullptr) { + void* handle = (*f)(); + return std::unique_ptr<CodeGenLLVM>(static_cast<CodeGenLLVM*>(handle)); + } else { + return std::unique_ptr<CodeGenLLVM>(new CodeGenLLVM()); + } +} + void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx) { @@ -93,17 +105,17 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { data_layout_.reset(new llvm::DataLayout(module_.get())); // initialize native vector bits std::string target = tm->getTarget().getName(); - if (target == "arm") { - native_vector_bits_ = 16 * 8; - } else if (target == "x86-64") { + if (target == "x86-64") { // for avx512 native_vector_bits_ = 64 * 8; } else if (target == "x86") { native_vector_bits_ = 32 * 8; } else { - native_vector_bits_ = 32 * 8; - LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8 - << " for target " << target; + if (native_vector_bits_ == 0) { + native_vector_bits_ = 32 * 8; + LOG(WARNING) << "set native vector to be " << native_vector_bits_ / 8 + << " for target " << target; + } } } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index bf7e426722042a37b984dc06c0082e994163d7f6..07a809d1a210b326c5bf7c68bff6e3b81976f0f5 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -28,6 +28,12 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value* (const Expr&)>, public StmtFunctor<void(const Stmt&)> { public: + /*! + * \brief Create new code generator based on target machine. + * \param tm The target machine + * \return The created llvm generator. + */ + static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm); /*! * \brief Initialize the code generator with given context * \param module_name The name of the module. @@ -136,6 +142,8 @@ class CodeGenLLVM : // do a scalarize call with f llvm::Value* CreateScalarizedCall( const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args); + // Initialize target + virtual void InitTarget(llvm::TargetMachine* tm); // apply optimization on the module. virtual void Optimize(); // Get the maximim storage align bits of buffer pointer given storage scope. @@ -216,8 +224,6 @@ class CodeGenLLVM : // if not directly finalize function and pass on return code. // return the end block after the check llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); - // Initialize target - void InitTarget(llvm::TargetMachine* tm); // Add a function to set global module context void InitGlobalContext(); // add alias information. diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index bf975e736b31037b20fd07dd9d02b75934e9b770..13d90820abd04e9cecc7bd878a83d5bc4e5ea1e1 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -36,6 +36,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>); + } // namespace llvm } // namespace codegen } // namespace tvm diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index 3b7d4753572d5b98e5bd4fb94d9e28f74bd81e63..16f579ed6b50288e98857eaa367925cf27dd284a 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -101,14 +101,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { tm_ = GetLLVMTargetMachine(target); CHECK_NE(funcs.size(), 0U); ctx_ = std::make_shared<llvm::LLVMContext>(); - CodeGenLLVM cg; + std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_); entry_func_ = funcs[0]->name; - cg.Init(funcs[0]->name, tm_, ctx_.get()); + cg->Init(funcs[0]->name, tm_, ctx_.get()); for (LoweredFunc f : funcs) { - cg.AddFunction(f); + cg->AddFunction(f); } - cg.AddMainFunction(funcs[0]->name); - module_ = cg.Finish(); + cg->AddMainFunction(funcs[0]->name); + module_ = cg->Finish(); mptr_ = module_.get(); } diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 4b6e53a3b3d4e7cff8d4a80dc546f9e3dbb9f63f..4a496063daeca03e4513ffb8c25e296d10d5112b 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -26,27 +26,33 @@ class IntrinInjecter : public IRMutator { Expr Mutate_(const Call* op, const Expr& e) final { if (op->call_type == Call::Intrinsic || op->call_type == Call::PureIntrinsic) { - for (size_t i = 0; i < patterns_.size(); ++i) { - std::string& p = patterns_[i]; - size_t psize = p.length(); - p.resize(psize + op->name.length()); - op->name.copy(&p[0] + psize, op->name.length()); - const runtime::PackedFunc* f = runtime::Registry::Get(p); - p.resize(psize); - // if pattern exists. - if (f != nullptr) { - Expr r = (*f)(e); - CHECK(r.defined()) << "intrinsic rule must always return valid Expr"; - if (!r.same_as(e)) { - return this->Mutate(r); - } - } - } + Expr r = ApplyPattern(op->name, e); + if (r.defined()) return r; } return IRMutator::Mutate_(op, e); } private: + Expr ApplyPattern(const std::string& name, const Expr& e) { + for (size_t i = 0; i < patterns_.size(); ++i) { + std::string& p = patterns_[i]; + size_t psize = p.length(); + p.resize(psize + name.length()); + name.copy(&p[0] + psize, name.length()); + const runtime::PackedFunc* f = runtime::Registry::Get(p); + p.resize(psize); + // if pattern exists. + if (f != nullptr) { + Expr r = (*f)(e); + CHECK(r.defined()) << "intrinsic rule must always return valid Expr"; + if (!r.same_as(e)) { + return this->Mutate(r); + } + } + } + return Expr(); + } + // patterns std::vector<std::string> patterns_; };