From b19e01bf27ecc0c4c0b0f461588bcddd58dd7d65 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sat, 11 Mar 2017 20:12:42 -0800 Subject: [PATCH] [PASS] RemoveNoOp. (#68) --- include/tvm/buffer.h | 1 - include/tvm/channel.h | 54 +++++ include/tvm/expr.h | 1 + include/tvm/ir.h | 11 +- include/tvm/ir_pass.h | 14 ++ src/api/api_pass.cc | 2 + src/lang/channel.cc | 22 ++ src/pass/remove_no_op.cc | 111 ++++++++++ src/pass/simple_passes.cc | 1 + src/pass/split_host_device.cc | 10 +- src/pass/split_pipeline.cc | 194 ++++++++++++++++++ src/schedule/schedule_ops.cc | 6 +- .../python/unittest/test_pass_remove_no_op.py | 29 +++ .../unittest/test_pass_split_pipeline.py | 31 +++ 14 files changed, 477 insertions(+), 10 deletions(-) create mode 100644 include/tvm/channel.h create mode 100644 src/lang/channel.cc create mode 100644 src/pass/remove_no_op.cc create mode 100644 src/pass/split_pipeline.cc create mode 100644 tests/python/unittest/test_pass_remove_no_op.py create mode 100644 tests/python/unittest/test_pass_split_pipeline.py diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 141e4b68f..9f266844f 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -1,4 +1,3 @@ - /*! * Copyright (c) 2016 by Contributors * \file buffer.h diff --git a/include/tvm/channel.h b/include/tvm/channel.h new file mode 100644 index 000000000..81f9e5248 --- /dev/null +++ b/include/tvm/channel.h @@ -0,0 +1,54 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file channel.h + * \brief Channel object for pipeline. + */ +#ifndef TVM_CHANNEL_H_ +#define TVM_CHANNEL_H_ + +#include <tvm/expr.h> + +namespace tvm { +// Node container of channel +struct ChannelNode; + +/*! \brief The data channel. */ +class Channel : public NodeRef { + public: + /*! \brief default constructor */ + Channel() {} + explicit Channel(std::shared_ptr<Node> n) : NodeRef(n) {} + /*! 
+ * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const ChannelNode* operator->() const; +}; + +/*! + * \brief Generalized FIFO channel. + */ +struct ChannelNode : public Node { + /*! \brief Variable to channel handle */ + Var handle_var; + /*! \brief default data type in read/write */ + Type dtype; + + // visit all attributes + void VisitAttrs(AttrVisitor* v) final { + v->Visit("handle_var", &handle_var); + v->Visit("dtype", &dtype); + } + + static Channel make(Var handle_var, Type dtype); + static constexpr const char* _type_key = "Channel"; + + TVM_DECLARE_NODE_TYPE_INFO(ChannelNode, Node); +}; + +// Inline implementations +inline const ChannelNode* Channel::operator->() const { + return static_cast<const ChannelNode*>(node_.get()); +} +} // namespace tvm +#endif // TVM_CHANNEL_H_ diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 761cd2b04..8d100d272 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -39,6 +39,7 @@ using Halide::Internal::as_const_int; using Halide::Internal::as_const_uint; using Halide::Internal::const_true; using Halide::Internal::const_false; +using Halide::Internal::is_no_op; inline Type TVMType2Type(TVMType t) { return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes); diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 153d3105f..6b7ba2927 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -90,9 +90,7 @@ constexpr const char* virtual_thread = "virtual_thread"; * \brief Mark storage scope of buffers */ constexpr const char* storage_scope = "storage_scope"; -/*! - * \brief Mark storage scope of realizations - */ +/*! \brief Mark storage scope of realization */ constexpr const char* realize_scope = "realize_scope"; /*! \brief Mark of loop scope */ constexpr const char* loop_scope = "loop_scope"; @@ -100,6 +98,13 @@ constexpr const char* loop_scope = "loop_scope"; constexpr const char* scan_update_scope = "scan_update_scope"; /*! 
\brief Mark of scan init scope */ constexpr const char* scan_init_scope = "scan_init_scope"; +// Pipeline related attributes +/*! \brief channel read scope */ +constexpr const char* channel_read_scope = "channel_read_scope"; +/*! \brief channel write scope */ +constexpr const char* channel_write_scope = "channel_write_scope"; +/*! \brief pipeline module scope */ +constexpr const char* pipeline_stage_scope = "pipeline_stage_scope"; } // namespace attr /*! \brief namespace of TVM Intrinsic functions */ diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index f1ee06188..8f71ad145 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -105,6 +105,20 @@ Stmt Inline(Stmt stmt, Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer); +/*! + * \brief Remove No Op from the Stmt. + * \param stmt The stmt to be transformed + * \return Transformed stmt. + */ +Stmt RemoveNoOp(Stmt stmt); + +/*! + * \brief Split statement into pipeline stages. + * \param stmt The stmt to be split + * \return Transformed stmt. + */ +Stmt SplitPipeline(Stmt stmt); + /*! * \brief unroll the constant loops * \param stmt The statment to be unrolled. diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index f995f13d1..82de8addf 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -70,6 +70,8 @@ REGISTER_PASS1(SplitHostDevice); REGISTER_PASS1(LiftAllocate); REGISTER_PASS1(InjectVirtualThread); REGISTER_PASS1(LoopPartition); +REGISTER_PASS1(RemoveNoOp); +REGISTER_PASS1(SplitPipeline); } // namespace ir } // namespace tvm diff --git a/src/lang/channel.cc b/src/lang/channel.cc new file mode 100644 index 000000000..dd850becf --- /dev/null +++ b/src/lang/channel.cc @@ -0,0 +1,22 @@ +/*! 
+ * Copyright (c) 2017 by Contributors + * \file channel.cc + */ +#include <tvm/channel.h> + +namespace tvm { + +Channel ChannelNode::make(Var handle_var, Type dtype) { + auto n = std::make_shared<ChannelNode>(); + n->handle_var = handle_var; + n->dtype = dtype; + return Channel(n); +} + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch<ChannelNode>([](const ChannelNode *op, IRPrinter *p) { + p->stream << "channel(" << op->handle_var << ", " << op->dtype << ")"; +}); + +TVM_REGISTER_NODE_TYPE(ChannelNode); +} // namespace tvm diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc new file mode 100644 index 000000000..9709ae1b8 --- /dev/null +++ b/src/pass/remove_no_op.cc @@ -0,0 +1,111 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file remove_no_op.cc + * \brief Remove no op from the stmt + */ +#include <tvm/ir.h> +#include <tvm/ir_pass.h> +#include <tvm/ir_mutator.h> +#include <unordered_map> + +namespace tvm { +namespace ir { + +// Remove no-op statements from the IR. +class NoOpRemover : public IRMutator { + public: + Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<LetStmt>(); + return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; + } + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<AttrStmt>(); + return is_no_op(op->body) ? 
MakeEvaluate(op->value) : stmt; + } + Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<IfThenElse>(); + if (op->else_case.defined()) { + if (is_no_op(op->else_case)) { + if (is_no_op(op->then_case)) { + return MakeEvaluate(op->condition); + } else { + return IfThenElse::make(op->condition, op->then_case); + } + } else { + return stmt; + } + } else { + if (is_no_op(op->then_case)) { + return MakeEvaluate(op->condition); + } else { + return stmt; + } + } + } + Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<For>(); + return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt; + } + Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<Allocate>(); + return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; + } + Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<ProducerConsumer>(); + return is_no_op(op->body) ? op->body : stmt; + } + Stmt Mutate_(const Realize* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<Realize>(); + return is_no_op(op->body) ? 
op->body : stmt; + } + Stmt Mutate_(const Evaluate* op, const Stmt& s) final { + if (HasSideEffect(op->value)) return s; + return Evaluate::make(0); + } + Stmt Mutate_(const Block* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<Block>(); + if (is_no_op(op->first)) { + return op->rest; + } else if (is_no_op(op->rest)) { + return op->first; + } else { + return stmt; + } + } + + private: + Stmt MakeEvaluate(Expr value) { + if (HasSideEffect(value)) { + return Evaluate::make(value); + } else { + return Evaluate::make(0); + } + } + Stmt MakeEvaluate(const Array<Expr>& values) { + Stmt stmt; + for (Expr e : values) { + if (HasSideEffect(e)) { + if (stmt.defined()) { + stmt = Block::make(stmt, Evaluate::make(e)); + } else { + stmt = Evaluate::make(e); + } + } + } + return stmt.defined() ? stmt : Evaluate::make(0); + } +}; + +Stmt RemoveNoOp(Stmt stmt) { + return NoOpRemover().Mutate(stmt); +} +} // namespace ir +} // namespace tvm diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 5fc928cdd..70af63ce7 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -48,6 +48,7 @@ class IRSubstitue : public IRMutator { }; Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) { + if (value_map.size() == 0) return stmt; IRSubstitue m; for (auto kv : value_map) { m.smap[kv.first.get()] = kv.second; diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index 642c1ed12..d64eff792 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -5,6 +5,7 @@ */ #include <tvm/ir.h> #include <tvm/lowered_func.h> +#include <tvm/channel.h> #include <tvm/ir_pass.h> #include <tvm/ir_mutator.h> #include <tvm/runtime/module.h> @@ -17,7 +18,7 @@ namespace ir { class IRUseDefAnalysis : public IRMutator { public: Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { - if (op->type_key == "thread_extent") { + if (op->type_key == attr::thread_extent) { IterVar iv(op->node.node_); 
CHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times @@ -35,6 +36,13 @@ class IRUseDefAnalysis : public IRMutator { Stmt body = this->Mutate(op->body); if (value.same_as(value) && body.same_as(body)) return s; return AttrStmt::make(op->node, op->type_key, value, body); + } else if (op->type_key == attr::channel_write_scope || + op->type_key == attr::channel_read_scope) { + Channel ch(op->node.node_); + if (!use_count_.count(ch->handle_var.get())) { + this->HandleDef(ch->handle_var.get()); + } + return IRMutator::Mutate_(op, s); } else { return IRMutator::Mutate_(op, s); } diff --git a/src/pass/split_pipeline.cc b/src/pass/split_pipeline.cc new file mode 100644 index 000000000..93b3b86ed --- /dev/null +++ b/src/pass/split_pipeline.cc @@ -0,0 +1,194 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file split_pipeline.cc + * \brief Split statement into pipeline stage modules. + */ +#include <tvm/ir.h> +#include <tvm/expr.h> +#include <tvm/ir_pass.h> +#include <tvm/ir_visitor.h> +#include <tvm/ir_mutator.h> +#include <tvm/channel.h> +#include <unordered_map> +#include "./ir_util.h" + +namespace tvm { +namespace ir { + +class MarkChannelAccess : public IRMutator { + public: + MarkChannelAccess( + const std::unordered_map<const Variable*, Channel>& cmap) + : cmap_(cmap) {} + + Expr Mutate_(const Load *op, const Expr& e) final { + auto it = rmap_.find(op->buffer_var.get()); + if (it != rmap_.end()) { + ++it->second.read_count; + } + return IRMutator::Mutate_(op, e); + } + Stmt Mutate_(const Store *op, const Stmt& s) final { + auto it = rmap_.find(op->buffer_var.get()); + if (it != rmap_.end()) { + ++it->second.write_count; + } + return IRMutator::Mutate_(op, s); + } + Stmt Mutate_(const Allocate* op, const Stmt& s) final { + if (cmap_.count(op->buffer_var.get())) { + CHECK(!rmap_.count(op->buffer_var.get())); + rmap_[op->buffer_var.get()] = Entry(); + Stmt body = Mutate(op->body); + body = CreateChannelAccess(op, body); + 
rmap_.erase(op->buffer_var.get()); + return body; + } else { + return IRMutator::Mutate_(op, s); + } + } + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + if (op->type_key == ir::attr::storage_scope) { + Var buf_var(op->node.node_); + if (cmap_.count(buf_var.get())) return Mutate(op->body); + } + return IRMutator::Mutate_(op, s); + } + + private: + // Create channel access wrap + Stmt CreateChannelAccess(const Allocate* op, Stmt body) { + const Entry& rw = rmap_.at(op->buffer_var.get()); + CHECK(rw.write_count == 0 || rw.read_count == 0) + << "Cannot read/write to the same channel " << op->buffer_var + << " body:" << body; + if (rw.write_count == 0 && rw.read_count == 0) { + return body; + } + const Channel& ch = cmap_.at(op->buffer_var.get()); + int32_t csize = op->constant_allocation_size(); + Expr alloc_size; + if (csize > 0) { + alloc_size = IntImm::make(Int(32), csize); + } else { + alloc_size = op->extents[0]; + for (size_t i = 1; i < op->extents.size(); ++i) { + alloc_size *= op->extents[i]; + } + alloc_size = ir::Simplify(alloc_size); + } + + if (rw.write_count) { + return AttrStmt::make( + ch, ir::attr::channel_write_scope, alloc_size, body); + } else { + CHECK(rw.read_count); + return AttrStmt::make( + ch, ir::attr::channel_read_scope, alloc_size, body); + } + } + struct Entry { + int read_count{0}; + int write_count{0}; + }; + // The channels of each allocation. + const std::unordered_map<const Variable*, Channel>& cmap_; + // the result. + std::unordered_map<const Variable*, Entry> rmap_; +}; + + +// Mark the statement of each stage. 
+class StageSplitter : public IRMutator {
+ public:
+  Stmt Mutate(Stmt stmt) final {  // keep a stack of the enclosing statements
+    nest_.push_back(stmt);
+    Stmt ret = IRMutator::Mutate(stmt);
+    nest_.pop_back();
+    return ret;
+  }
+  Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final {
+    if (!op->is_producer) return IRMutator::Mutate_(op, s);
+    Stmt body = Mutate(op->body);  // inner producers become stages first
+    stages_.emplace_back(BuildStage(body, op->func));
+    return Evaluate::make(0);  // the producer itself is replaced by a no-op
+  }
+
+  Stmt Split(Stmt stmt) {  // entry point: run the split on stmt
+    stmt = Mutate(stmt);
+    stmt = RemoveNoOp(stmt);
+    CHECK(is_no_op(stmt));  // nothing may remain outside the stages
+    CHECK_NE(stages_.size(), 0U);
+    stmt = stages_.back();
+    for (size_t i = stages_.size() - 1; i != 0; --i) {
+      stmt = Block::make(stages_[i - 1], stmt);  // prepend earlier stage
+    }
+    stmt = MarkChannelAccess(cmap_).Mutate(stmt);
+    return RemoveNoOp(stmt);
+  }
+
+ private:
+  // Wrap body in copies of the enclosing nest_ scopes and tag it as a stage.
+  Stmt BuildStage(Stmt body, NodeRef target) {
+    int stage_index = static_cast<int>(stages_.size());
+    std::string stage_suffix = "." + std::to_string(stage_index);
+    // Original loop/let variable -> its copy renamed with stage_suffix.
+    Map<Var, Expr> subst;
+    std::vector<Stmt> nest;
+    Stmt no_op = Evaluate::make(0);
+
+    for (const Stmt& s : nest_) {
+      if (const For* op = s.as<For>()) {
+        Var loop_var(op->loop_var);
+        Var new_var = loop_var.copy_with_suffix(stage_suffix);
+        subst.Set(loop_var, new_var);
+        nest.emplace_back(For::make(
+            new_var, op->min, op->extent,
+            op->for_type, op->device_api, no_op));
+      } else if (const LetStmt* op = s.as<LetStmt>()) {
+        Var var(op->var);
+        Var new_var = var.copy_with_suffix(stage_suffix);
+        subst.Set(var, new_var);
+        nest.emplace_back(LetStmt::make(new_var, op->value, no_op));
+      } else if (const IfThenElse* op = s.as<IfThenElse>()) {
+        CHECK(!op->else_case.defined());
+        nest.emplace_back(IfThenElse::make(op->condition, no_op));
+      } else if (const AttrStmt* op = s.as<AttrStmt>()) {
+        nest.emplace_back(AttrStmt::make(
+            op->node, op->type_key, op->value, no_op));
+      } else if (s.as<ProducerConsumer>()) {  // not replicated in the stage
+      } else if (s.as<Block>()) {  // not replicated in the stage
+      } else if (const Allocate* op = s.as<Allocate>()) {
+        nest.emplace_back(Allocate::make(
+            op->buffer_var, op->type, op->extents,
+            op->condition, no_op, op->new_expr, op->free_function));
+        MarkChannel(op);  // buffers allocated around a stage become channels
+      } else {
+        LOG(FATAL) << "not supported nest type " << s->type_key();
+      }
+    }
+    body = Substitute(MergeNest(nest, body), subst);
+    return AttrStmt::make(
+        target, ir::attr::pipeline_stage_scope,
+        make_const(Int(32), stage_index), body);
+  }
+  void MarkChannel(const Allocate* op) {  // create the Channel only once
+    if (!cmap_.count(op->buffer_var.get())) {
+      Channel ch = ChannelNode::make(Var(op->buffer_var), op->type);
+      cmap_[op->buffer_var.get()] = ch;
+    }
+  }
+  // Statements on the current mutation path, outermost first.
+  std::vector<Stmt> nest_;
+  // Stages built so far, one per producer ProducerConsumer node.
+  std::vector<Stmt> stages_;
+  // Channel created for each buffer variable seen by MarkChannel.
+  std::unordered_map<const Variable*, Channel> cmap_;
+};
+
+Stmt SplitPipeline(Stmt stmt) {
+  return StageSplitter().Split(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index 43081d8b3..6489f21e7 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -26,12 +26,8 @@ Stmt MakePipeline(const Stage& s,
     producer = ProducerConsumer::make(s->op, true, producer);
   }
   Stmt pipeline = producer;
-  // check if consumer is nop.
- bool is_no_op{false}; - const Evaluate* ev = consumer.as<Evaluate>(); - if (ev && ev->value.as<IntImm>()) is_no_op = true; - if (consumer.defined() && !is_no_op) { + if (consumer.defined() && !is_no_op(consumer)) { consumer = ProducerConsumer::make(s->op, false, consumer); pipeline = Block::make(producer, consumer); } diff --git a/tests/python/unittest/test_pass_remove_no_op.py b/tests/python/unittest/test_pass_remove_no_op.py new file mode 100644 index 000000000..8aadaf8c0 --- /dev/null +++ b/tests/python/unittest/test_pass_remove_no_op.py @@ -0,0 +1,29 @@ +import tvm + +def test_remove_no_op(): + i = tvm.Var('i') + j = tvm.Var('j') + k = tvm.Var('k') + m = tvm.Var('m') + n = tvm.Var('n') + dtype = 'int64' + Ab = tvm.Buffer((n, ), dtype) + stmt = tvm.make.For( + i, 0, 4, 0, 0, + tvm.make.For( + j, 0, n, 0, 0, + tvm.make.For( + k, 0, m, 0, 0, + tvm.make.IfThenElse( + (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))))) + ret = tvm.ir_pass.RemoveNoOp(stmt) + assert(isinstance(ret, tvm.stmt.Evaluate)) + store = tvm.make.Store(Ab.data, + tvm.make.Load(dtype, Ab.data, i) + 1, + i + 1) + stmt2 = tvm.make.Block(stmt, store) + assert(tvm.ir_pass.RemoveNoOp(stmt2) == store) + + +if __name__ == "__main__": + test_remove_no_op() diff --git a/tests/python/unittest/test_pass_split_pipeline.py b/tests/python/unittest/test_pass_split_pipeline.py new file mode 100644 index 000000000..86beb5eee --- /dev/null +++ b/tests/python/unittest/test_pass_split_pipeline.py @@ -0,0 +1,31 @@ +import tvm + +def test_basic_pipeline(): + n = tvm.convert(128) + A = tvm.placeholder((n,), name='A') + stages = [] + num_stage = 3 + + B = A + for k in range(num_stage): + stages.append(B) + B = tvm.compute((n,), lambda i: B[i] + k, name="A%s" % k) + + s = tvm.Schedule(B.op) + xo, xi = s[B].split(B.op.axis[0], factor=4) + for S in stages: + s[S].compute_at(s[B], xo) + + # Lowering + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + Ab = tvm.Buffer(A.shape, 
A.dtype, name='A') + Bb = tvm.Buffer(B.shape, B.dtype, name='B') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb}) + stmt = tvm.ir_pass.Simplify(stmt) + stmt = tvm.ir_pass.SplitPipeline(stmt) + print(stmt) + assert(tvm.ir_pass.VerifySSA(stmt)) + +if __name__ == "__main__": + test_basic_pipeline() -- GitLab