From 5f829774f2a7c47784799f9bd25add2ffd4064b0 Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Wed, 19 Oct 2016 12:13:30 -0700
Subject: [PATCH] Add domain

---
 include/tvm/array.h      |  12 ++++
 include/tvm/domain.h     | 123 ++++++++++++++++++++++++++++++++++++++-
 include/tvm/expr_node.h  |   2 +-
 include/tvm/expr_util.h  |   4 +-
 src/expr/domain.cc       |  36 ++++++++++++
 src/expr/expr.cc         |   7 +++
 src/expr/expr_node.cc    |   1 +
 tests/cpp/expr_test.cc   |  10 ++++
 tests/cpp/tensor_test.cc |   6 +-
 9 files changed, 192 insertions(+), 9 deletions(-)
 create mode 100644 src/expr/domain.cc

diff --git a/include/tvm/array.h b/include/tvm/array.h
index 9a1af9811..4484d3b89 100644
--- a/include/tvm/array.h
+++ b/include/tvm/array.h
@@ -128,6 +128,18 @@ class Array : public NodeRef {
     if (node_.get() == nullptr) return 0;
     return static_cast<const ArrayNode*>(node_.get())->data.size();
   }
+  friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) {  // NOLINT(*)
+    for (size_t i = 0; i < r.size(); ++i) {
+      if (i == 0) {
+        os << '[';
+      } else {
+        os << ", ";
+      }
+      os << r[i];
+    }
+    os << ']';
+    return os;
+  }
 };
 
 }  // namespace tvm
diff --git a/include/tvm/domain.h b/include/tvm/domain.h
index 63cf6edbf..02fe7ca01 100644
--- a/include/tvm/domain.h
+++ b/include/tvm/domain.h
@@ -13,14 +13,133 @@
 
 namespace tvm {
 
+/*! \brief range over one dimension */
+class RangeNode : public Node {
+ public:
+  /*! \brief beginning of the node */
+  Expr begin;
+  /*! \brief end of the node */
+  Expr end;
+  /*! \brief constructor */
+  RangeNode() {}
+  RangeNode(Expr && begin, Expr && end)
+      : begin(std::move(begin)), end(std::move(end)) {
+  }
+  const char* type_key() const override {
+    return "RangeNode";
+  }
+  void VisitNodeRefFields(FNodeRefVisit fvisit) override {
+    fvisit("begin", &begin);
+    fvisit("end", &end);
+  }
+  void VisitAttrs(AttrVisitor* visitor) override {}
+};
 
-//using Domain = Array<Range>;
+/*! \brief Node range */
+class Range : public NodeRef {
+ public:
+  /*! \brief constructor */
+  Range() {}
+  /*!
+   * \brief constructor
+   * \param begin start of the range.
+   * \param end end of the range.
+   */
+  Range(Expr begin, Expr end);
+  /*! \return The extent of the range */
+  Expr extent() const;
+  /*! \return the begining of the range */
+  inline const Expr& begin() const {
+    return static_cast<const RangeNode*>(node_.get())->begin;
+  }
+  /*! \return the end  of the range */
+  inline const Expr& end() const {
+    return static_cast<const RangeNode*>(node_.get())->end;
+  }
+  friend std::ostream& operator<<(std::ostream &os, const Range& r) {  // NOLINT(*)
+    os << '[' << r.begin() << ", " << r.end() <<')';
+    return os;
+  }
+};
 
+/*! \brief Domain is a multi-dimensional range */
+using Domain = Array<Range>;
 
-class RDomain : public NodeRef {
+/*! \brief reduction domain node */
+class RDomainNode : public Node {
+ public:
+  /*! \brief internal index */
+  Array<Var> index;
+  /*! \brief The inernal domain */
+  Domain domain;
+  /*! \brief constructor */
+  RDomainNode() {}
+  RDomainNode(Array<Var> && index, Domain && domain)
+      : index(std::move(index)), domain(std::move(domain)) {
+  }
+  const char* type_key() const override {
+    return "RDomainNode";
+  }
+  void VisitNodeRefFields(FNodeRefVisit fvisit) override {
+    fvisit("index", &index);
+    fvisit("domain", &domain);
+  }
+  void VisitAttrs(AttrVisitor* visitor) override {}
+};
 
+/*! \brief reduction domain */
+class RDomain : public NodeRef {
+ public:
+  /*! \brief constructor*/
+  RDomain() {}
+  /*!
+   * constructor by domain
+   * \param domain The domain of reduction.
+   */
+  explicit RDomain(Domain domain);
+  /*!
+   * \brief constructor by list of ranges
+   * \param domain The reduction domain
+   */
+  explicit RDomain(std::initializer_list<Range> domain)
+      : RDomain(Domain(domain)) {}
+  /*!
+   * \brief constructor from node pointer
+   * \param nptr Another node shared pointer
+   */
+  explicit RDomain(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
+    CHECK(node_.get() != nullptr);
+    CHECK(node_->is_type<RDomainNode>());
+  }
+  /*! \return The dimension of the RDomain */
+  inline size_t ndim() const {
+    return static_cast<const RDomainNode*>(node_.get())->index.size();
+  }
+  /*! \return the 0-th index of the domain */
+  inline Var i0() const {
+    return index(0);
+  }
+  /*!
+   * \param i the index.
+   * \return i-th index variable in the RDomain
+   */
+  inline Var index(size_t i) const {
+    return static_cast<const RDomainNode*>(node_.get())->index[i];
+  }
+  /*!
+   * \return The domain of the reduction.
+   */
+  inline const Domain& domain() const {
+    return static_cast<const RDomainNode*>(node_.get())->domain;
+  }
+  friend std::ostream& operator<<(std::ostream &os, const RDomain& r) {  // NOLINT(*)
+    os << "rdomain(" << r.domain() << ")";
+    return os;
+  }
 };
 
+/*! \brief use RDom as alias of RDomain */
+using RDom = RDomain;
 
 }  // namespace tvm
 
diff --git a/include/tvm/expr_node.h b/include/tvm/expr_node.h
index d0946d053..371d00908 100644
--- a/include/tvm/expr_node.h
+++ b/include/tvm/expr_node.h
@@ -11,8 +11,8 @@
 #include "./tensor.h"
 #include "./expr.h"
 
-
 namespace tvm {
+
 /*! \brief variable node for symbolic variables */
 class VarNode : public ExprNode {
  public:
diff --git a/include/tvm/expr_util.h b/include/tvm/expr_util.h
index b91b73231..ec4283e48 100644
--- a/include/tvm/expr_util.h
+++ b/include/tvm/expr_util.h
@@ -16,7 +16,9 @@ namespace tvm {
  * \param src The source expression
  * \return the simplified expression.
  */
-Expr Simplify(const Expr& src);
+inline Expr Simplify(Expr src) {
+  return src;
+}
 
 /*!
  * \brief visit the exression node in expr tree in post DFS order.
diff --git a/src/expr/domain.cc b/src/expr/domain.cc
new file mode 100644
index 000000000..1fd51bd0c
--- /dev/null
+++ b/src/expr/domain.cc
@@ -0,0 +1,36 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file domain.cc
+ */
+#include <tvm/domain.h>
+#include <tvm/op.h>
+#include <tvm/expr_node.h>
+#include <tvm/expr_util.h>
+
+namespace tvm {
+
+Range::Range(Expr begin, Expr end) {
+  node_ = std::make_shared<RangeNode>(
+      std::move(begin), std::move(end));
+}
+
+Expr Range::extent() const {
+  return Simplify(end() - begin());
+}
+
+
+RDomain::RDomain(Domain domain) {
+  std::vector<Var> index;
+  for (size_t i = 0; i < domain.size(); ++i) {
+    index.push_back(Var("reduction_index"));
+  }
+  Array<Var> idx(index);
+  node_ = std::make_shared<RDomainNode>(
+      std::move(idx), std::move(domain));
+}
+
+TVM_REGISTER_NODE_TYPE(RangeNode);
+TVM_REGISTER_NODE_TYPE(ArrayNode);
+TVM_REGISTER_NODE_TYPE(RDomainNode);
+
+}  // namespace tvm
diff --git a/src/expr/expr.cc b/src/expr/expr.cc
index fe93bb08c..d0479d3f6 100644
--- a/src/expr/expr.cc
+++ b/src/expr/expr.cc
@@ -48,6 +48,13 @@ void Expr::Print(std::ostream& os) const {
       os << ')';
       return;
     }
+    case kReduceNode: {
+      const auto* n = Get<ReduceNode>();
+      os << "reduce("<< n->op->FunctionName() << ", ";
+      n->src.Print(os);
+      os << ", " << n->rdom << ')';
+      return;
+    }
     default: {
       LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
     }
diff --git a/src/expr/expr_node.cc b/src/expr/expr_node.cc
index c6626672e..cc05472c9 100644
--- a/src/expr/expr_node.cc
+++ b/src/expr/expr_node.cc
@@ -42,5 +42,6 @@ TVM_REGISTER_NODE_TYPE(IntNode);
 TVM_REGISTER_NODE_TYPE(FloatNode);
 TVM_REGISTER_NODE_TYPE(UnaryOpNode);
 TVM_REGISTER_NODE_TYPE(BinaryOpNode);
+TVM_REGISTER_NODE_TYPE(ReduceNode);
 
 }  // namespace tvm
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index 356ffb6cb..cf48f74ce 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -11,6 +11,16 @@ TEST(Expr, Basic) {
   CHECK(os.str() == "max(((x + 1) + 2), 100)");
 }
 
+TEST(Expr, Reduction) {
+  using namespace tvm;
+  Var x("x");
+  RDomain rdom({{0, 3}});
+  auto z = sum(x + 1 + 2, rdom);
+  std::ostringstream os;
+  os << z;
+  CHECK(os.str() == "reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))");
+}
+
 int main(int argc, char ** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";
diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc
index 9c33fe60f..5281bccdd 100644
--- a/tests/cpp/tensor_test.cc
+++ b/tests/cpp/tensor_test.cc
@@ -1,3 +1,4 @@
+
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/tvm.h>
@@ -7,11 +8,6 @@ TEST(Tensor, Basic) {
   Var m, n, k;
   Tensor A({m, k});
   Tensor B({n, k});
-
-  auto x = [=](Var i, Var j, Var k) {
-    return A(i, k) * B(j, k);
-  };
-  auto C = Tensor({m, n}, x);
 }
 
 int main(int argc, char ** argv) {
-- 
GitLab