Skip to content
Snippets Groups Projects
Commit 5f829774 authored by tqchen's avatar tqchen
Browse files

Add domain

parent 5324b211
No related branches found
No related tags found
No related merge requests found
......@@ -128,6 +128,18 @@ class Array : public NodeRef {
if (node_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(node_.get())->data.size();
}
friend std::ostream& operator<<(std::ostream &os, const Array<T>& r) { // NOLINT(*)
for (size_t i = 0; i < r.size(); ++i) {
if (i == 0) {
os << '[';
} else {
os << ", ";
}
os << r[i];
}
os << ']';
return os;
}
};
} // namespace tvm
......
......@@ -13,14 +13,133 @@
namespace tvm {
/*! \brief range over one dimension */
class RangeNode : public Node {
public:
/*! \brief beginning of the node */
Expr begin;
/*! \brief end of the node */
Expr end;
/*! \brief constructor */
RangeNode() {}
RangeNode(Expr && begin, Expr && end)
: begin(std::move(begin)), end(std::move(end)) {
}
const char* type_key() const override {
return "RangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("begin", &begin);
fvisit("end", &end);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
//using Domain = Array<Range>;
/*! \brief Node range */
class Range : public NodeRef {
public:
/*! \brief constructor */
Range() {}
/*!
* \brief constructor
* \param begin start of the range.
* \param end end of the range.
*/
Range(Expr begin, Expr end);
/*! \return The extent of the range */
Expr extent() const;
/*! \return the begining of the range */
inline const Expr& begin() const {
return static_cast<const RangeNode*>(node_.get())->begin;
}
/*! \return the end of the range */
inline const Expr& end() const {
return static_cast<const RangeNode*>(node_.get())->end;
}
friend std::ostream& operator<<(std::ostream &os, const Range& r) { // NOLINT(*)
os << '[' << r.begin() << ", " << r.end() <<')';
return os;
}
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
class RDomain : public NodeRef {
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> && index, Domain && domain)
: index(std::move(index)), domain(std::move(domain)) {
}
const char* type_key() const override {
return "RDomainNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("index", &index);
fvisit("domain", &domain);
}
void VisitAttrs(AttrVisitor* visitor) override {}
};
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
/*! \brief constructor*/
RDomain() {}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit RDomain(Domain domain);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit RDomain(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
CHECK(node_->is_type<RDomainNode>());
}
/*! \return The dimension of the RDomain */
inline size_t ndim() const {
return static_cast<const RDomainNode*>(node_.get())->index.size();
}
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const {
return static_cast<const RDomainNode*>(node_.get())->index[i];
}
/*!
* \return The domain of the reduction.
*/
inline const Domain& domain() const {
return static_cast<const RDomainNode*>(node_.get())->domain;
}
friend std::ostream& operator<<(std::ostream &os, const RDomain& r) { // NOLINT(*)
os << "rdomain(" << r.domain() << ")";
return os;
}
};
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
} // namespace tvm
......
......@@ -11,8 +11,8 @@
#include "./tensor.h"
#include "./expr.h"
namespace tvm {
/*! \brief variable node for symbolic variables */
class VarNode : public ExprNode {
public:
......
......@@ -16,7 +16,9 @@ namespace tvm {
* \param src The source expression
* \return the simplified expression.
*/
Expr Simplify(const Expr& src);
inline Expr Simplify(Expr src) {
return src;
}
/*!
* \brief visit the exression node in expr tree in post DFS order.
......
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/domain.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <tvm/expr_util.h>
namespace tvm {
Range::Range(Expr begin, Expr end) {
node_ = std::make_shared<RangeNode>(
std::move(begin), std::move(end));
}
Expr Range::extent() const {
return Simplify(end() - begin());
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
index.push_back(Var("reduction_index"));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
std::move(idx), std::move(domain));
}
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
......@@ -48,6 +48,13 @@ void Expr::Print(std::ostream& os) const {
os << ')';
return;
}
case kReduceNode: {
const auto* n = Get<ReduceNode>();
os << "reduce("<< n->op->FunctionName() << ", ";
n->src.Print(os);
os << ", " << n->rdom << ')';
return;
}
default: {
LOG(FATAL) << "not able to handle type " << typeid(node_.get()).name();
}
......
......@@ -42,5 +42,6 @@ TVM_REGISTER_NODE_TYPE(IntNode);
TVM_REGISTER_NODE_TYPE(FloatNode);
TVM_REGISTER_NODE_TYPE(UnaryOpNode);
TVM_REGISTER_NODE_TYPE(BinaryOpNode);
TVM_REGISTER_NODE_TYPE(ReduceNode);
} // namespace tvm
......@@ -11,6 +11,16 @@ TEST(Expr, Basic) {
CHECK(os.str() == "max(((x + 1) + 2), 100)");
}
TEST(Expr, Reduction) {
using namespace tvm;
Var x("x");
RDomain rdom({{0, 3}});
auto z = sum(x + 1 + 2, rdom);
std::ostringstream os;
os << z;
CHECK(os.str() == "reduce(+, ((x + 1) + 2), rdomain([[0, 3)]))");
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
......@@ -7,11 +8,6 @@ TEST(Tensor, Basic) {
Var m, n, k;
Tensor A({m, k});
Tensor B({n, k});
auto x = [=](Var i, Var j, Var k) {
return A(i, k) * B(j, k);
};
auto C = Tensor({m, n}, x);
}
int main(int argc, char ** argv) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment