Skip to content
Snippets Groups Projects
Commit 357ad592 authored by tqchen's avatar tqchen
Browse files

Fix Schedule structure, refactor compute to all rely on iter var

parent 3a48b323
No related branches found
No related tags found
No related merge requests found
......@@ -133,13 +133,13 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
*/
class IterVarNode : public Node {
public:
/*! \brief The looping variable */
Var var;
/*!
* \brief the domain of iteration, if known, can be None
* For the intermediate schedule node, before schedule.
*/
Range dom;
/*! \brief The looping variable */
Var var;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
......@@ -147,12 +147,13 @@ class IterVarNode : public Node {
std::string thread_tag;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("thread_tag", &thread_tag);
}
static IterVar make(Var var, Range dom, std::string thread_tag);
static IterVar make(Range dom, Var var, std::string thread_tag);
static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
};
......
......@@ -17,6 +17,8 @@ namespace tvm {
*/
class ComputeOpNode : public OperationNode {
public:
/*! \brief Iteration variables over the dimensions */
Array<IterVar> dim_var;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
......@@ -25,19 +27,18 @@ class ComputeOpNode : public OperationNode {
size_t num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("domain", &domain);
v->Visit("name", &name);
v->Visit("dim_var", &dim_var);
v->Visit("body", &body);
}
static Operation make(Domain domain,
std::string name,
Array<Var> dim_var,
static Operation make(std::string name,
Array<IterVar> dim_var,
Expr body);
static constexpr const char* _type_key = "ComputeOp";
......
......@@ -8,15 +8,14 @@
#include <string>
#include "./base.h"
#include "./split.h"
#include "./operation.h"
namespace tvm {
// Node container for Schedule
class ScheduleNode;
// Node container for AttachSpec
class AttachSpecNode;
// Node container for IterVarRelation
class IterVarRelationNode;
/*! \brief the attachment type */
enum AttachType : int {
......@@ -38,42 +37,132 @@ class Schedule : public NodeRef {
inline const ScheduleNode* operator->() const;
};
/*!
* \brief The schedule relation between IterVars
* can be Split, Fuse.
*/
class IterVarRelation : public NodeRef {
public:
IterVarRelation() {}
explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarRelationNode* operator->() const;
};
// defintion of node containers
/*! \brief represents the schedule of the tensor */
/*!
* \brief represents the schedule of the tensor
*
* A schedule is a Directed acylic hypergraph.
* With each node is represented by a IterVar,
* and each hyper-edge is represented by a IterVarRelation.
*
* The relations can be Split/Fuse.
*
* The current data structure stores the hyper graph in its
* bipartite representation.
*
* The relations connects the IterVars in the graph.
*/
class ScheduleNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the schedule */
std::string scope;
/*! \brief Splits over iteration domains */
Array<Split> splits;
/*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars;
/*!
* \brief The current leafs in the schedule.
* Operations can only be performed in leaves.
*/
Array<IterVar> leaf_iter_vars;
/*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations;
/*! \brief The attachment type of the schedule */
AttachType attach_type;
/*!
* \brief The attach point of this schedule, if it is a split
* \note This is not a cyclic dependency,
* because split do not refer back to parent schedule.
* \brief The attach point of this schedule.
*/
Split attach_parent;
IterVar attach_parent;
/*! \brief the schedules that this schedule depend on */
Array<Schedule> children;
// the type key
const char* type_key() const final {
return "Schedule";
}
const uint32_t type_index() const final {
static uint32_t tidx = TypeKey2Index(type_key());
return tidx;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("splits", &splits);
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent);
v->Visit("children", &children);
}
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
};
/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
};
/*!
* \brief Split the parent domain into product of
* outer and iter.
*/
class SplitNode : public IterVarRelationNode {
public:
/*! \brief The parent domain */
IterVar parent;
/*! \brief The outer domain */
IterVar outer;
/*! \brief The inner domain */
IterVar inner;
/*! \brief The split factor */
Expr factor;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("parent", &parent);
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("factor", &factor);
}
static IterVarRelation make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor);
static constexpr const char* _type_key = "Split";
TVM_DECLARE_NODE_TYPE_INFO(SplitNode);
};
/*!
* \brief Fuse two domains into one domain.
*/
class FuseNode : public IterVarRelationNode {
public:
/*! \brief The outer domain */
IterVar outer;
/*! \brief The inner domain */
IterVar inner;
/*! \brief The target domain */
IterVar fused;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("outer", &outer);
v->Visit("inner", &inner);
v->Visit("fused", &fused);
}
static IterVarRelation make(
IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
};
// implementations
......@@ -81,5 +170,9 @@ inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SCHEDULE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file split.h
* \brief Define a split over Domain or RDomain
*/
#ifndef TVM_SPLIT_H_
#define TVM_SPLIT_H_
#include "./base.h"
#include "./expr.h"
namespace tvm {
// internal node container for split.
class SplitNode;
/*! \brief Split over input domain */
class Split : public NodeRef {
public:
/*! \brief default constructor */
Split() {}
explicit Split(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SplitNode* operator->() const;
};
/*!
* \brief base class of split node,
* specifies a split over domain
* split also defines how to generate
*/
class SplitNode : public Node {
public:
/*! \brief the variable to be splitted on */
Var var;
};
/*! \brief simple split node that splits over one dimension */
class DimSplitNode : public SplitNode {
public:
/*! \brief The factor of the split */
Expr factor;
/*! \brief constructor */
DimSplitNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("factor", &factor);
}
static Split make(Var var, Expr factor);
static constexpr const char* _type_key = "DimSplit";
TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode);
};
// Implementations of inline functions
inline const SplitNode* Split::operator->() const {
return static_cast<const SplitNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SPLIT_H_
......@@ -174,12 +174,10 @@ class TensorNode : public FunctionBaseNode {
*/
class OperationNode : public Node {
public:
/*! \brief The domain of iteration of this op. */
Domain domain;
/*! \brief iter-Var over the dimensions */
Array<Var> dim_var;
/*! \brief optional name of the operation */
std::string name;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
......
......@@ -83,11 +83,11 @@ def compute(shape, fcompute, name="TensorCompute"):
arg_names = fcompute.__code__.co_varnames
if ndim != len(arg_names):
raise ValueError("fcompute do not match dimension")
dim_var = [Var(x) for x in arg_names]
body = fcompute(*dim_var)
dom = [Range(0, x) for x in shape]
dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
body = fcompute(*[v.var for v in dim_var])
op_node = _function_internal._ComputeOp(
dom, name, dim_var, body)
name, dim_var, body)
return _function_internal._Tensor(
shape, name, body.dtype, op_node, 0)
......
......@@ -5,7 +5,6 @@
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/split.h>
#include <tvm/schedule.h>
#include "./c_api_registry.h"
......@@ -89,8 +88,7 @@ TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3));
args.at(2));
});
......@@ -100,11 +98,6 @@ TVM_REGISTER_API(_IterVar)
});
TVM_REGISTER_API(_DimSplit)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = DimSplitNode::make(args.at(0), args.at(1));
});
TVM_REGISTER_API(_Schedule)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Schedule(args.at(0), args.at(1));
......
......@@ -24,12 +24,12 @@ Range Range::make_with_min_extent(Expr min, Expr extent) {
}
IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag)
: IterVar(IterVarNode::make(Var(var_name, Int(32)), dom, thread_tag)) {}
: IterVar(IterVarNode::make(dom, Var(var_name, Int(32)), thread_tag)) {}
IterVar IterVarNode::make(Var var, Range dom, std::string thread_tag) {
IterVar IterVarNode::make(Range dom, Var var, std::string thread_tag) {
std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>();
n->var = var;
n->dom = dom;
n->var = var;
n->thread_tag = thread_tag;
return IterVar(n);
}
......
......@@ -13,32 +13,25 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
std::vector<Var> dim_index;
std::vector<IterVar> dim_var;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
os << "dim_var" << i;
dim_index.push_back(Var(os.str()));
dim_var.push_back(IterVar(Range(0, shape[i]), os.str()));
args.push_back(dim_var.back()->var);
}
std::vector<Range> dom;
for (size_t i = 0; i < ndim; ++i) {
dom.push_back(Range(0, shape[i]));
}
op_node->dim_var = Array<Var>(dim_index);
op_node->domain = Domain(dom);
op_node->body = fcompute(op_node->dim_var);
op_node->dim_var = Array<IterVar>(dim_var);
op_node->body = fcompute(args);
op_node->name = name;
return Operation(op_node).output(0);
}
Operation ComputeOpNode::make(Domain domain,
std::string name,
Array<Var> dim_var,
Operation ComputeOpNode::make(std::string name,
Array<IterVar> dim_var,
Expr body) {
auto n = std::make_shared<ComputeOpNode>();
n->domain = domain;
n->name = name;
n->dim_var = dim_var;
n->body = body;
......@@ -55,6 +48,10 @@ Tensor Operation::output(size_t i) const {
return Tensor(node);
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return dim_var;
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0);
return name;
......@@ -68,8 +65,9 @@ Type ComputeOpNode::output_dtype(size_t i) const {
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0);
std::vector<Expr> shape;
for (size_t i = 0; i < domain.size(); ++i) {
shape.push_back(domain[i]->extent);
for (size_t i = 0; i < dim_var.size(); ++i) {
const Range& r = dim_var[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
......
......@@ -13,6 +13,28 @@ Schedule::Schedule(Operation op, std::string scope) {
node_ = n;
}
IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor) {
auto n = std::make_shared<SplitNode>();
n->parent = parent;
n->outer = outer;
n->inner = inner;
n->factor = factor;
return IterVarRelation(n);
}
IterVarRelation FuseNode::make(
IterVar outer, IterVar inner, IterVar fused) {
auto n = std::make_shared<FuseNode>();
n->outer = outer;
n->inner = inner;
n->fused = fused;
return IterVarRelation(n);
}
TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file split.cc
*/
#include <tvm/split.h>
namespace tvm {
Split DimSplitNode::make(Var var,
Expr factor) {
auto n = std::make_shared<DimSplitNode>();
CHECK_EQ(factor.type().lanes(), 1);
n->var = var;
n->factor = factor;
return Split(n);
}
TVM_REGISTER_NODE_TYPE(DimSplitNode);
} // namespace tvm
......@@ -53,7 +53,8 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) {
if (!r->min.same_as(new_min)) changed = true;
if (!r->extent.same_as(new_extent)) changed = true;
new_dom[i] = IterVarNode::make(
v->var, Range::make_with_min_extent(new_min, new_extent), v->thread_tag);
Range::make_with_min_extent(new_min, new_extent),
v->var, v->thread_tag);
}
if (!changed) {
return rdom;
......
......@@ -38,32 +38,6 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
return body;
}
void MakeLoop(const DimSplitNode* op,
const Split& s,
Scope<AttrKey, Expr>* pscope,
std::vector<Stmt>* nest) {
auto& scope = *pscope;
Expr out_min = scope[{op->var, "min"}];
Expr out_ext = scope[{op->var, "extent"}];
Expr stride = op->factor;
Var offset(s->var->name_hint + ".offset", Int(32));
// for loop with stride
// TODO(tqchen) split the loop to deal with tails
nest->emplace_back(
For::make(
offset, out_min, out_ext,
ForType::Parallel, DeviceAPI::None, Stmt()));
Expr in_min = offset + out_min;
Expr in_ext = min(stride, out_ext - offset);
// declare min and extent of the corresponding variable
nest->emplace_back(AttrStmt::make(op->var, "min", in_min, Stmt()));
nest->emplace_back(AttrStmt::make(op->var, "extent", in_ext, Stmt()));
// declare this is the loop
nest->emplace_back(AttrStmt::make(s, "split", 0, Stmt()));
// setup the scope.
pscope->Push({op->var, "min"}, in_min);
pscope->Push({op->var, "extent"}, in_ext);
}
Stmt MakePipeline(const Schedule& sch, Stmt body) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment