Skip to content
Snippets Groups Projects
Commit 78ea652d authored by Tianqi Chen's avatar Tianqi Chen Committed by Haichen Shen
Browse files

[PASS] Schedule Ops init working version (#6)

* [PASS] Schedule Ops init working version

* bugfix in PassUp
parent 302c2e64
No related branches found
No related tags found
No related merge requests found
Showing
with 428 additions and 200 deletions
Subproject commit 5d1bd103c2abe19392b4d8def7e3ff1c854e8683 Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
...@@ -17,6 +17,7 @@ namespace tvm { ...@@ -17,6 +17,7 @@ namespace tvm {
using Halide::Type; using Halide::Type;
using Halide::Float; using Halide::Float;
using Halide::Bool;
using Halide::Int; using Halide::Int;
using Halide::UInt; using Halide::UInt;
using Halide::Handle; using Halide::Handle;
...@@ -29,6 +30,8 @@ using Halide::Internal::Stmt; ...@@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter; using Halide::Internal::IRPrinter;
using Halide::Internal::Variable; using Halide::Internal::Variable;
using Halide::Internal::make_const;
/*! \brief a named variable in TVM */ /*! \brief a named variable in TVM */
class Var : public Halide::VarExpr { class Var : public Halide::VarExpr {
public: public:
......
...@@ -18,6 +18,16 @@ ...@@ -18,6 +18,16 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
/*!
* \brief Schedule the operations that the schedule s depends on.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter var.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
/*! /*!
* \brief verifies whether the IR stmt or Expr is in SSA form. * \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For) * That is: each VarExpr is defined and assigned once(in Let/For)
...@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f, ...@@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
Expr body, Expr body,
Stmt stmt); Stmt stmt);
/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -12,6 +12,36 @@ ...@@ -12,6 +12,36 @@
namespace tvm { namespace tvm {
/*!
 * \brief Operation node describing an input placeholder.
 *
 *  Stores the shape and element type of the input and reports a single output.
 */
class PlaceholderOpNode : public OperationNode {
 public:
  /*! \brief Shape of the input. */
  Array<Expr> shape;
  /*! \brief Element data type of the input. */
  Type dtype;

  /*! \return number of outputs, always 1 for a placeholder. */
  int num_outputs() const final { return 1; }
  /*! \return the list of iteration variables at root. */
  Array<IterVar> root_iter_vars() const final;
  /*! \return type of the i-th output. */
  Type output_dtype(size_t i) const final;
  /*! \return shape of the i-th output. */
  Array<Expr> output_shape(size_t i) const final;

  /*! \brief Expose the name, shape and dtype fields to an attribute visitor. */
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("name", &name);
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
  }

  /*!
   * \brief Create a placeholder operation.
   * \param name The name of the operation.
   * \param shape The shape of the input.
   * \param dtype The data type of the input.
   * \return The created Operation.
   */
  static Operation make(std::string name, Array<Expr> shape, Type dtype);

  static constexpr const char* _type_key = "PlaceholderOp";
  TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
};
/*! /*!
* \brief A Compute op that compute a tensor on certain domain. * \brief A Compute op that compute a tensor on certain domain.
*/ */
...@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode { ...@@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */ /*! \brief constructor */
ComputeOpNode() {} ComputeOpNode() {}
size_t num_outputs() const final { int num_outputs() const final {
return 1; return 1;
} }
Array<IterVar> root_iter_vars() const final; Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final; Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final; Array<Expr> output_shape(size_t i) const final;
...@@ -49,6 +78,16 @@ class ComputeOpNode : public OperationNode { ...@@ -49,6 +78,16 @@ class ComputeOpNode : public OperationNode {
/*! \brief The compute function to specify the input source of a Tensor */ /*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>; using FCompute = std::function<Expr (const Array<Var>& i)>;
/*!
* \brief Create a placeholder tensor.
* \param shape The shape of the tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");
/*! /*!
* \brief Construct a new tensor by computing over shape, * \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis) * using the computation rule: result_tensor[axis] = fcompute(axis)
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file bound.h * \file schedule_pass.h
* \brief The bound inference logics on the schedule. * \brief Collection of Schedule pass functions.
*
* These passes works on the schedule hyper-graph
* and infers information such as bounds, check conditions
* read/write dependencies between the IterVar
*/ */
#ifndef TVM_SCHEDULE_BOUND_H_ #ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_BOUND_H_ #define TVM_SCHEDULE_PASS_H_
#include <tvm/expr.h> #include "./base.h"
#include <tvm/schedule.h> #include "./schedule.h"
#include <unordered_map>
namespace tvm { namespace tvm {
namespace schedule { namespace schedule {
...@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch); ...@@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
#endif // TVM_SCHEDULE_BOUND_H_
...@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef; ...@@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
* \brief Tensor structure representing a possible input, * \brief Tensor structure representing a possible input,
* or intermediate computation result. * or intermediate computation result.
*/ */
class Tensor : public FunctionRef { class Tensor : public NodeRef {
public: public:
/*! \brief default constructor, used internally */ /*! \brief default constructor, used internally */
Tensor() {} Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {} explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
Type dtype = Float(32));
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -116,11 +107,11 @@ class Tensor : public FunctionRef { ...@@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
}; };
/*! \brief Operation that produces tensors */ /*! \brief Operation that produces tensors */
class Operation : public NodeRef { class Operation : public FunctionRef {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Operation() {} Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {} explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*! /*!
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
...@@ -137,12 +128,10 @@ class Operation : public NodeRef { ...@@ -137,12 +128,10 @@ class Operation : public NodeRef {
}; };
/*! \brief Node to represent a tensor */ /*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode { class TensorNode : public Node {
public: public:
/*! \brief The shape of the tensor */ /*! \brief The shape of the tensor */
Array<Expr> shape; Array<Expr> shape;
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */ /*! \brief data type in the content of the tensor */
Type dtype; Type dtype;
/*! \brief the source operation, can be None */ /*! \brief the source operation, can be None */
...@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode { ...@@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("name", &name);
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("op", &op); v->Visit("op", &op);
v->Visit("value_index", &value_index); v->Visit("value_index", &value_index);
} }
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape, static Tensor make(Array<Expr> shape,
std::string name,
Type dtype, Type dtype,
Operation op, Operation op,
int value_index); int value_index);
...@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode { ...@@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
/*! /*!
* \brief base class of operation node. * \brief base class of operation node.
*/ */
class OperationNode : public Node { class OperationNode : public FunctionBaseNode {
public: public:
/*! \brief optional name of the operation */ /*! \brief optional name of the operation */
std::string name; std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */ /*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0; virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */ /*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0; virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */ /*! \return shape of i-th output */
......
...@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32): ...@@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype) return _function_internal._Var(name, dtype)
def placeholder(shape, dtype = None, name="TensorObj"): def placeholder(shape, dtype = None, name="placeholder"):
"""Construct an empty tensor object. """Construct an empty tensor object.
Parameters Parameters
...@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"): ...@@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
The created tensor The created tensor
""" """
dtype = float32 if dtype is None else dtype dtype = float32 if dtype is None else dtype
return _function_internal._Tensor( return _function_internal._Placeholder(
shape, name, dtype, None, 0) shape, dtype, name)
def compute(shape, fcompute, name="TensorCompute"): def compute(shape, fcompute, name="compute"):
"""Construct a new tensor by computing over the shape domain. """Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis) The compute rule is result[axis] = fcompute(axis)
......
...@@ -34,7 +34,9 @@ class Tensor(NodeBase): ...@@ -34,7 +34,9 @@ class Tensor(NodeBase):
else: else:
raise ValueError("The indices must be expression") raise ValueError("The indices must be expression")
return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0) return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)
def __getitem__(self, indices): def __getitem__(self, indices):
return TensorSlice(self, indices) return TensorSlice(self, indices)
...@@ -71,3 +73,7 @@ class Operation(NodeBase): ...@@ -71,3 +73,7 @@ class Operation(NodeBase):
@register_node @register_node
class ComputeOp(Operation): class ComputeOp(Operation):
pass pass
@register_node
class PlaceholderOp(Operation):
pass
...@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For) ...@@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
args.at(5)); args.at(5));
}); });
// Registers the `_make_Realize` API entry. The body forwards six positional
// arguments, unchanged and in order, to Realize::make and stores the result in *ret.
// NOTE(review): elsewhere in this change Realize::make is called as
// (func, value_index, type, bounds, condition, body); callers presumably
// pass arguments in that order -- confirm against the Realize declaration.
TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Realize::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4),
args.at(5));
});
TVM_REGISTER_API(_make_Call) TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0), *ret = Call::make(args.at(0),
...@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt); ...@@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store); REGISTER_MAKE3(Store);
REGISTER_MAKE3(Provide); REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free); REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block); REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse); REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate); REGISTER_MAKE1(Evaluate);
......
...@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range) ...@@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
TVM_REGISTER_API(_Tensor) TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0), *ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2), args.at(2),
args.at(3), args.at(3),
args.at(4)); args.at(4));
...@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash) ...@@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
std::hash<Tensor>()(args.at(0).operator Tensor())); std::hash<Tensor>()(args.at(0).operator Tensor()));
}); });
// Registers the `_Placeholder` API entry. The body forwards three positional
// arguments to Placeholder(shape, dtype, name) and stores the resulting
// Tensor in *ret; the Python `placeholder` wrapper passes (shape, dtype, name)
// in this same order.
TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_ComputeOp) TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0), *ret = ComputeOpNode::make(args.at(0),
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
#include "../schedule/bound.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -36,6 +35,7 @@ using RetValue = APIVariantValue; ...@@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA); REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#include <tvm/expr.h> #include <tvm/expr.h>
#include <tvm/tensor.h> #include <tvm/tensor.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/schedule_pass.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h" #include "../schedule/graph.h"
namespace tvm { namespace tvm {
......
...@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IterVarNode); TVM_REGISTER_NODE_TYPE(IterVarNode);
} // namespace tvm } // namespace tvm
...@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) { .set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = "; p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value); p->print(op->value);
p->stream << '\n'; p->stream << '\n';
p->print(op->body); p->print(op->body);
......
...@@ -9,11 +9,73 @@ ...@@ -9,11 +9,73 @@
namespace tvm { namespace tvm {
/*!
 * \brief Build the Tensor handle for the i-th output of this operation.
 * \param i Index of the requested output.
 * \return A Tensor whose op is this operation, whose value_index is i, and
 *  whose dtype and shape come from output_dtype(i) and output_shape(i).
 */
Tensor Operation::output(size_t i) const {
  auto node = std::make_shared<TensorNode>();
  node->op = *this;
  // Record which output this tensor refers to. This used to be hard-coded to
  // 0, which would label every output of a multi-output op as output 0.
  node->value_index = static_cast<int>(i);
  node->dtype = (*this)->output_dtype(i);
  node->shape = (*this)->output_shape(i);
  return Tensor(node);
}
// PlaceholderOpNode
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) { .set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
p->stream << "op(" << op << ")"; p->stream << "placeholder(" << op->name << ", " << op << ")";
}); });
TVM_REGISTER_NODE_TYPE(PlaceholderOpNode);
/*! \return an empty list: a placeholder reports no root iteration variables. */
Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
  return Array<IterVar>();
}
/*!
 * \brief Data type of the placeholder's output.
 * \param i Output index; checked to be 0.
 * \return The stored dtype.
 */
Type PlaceholderOpNode::output_dtype(size_t i) const {
  CHECK_EQ(i, 0U);
  return this->dtype;
}
/*!
 * \brief Shape of the placeholder's output.
 * \param i Output index; checked to be 0.
 * \return The stored shape.
 */
Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
  CHECK_EQ(i, 0U);
  return this->shape;
}
/*!
 * \brief Create a placeholder operation from its fields.
 * \param name The name of the operation.
 * \param shape The shape of the input.
 * \param dtype The data type of the input.
 * \return An Operation wrapping a newly allocated PlaceholderOpNode.
 */
Operation PlaceholderOpNode::make(std::string name,
                                  Array<Expr> shape,
                                  Type dtype) {
  auto node = std::make_shared<PlaceholderOpNode>();
  node->dtype = dtype;
  node->shape = shape;
  node->name = name;
  return Operation(node);
}
/*!
 * \brief Create a placeholder tensor.
 * \param shape The shape of the tensor.
 * \param dtype The data type of the tensor.
 * \param name The name given to the underlying placeholder operation.
 * \return Output 0 of a newly created placeholder operation.
 */
Tensor Placeholder(Array<Expr> shape, Type dtype, std::string name) {
  Operation placeholder_op = PlaceholderOpNode::make(name, shape, dtype);
  return placeholder_op.output(0);
}
// ComputeOpNode
/*! \return the compute op's axis list, used as its root iteration variables. */
Array<IterVar> ComputeOpNode::root_iter_vars() const {
  return this->axis;
}
/*!
 * \brief Data type of the compute op's output.
 * \param i Output index; checked to be 0.
 * \return The type of the body expression.
 */
Type ComputeOpNode::output_dtype(size_t i) const {
  CHECK_EQ(i, 0U);
  Type body_type = body.type();
  return body_type;
}
/*!
 * \brief Shape of the compute op's output.
 * \param i Output index; checked to be 0.
 * \return One extent per axis, taken from each axis' domain, in axis order.
 */
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
  CHECK_EQ(i, 0U);
  std::vector<Expr> extents;
  // Iterate the axes directly; the previous index loop reused the name `i`
  // and shadowed the output-index parameter.
  for (IterVar iv : axis) {
    extents.push_back(iv->dom->extent);
  }
  return Array<Expr>(extents);
}
Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) { Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>(); auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension. // compute dimension.
...@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name, ...@@ -43,39 +105,10 @@ Operation ComputeOpNode::make(std::string name,
return Operation(n); return Operation(n);
} }
Tensor Operation::output(size_t i) const { TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
auto node = std::make_shared<TensorNode>(); .set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
node->op = *this; p->stream << "compute(" << op->name << ", " << op << ")";
node->value_index = 0; });
node->name = (*this)->output_name(i);
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
return axis;
}
std::string ComputeOpNode::output_name(size_t i) const {
CHECK_EQ(i, 0U);
return name;
}
Type ComputeOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return body.type();
}
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
for (size_t i = 0; i < axis.size(); ++i) {
const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
}
TVM_REGISTER_NODE_TYPE(ComputeOpNode); TVM_REGISTER_NODE_TYPE(ComputeOpNode);
......
...@@ -8,33 +8,24 @@ ...@@ -8,33 +8,24 @@
namespace tvm { namespace tvm {
Tensor::Tensor(Array<Expr> shape, std::string name, Type dtype) {
auto node = std::make_shared<TensorNode>();
node->name = std::move(name);
node->dtype = dtype;
node->shape = std::move(shape);
node_ = std::move(node);
}
Expr Tensor::operator()(Array<Expr> indices) const { Expr Tensor::operator()(Array<Expr> indices) const {
using Halide::Internal::Call; using Halide::Internal::Call;
CHECK_EQ(ndim(), indices.size()) CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read" << "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size(); << "ndim = " << ndim() << ", indices.size=" << indices.size();
auto n = Call::make( auto n = Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this); (*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
return n; return n;
} }
Tensor TensorNode::make(Array<Expr> shape, Tensor TensorNode::make(Array<Expr> shape,
std::string name,
Type dtype, Type dtype,
Operation op, Operation op,
int value_index) { int value_index) {
auto n = std::make_shared<TensorNode>(); auto n = std::make_shared<TensorNode>();
n->shape = shape; n->shape = shape;
n->name = name;
n->dtype = dtype; n->dtype = dtype;
n->op = op; n->op = op;
n->value_index = value_index; n->value_index = value_index;
...@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape, ...@@ -44,7 +35,7 @@ Tensor TensorNode::make(Array<Expr> shape,
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) { .set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
p->stream << "Tensor(shape=" << t->shape p->stream << "Tensor(shape=" << t->shape
<< ", name=" << t->name << ')'; << ", op.name=" << t->op->name << ')';
}); });
TVM_REGISTER_NODE_TYPE(TensorNode); TVM_REGISTER_NODE_TYPE(TensorNode);
......
...@@ -22,6 +22,7 @@ class IRInline : public IRMutator { ...@@ -22,6 +22,7 @@ class IRInline : public IRMutator {
expr = IRMutator::Mutate(expr); expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>(); const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) { if (call != nullptr && call->func == f_) {
CHECK_EQ(call->value_index, 0);
return InlineCall(call); return InlineCall(call);
} else { } else {
return expr; return expr;
...@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f, ...@@ -55,6 +56,8 @@ Stmt Inline(FunctionRef f,
Array<Var> args, Array<Var> args,
Expr body, Expr body,
Stmt stmt) { Stmt stmt) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt)); return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
} }
} // namespace ir } // namespace ir
......
...@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -254,11 +254,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
}) })
.set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) { .set_dispatch<Provide>([](const Provide *op, const Stmt& s, IRMutator* m) {
auto new_args = MutateArray(op->args, m); auto new_args = MutateArray(op->args, m);
auto new_values = MutateArray(op->values, m); auto new_value = m->Mutate(op->value);
if (op->args.same_as(new_args) && op->values.same_as(new_values)) { if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s; return s;
} else { } else {
return Provide::make(op->func, new_values, new_args); return Provide::make(op->func, op->value_index, new_value, new_args);
} }
}) })
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) { .set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
...@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -312,7 +312,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
condition.same_as(op->condition)) { condition.same_as(op->condition)) {
return s; return s;
} else { } else {
return Realize::make(op->func, op->types, new_bounds, return Realize::make(op->func, op->value_index,
op->type, new_bounds,
condition, body); condition, body);
} }
}) })
...@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -329,7 +330,10 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) { .set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition); Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case); Stmt then_case = m->Mutate(op->then_case);
Stmt else_case = m->Mutate(op->else_case); Stmt else_case;
if (else_case.defined()) {
else_case = m->Mutate(op->else_case);
}
if (condition.same_as(op->condition) && if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
......
...@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) ...@@ -157,7 +157,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
}) })
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) { .set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v); VisitArray(op->args, v);
VisitArray(op->values, v); v->Visit(op->value);
}) })
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) { .set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) { for (size_t i = 0; i < op->extents.size(); i++) {
......
...@@ -6,7 +6,10 @@ ...@@ -6,7 +6,10 @@
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h> #include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include "./scope.h" #include "./scope.h"
#include "../schedule/graph.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -20,7 +23,7 @@ namespace { ...@@ -20,7 +23,7 @@ namespace {
* IterVar->The assignment. * IterVar->The assignment.
*/ */
void PassUpOffset(const Schedule& s, void PassUpOffset(const Schedule& s,
const std::unordered_map<IterVar, Range>& dom_map, const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, Expr>* p_state) { std::unordered_map<IterVar, Expr>* p_state) {
auto& state = *p_state; auto& state = *p_state;
for (size_t i = s->relations.size(); i != 0; --i) { for (size_t i = s->relations.size(); i != 0; --i) {
...@@ -28,8 +31,8 @@ void PassUpOffset(const Schedule& s, ...@@ -28,8 +31,8 @@ void PassUpOffset(const Schedule& s,
if (rel.as<SplitNode>()) { if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>(); const SplitNode* s = rel.as<SplitNode>();
Expr outer = state.at(s->outer); Expr outer = state.at(s->outer);
Expr inner = state.at(s->outer); Expr inner = state.at(s->inner);
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
Expr offset = inner + outer * factor; Expr offset = inner + outer * factor;
Expr outer_min = dom_map.at(s->parent)->min; Expr outer_min = dom_map.at(s->parent)->min;
if (!is_zero(outer_min)) { if (!is_zero(outer_min)) {
...@@ -39,7 +42,7 @@ void PassUpOffset(const Schedule& s, ...@@ -39,7 +42,7 @@ void PassUpOffset(const Schedule& s,
} else if (rel.as<FuseNode>()) { } else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>(); const FuseNode* s = rel.as<FuseNode>();
Expr value = state.at(s->fused); Expr value = state.at(s->fused);
Expr factor = dom_map.at(s->outer)->extent; Expr factor = dom_map.at(s->inner)->extent;
state[s->outer] = value / factor; state[s->outer] = value / factor;
state[s->inner] = value % factor; state[s->inner] = value % factor;
} else { } else {
...@@ -84,24 +87,35 @@ void SplitByAdd(Expr expr, ...@@ -84,24 +87,35 @@ void SplitByAdd(Expr expr,
* \param nest A list of For and LetStmt, whose body is not defined. * \param nest A list of For and LetStmt, whose body is not defined.
* \param body body * \param body body
*/ */
Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) { Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
while (!nest.empty()) { // use reverse iteration
Stmt s = std::move(nest.back()); for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
nest.pop_back(); for (auto rj = ri->rbegin(); rj != ri->rend(); ++rj) {
if (s.as<For>()) { Stmt s = *rj;
auto n = std::make_shared<For>(*s.as<For>()); if (s.as<For>()) {
n->body = body; auto n = std::make_shared<For>(*s.as<For>());
body = Stmt(n); CHECK(is_no_op(n->body));
} else if (s.as<LetStmt>()) { n->body = body;
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>()); body = Stmt(n);
n->body = body; } else if (s.as<LetStmt>()) {
body = Stmt(n); auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
} else if (s.as<AttrStmt>()) { CHECK(is_no_op(n->body));
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>()); n->body = body;
n->body = body; body = Stmt(n);
body = Stmt(n); } else if (s.as<AttrStmt>()) {
} else { auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
LOG(FATAL) << "not supported nest type"; CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
} }
} }
return body; return body;
...@@ -111,119 +125,251 @@ Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) { ...@@ -111,119 +125,251 @@ Stmt CombineNest(std::vector<Stmt>&& nest, Stmt body) {
* \brief Make the loop nest of the correspondings schedule. * \brief Make the loop nest of the correspondings schedule.
* \param sch The schedule. * \param sch The schedule.
* \param dom_map The domain map. * \param dom_map The domain map.
*
* \return a nested representation of loop statements.
* The flattened Stmt are ordered from outmost to inner most order.
*/ */
std::vector<Stmt> MakeLoopNest( std::vector<std::vector<Stmt> > MakeLoopNest(
const Schedule& sch, const Schedule& sch,
const std::unordered_map<IterVar, Range>& dom_map) { const Map<IterVar, Range>& dom_map) {
// optional, use let to define some CSE in dom_map. // optional, use let to define some CSE in dom_map.
auto leaf_iter_vars = sch->leaf_iter_vars; auto leaf_iter_vars = sch->leaf_iter_vars;
std::unordered_map<IterVar, Expr> offset; std::unordered_map<IterVar, Expr> offset;
std::unordered_map<const Variable*, size_t> loop_level; std::unordered_map<const Variable*, size_t> loop_level;
Stmt no_op = Evaluate::make(0);
// create the loop nest // create the loop nest
std::vector<Stmt> nest; std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1, Stmt()); nest.resize(leaf_iter_vars.size() + 1);
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i]; auto iv = leaf_iter_vars[i];
// initialize the offset and loop_level // initialize the offset and loop_level
offset[iv] = iv->var; offset[iv] = iv->var;
loop_level[iv->var.as<Variable>()] = i + 1; loop_level[iv->var.as<Variable>()] = i + 1;
// Mark the iter var in the IR, to remember the point
nest[i] = AttrStmt::make(iv->var, "scope", iv, Stmt());
if (iv->thread_tag.length() == 0) { if (iv->thread_tag.length() == 0) {
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
nest[i] = For::make(iv->var, dom->min, dom->extent, nest[i + 1].emplace_back(
ForType::Serial, DeviceAPI::None, nest[i]); For::make(iv->var, dom->min, dom->extent,
ForType::Serial, DeviceAPI::None, no_op));
} }
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
} }
// message passing to get offset of root iter vars. // message passing to get offset of root iter vars.
PassUpOffset(sch, dom_map, &offset); PassUpOffset(sch, dom_map, &offset);
for (IterVar iv : sch->op->root_iter_vars()) { for (IterVar iv : sch->op->root_iter_vars()) {
Expr value = offset.at(iv); Expr value = offset.at(iv);
if (value.same_as(iv->var)) continue; if (!value.same_as(iv->var)) {
using Entry = std::pair<size_t, Expr>; using Entry = std::pair<size_t, Expr>;
std::vector<Entry> splits; std::vector<Entry> splits;
SplitByAdd(value, loop_level, &splits); SplitByAdd(value, loop_level, &splits);
Expr offset = 0; Expr offset = 0;
for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) { size_t nsplit_left = splits.size() - 1;
auto iv = leaf_iter_vars[i]; for (size_t i = 0; i <= leaf_iter_vars.size(); ++i) {
for (const auto& kv : splits) { size_t hit = 0;
if (kv.first == i) { for (const auto& kv : splits) {
offset = offset + splits[i].second; if (kv.first == i) {
if (is_zero(offset)) {
offset = kv.second;
} else {
offset = offset + kv.second;
++hit;
}
}
} }
nsplit_left -= hit;
if (hit != 0) {
std::ostringstream os;
os << iv->var->name_hint << ".at.l" << i;
Var base_offset(os.str());
if (nsplit_left == 0) {
base_offset = iv->var;
}
nest[i].emplace_back(
LetStmt::make(base_offset, offset, no_op));
offset = base_offset;
}
}
Range dom = dom_map.at(iv);
if (!offset.same_as(iv->var)) {
// define the iv->var
nest.back().emplace_back(
LetStmt::make(iv->var, offset, no_op));
} }
std::ostringstream os; Expr condition = (iv->var - dom->min) < dom->extent;
os << iv->var->name_hint << ".at.l" << i; // Boundary condition checking
Var base_offset(os.str()); // Need better boundary condition here.
nest[i] = LetStmt::make(base_offset, offset, nest[i]); nest.back().emplace_back(IfThenElse::make(condition, no_op));
offset = base_offset;
} }
nest.back() = LetStmt::make(iv->var, offset, nest.back());
} }
return nest; return nest;
} }
/*! /*!
* \brief Make the loop nest of the correspondings schedule. * \brief Make pipeline specifically for compute op node.
* \param op The operation. * \param op The compute node
* \param tensors The tensors generated by provide.
*/ */
Stmt MakeBody(const Operation& op) { Stmt MakeProvide(const ComputeOpNode* op,
Stmt body; const std::vector<Tensor>& tensors) {
if (op.as<ComputeOpNode>()) { Tensor t = tensors[0];
const ComputeOpNode* compute = op.as<ComputeOpNode>(); Array<Expr> args;
// Note: Tensor's address cannot uniquely for (IterVar iv : op->axis) {
Tensor t = op.output(0); args.push_back(iv->var);
Array<Expr> args; }
for (IterVar iv : compute->axis) { return Provide::make(t->op, t->value_index, op->body, args);
args.push_back(iv->var); }
}
body = Provide::make(t, {compute->body}, args); /*!
* \brief Make pipeline specifically for compute op node.
* \param op The compute node
* \param dom_map The domain map
* \param tensors The tensors generated by provide.
* \param body The content of the pipeline.
*/
Stmt MakeRealize(const ComputeOpNode* op,
const Map<IterVar, Range>& dom_map,
const std::vector<Tensor>& tensors,
Stmt body) {
Tensor t = tensors[0];
Halide::Internal::Region bounds;
for (IterVar iv : op->axis) {
bounds.push_back(dom_map.at(iv));
}
return Realize::make(t->op, t->value_index, t->dtype,
bounds, make_const(Bool(1), true), body);
}
Stmt MakePipeline(const Schedule& sch,
const Map<IterVar, Range>& dom_map,
Stmt consumer) {
std::vector<Tensor> tensors;
for (int i = 0; i < sch->op->num_outputs(); ++i) {
tensors.emplace_back(sch->op.output(i));
}
Stmt provide;
if (sch->op.as<ComputeOpNode>()) {
provide = MakeProvide(sch->op.as<ComputeOpNode>(), tensors);
} else { } else {
LOG(FATAL) << "not supported op"; LOG(FATAL) << "not supported op";
} }
return body; std::vector<std::vector<Stmt> > nest = MakeLoopNest(sch, dom_map);
} Stmt producer = MergeNest(nest, provide);
producer = ProducerConsumer::make(sch->op, true, producer);
Stmt MakePipeline(const Schedule& sch, Stmt body) { Stmt pipeline = producer;
return body; if (consumer.defined()) {
consumer = ProducerConsumer::make(sch->op, false, consumer);
pipeline = Block::make(producer, consumer);
}
if (sch->op.as<ComputeOpNode>()) {
return MakeRealize(sch->op.as<ComputeOpNode>(),
dom_map, tensors, pipeline);
} else {
LOG(FATAL) << "not supported op";
return Stmt();
}
} }
// inject the operator's realization on the stmt. // inject the operator's realization on the stmt.
class InjectRealize : public IRMutator { class InjectRealize : public IRMutator {
public: public:
explicit InjectRealize(Schedule sch) InjectRealize(Schedule schedule, Map<IterVar, Range> dom_map)
: sch_(sch) {} : schedule(schedule), dom_map(dom_map) {}
Stmt Mutate(Stmt stmt) final { Stmt Mutate(Stmt stmt) final {
CHECK(stmt.defined());
stmt = IRMutator::Mutate(stmt);
const AttrStmt* op = stmt.as<AttrStmt>(); const AttrStmt* op = stmt.as<AttrStmt>();
if (op != nullptr) {
attr_scope_.Push({op->node, op->type_key}, op->value);
stmt = IRMutator::Mutate(stmt);
attr_scope_.Pop({op->node, op->type_key});
} else {
stmt = IRMutator::Mutate(stmt);
}
if (op != nullptr && if (op != nullptr &&
op->type_key == "scope" && op->type_key == "scope") {
op->node == sch_->attach_parent) { if (op->node == schedule->attach_parent) {
return AttrStmt::make( CHECK(!found_attach);
op->node, op->type_key, op->value, found_attach = true;
MakePipeline(sch_, op->body)); stmt = AttrStmt::make(
} else { op->node, op->type_key, op->value,
return stmt; MakePipeline(schedule, dom_map,
IRMutator::Mutate(op->body)));
}
} }
return stmt;
} }
private:
// the operations to be carried // the operations to be carried
Schedule sch_; Schedule schedule;
Scope<AttrKey, Expr> attr_scope_; // domain map
Map<IterVar, Range> dom_map;
// whether attach point is found
bool found_attach{false};
}; };
/*!
 * \brief Recursively record the schedule of every operation in a schedule tree.
 * \param s The schedule whose op is registered, followed by all its children.
 * \param ret The op -> schedule table to fill; an op that is already present
 *            triggers a fatal check.
 */
void GetOpToScheduleMap(
    Schedule s,
    std::unordered_map<Operation, Schedule>* ret) {
  // emplace leaves the table untouched when the key already exists,
  // so the returned flag tells us whether the op was seen before.
  const bool inserted = ret->emplace(s->op, s).second;
  CHECK(inserted)
      << "Duplicated schedule for op";
  for (Schedule child : s->children) {
    GetOpToScheduleMap(child, ret);
  }
}
// Collect every schedule reachable from s, arranged in the reverse of the
// order returned by schedule::PostDFSOrder on the read graph of s->op.
std::vector<Schedule> OrderSchedule(Schedule s) {
  auto read_graph = schedule::CreateReadGraph(s->op);
  auto dfs_order = schedule::PostDFSOrder(s->op, read_graph);
  // Lookup table from each operation to its schedule.
  std::unordered_map<Operation, Schedule> sch_of_op;
  GetOpToScheduleMap(s, &sch_of_op);
  std::vector<Schedule> ordered;
  // Walk the post order from its last entry down to its first.
  size_t remaining = dfs_order.size();
  while (remaining != 0) {
    --remaining;
    ordered.push_back(sch_of_op.at(dfs_order[remaining]));
  }
  return ordered;
}
/*!
 * \brief Inline a compute operation into a statement.
 * \param op The operation to inline; a fatal check fires unless it is a ComputeOpNode.
 * \param body The statement to rewrite; must be defined.
 * \return The result of Inline applied with the op's axis variables and body.
 */
Stmt InjectInline(const Operation op, Stmt body) {
  CHECK(body.defined());
  const auto* node = op.as<ComputeOpNode>();
  CHECK(node != nullptr)
      << "can only inline compute op";
  // The axis variables act as the formal arguments of the inlined function.
  Array<Var> axis_vars;
  for (const auto& axis : node->axis) {
    axis_vars.push_back(axis->var);
  }
  return Inline(op, axis_vars, node->body, body);
}
} // namespace } // namespace
/*!
 * \brief Build the statement for a schedule and the operations it depends on.
 *
 *  Schedules are visited in the order produced by OrderSchedule and each one
 *  is folded into the statement built so far, according to its attach_type.
 *
 * \param s The schedule to be realized.
 * \param dom_map The domain of each iter var.
 * \return The resulting statement.
 */
Stmt ScheduleOps(
    Schedule s, Map<IterVar, Range> dom_map) {
  Stmt result;
  for (Schedule cur : OrderSchedule(s)) {
    switch (cur->attach_type) {
      case kInline:
        result = InjectInline(cur->op, result);
        break;
      case kRoot:
      case kNone:
        // The statement built so far becomes the consumer of this op.
        result = MakePipeline(cur, dom_map, result);
        break;
      case kScope: {
        // Attach inside the existing statement at the op's attach point.
        CHECK(result.defined());
        InjectRealize injector(cur, dom_map);
        result = injector.Mutate(result);
        CHECK(injector.found_attach)
            << "did not find attachment point";
        break;
      }
      default:
        // Any other attach type leaves the statement unchanged.
        break;
    }
  }
  return result;
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment