Skip to content
Snippets Groups Projects
Commit 5445a936 authored by tqchen's avatar tqchen
Browse files

Refactor to use iterVar

parent 7591714a
No related branches found
No related tags found
No related merge requests found
Showing with 300 additions and 274 deletions
Subproject commit eb2f7d604a611318fc685172847bcf5ba2fcf835
Subproject commit e96ee0f2fb5239021c0facd5398a9a96644bc411
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the domain in AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_
#include <ir/Range.h>
#include <memory>
#include "./base.h"
#include "./expr.h"
namespace tvm {
/*! \brief container class of reduction domain */
class RDomainNode;
class IterDomainNode;
/*!
* \brief same as Halide::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class Range : public Halide::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
};
/*! \brief Domain is a multi-dimensional range */
using Domain = Array<Range>;
/*! \brief reduction domain */
class RDomain : public NodeRef {
public:
/*! \brief constructor*/
RDomain() {}
explicit RDomain(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit RDomain(Domain domain);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit RDomain(std::initializer_list<Range> domain)
: RDomain(Domain(domain)) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const RDomainNode* operator->() const;
/*! \return The dimension of the RDomain */
inline size_t ndim() const;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline Var index(size_t i) const;
/*! \return the 0-th index of the domain */
inline Var i0() const {
return index(0);
}
// low level constructor
static RDomain make(Array<Var> index, Domain domain);
};
/*! \brief use RDom as alias of RDomain */
using RDom = RDomain;
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional domain.
*/
class IterVarNode : public Node {
/*! \brief The */
Var var;
/*! \brief the domain of iteration */
Range dom;
/*! \brief additional tag on the iteration variable */
std::string tag;
};
/*! \brief reduction domain node */
class RDomainNode : public Node {
public:
/*! \brief internal index */
Array<Var> index;
/*! \brief The inernal domain */
Domain domain;
/*! \brief constructor */
RDomainNode() {}
RDomainNode(Array<Var> index, Domain domain)
: index(index), domain(domain) {
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("index", &index);
v->Visit("domain", &domain);
}
static constexpr const char* _type_key = "RDomain";
TVM_DECLARE_NODE_TYPE_INFO(RDomainNode);
};
inline const RDomainNode* RDomain::operator->() const {
return static_cast<const RDomainNode*>(node_.get());
}
inline size_t RDomain::ndim() const {
return (*this)->index.size();
}
inline Var RDomain::index(size_t i) const {
return (*this)->index[i];
}
// overload print function
inline std::ostream& operator<<(std::ostream &os, const RDomain& r){ // NOLINT(*)
os << "rdomain(" << r->domain << ")";
return os;
}
} // namespace tvm
#endif // TVM_DOMAIN_H_
/*!
* Copyright (c) 2016 by Contributors
* \file expr.h
* \brief Defines the expressions in AST.
* \brief The Expr and related elements in DataFlow construction.
*/
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/IRPrinter.h>
#include <ir/IROperator.h>
#include <string>
#include <algorithm>
#include "./base.h"
namespace tvm {
......@@ -19,20 +21,14 @@ using Halide::Int;
using Halide::UInt;
using Halide::Handle;
// functions
using Halide::cast;
using Halide::min;
using Halide::max;
using Halide::abs;
using Halide::select;
using Halide::Expr;
using Halide::VarExpr;
using Halide::IR::FunctionRef;
using Halide::IR::FunctionBaseNode;
using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
public:
explicit Var(const std::string& name_hint = "v",
......@@ -41,5 +37,134 @@ class Var : public Halide::VarExpr {
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
};
/*! \brief container class of iteration variable. */
class IterVarNode;
/*!
* \brief same as Halide::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class Range : public Halide::IR::Range {
public:
/*! \brief constructor */
Range() {}
explicit Range(std::shared_ptr<Node> n) : Halide::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range(Expr begin, Expr end);
static Range make_with_min_extent(Expr min, Expr extent);
};
/*!
* \brief Iteration Variable,
* represents an iteration over an integer interval.
*/
class IterVar : public NodeRef {
public:
// construct a new iter var without a domain
IterVar() {}
// construct from shared ptr.
explicit IterVar(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construction of iteration variable.
* \param dom The iteration domain.
* \param var_name The name of iteration variable.
* \param thread_tag The additional tag to indicate whether the var is binded to fixed-thread.
*/
explicit IterVar(Range dom, std::string var_name = "i", std::string thread_tag = "");
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarNode* operator->() const;
/*!
* \return the corresponding var in the IterVar.
*/
inline operator Expr() const;
/*! \brief specify container node */
using ContainerType = IterVarNode;
};
using Domain = Array<Range>;
// functions
using Halide::cast;
using Halide::min;
using Halide::max;
using Halide::abs;
using Halide::select;
/*!
* \brief sum of of source expression over rdom
* \param source The source expression.
*/
Expr sum(Expr source, Array<IterVar> rdom);
/*!
* \brief max of of source expression over rdom
* \param source The source expression.
*/
Expr max(Expr source, Array<IterVar> rdom);
/*!
* \brief max of of source expression over rdom
* \param source The source expression.
*/
Expr min(Expr source, Array<IterVar> rdom);
// print functions for expr
std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
// definition of Node.
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
*/
class IterVarNode : public Node {
public:
/*! \brief The looping variable */
Var var;
/*!
* \brief the domain of iteration, if known, can be None
* For the intermediate schedule node, before schedule.
*/
Range dom;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
*/
std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("dom", &dom);
v->Visit("thread_tag", &thread_tag);
}
static IterVar make(Var var, Range dom, std::string thread_tag);
static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
};
// inline implementations
inline const IterVarNode* IterVar::operator->() const {
return static_cast<const IterVarNode*>(node_.get());
}
inline IterVar::operator Expr() const {
return (*this)->var;
}
} // namespace tvm
#endif // TVM_EXPR_H_
......@@ -11,7 +11,7 @@
#include <type_traits>
#include <string>
#include "./base.h"
#include "./domain.h"
#include "./expr.h"
namespace tvm {
namespace ir {
......@@ -30,11 +30,11 @@ struct Reduce : public ExprNode<Reduce> {
std::string op;
/*! \brief The source operand */
Expr source;
/*! \brief The reduction domain */
RDomain rdom;
/*! \brief The reduction domains */
Array<IterVar> rdom;
/*! \brief construct expr from name and rdom */
static Expr make(std::string name, Expr src, RDomain rdom);
/*! \brief construct expr from op and rdom */
static Expr make(std::string op, Expr src, Array<IterVar> rdom);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
......
......@@ -8,7 +8,6 @@
#include <string>
#include "./expr.h"
#include "./domain.h"
#include "./tensor.h"
namespace tvm {
......
......@@ -8,7 +8,6 @@
#include "./base.h"
#include "./expr.h"
#include "./domain.h"
namespace tvm {
......
......@@ -14,7 +14,6 @@
#include "./base.h"
#include "./expr.h"
#include "./domain.h"
namespace tvm {
......@@ -66,8 +65,8 @@ class Tensor : public FunctionRef {
* \return the result expression representing tensor read.
*/
Expr operator()(Array<Expr> indices) const;
// overload print function
friend std::ostream& operator<<(std::ostream &os, const Tensor& t);
/*! \brief specify container node */
using ContainerType = TensorNode;
};
/*! \brief Operation that produces tensors */
......@@ -87,6 +86,8 @@ class Operation : public NodeRef {
* \return The i-th output.
*/
Tensor output(size_t i) const;
/*! \brief specify container node */
using ContainerType = OperationNode;
};
/*! \brief Node to represent a tensor */
......@@ -162,11 +163,5 @@ inline size_t Tensor::ndim() const {
return (*this)->shape.size();
}
inline std::ostream& operator<<(std::ostream &os, const Tensor& t) { // NOLINT(*)
os << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')';
return os;
}
} // namespace tvm
#endif // TVM_TENSOR_H_
......@@ -118,6 +118,7 @@ def convert(value):
raise ValueError("don't know how to handle type %s" % type(value))
return value
def _push_arg(arg):
a = ArgVariant()
if arg is None:
......
......@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
from . import expr as _expr
@register_node
class Array(NodeBase):
......@@ -19,11 +20,9 @@ class Array(NodeBase):
@register_node
class Range(NodeBase):
def __repr__(self):
return ('Range(min='+ str(self.min) +
', extent=' + str(self.extent) + ')')
pass
@register_node
class RDomain(NodeBase):
class IterVar(_expr.ExprCompatible):
pass
......@@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import make as _make
class Expr(NodeBase):
class ExprCompatible(NodeBase):
def __add__(self, other):
return _make.Add(self, other)
......@@ -36,6 +36,10 @@ class Expr(NodeBase):
def __neg__(self):
return self.__mul__(-1)
class Expr(ExprCompatible):
pass
class ConstExpr(Expr):
pass
......
......@@ -103,33 +103,34 @@ def compute(shape, fcompute, name="TensorCompute"):
shape, name, body.dtype, op_node, 0)
def RDomain(dom):
"""Create a reduction domain given domain
def IterVar(dom, name='iter', thread_tag=''):
"""Create a iteration variable
Parameters
----------
dom : list of Range or list of pairs
The reduction domain.
dom : Range
The domain of iteration.
name : str
The name of iteration variable.
thread_tag : str
The thread tag of the iteration variable.
Returns
-------
rdom : RDomain
The result rdomain
iter_var : IterVar
The result itervar
"""
if not isinstance(dom, (list, tuple)):
dom = [dom]
elif not isinstance(dom[0], (list, tuple)):
dom = [dom]
dnorm = []
for x in dom:
if isinstance(x, (list, tuple)):
if len(x) != 2:
raise ValueError("need to list of ranges")
dnorm.append(Range(x[0], x[1]))
else:
dnorm.append(x)
dnorm = convert(dnorm)
return _function_internal._RDomain(dnorm)
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
raise ValueError("need to list of ranges")
dom = Range(dom[0], dom[1])
if not isinstance(dom, _collections.Range):
raise ValueError("dom need to be Range")
return _function_internal._IterVar(dom, name, thread_tag)
def sum(expr, rdom):
......@@ -143,10 +144,11 @@ def sum(expr, rdom):
rdom : RDomain
The reduction domainx
"""
assert isinstance(rdom, _collections.RDomain)
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom)
return x
def min(expr, rdom):
"""Create a min expression over rdom
......@@ -158,11 +160,11 @@ def min(expr, rdom):
rdom : RDomain
The reduction domainx
"""
assert isinstance(expr, _expr.Expr)
assert isinstance(rdom, _collections.RDomain)
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Min", expr, rdom)
return x
def max(expr, rdom):
"""Create a min expression over rdom
......@@ -174,8 +176,7 @@ def max(expr, rdom):
rdom : RDomain
The reduction domainx
"""
assert isinstance(expr, _expr.Expr)
assert isinstance(rdom, _collections.RDomain)
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Max", expr, rdom)
return x
......
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from ._ctypes._api import NodeBase, register_node, convert
from . import collections as _collections
from . import make as _make
from . import expr as _expr
......@@ -10,7 +11,18 @@ class Tensor(NodeBase):
ndim = self.ndim
if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim)
return _make.Call(self.dtype, self.name, indices, _expr.Call.Halide, self, 0)
indices = convert(indices)
args = []
for x in indices:
if isinstance(x, _collections.IterVar):
args.append(x.var)
elif isinstance(x, _expr.Expr):
args.append(x)
else:
raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
@property
def ndim(self):
......
......@@ -4,9 +4,7 @@
* \file c_api_impl.cc
*/
#include <tvm/expr.h>
#include <tvm/domain.h>
#include <tvm/tensor.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace dmlc {
......@@ -22,21 +20,9 @@ TVM_REGISTER_API(_format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::BaseExprNode;
using Halide::Internal::BaseStmtNode;
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
auto& sptr = args.at(0).sptr;
if (dynamic_cast<const TensorNode*>(sptr.get())) {
os << args.at(0).operator Tensor();
} else if (dynamic_cast<const RDomainNode*>(sptr.get())) {
os << args.at(0).operator RDomain();
} else if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt();
} else {
LOG(FATAL) << "don't know how to print input NodeBaseType";
}
os << args.at(0).operator NodeRef();
*ret = os.str();
})
.add_argument("expr", "Node", "expression to be printed");
......
......@@ -5,10 +5,8 @@
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/domain.h>
#include <tvm/split.h>
#include <tvm/schedule.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace tvm {
......@@ -95,11 +93,13 @@ TVM_REGISTER_API(_ComputeOp)
args.at(3));
});
TVM_REGISTER_API(_RDomain)
TVM_REGISTER_API(_IterVar)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = RDomain(args.at(0).operator Domain());
*ret = IterVar(args.at(0), args.at(1), args.at(2));
});
TVM_REGISTER_API(_DimSplit)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = DimSplitNode::make(args.at(0), args.at(1));
......
......@@ -125,7 +125,13 @@ class APIVariantValue {
return Expr(static_cast<float>(operator double()));
}
CHECK_EQ(type_id, kNodeHandle);
return Expr(sptr);
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
} else {
CHECK(dynamic_cast<typename Expr::ContainerType*>(sptr.get()))
<< "did not pass in Expr in a place need Expr";
return Expr(sptr);
}
}
inline operator double() const {
CHECK_EQ(type_id, kDouble);
......
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/base.h>
#include <tvm/domain.h>
namespace tvm {
Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
// TODO(tqchen) add simplify to end - begin
}
Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
RDomain::RDomain(Domain domain) {
std::vector<Var> index;
for (size_t i = 0; i < domain.size(); ++i) {
std::ostringstream os;
os << "reduction_index" << i;
index.push_back(Var(os.str()));
}
Array<Var> idx(index);
node_ = std::make_shared<RDomainNode>(
std::move(idx), std::move(domain));
}
RDomain RDomain::make(Array<Var> index, Domain domain) {
return RDomain(std::make_shared<RDomainNode>(index, domain));
}
TVM_REGISTER_NODE_TYPE(RDomainNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <ir/IRPrinter.h>
#include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
Range::Range(Expr begin, Expr end)
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
// TODO(tqchen) add simplify to end - begin
}
Range Range::make_with_min_extent(Expr min, Expr extent) {
return Range(std::make_shared<Halide::IR::RangeNode>(min, extent));
}
IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag)
: IterVar(IterVarNode::make(Var(var_name, Int(32)), dom, thread_tag)) {}
IterVar IterVarNode::make(Var var, Range dom, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>();
n->var = var;
n->dom = dom;
n->thread_tag = thread_tag;
return IterVar(n);
}
Expr sum(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Add", source, rdom);
}
Expr max(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Max", source, rdom);
}
Expr min(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make("Min", source, rdom);
}
std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
IRPrinter(os).print(n);
return os;
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) {
p->stream << op->var->name_hint << ", ";
}
p->stream << op->dom;
if (op->thread_tag.length() != 0) {
p->stream << ", " << op->thread_tag;
}
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Halide::IR::RangeNode>([](const Halide::IR::RangeNode *op, IRPrinter *p) {
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file ir_node.cc
* \file ir.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
......@@ -9,11 +9,6 @@
#include <ir/IRPrinter.h>
#include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace Halide {
namespace Internal {
......@@ -53,9 +48,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace tvm {
namespace ir {
Expr Reduce::make(std::string op, Expr source, RDomain rdom) {
Expr Reduce::make(std::string op, Expr source, Array<IterVar> rdom) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
for (size_t i = 0; i < rdom.size(); ++i) {
CHECK(rdom[i].defined());
}
n->type = source.type();
n->source = source;
n->op = op;
......
......@@ -41,6 +41,12 @@ Tensor TensorNode::make(Array<Expr> shape,
return Tensor(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
p->stream << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')';
});
TVM_REGISTER_NODE_TYPE(TensorNode);
} // namespace tvm
......@@ -42,27 +42,29 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
}
}
inline RDomain MutateRDom(RDomain rdom, IRMutator *m) {
std::vector<Range> new_dom(rdom->domain.size());
inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
std::vector<IterVar> new_dom(rdom.size());
bool changed = false;
for (size_t i = 0; i < rdom->domain.size(); i++) {
Range r = rdom->domain[i];
for (size_t i = 0; i < rdom.size(); i++) {
IterVar v = rdom[i];
Range r = v->dom;
Expr new_min = m->Mutate(r->min);
Expr new_extent = m->Mutate(r->extent);
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = Range::make_with_min_extent(new_min, new_extent);
new_dom[i] = IterVarNode::make(
v->var, Range::make_with_min_extent(new_min, new_extent), v->thread_tag);
}
if (!changed) {
return rdom;
} else {
return RDomain::make(rdom->index, Domain(new_dom));
return Array<IterVar>(new_dom);
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
RDomain new_rdom = MutateRDom(op->rdom, m);
Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
Expr new_source = m->Mutate(op->source);
if (op->rdom.same_as(new_rdom) &&
op->source.same_as(new_source)) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment