diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 391901730e57d539c081e6befee6d2c65abb98a5..d4186e8f816701db40e65979b887a57b44ce929f 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -85,10 +85,13 @@ EXPORT Target stackvm(); } // namespace target +class BuildConfig; + /*! * \brief Container for build configuration options */ -struct BuildConfig { +class BuildConfigNode : public Node { + public: /*! * \brief The data alignment to use when constructing buffers. If this is set to * -1, then TVM's internal default will be used @@ -126,10 +129,31 @@ struct BuildConfig { /*! \brief Whether to partition const loop */ bool partition_const_loop = false; - BuildConfig() { + void VisitAttrs(AttrVisitor* v) final { + v->Visit("data_alignment", &data_alignment); + v->Visit("offset_factor", &offset_factor); + v->Visit("double_buffer_split_loop", &double_buffer_split_loop); + v->Visit("auto_unroll_max_step", &auto_unroll_max_step); + v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth); + v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent); + v->Visit("unroll_explicit", &unroll_explicit); + v->Visit("restricted_func", &restricted_func); + v->Visit("detect_global_barrier", &detect_global_barrier); + v->Visit("partition_const_loop", &partition_const_loop); } + + static constexpr const char* _type_key = "BuildConfig"; + TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node); }; +TVM_DEFINE_NODE_REF(BuildConfig, BuildConfigNode); + +/*! +* \brief Construct a BuildConfig containing a new BuildConfigNode +* \return The new BuildConfig +*/ +EXPORT BuildConfig build_config(); + /*! * \brief Build a LoweredFunc given a schedule, args and binds * \param sch The schedule to lower. diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 8b52b11d86e7ebfa820a6b94471aec0750908000..9c442a07425df0130033cd7d06a92aacb5cf2734 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -6,7 +6,9 @@ LoweredFunc and compiled Module. from __future__ import absolute_import as _abs import warnings import types +import os +from ._ffi.node import NodeBase, register_node from . import api from . import tensor from . import schedule @@ -18,6 +20,7 @@ from . import module from . import codegen from . import ndarray from . import target as _target +from . import make class DumpIR(object): """Dump IR for each pass. @@ -95,16 +98,23 @@ class DumpIR(object): BuildConfig.current.add_lower_pass = self._old_custom_pass DumpIR.scope_level -= 1 -class BuildConfig(object): +@register_node +class BuildConfig(NodeBase): """Configuration scope to set a build config option. - Parameters - ---------- - kwargs - Keyword arguments of configurations to set. + Note + ---- + This object is backed by node system in C++, with arguments that can be + exchanged between python and C++. + + Do not construct directly, use build_config instead. + + The fields that are backed by the C++ node are immutable once an instance + is constructed. See _node_defaults for the fields. """ + current = None - defaults = { + _node_defaults = { "auto_unroll_max_step": 0, "auto_unroll_max_depth": 8, "auto_unroll_max_extent": 0, @@ -114,30 +124,28 @@ class BuildConfig(object): "offset_factor": 0, "data_alignment": -1, "restricted_func": True, - "double_buffer_split_loop": 1, - "add_lower_pass": None, - "dump_pass_ir": False + "double_buffer_split_loop": 1 } - def __init__(self, **kwargs): + + # pylint: disable=no-member + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + super(BuildConfig, self).__init__(handle) + self.handle = handle self._old_scope = None self._dump_ir = DumpIR() - for k, _ in kwargs.items(): - if k not in BuildConfig.defaults: - raise ValueError( - "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys())) - self._attr = kwargs - - def __getattr__(self, name): - if name not in self._attr: - return BuildConfig.defaults[name] - return self._attr[name] + self.dump_pass_ir = False + self.add_lower_pass = None def __enter__(self): # pylint: disable=protected-access self._old_scope = BuildConfig.current - attr = BuildConfig.current._attr.copy() - attr.update(self._attr) - self._attr = attr BuildConfig.current = self if self.dump_pass_ir is True: self._dump_ir.enter() @@ -149,8 +157,11 @@ class BuildConfig(object): self._dump_ir.exit() BuildConfig.current = self._old_scope - -BuildConfig.current = BuildConfig() + def __setattr__(self, name, value): + if name in BuildConfig._node_defaults: + raise AttributeError( + "'%s' object cannot set attribute '%s'" % (str(type(self)), name)) + return super(BuildConfig, self).__setattr__(name, value) def build_config(**kwargs): """Configure the build behavior by setting config variables. @@ -206,8 +217,18 @@ def build_config(**kwargs): config: BuildConfig The build configuration """ - return BuildConfig(**kwargs) - + node_args = {k: v if k not in kwargs else kwargs[k] + for k, v in BuildConfig._node_defaults.items()} + config = make.node("BuildConfig", **node_args) + + for k in kwargs: + if not k in node_args: + setattr(config, k, kwargs[k]) + return config + +if not os.environ.get("TVM_USE_RUNTIME_LIB", False): + # BuildConfig is not available in tvm_runtime + BuildConfig.current = build_config() def get_binds(args, binds=None): """Internal function to get binds and arg_list given arguments. diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 3b5916ea5fecc09aeceac0b8bed1f453c450bcf1..de388cf0b51fcf8d6e59f5614246fce39b0fce78 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -10,6 +10,7 @@ #include <tvm/buffer.h> #include <tvm/schedule.h> #include <tvm/api_registry.h> +#include <tvm/build_module.h> namespace tvm { diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 2e8e5bb278eb4124388525856f3fb2435fef1c1f..cca09a966e2191883ceaf208a5d3e478d8df8103 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -179,7 +179,7 @@ void GetBinds(const Array<Tensor>& args, for (const auto &x : args) { if (out_binds->find(x) == out_binds->end()) { auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, - config.data_alignment, config.offset_factor); + config->data_alignment, config->offset_factor); out_binds->Set(x, buf); out_arg_list->push_back(buf); } else { @@ -218,14 +218,14 @@ Stmt BuildStmt(Schedule sch, stmt = ir::StorageFlatten(stmt, out_binds, 64); stmt = ir::CanonicalSimplify(stmt); if (loop_partition) { - stmt = ir::LoopPartition(stmt, config.partition_const_loop); + stmt = ir::LoopPartition(stmt, config->partition_const_loop); } stmt = ir::VectorizeLoop(stmt); stmt = ir::InjectVirtualThread(stmt); - stmt = ir::InjectDoubleBuffer(stmt, config.double_buffer_split_loop); + stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); stmt = ir::StorageRewrite(stmt); - stmt = ir::UnrollLoop(stmt, config.auto_unroll_max_step, config.auto_unroll_max_depth, - config.auto_unroll_max_extent, config.unroll_explicit); + stmt = ir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth, + config->auto_unroll_max_extent, config->unroll_explicit); // Phase 2 stmt = ir::Simplify(stmt); @@ -243,7 +243,7 @@ Array<LoweredFunc> lower(Schedule sch, const BuildConfig& config) { Array<NodeRef> out_arg_list; auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); - return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config.restricted_func) }); + return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); } runtime::Module build(const Array<LoweredFunc>& funcs, @@ -266,7 +266,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs, for (const auto &x : funcs) { if (x->func_type == kMixedFunc) { auto func = x; - if (config.detect_global_barrier) { + if (config->detect_global_barrier) { func = ir::ThreadSync(func, "global"); } @@ -321,4 +321,27 @@ runtime::Module build(const Array<LoweredFunc>& funcs, return mhost; } + +BuildConfig build_config() { + return BuildConfig(std::make_shared<BuildConfigNode>()); +} + +TVM_REGISTER_NODE_TYPE(BuildConfigNode); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch<BuildConfigNode>([](const BuildConfigNode *op, IRPrinter *p) { + p->stream << "build_config("; + p->stream << "data_alignment=" << op->data_alignment << ", "; + p->stream << "offset_factor=" << op->offset_factor << ", "; + p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", "; + p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", "; + p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", "; + p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", "; + p->stream << "unroll_explicit=" << op->unroll_explicit << ", "; + p->stream << "restricted_func=" << op->restricted_func << ", "; + p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; + p->stream << "partition_const_loop=" << op->partition_const_loop; + p->stream << ")"; +}); + } // namespace tvm diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index fc3f6ac9324d3e946170f93735c105ae52631fc0..fe0a9151cc2cc1e333b5713fa512572c8efd8da7 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -27,7 +27,7 @@ TEST(BuildModule, Basic) { auto args = Array<Tensor>({ A, B, C }); std::unordered_map<Tensor, Buffer> binds; - BuildConfig config; + auto config = build_config(); auto target = target::llvm(); auto lowered = lower(s, args, "func", binds, config);