Skip to content
Snippets Groups Projects
Commit 28876530 authored by tqchen's avatar tqchen
Browse files

Add AttrStmt

parent 61de73b4
No related branches found
No related tags found
No related merge requests found
Subproject commit 7f1d811972bccc26f651ea2289d88bcadea9fe9f
Subproject commit bf96f8af0dfd1f79d258c7c1506f9ded932b94a9
......@@ -17,6 +17,7 @@ namespace tvm {
namespace ir {
using Halide::Internal::ExprNode;
using Halide::Internal::StmtNode;
using Halide::Internal::IRNodeType;
using Halide::Internal::ForType;
......@@ -47,6 +48,34 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This is used to insert hint(shape, storage, split) about certain scopes.
*/
struct AttrStmt : public StmtNode<AttrStmt> {
/*! \brief this is attribute about certain node */
NodeRef node;
/*! \brief the type key of the attribute */
std::string type_key;
/*! \brief The attribute value, value is well defined at current scope. */
Expr value;
/*! \brief The body statement to be executed */
Stmt body;
/*! \brief construct expr from name and rdom */
static Stmt make(NodeRef node, std::string type_key, Expr value, Stmt body);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("node", &node);
v->Visit("type_key", &type_key);
v->Visit("value", &value);
v->Visit("body", &body);
}
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "AttrStmt";
};
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
......
......@@ -32,6 +32,7 @@ class ComputeOpNode : public OperationNode {
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("domain", &domain);
v->Visit("name", &name);
......
......@@ -38,42 +38,7 @@ class Schedule : public NodeRef {
inline const ScheduleNode* operator->() const;
};
/*! \brief schedule container */
class AttachSpec : public NodeRef {
public:
AttachSpec() {}
explicit AttachSpec(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const AttachSpecNode* operator->() const;
};
// defintion of node containers
/*! \brief The attach specification of each subschedule */
class AttachSpecNode : public Node {
public:
/*! \brief The attachment type */
AttachType attach_type;
/*!
* \brief The split to be attached to,
* only valid when attach_type is kRoot
*/
Split attach_split;
/*! \brief the child schedule to be attached. */
Schedule schedule;
const char* type_key() const final {
return "AttachSpec";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("attach_type", &attach_type);
v->Visit("attach_split", &attach_split);
v->Visit("schedule", &schedule);
}
};
/*! \brief represents the schedule of the tensor */
class ScheduleNode : public Node {
public:
......@@ -83,8 +48,17 @@ class ScheduleNode : public Node {
std::string scope;
/*! \brief Splits over iteration domains */
Array<Split> splits;
/*! \brief attach specifications */
Array<AttachSpec> attachs;
/*! \brief The attachment type of the schedule */
AttachType attach_type;
/*!
* \brief The attach point of this schedule, if it is a split
* \note This is not a cyclic dependency,
* because split do not refer back to parent schedule.
*/
Split attach_parent;
/*! \brief the schedules that this schedule depend on */
Array<Schedule> children;
// the type key
const char* type_key() const final {
return "Schedule";
}
......@@ -92,7 +66,9 @@ class ScheduleNode : public Node {
v->Visit("scope", &scope);
v->Visit("op", &op);
v->Visit("splits", &splits);
v->Visit("attachs", &attachs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent);
v->Visit("children", &children);
}
};
......@@ -101,9 +77,5 @@ inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline const AttachSpecNode* AttachSpec::operator->() const {
return static_cast<const AttachSpecNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_SCHEDULE_H_
......@@ -29,13 +29,6 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});
TVM_REGISTER_API(_make_Reduce)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Reduce::make(args.at(0),
args.at(1),
args.at(2));
});
TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
......@@ -54,22 +47,6 @@ TVM_REGISTER_API(_make_Allocate)
args.at(4));
});
TVM_REGISTER_API(_make_LetStmt)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 3) {
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2));
} else {
CHECK_EQ(args.size(), 5);
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
}
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
......@@ -89,6 +66,12 @@ TVM_REGISTER_API(_make_LetStmt)
*ret = Node::make(args.at(0), args.at(1), args.at(2)); \
}) \
#define REGISTER_MAKE4(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API(_make_## Node) \
.set_body([](const ArgStack& args, RetValue *ret) { \
......@@ -99,6 +82,9 @@ TVM_REGISTER_API(_make_LetStmt)
.add_argument("lhs", "Expr", "left operand") \
.add_argument("rhs", "Expr", "right operand")
REGISTER_MAKE3(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
......@@ -123,6 +109,7 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
......
......@@ -18,10 +18,16 @@ namespace Halide {
namespace Internal {
using tvm::ir::Reduce;
using tvm::ir::AttrStmt;
template<>
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
LOG(FATAL) << "Reduce do not work with IRVisitor yet";
LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
}
template<>
void StmtNode<AttrStmt>::accept(IRVisitor *v, const Stmt&) const {
LOG(FATAL) << "AttrStmt do not work with old Visitor, use IRFunctor style visitor";
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
......@@ -33,15 +39,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ", rdom=" << op->rdom << ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
});
} // namespace Internal
} // namespace Halide
namespace tvm {
namespace ir {
// reduce
TVM_REGISTER_NODE_TYPE(Reduce);
Expr Reduce::make(std::string op, Expr source, RDomain rdom) {
auto n = std::make_shared<Reduce>();
CHECK(source.defined());
......@@ -52,9 +63,17 @@ Expr Reduce::make(std::string op, Expr source, RDomain rdom) {
return Expr(n);
}
Stmt AttrStmt::make(NodeRef node, std::string type_key, Expr value, Stmt body) {
auto n = std::make_shared<AttrStmt>();
n->node = node;
n->type_key = type_key;
n->value = value;
n->body = body;
return Stmt(n);
}
// HalideIR node
using namespace Halide::Internal;
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
......
......@@ -74,8 +74,6 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const {
return Array<Expr>(shape);
}
TVM_REGISTER_NODE_TYPE(ComputeOpNode);
} // namespace tvm
......@@ -13,7 +13,6 @@ Schedule::Schedule(Operation op, std::string scope) {
node_ = n;
}
TVM_REGISTER_NODE_TYPE(AttachSpecNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
} // namespace tvm
......@@ -19,11 +19,12 @@ class IRInline : public IRMutator {
: f_(f), args_(args), body_(body) {}
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
return InlineCall(call);
} else {
return IRMutator::Mutate(expr);
return expr;
}
}
......
......@@ -72,6 +72,18 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.set_dispatch<AttrStmt>([](const AttrStmt* op, const Stmt& s, IRMutator* m) {
Expr value = m->Mutate(op->value);
Stmt body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return s;
} else {
return AttrStmt::make(op->node, op->type_key, op->value, op->body);
}
});
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<IntImm>(ReturnSelfExpr)
.set_dispatch<UIntImm>(ReturnSelfExpr)
......
......@@ -65,6 +65,12 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
v->Visit(op->source);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt* op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
......
......@@ -13,11 +13,19 @@ namespace {
// inject the operator's realization on the stmt.
class InjectRealize : public IRMutator {
public:
explicit InjectRealize(std::vector<Tensor> tensors)
: tensors_(tensors) {}
std::vector<Tensor> tensors_;
};
explicit InjectRealize(Schedule sch)
: sch_(sch) {}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
const For* op = stmt.as<For>();
return stmt;
}
private:
// the operations to be carried
Schedule sch_;
};
} // namespace
} // namespace ir
......
......@@ -22,10 +22,15 @@ def test_let():
x = tvm.Var('x')
y = tvm.Var('y')
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1), y, "stride")
assert stmt.attr_of_node == y
print(stmt)
x, 10, tvm.make.Evaluate(x + 1));
def test_attr():
x = tvm.Var('x')
y = tvm.Var('y')
stmt = tvm.make.AttrStmt(
y, "stride", 10, tvm.make.Evaluate(x + 1));
assert stmt.node == y
print(stmt)
def test_basic():
a = tvm.Var('a')
......@@ -44,6 +49,8 @@ def test_stmt():
if __name__ == "__main__":
test_attr()
test_const()
test_make()
test_ir()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment