From 84aeaf4803901ef738c8d21005e21c899218c8ac Mon Sep 17 00:00:00 2001 From: ziheng <ziheng@apache.org> Date: Sun, 4 Jun 2017 17:34:14 -0700 Subject: [PATCH] Change Schedule Array constructor to static make method (#170) * Change Schedule Array constructor to static make method * Add CreateSchedule * Add doc * Change CreateSchedule to create_schedule at cpp side --- include/tvm/schedule.h | 21 +++++++--- python/tvm/schedule.py | 2 +- src/api/api_lang.cc | 4 +- src/schedule/schedule_lang.cc | 79 ++++++++++++++++++----------------- 4 files changed, 59 insertions(+), 47 deletions(-) diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 085730a9b..4479c4fbe 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -197,11 +197,6 @@ class Schedule : public NodeRef { public: Schedule() {} explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {} - /*! - * \brief construct schedule for array of ops(and their dependencies). - * \param ops The ops to be scheduled. - */ - explicit Schedule(Array<Operation> ops); /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -439,10 +434,26 @@ class ScheduleNode : public Node { /*! \brief Invalidate temp cache. */ void InvalidateCache(); + /*! + * \brief Create a schedule for array of ops(and their dependencies). + * \param ops The ops to be scheduled. + * \return sch The created Schedule. + */ + static Schedule make(Array<Operation> ops); + static constexpr const char* _type_key = "Schedule"; TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node); }; +/*! + * \brief Create a schedule for array of ops(and their dependencies). + * \param ops The ops to be scheduled. + * \return sch The created Schedule. + */ +inline Schedule create_schedule(Array<Operation> ops) { + return ScheduleNode::make(ops); +} + /*! \brief node container for IterVar attr */ class IterVarAttrNode : public Node { public: diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index a23721514..8b2e63f9c 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -73,7 +73,7 @@ def create_schedule(ops): """ if not isinstance(ops, (list, _collections.Array)): ops = [ops] - return _api_internal._Schedule(ops) + return _api_internal._CreateSchedule(ops) @register_node diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 496b499f2..8dc53bc28 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -225,9 +225,9 @@ TVM_REGISTER_API("_IterVar") args[3]); }); -TVM_REGISTER_API("_Schedule") +TVM_REGISTER_API("_CreateSchedule") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Schedule(args[0].operator Array<Operation>()); + *ret = create_schedule(args[0].operator Array<Operation>()); }); TVM_REGISTER_API("_StageSetScope") diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 9e8056be1..25d72e346 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -322,45 +322,6 @@ Stage& Stage::parallel(IterVar var) { // NOLINT(*) return *this; } -Schedule::Schedule(Array<Operation> ops) { - auto n = std::make_shared<ScheduleNode>(); - node_ = n; - n->outputs = ops; - auto g = schedule::CreateReadGraph(n->outputs); - Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g); - // output set. - std::unordered_set<Operation> output_set; - for (Operation x : ops) { - output_set.insert(x); - } - for (Operation op : post_order) { - Stage stage(op); - stage->is_output = output_set.count(op) != 0; - n->stages.push_back(stage); - n->stage_map.Set(op, stage); - // mark scan updates. - if (op.as<ScanOpNode>()) { - const ScanOpNode* scan = op.as<ScanOpNode>(); - Array<Tensor> inputs; - for (Tensor t : scan->state_placeholder) { - inputs.push_back(t); - } - for (Tensor t : scan->inputs) { - inputs.push_back(t); - } - // Create the scan group. - Stage scan_group = create_group(scan->update, inputs, false); - scan_group->attach_type = kScanUpdate; - scan_group->attach_stage = stage; - - for (size_t i = 0; i < scan->update.size(); ++i) { - Stage s = n->stage_map[scan->update[i]->op]; - CHECK(scan_group.same_as(s->group)); - } - } - } -} - Stage CopyStage(const Stage& s) { std::shared_ptr<StageNode> n = std::make_shared<StageNode>(*s.operator->()); @@ -580,6 +541,46 @@ void ScheduleNode::InitCache() { CHECK_EQ(op2stage_cache_.size(), stages.size()); } +Schedule ScheduleNode::make(Array<Operation> ops) { + auto n = std::make_shared<ScheduleNode>(); + Schedule sch(n); + n->outputs = ops; + auto g = schedule::CreateReadGraph(n->outputs); + Array<Operation> post_order = schedule::PostDFSOrder(n->outputs, g); + // output set. + std::unordered_set<Operation> output_set; + for (Operation x : ops) { + output_set.insert(x); + } + for (Operation op : post_order) { + Stage stage(op); + stage->is_output = output_set.count(op) != 0; + n->stages.push_back(stage); + n->stage_map.Set(op, stage); + // mark scan updates. + if (op.as<ScanOpNode>()) { + const ScanOpNode* scan = op.as<ScanOpNode>(); + Array<Tensor> inputs; + for (Tensor t : scan->state_placeholder) { + inputs.push_back(t); + } + for (Tensor t : scan->inputs) { + inputs.push_back(t); + } + // Create the scan group. + Stage scan_group = sch.create_group(scan->update, inputs, false); + scan_group->attach_type = kScanUpdate; + scan_group->attach_stage = stage; + + for (size_t i = 0; i < scan->update.size(); ++i) { + Stage s = n->stage_map[scan->update[i]->op]; + CHECK(scan_group.same_as(s->group)); + } + } + } + return sch; +} + IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, -- GitLab