From 3b8ad0a22870152d6eb4979d7dcbd55f5b960179 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 17 Apr 2017 18:03:12 -0700
Subject: [PATCH] [SCHEDULE] Normalize returns a new schedule (#94)

---
 include/tvm/schedule.h                        |  7 ++-
 python/tvm/build.py                           |  2 +-
 python/tvm/schedule.py                        | 11 +++--
 src/api/api_lang.cc                           |  2 +-
 src/schedule/schedule_dataflow_rewrite.cc     |  8 ++--
 src/schedule/schedule_lang.cc                 | 46 +++++++++++++++++++
 tests/python/integration/test_dot.py          |  2 +-
 tests/python/integration/test_gemm.py         |  2 +-
 tests/python/unittest/test_codegen_device.py  |  1 +
 .../unittest/test_schedule_bound_inference.py | 20 ++++----
 .../unittest/test_schedule_schedule_ops.py    |  6 +--
 .../integration/test_codegen_verilog.py       |  2 +-
 12 files changed, 86 insertions(+), 23 deletions(-)

diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index ec1ac4c89..ca78e9e51 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -190,6 +190,11 @@ class Schedule : public NodeRef {
    * \param ops The ops to be scheduled.
    */
   explicit Schedule(Array<Operation> ops);
+  /*!
+   * \brief Get a copy of the current schedule.
+   * \return The copied schedule.
+   */
+  Schedule copy() const;
   /*!
    * \brief Get the stage that corresponds to the op.
    * \param op The operation.
@@ -257,7 +262,7 @@ class Schedule : public NodeRef {
    *
    * \return A normalized schedule, which can be the same as the current one.
    */
-  void normalize();
+  Schedule normalize();
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
diff --git a/python/tvm/build.py b/python/tvm/build.py
index f144dc260..6f1580e94 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -57,7 +57,7 @@ def lower(sch,
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
     # normalize schedule first
-    sch.normalize()
+    sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 96569730d..0783f260b 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -78,12 +78,17 @@ class Schedule(NodeBase):
         return self.stage_map[k]
 
     def normalize(self):
-        """Build a normalized schedule.
+        """Build a normalized schedule from the current schedule.
 
         Insert the necessary rebases to make certain iter vars start from 0.
         This is needed before bound inference and follow-up steps.
+
+        Returns
+        -------
+        sch : Schedule
+            The normalized schedule.
         """
-        _api_internal._ScheduleNormalize(self)
+        return _api_internal._ScheduleNormalize(self)
 
     def create_group(self, outputs, inputs, include_inputs=False):
         """Create a stage group by giving output and input boundaries.
@@ -261,7 +266,7 @@ class Stage(NodeBase):
         threads : list of threads
             The threads to be launched.
""" - if isinstance(threads, _collections.IterVar): + if isinstance(threads, IterVar): threads = [threads] _api_internal._StageEnvThreads(self, threads) diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 788b7e1ac..15155a232 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -311,7 +311,7 @@ TVM_REGISTER_API(_StageParallel) TVM_REGISTER_API(_ScheduleNormalize) .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Schedule() + *ret = args[0].operator Schedule() .normalize(); }); diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 9705f18de..c0debcc29 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -242,9 +242,11 @@ void InjectInline(ScheduleNode* sch) { ReplaceDataFlow(sch->stages, &repl); } -void Schedule::normalize() { - InjectInline(operator->()); - RebaseNonZeroMinLoop(*this); +Schedule Schedule::normalize() { + Schedule sn = copy(); + InjectInline(sn.operator->()); + RebaseNonZeroMinLoop(sn); + return sn; } // Handle reduction factor. diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index e31d2418b..4ef21f780 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -355,6 +355,52 @@ Schedule::Schedule(Array<Operation> ops) { } } +Stage CopyStage(const Stage& s) { + std::shared_ptr<StageNode> n = + std::make_shared<StageNode>(*s.operator->()); + return Stage(n); +} + +Schedule Schedule::copy() const { + // map of stages. + const ScheduleNode* self = operator->(); + std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap; + std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>(); + n->outputs = self->outputs; + // Copy the stages. + for (Stage s : self->stages) { + Stage scopy = CopyStage(s); + smap[s] = scopy; + n->stages.push_back(scopy); + } + for (Stage g : self->groups) { + Stage gcopy = CopyStage(g); + smap[g] = gcopy; + n->groups.push_back(gcopy); + } + // Remaps the reference relations. + for (auto kv : self->stage_map) { + n->stage_map.Set(kv.first, smap.at(kv.second)); + } + for (Stage s : n->stages) { + if (s->attach_stage.defined()) { + s->attach_stage = smap.at(s->attach_stage); + } + if (s->group.defined()) { + s->group = smap.at(s->group); + } + } + for (Stage s : n->groups) { + if (s->attach_stage.defined()) { + s->attach_stage = smap.at(s->attach_stage); + } + if (s->group.defined()) { + s->group = smap.at(s->group); + } + } + return Schedule(n); +} + Stage Schedule::operator[](const Operation& op) { auto it = (*this)->stage_map.find(op); CHECK(it != (*this)->stage_map.end()) diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py index 2aa3d0ddf..ab0b32e2a 100644 --- a/tests/python/integration/test_dot.py +++ b/tests/python/integration/test_dot.py @@ -10,7 +10,7 @@ def lower(s, args, name="mydot"): buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name) binds[x] = buf arg_list.append(buf) - s.normalize() + s = s.normalize() bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.StorageFlatten(stmt, binds) diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py index 4af11364d..fb66c6011 100644 --- a/tests/python/integration/test_gemm.py +++ b/tests/python/integration/test_gemm.py @@ -60,7 +60,7 @@ def test_gemm(): max_auto_unroll_step = 0 # lowering test - s.normalize() + s = s.normalize() # one line to build the function. 
     def check_device(device, host="stackvm"):
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index 90c28a8af..8548a33fa 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -16,6 +16,7 @@ def test_add_pipeline():
     s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
 
     # compile to IR
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index 21639399e..c47a29546 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -22,6 +22,8 @@ def test_bound2():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     s = tvm.create_schedule(A2.op)
     xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
+    # test that normalize does not affect the schedule
+    _ = s.normalize()
     s[A1].compute_at(s[A2], yo)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
@@ -41,6 +43,8 @@ def test_bound3():
     xi0, xi1 = s[A2].split(xi, nparts=16)
     s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
     yo, yi = s[A2].split(A2.op.axis[1], 16)
+    # test that normalize does not affect the schedule
+    _ = s.normalize()
     s[A2].reorder(xo, xi0, yo, xi1, yi)
     s[A1].compute_at(s[A2], yo)
 
@@ -63,7 +67,7 @@ def test_bound_scan():
     XX = s.cache_read(X, "local", s_update)
     xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
     s[XX].compute_at(s[s_update], xo)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     assert bounds[XX.op.axis[1]].extent.value == 4
@@ -77,7 +81,7 @@ def test_bound_conv1d():
     B = tvm.compute(n, computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[0])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
 
@@ -92,7 +96,7 @@ def test_bound_blur():
     B = tvm.compute((n-2, n-2), computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
     assert(bounds[A.op.axis[1]].extent.value == 3)
@@ -106,7 +110,7 @@ def test_bound_rfactor():
     s = tvm.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
     BF = s.rfactor(B, kf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BF.op.axis[0]].extent.value == 4)
 
@@ -123,7 +127,7 @@ def test_bound_group_schedule():
     g.compute_at(s[x2], x2.op.axis[0])
     assert s[x1].group == g
     assert s[x].group == g
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent == n
@@ -141,7 +145,7 @@ def test_bound_nest_group():
     assert s[x1].group == g2
     g2.compute_at(s[x2], x2.op.axis[0])
     g1.compute_at(s[x1], s[x1].op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent.value == 1
@@ -169,7 +173,7 @@ def test_bound_nest_thread():
     _, xi = s[A2].split(A2.op.axis[0], nparts=1)
     s[A2].bind(xi, thread_x)
     s[A1].compute_at(s[A3], tx)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A1.op.axis[0]].extent.value==1)
     assert(bounds[A2.op.axis[0]].extent.value==32)
@@ -225,7 +229,7 @@ def test_gemm_bound():
     tx, xi = s[BB].split(xi, nparts=num_thread)
     s[BB].bind(ty, thread_y)
     s[BB].bind(tx, thread_x)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BB.op.axis[0]].extent.value==64)
     assert(bounds[AA.op.axis[0]].extent.value==64)
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index f52f9eca3..ea02c60b8 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -51,7 +51,7 @@ def test_schedule_scan():
     assert tuple(res.shape) == (m, n)
     s = tvm.create_schedule(res.op)
 
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -68,7 +68,7 @@ def test_auto_inline():
 
     s = tvm.create_schedule(T2.op)
     tvm.schedule.AutoInlineElemWise(s)
-    s.normalize()
+    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
 
@@ -83,7 +83,7 @@ def test_inline_mixed():
     xo, xi = s[C].split(C.op.axis[0], factor=8)
     s[A1].compute_at(s[C], xo)
     s[A2].compute_inline()
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
diff --git a/tests/verilog/integration/test_codegen_verilog.py b/tests/verilog/integration/test_codegen_verilog.py
index 7003fa936..f1994e38c 100644
--- a/tests/verilog/integration/test_codegen_verilog.py
+++ b/tests/verilog/integration/test_codegen_verilog.py
@@ -11,7 +11,7 @@ def lower(s, args, name):
         buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
         binds[x] = buf
         arg_list.append(buf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
--
GitLab
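
The practical consequence of this patch for downstream code: normalize() is now
functional rather than mutating. It first calls the new Schedule::copy(), which
deep-copies every stage and group and remaps each copied stage's attach_stage
and group references onto the copied stages, then applies inlining and rebasing
to the copy and returns it. Call sites must therefore rebind the result, which
is why every s.normalize() above becomes s = s.normalize(); the _ = s.normalize()
calls in the bound-inference tests rely on the same copy semantics to verify that
the original schedule is left untouched. A minimal sketch of the new contract,
assuming the 2017-era tvm Python API used by the tests in this patch (the
element-wise compute is illustrative only):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
    s = tvm.create_schedule(B.op)

    # Before this patch, normalize() rebased iter vars of s in place.
    # Now it returns a new, normalized schedule and leaves s untouched,
    # so the result must be rebound or the normalization is silently lost.
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)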