From 17e7e3d50a6d265d07396ff2dff961bd4ce8e76f Mon Sep 17 00:00:00 2001
From: alex-weaver <awsweaver@gmail.com>
Date: Tue, 5 Dec 2017 20:06:46 +0000
Subject: [PATCH] Port build_module.py to C++ (#667)

* Port build_module.py to C++

* Fix lint errors

* Fix more lint errors

* Fix more lint errors

* Fix more lint errors

* Fix build error

* Implemented style fixes

* Fix lint errors

* Added function to construct target from string; lower now returns array

* Fix lint error

* Implemented review changes - style & Target options -> std::vector

* Fixed lint, argument alignment and added unit test

* Changed test to target LLVM, fixed sign compare warnings

* Reverted unit test to CUDA, changed Jenkinsfile to enable GPU for C++ tests

* Slight change to Jenkinsfile

* Changed build_module test from CUDA to LLVM

* Added function var() to construct a Var instance; changed implementation of LLVMEnabled()

* Reverted Jenkinsfile
---
 include/tvm/build_module.h     | 153 ++++++++++++++++
 include/tvm/expr.h             |   7 +
 include/tvm/schedule.h         |  10 +-
 src/codegen/build_module.cc    | 314 +++++++++++++++++++++++++++++++++
 src/lang/expr.cc               |   4 +
 tests/cpp/build_module_test.cc |  42 +++++
 6 files changed, 525 insertions(+), 5 deletions(-)
 create mode 100644 include/tvm/build_module.h
 create mode 100644 src/codegen/build_module.cc
 create mode 100644 tests/cpp/build_module_test.cc

diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h
new file mode 100644
index 000000000..a1563e8e7
--- /dev/null
+++ b/include/tvm/build_module.h
@@ -0,0 +1,153 @@
+/*!
+*  Copyright (c) 2017 by Contributors
+* \file build_module.h
+* \brief Functions for compiling ops.
+*/
+#ifndef TVM_BUILD_MODULE_H_
+#define TVM_BUILD_MODULE_H_
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "./tvm/runtime/packed_func.h"
+#include "./tvm/schedule_pass.h"
+#include "./tvm/lowered_func.h"
+
+namespace tvm {
+
+/*!
+* \brief Container for target device information.
+* Use the target::llvm, target::cuda etc. functions instead of constructing one directly.
+*/
+struct Target {
+  /*! \brief The name of the target device */
+  std::string target_name;
+  /*! \brief The type of the target device */
+  DLDeviceType device_type;
+  /*! \brief The maximum number of threads that a schedule should use for this device */
+  int max_num_threads = 1;
+  /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
+  int thread_warp_size = 1;
+  /*! \brief Keys for this target */
+  std::unordered_set<std::string> keys;
+  /*! \brief Options for this target */
+  std::vector<std::string> options;
+
+  Target(const std::string& target_name,
+         DLDeviceType device_type,
+         int max_num_threads,
+         int thread_warp_size,
+         const std::unordered_set<std::string>& keys,
+         const std::vector<std::string>& options) :
+    target_name(target_name),
+    device_type(device_type),
+    max_num_threads(max_num_threads),
+    thread_warp_size(thread_warp_size),
+    keys(keys),
+    options(options) {
+  }
+
+  /*! \return The full device string to pass to codegen::Build */
+  EXPORT std::string str() const;
+
+  /*!
+   * \brief Create a Target given a string
+   * \param target_str the string to parse
+   */
+  EXPORT static Target create(const std::string& target_str);
+};
+
+/*! \brief This namespace provides functions to construct Target instances */
+namespace target {
+/*! \return A target for LLVM */
+EXPORT Target llvm();
+
+/*! \return A target for CUDA */
+EXPORT Target cuda();
+
+/*! \return A target for ROCm */
+EXPORT Target rocm();
+
+/*! \return A target for Metal */
+EXPORT Target metal();
+
+/*! \return A target for the Raspberry Pi (rasp) */
+EXPORT Target rasp();
+
+/*! \return A target for stackvm */
+EXPORT Target stackvm();
+
+}  // namespace target
+
+/*!
+* \brief Container for build configuration options
+*/
+struct BuildConfig {
+  /*!
+   * \brief The data alignment to use when constructing buffers. If this is set to
+   * -1, then TVM's internal default will be used
+   */
+  int data_alignment = -1;
+  /*!
+   * \brief The offset factor to use when constructing buffers. If this is set to
+   * 0, then the offset field is not used.
+   */
+  int offset_factor = 0;
+
+  /*!
+   * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
+   * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
+   */
+  int double_buffer_split_loop = 1;
+  /*! \brief The maximum number of steps a loop may have for it to be automatically unrolled */
+  int auto_unroll_max_step = 0;
+  /*! \brief The maximum nested level of loops that can be automatically unrolled */
+  int auto_unroll_max_depth = 8;
+  /*! \brief The maximum extent of a loop that will be unrolled */
+  int auto_unroll_max_extent = 0;
+  /*!
+   * \brief Whether to explicitly unroll the loop. If set to false, only an unroll hint is
+   * passed to the CodeGen phase; set it to false only when CodeGen supports the unroll pragma.
+   */
+  bool unroll_explicit = true;
+
+  /*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
+  bool restricted_func = true;
+
+  /*! \brief Whether to detect global barrier */
+  bool detect_global_barrier = false;
+
+  BuildConfig() {
+  }
+};
+
+/*!
+* \brief Lower a schedule into an array of LoweredFuncs, given the args and binds
+* \param sch The schedule to lower.
+* \param args The arguments to the function.
+* \param name The name of the lowered function.
+* \param binds Buffer assignments.
+* \param config The build configuration.
+* \return The array of lowered functions.
+*/
+EXPORT Array<LoweredFunc> lower(Schedule sch,
+                                const Array<Tensor>& args,
+                                const std::string& name,
+                                const std::unordered_map<Tensor, Buffer>& binds,
+                                const BuildConfig& config);
+
+/*!
+* \brief Build a device and host module for a specific target from an array of lowered functions.
+* \param funcs The functions to be built.
+* \param target The target device to build for.
+* \param target_host The target for building host code. If null, a suitable default will be used.
+* \param config The build configuration.
+* \return The built module.
+*/
+EXPORT runtime::Module build(const Array<LoweredFunc>& funcs,
+                             const Target& target,
+                             Target* target_host,
+                             const BuildConfig& config);
+
+}  // namespace tvm
+
+#endif  // TVM_BUILD_MODULE_H_
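The header above is the C++ counterpart of the Python tvm.build flow: construct a
Target and a BuildConfig, lower a schedule, then build a module. A minimal usage
sketch, assuming a schedule s and tensors A, B, C created the same way as in the
unit test at the end of this patch:

    // Sketch only: s, A, B and C are assumed to come from the usual TVM C++
    // API (placeholder/compute/create_schedule), as in the test below.
    tvm::BuildConfig config;                    // default build options
    tvm::Target target = tvm::target::llvm();   // or tvm::Target::create("llvm")

    // lower() runs the IR passes and returns the lowered functions; build()
    // turns them into a runnable module. Passing nullptr as target_host
    // selects a suitable default host target.
    tvm::Array<tvm::LoweredFunc> funcs =
        tvm::lower(s, {A, B, C}, "vector_add", {}, config);
    tvm::runtime::Module mod = tvm::build(funcs, target, nullptr, config);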
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 4e4e25c0c..c0f4fea24 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -291,6 +291,13 @@ inline const char* IterVarType2String(IterVarType t) {
   return "Unknown";
 }
 
+/*!
+ * \brief Construct a new Var expression
+ * \param name_hint The name hint for the expression
+ * \param t The type of the expression
+ */
+TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));
+
 /*
  * \brief Template function to convert Map to unordered_map
  *  Sometimes useful for API gluing when internal uses unordered_map
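The new var() helper mirrors tvm.var() on the Python side. A short sketch of its
use (the variable names are illustrative only):

    // var() defaults to Int(32), matching the declaration above.
    tvm::Var n = tvm::var("n");                    // 32-bit integer variable
    tvm::Var big = tvm::var("big", tvm::Int(64));  // explicit type
    tvm::Array<tvm::Expr> shape = {n};             // e.g. a symbolic shape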
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index a0e4a2c9e..3efc31774 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -81,7 +81,7 @@ class Stage : public NodeRef {
    * \param thread_ivar The thread axis to be binded.
    * \return reference to self.
    */
-  Stage& bind(IterVar ivar, IterVar thread_ivar);
+  EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar);
   /*!
    * \brief Set predicate under which store to the array can be performed.
    *  Use this when there are duplicated threads doing the same store and we only
@@ -110,7 +110,7 @@ class Stage : public NodeRef {
    * \param p_inner The result inner domain.
    * \return reference to self.
    */
-  Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
+  EXPORT Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
   /*!
    * \brief Split the iteration with given number of parts.
    *
@@ -248,13 +248,13 @@ class Schedule : public NodeRef {
    * \brief Get the stage corresponds to the op
    * \param op The operation.
    */
-  Stage operator[](const Operation& op);
+  EXPORT Stage operator[](const Operation& op);
   /*!
    * \brief Short hand for getting the stage of tensor's operation.
    * \param tensor The tensor
    * \return The stage corresponding to the tensor's op
    */
-  Stage operator[](const Tensor& tensor) {
+  EXPORT Stage operator[](const Tensor& tensor) {
     return this->operator[](tensor->op);
   }
   /*!
@@ -493,7 +493,7 @@ class ScheduleNode : public Node {
    * \param ops The ops to be scheduled.
    * \return sch The created Schedule.
    */
-  static Schedule make(Array<Operation> ops);
+  EXPORT static Schedule make(Array<Operation> ops);
 
   static constexpr const char* _type_key = "Schedule";
   TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
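The schedule.h changes are mechanical: the EXPORT annotation makes these entry
points callable across the shared-library boundary, which the new C++ build path
and its test need, particularly on Windows. For orientation, a macro of this kind
typically expands along these lines (a sketch; the real definition lives in TVM's
base headers, not in this patch):

    // Typical shape of an EXPORT-style macro; not TVM's exact code.
    #ifdef _MSC_VER
    #define EXPORT __declspec(dllexport)
    #else
    #define EXPORT
    #endif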
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
new file mode 100644
index 000000000..d936b873b
--- /dev/null
+++ b/src/codegen/build_module.cc
@@ -0,0 +1,314 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ *  Compile executable modules.
+ * \file build_module.cc
+ */
+#include <tvm/build_module.h>
+#include <tvm/operation.h>
+#include <tvm/ir_pass.h>
+#include <tvm/codegen.h>
+
+
+namespace tvm {
+
+std::string Target::str() const {
+  std::ostringstream result;
+  result << target_name;
+  for (const auto &x : options) {
+    result << " " << x;
+  }
+  return result.str();
+}
+
+Target TargetFromName(const std::string& name) {
+  if (name == "llvm") {
+    return target::llvm();
+  } else if (name == "cuda" || name == "nvptx") {
+    return target::cuda();
+  } else if (name == "rocm" || name == "opencl") {
+    /* For now, assume rocm schedule for opencl */
+    return target::rocm();
+  } else if (name == "metal") {
+    return target::metal();
+  } else if (name == "stackvm" || name == "ext_dev") {
+    return target::stackvm();
+  } else {
+    LOG(ERROR) << "Unknown target name " << name;
+    return target::stackvm();
+  }
+}
+
+bool StartsWith(const std::string& str, const std::string& pattern) {
+  return str.compare(0, pattern.length(), pattern) == 0;
+}
+
+std::string GetDeviceName(const std::string& target_str) {
+  std::istringstream ss(target_str);
+  std::string target_name;
+  ss >> target_name;
+
+  std::string item;
+  while (ss >> item) {
+    if (StartsWith(item, "-device=")) {
+      return item.substr(std::string("-device=").length());
+    }
+  }
+
+  return "";
+}
+
+Target Target::create(const std::string& target_str) {
+  if (target_str.length() == 0) {
+    LOG(ERROR) << "target_str must not be empty";
+  }
+
+  std::istringstream ss(target_str);
+  std::string target_name;
+
+  ss >> target_name;
+  auto device_name = GetDeviceName(target_str);
+
+  auto result = device_name == "rasp" ?
+    target::rasp() :
+    TargetFromName(target_name);
+
+  std::string item;
+  while (ss >> item) {
+    result.options.push_back(item);
+  }
+
+  return result;
+}
+
+namespace target {
+Target llvm() {
+  std::unordered_set<std::string> keys({ "llvm", "cpu" });
+  std::vector<std::string> options;
+  return Target("llvm", kDLCPU, 512, 1, keys, options);
+}
+
+Target cuda() {
+  std::unordered_set<std::string> keys({ "cuda", "gpu" });
+  std::vector<std::string> options;
+  return Target("cuda", kDLGPU, 512, 32, keys, options);
+}
+
+Target rocm() {
+  std::unordered_set<std::string> keys({ "rocm", "gpu" });
+  std::vector<std::string> options;
+  return Target("rocm", kDLROCM, 256, 1, keys, options);
+}
+
+Target metal() {
+  std::unordered_set<std::string> keys({ "gpu" });
+  std::vector<std::string> options;
+  return Target("metal", kDLMetal, 256, 1, keys, options);
+}
+
+Target rasp() {
+  std::unordered_set<std::string> keys({ "llvm", "cpu" });
+  std::vector<std::string> options({
+    "-device=rasp",
+    "-mtriple=armv7l-none-linux-gnueabihf",
+    "-mcpu=cortex-a53",
+    "-mattr=+neon"
+  });
+  return Target("llvm", kDLCPU, 512, 1, keys, options);
+}
+
+Target stackvm() {
+  std::unordered_set<std::string> keys({ "stackvm", "cpu" });
+  std::vector<std::string> options;
+  return Target("stackvm", kDLCPU, 512, 1, keys, options);
+}
+}  // namespace target
+
+bool LLVMEnabled() {
+  const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm");
+  return pf != nullptr;
+}
+
+/*! \return The default host target for a given device target */
+Target DefaultTargetHost(Target target) {
+  if (target.device_type == kDLCPU) {
+    return target;
+  } else {
+    if (LLVMEnabled()) {
+      return target::llvm();
+    } else {
+      return target::stackvm();
+    }
+  }
+}
+
+Buffer BufferWithOffsetAlignment(Array<Expr> shape,
+                                 Type dtype,
+                                 std::string name,
+                                 int data_alignment,
+                                 int offset_factor) {
+  auto data = Var(name, Handle());
+
+  Expr elem_offset;
+  if (offset_factor != 0) {
+    elem_offset = Var(name + "_elem_offset", shape[0].type());
+  } else {
+    elem_offset = Expr();
+  }
+
+  return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
+    data_alignment, offset_factor);
+}
+
+void GetBinds(const Array<Tensor>& args,
+              const std::unordered_map<Tensor, Buffer>& binds,
+              Map<Tensor, Buffer>* out_binds,
+              Array<NodeRef>* out_arg_list,
+              const BuildConfig& config) {
+  *out_binds = binds;
+
+  for (const auto &x : args) {
+    if (out_binds->find(x) == out_binds->end()) {
+      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name,
+        config.data_alignment, config.offset_factor);
+      out_binds->Set(x, buf);
+      out_arg_list->push_back(buf);
+    } else {
+      out_arg_list->push_back((*out_binds)[x]);
+    }
+  }
+}
+
+/*!
+* \brief Build a Stmt given a schedule, args and binds. This function runs the IR passes.
+* \param sch The schedule to build.
+* \param args The arguments for the schedule.
+* \param binds Buffer assignments.
+* \param loop_partition True if the LoopPartition pass should be included.
+* \param out_arg_list Output parameter that receives the arguments for the Stmt.
+* \param config The build configuration.
+* \return The built Stmt.
+*/
+Stmt BuildStmt(Schedule sch,
+               const Array<Tensor>& args,
+               const std::unordered_map<Tensor, Buffer>& binds,
+               bool loop_partition,
+               Array<NodeRef> *out_arg_list,
+               const BuildConfig& config) {
+  Map<Tensor, Buffer> out_binds;
+  GetBinds(args, binds, &out_binds, out_arg_list, config);
+
+  sch = sch.normalize();
+
+  // Phase 0
+  auto bounds = schedule::InferBound(sch);
+  auto stmt = schedule::ScheduleOps(sch, bounds);
+  stmt = ir::InjectPrefetch(stmt);
+
+  // Phase 1
+  stmt = ir::StorageFlatten(stmt, out_binds, 64);
+  stmt = ir::CanonicalSimplify(stmt);
+  if (loop_partition) {
+    stmt = ir::LoopPartition(stmt);
+  }
+  stmt = ir::VectorizeLoop(stmt);
+  stmt = ir::InjectVirtualThread(stmt);
+  stmt = ir::InjectDoubleBuffer(stmt, config.double_buffer_split_loop);
+  stmt = ir::StorageRewrite(stmt);
+  stmt = ir::UnrollLoop(stmt, config.auto_unroll_max_step, config.auto_unroll_max_depth,
+    config.auto_unroll_max_extent, config.unroll_explicit);
+
+  // Phase 2
+  stmt = ir::Simplify(stmt);
+  stmt = ir::LowerStorageAccessInfo(stmt);
+  stmt = ir::RemoveNoOp(stmt);
+  stmt = ir::RewriteUnsafeSelect(stmt);
+
+  return stmt;
+}
+
+Array<LoweredFunc> lower(Schedule sch,
+                         const Array<Tensor>& args,
+                         const std::string& name,
+                         const std::unordered_map<Tensor, Buffer>& binds,
+                         const BuildConfig& config) {
+  Array<NodeRef> out_arg_list;
+  auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
+  return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config.restricted_func) });
+}
+
+runtime::Module build(const Array<LoweredFunc>& funcs,
+                      const Target& target,
+                      Target* target_host,
+                      const BuildConfig& config) {
+  std::unordered_set<std::string> all_names;
+  for (const auto &x : funcs) {
+    CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
+    all_names.insert(x->name);
+  }
+
+  Target target_host_val = target_host == nullptr ?
+    DefaultTargetHost(target) :
+    *target_host;
+
+  Array<LoweredFunc> fhost;
+  Array<LoweredFunc> fdevice;
+
+  for (const auto &x : funcs) {
+    if (x->func_type == kMixedFunc) {
+      auto func = x;
+      if (config.detect_global_barrier) {
+        func = ir::ThreadSync(func, "global");
+      }
+
+      func = ir::ThreadSync(func, "shared");
+      func = ir::LowerThreadAllreduce(func, target.thread_warp_size);
+      auto fsplits = ir::SplitHostDevice(func);
+      fhost.push_back(fsplits[0]);
+      for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
+        fdevice.push_back(*f);
+      }
+    } else if (x->func_type == kHostFunc) {
+      fhost.push_back(x);
+    } else if (x->func_type == kDeviceFunc) {
+      fdevice.push_back(x);
+    } else {
+      LOG(FATAL) << "unknown function type " << x->func_type;
+    }
+  }
+
+  if (target.keys.count("gpu") > 0 && fdevice.size() == 0) {
+    LOG(WARNING) << "Specified target " + target.str() +
+      " but cannot find device code. Did you forget to bind?";
+  }
+
+  for (size_t i = 0; i < fhost.size(); ++i) {
+    auto func = fhost[i];
+    func = ir::BindDeviceType(func, target.device_type);
+    func = ir::LowerTVMBuiltin(func);
+    fhost.Set(i, func);
+  }
+
+  for (size_t i = 0; i < fdevice.size(); ++i) {
+    auto func = fdevice[i];
+    func = ir::LowerIntrin(func, target.target_name);
+    fdevice.Set(i, func);
+  }
+
+  for (size_t i = 0; i < fhost.size(); ++i) {
+    auto func = fhost[i];
+    func = ir::LowerIntrin(func, target_host_val.target_name);
+    func = ir::CombineContextCall(func);
+    fhost.Set(i, func);
+  }
+
+  auto mhost = codegen::Build(fhost, target_host_val.str());
+
+  if (fdevice.size() > 0) {
+    auto mdev = codegen::Build(fdevice, target.str());
+    mhost.Import(mdev);
+  }
+
+  return mhost;
+}
+}  // namespace tvm
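Target::create() simply tokenizes the target string: the first token selects the
base target via TargetFromName(), a -device=rasp option overrides it with
target::rasp(), and every remaining token is kept verbatim in options. A hedged
example of the accepted format (the -libs flag is illustrative; the parser stores
but does not interpret it):

    // Sketch of the string format handled by Target::create() above.
    tvm::Target t = tvm::Target::create("cuda -libs=cublas");
    // t.target_name == "cuda", t.device_type == kDLGPU
    // t.options == {"-libs=cublas"}
    // t.str()    == "cuda -libs=cublas"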
diff --git a/src/lang/expr.cc b/src/lang/expr.cc
index 348733bad..be83b521e 100644
--- a/src/lang/expr.cc
+++ b/src/lang/expr.cc
@@ -47,6 +47,10 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
   return os;
 }
 
+Var var(const std::string& name_hint, Type t) {
+  return Var(name_hint, t);
+}
+
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
     p->stream << "iter_var(";
diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc
new file mode 100644
index 000000000..fc3f6ac93
--- /dev/null
+++ b/tests/cpp/build_module_test.cc
@@ -0,0 +1,42 @@
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <tvm/tvm.h>
+#include <tvm/operation.h>
+#include <tvm/build_module.h>
+
+TEST(BuildModule, Basic) {
+  using namespace tvm;
+  auto n = var("n");
+  Array<Expr> shape;
+  shape.push_back(n);
+
+  auto A = placeholder(shape, Float(32), "A");
+  auto B = placeholder(shape, Float(32), "B");
+
+  auto C = compute(A->shape, [&A, &B](Expr i) {
+    return A[i] + B[i];
+  }, "C");
+
+  auto s = create_schedule({ C->op });
+
+  auto cAxis = C->op.as<ComputeOpNode>()->axis;
+
+  IterVar bx, tx;
+  s[C].split(cAxis[0], 64, &bx, &tx);
+
+  auto args = Array<Tensor>({ A, B, C });
+  std::unordered_map<Tensor, Buffer> binds;
+
+  BuildConfig config;
+  auto target = target::llvm();
+
+  auto lowered = lower(s, args, "func", binds, config);
+  auto module = build(lowered, target, nullptr, config);
+}
+
+int main(int argc, char ** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  return RUN_ALL_TESTS();
+}
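The test targets LLVM so that the C++ test stage does not need a GPU (see the
Jenkinsfile back-and-forth in the commit history above). A GPU variant would
additionally have to bind the split axes to thread and block indices before
lowering, otherwise build() emits the "Did you forget to bind?" warning. A
sketch, assuming a CUDA-enabled build:

    // Hypothetical CUDA variant of the test body above.
    s[C].bind(bx, thread_axis(Range(), "blockIdx.x"));
    s[C].bind(tx, thread_axis(Range(), "threadIdx.x"));
    auto target = target::cuda();
    auto lowered = lower(s, args, "func", binds, config);
    auto module = build(lowered, target, nullptr, config);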
-- 
GitLab