From 5f829774f2a7c47784799f9bd25add2ffd4064b0 Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Wed, 19 Oct 2016 12:13:30 -0700 Subject: [PATCH] Add domain --- include/tvm/array.h | 12 ++++ include/tvm/domain.h | 123 ++++++++++++++++++++++++++++++++++++++- include/tvm/expr_node.h | 2 +- include/tvm/expr_util.h | 4 +- src/expr/domain.cc | 36 ++++++++++++ src/expr/expr.cc | 7 +++ src/expr/expr_node.cc | 1 + tests/cpp/expr_test.cc | 10 ++++ tests/cpp/tensor_test.cc | 6 +- 9 files changed, 192 insertions(+), 9 deletions(-) create mode 100644 src/expr/domain.cc diff --git a/include/tvm/array.h b/include/tvm/array.h index 9a1af9811..4484d3b89 100644 --- a/include/tvm/array.h +++ b/include/tvm/array.h @@ -128,6 +128,18 @@ class Array : public NodeRef { if (node_.get() == nullptr) return 0; return static_cast<const ArrayNode*>(node_.get())->data.size(); } + friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*) + for (size_t i = 0; i < r.size(); ++i) { + if (i == 0) { + os << '['; + } else { + os << ", "; + } + os << r[i]; + } + os << ']'; + return os; + } }; } // namespace tvm diff --git a/include/tvm/domain.h b/include/tvm/domain.h index 63cf6edbf..02fe7ca01 100644 --- a/include/tvm/domain.h +++ b/include/tvm/domain.h @@ -13,14 +13,133 @@ namespace tvm { +/*! \brief range over one dimension */ +class RangeNode : public Node { + public: + /*! \brief beginning of the node */ + Expr begin; + /*! \brief end of the node */ + Expr end; + /*! \brief constructor */ + RangeNode() {} + RangeNode(Expr && begin, Expr && end) + : begin(std::move(begin)), end(std::move(end)) { + } + const char* type_key() const override { + return "RangeNode"; + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("begin", &begin); + fvisit("end", &end); + } + void VisitAttrs(AttrVisitor* visitor) override {} +}; -//using Domain = Array<Range>; +/*! \brief Node range */ +class Range : public NodeRef { + public: + /*! \brief constructor */ + Range() {} + /*! + * \brief constructor + * \param begin start of the range. + * \param end end of the range. + */ + Range(Expr begin, Expr end); + /*! \return The extent of the range */ + Expr extent() const; + /*! \return the begining of the range */ + inline const Expr& begin() const { + return static_cast<const RangeNode*>(node_.get())->begin; + } + /*! \return the end of the range */ + inline const Expr& end() const { + return static_cast<const RangeNode*>(node_.get())->end; + } + friend std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*) + os << '[' << r.begin() << ", " << r.end() <<')'; + return os; + } +}; +/*! \brief Domain is a multi-dimensional range */ +using Domain = Array<Range>; -class RDomain : public NodeRef { +/*! \brief reduction domain node */ +class RDomainNode : public Node { + public: + /*! \brief internal index */ + Array<Var> index; + /*! \brief The inernal domain */ + Domain domain; + /*! \brief constructor */ + RDomainNode() {} + RDomainNode(Array<Var> && index, Domain && domain) + : index(std::move(index)), domain(std::move(domain)) { + } + const char* type_key() const override { + return "RDomainNode"; + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("index", &index); + fvisit("domain", &domain); + } + void VisitAttrs(AttrVisitor* visitor) override {} +}; +/*! \brief reduction domain */ +class RDomain : public NodeRef { + public: + /*! \brief constructor*/ + RDomain() {} + /*! + * constructor by domain + * \param domain The domain of reduction. + */ + explicit RDomain(Domain domain); + /*! + * \brief constructor by list of ranges + * \param domain The reduction domain + */ + explicit RDomain(std::initializer_list<Range> domain) + : RDomain(Domain(domain)) {} + /*! + * \brief constructor from node pointer + * \param nptr Another node shared pointer + */ + explicit RDomain(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) { + CHECK(node_.get() != nullptr); + CHECK(node_->is_type<RDomainNode>()); + } + /*! \return The dimension of the RDomain */ + inline size_t ndim() const { + return static_cast<const RDomainNode*>(node_.get())->index.size(); + } + /*! \return the 0-th index of the domain */ + inline Var i0() const { + return index(0); + } + /*! + * \param i the index. + * \return i-th index variable in the RDomain + */ + inline Var index(size_t i) const { + return static_cast<const RDomainNode*>(node_.get())->index[i]; + } + /*! + * \return The domain of the reduction. + */ + inline const Domain& domain() const { + return static_cast<const RDomainNode*>(node_.get())->domain; + } + friend std::ostream& operator<<(std::ostream &os, const RDomain& r) { // NOLINT(*) + os << "rdomain(" << r.domain() << ")"; + return os; + } }; +/*! \brief use RDom as alias of RDomain */ +using RDom = RDomain; } // namespace tvm diff --git a/include/tvm/expr_node.h b/include/tvm/expr_node.h index d0946d053..371d00908 100644 --- a/include/tvm/expr_node.h +++ b/include/tvm/expr_node.h @@ -11,8 +11,8 @@ #include "./tensor.h" #include "./expr.h" - namespace tvm { + /*! \brief variable node for symbolic variables */ class VarNode : public ExprNode { public: diff --git a/include/tvm/expr_util.h b/include/tvm/expr_util.h index b91b73231..ec4283e48 100644 --- a/include/tvm/expr_util.h +++ b/include/tvm/expr_util.h @@ -16,7 +16,9 @@ namespace tvm { * \param src The source expression * \return the simplified expression. */ -Expr Simplify(const Expr& src); +inline Expr Simplify(Expr src) { + return src; +} /*! * \brief visit the exression node in expr tree in post DFS order. diff --git a/src/expr/domain.cc b/src/expr/domain.cc new file mode 100644 index 000000000..1fd51bd0c --- /dev/null +++ b/src/expr/domain.cc @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file domain.cc + */ +#include <tvm/domain.h> +#include <tvm/op.h> +#include <tvm/expr_node.h> +#include <tvm/expr_util.h> + +namespace tvm { + +Range::Range(Expr begin, Expr end) { + node_ = std::make_shared<RangeNode>( + std::move(begin), std::move(end)); +} + +Expr Range::extent() const { + return Simplify(end() - begin()); +} + + +RDomain::RDomain(Domain domain) { + std::vector<Var> index; + for (size_t i = 0; i < domain.size(); ++i) { + index.push_back(Var("reduction_index")); + } + Array<Var> idx(index); + node_ = std::make_shared<RDomainNode>( + std::move(idx), std::move(domain)); +} + +TVM_REGISTER_NODE_TYPE(RangeNode); +TVM_REGISTER_NODE_TYPE(ArrayNode); +TVM_REGISTER_NODE_TYPE(RDomainNode); + +} // namespace tvm diff --git a/src/expr/expr.cc b/src/expr/expr.cc index fe93bb08c..d0479d3f6 100644 --- a/src/expr/expr.cc +++ b/src/expr/expr.cc @@ -48,6 +48,13 @@ void Expr::Print(std::ostream& os) const { os << ')'; return; } + case kReduceNode: { + const auto* n = Get<ReduceNode>(); + os << "reduce("<< n->op->FunctionName() << ", "; + n->src.Print(os); + os << ", " << n->rdom << ')'; + return; + } default: { LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name(); } diff --git a/src/expr/expr_node.cc b/src/expr/expr_node.cc index c6626672e..cc05472c9 100644 --- a/src/expr/expr_node.cc +++ b/src/expr/expr_node.cc @@ -42,5 +42,6 @@ TVM_REGISTER_NODE_TYPE(IntNode); TVM_REGISTER_NODE_TYPE(FloatNode); TVM_REGISTER_NODE_TYPE(UnaryOpNode); TVM_REGISTER_NODE_TYPE(BinaryOpNode); +TVM_REGISTER_NODE_TYPE(ReduceNode); } // namespace tvm diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 356ffb6cb..cf48f74ce 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -11,6 +11,16 @@ TEST(Expr, Basic) { CHECK(os.str() == "max(((x + 1) + 2), 100)"); } +TEST(Expr, Reduction) { + using namespace tvm; + Var x("x"); + RDomain rdom({{0, 3}}); + auto z = sum(x + 1 + 2, rdom); + std::ostringstream os; + os << z; + CHECK(os.str() == "reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))"); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index 9c33fe60f..5281bccdd 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -1,3 +1,4 @@ + #include <dmlc/logging.h> #include <gtest/gtest.h> #include <tvm/tvm.h> @@ -7,11 +8,6 @@ TEST(Tensor, Basic) { Var m, n, k; Tensor A({m, k}); Tensor B({n, k}); - - auto x = [=](Var i, Var j, Var k) { - return A(i, k) * B(j, k); - }; - auto C = Tensor({m, n}, x); } int main(int argc, char ** argv) { -- GitLab