diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index ec1ac4c898a69356b39574d7fb0082dfb324fa73..ca78e9e51789a30724684fcb448d4138ac804a47 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -190,6 +190,11 @@ class Schedule : public NodeRef {
    * \param ops The ops to be scheduled.
    */
   explicit Schedule(Array<Operation> ops);
+  /*!
+   * \brief Get a copy of the current schedule.
+   * \return The copied schedule.
+   */
+  Schedule copy() const;
   /*!
    * \brief Get the stage corresponds to the op
    * \param op The operation.
@@ -257,7 +262,7 @@ class Schedule : public NodeRef {
    *
    * \return A normalized schedule, can be same as current one.
    */
-  void normalize();
+  Schedule normalize();
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
diff --git a/python/tvm/build.py b/python/tvm/build.py
index f144dc260097b9d44d7fdbb538153133207a30fe..6f1580e948d9d397956b6eb69023e7bbf17652dd 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -57,7 +57,7 @@ def lower(sch,
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
     # normalize schedule first
-    sch.normalize()
+    sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 96569730d64327ef76ef52a3ddc48b425ac5a810..0783f260b2a8260952be67c3f447cbdeb745c98e 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -78,12 +78,17 @@ class Schedule(NodeBase):
         return self.stage_map[k]
 
     def normalize(self):
-        """Build a normalized schedule.
+        """Build a normalized schedule from the current schedule.
 
         Insert necessary rebase to make certain iter var to start from 0.
         This is needed before bound inference and followup step.
+
+        Returns
+        -------
+        sch : Schedule
+            The normalized schedule.
         """
-        _api_internal._ScheduleNormalize(self)
+        return _api_internal._ScheduleNormalize(self)
 
     def create_group(self, outputs, inputs, include_inputs=False):
         """Create stage group by giving output and input boundary.
@@ -261,7 +266,7 @@ class Stage(NodeBase):
         threads : list of threads
             The threads to be launched.
         """
-        if isinstance(threads, _collections.IterVar):
+        if isinstance(threads, IterVar):
             threads = [threads]
         _api_internal._StageEnvThreads(self, threads)
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 788b7e1acd0507b04cb1524a356d27f2caeb92ee..15155a232f4389a2f32574dec49d3aacbbee971d 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -311,7 +311,7 @@ TVM_REGISTER_API(_StageParallel)
 
 TVM_REGISTER_API(_ScheduleNormalize)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    args[0].operator Schedule()
+    *ret = args[0].operator Schedule()
         .normalize();
   });
 
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 9705f18de4bf1e4edaf7525b747b2514eaa5727c..c0debcc29e19a2202fafd1dabc40d91bae4f738e 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -242,9 +242,11 @@ void InjectInline(ScheduleNode* sch) {
   ReplaceDataFlow(sch->stages, &repl);
 }
 
-void Schedule::normalize() {
-  InjectInline(operator->());
-  RebaseNonZeroMinLoop(*this);
+Schedule Schedule::normalize() {
+  Schedule sn = copy();
+  InjectInline(sn.operator->());
+  RebaseNonZeroMinLoop(sn);
+  return sn;
 }
 
 // Handle reduction factor.
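Reviewer note: the behavioral change above is that `Schedule::normalize()` no longer mutates the schedule in place; it now copies the schedule, applies inlining and rebasing to the copy, and returns it. Every call site therefore has to rebind the result. A minimal sketch of the new calling convention, using a hypothetical elementwise compute (the tensor names and shapes are illustrative, not from this patch):

```python
import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")

s = tvm.create_schedule(B.op)
# Before this patch: s.normalize() mutated s in place and returned None.
# After this patch: normalize() returns a normalized copy, so the
# result must be rebound or the normalization is silently lost.
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
```

A bare `s.normalize()` now leaves `s` untouched, which is why every caller in the test diffs below is rewritten to `s = s.normalize()`.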
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index e31d2418bcec1a083588367b0540d9d37db40d06..4ef21f7809f231afac196edbb717d2e71ac7b30a 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -355,6 +355,52 @@ Schedule::Schedule(Array<Operation> ops) {
   }
 }
 
+Stage CopyStage(const Stage& s) {
+  std::shared_ptr<StageNode> n =
+      std::make_shared<StageNode>(*s.operator->());
+  return Stage(n);
+}
+
+Schedule Schedule::copy() const {
+  // map of stages.
+  const ScheduleNode* self = operator->();
+  std::unordered_map<Stage, Stage, NodeHash, NodeEqual> smap;
+  std::shared_ptr<ScheduleNode> n = std::make_shared<ScheduleNode>();
+  n->outputs = self->outputs;
+  // Copy the stages.
+  for (Stage s : self->stages) {
+    Stage scopy = CopyStage(s);
+    smap[s] = scopy;
+    n->stages.push_back(scopy);
+  }
+  for (Stage g : self->groups) {
+    Stage gcopy = CopyStage(g);
+    smap[g] = gcopy;
+    n->groups.push_back(gcopy);
+  }
+  // Remap the reference relations.
+  for (auto kv : self->stage_map) {
+    n->stage_map.Set(kv.first, smap.at(kv.second));
+  }
+  for (Stage s : n->stages) {
+    if (s->attach_stage.defined()) {
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      s->group = smap.at(s->group);
+    }
+  }
+  for (Stage s : n->groups) {
+    if (s->attach_stage.defined()) {
+      s->attach_stage = smap.at(s->attach_stage);
+    }
+    if (s->group.defined()) {
+      s->group = smap.at(s->group);
+    }
+  }
+  return Schedule(n);
+}
+
 Stage Schedule::operator[](const Operation& op) {
   auto it = (*this)->stage_map.find(op);
   CHECK(it != (*this)->stage_map.end())
diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py
index 2aa3d0ddfe6249b254f7f0b72352de176b86ccb2..ab0b32e2afa5fef010aabf0c35ca2f1a6d71020b 100644
--- a/tests/python/integration/test_dot.py
+++ b/tests/python/integration/test_dot.py
@@ -10,7 +10,7 @@ def lower(s, args, name="mydot"):
         buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
         binds[x] = buf
         arg_list.append(buf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py
index 4af11364de914c384252ca050f5834ed8ff180cf..fb66c6011b285ee0b1795f0433223e0a73081b1b 100644
--- a/tests/python/integration/test_gemm.py
+++ b/tests/python/integration/test_gemm.py
@@ -60,7 +60,7 @@ def test_gemm():
     max_auto_unroll_step = 0
 
     # lowering test
-    s.normalize()
+    s = s.normalize()
 
     # one line to build the function.
     def check_device(device, host="stackvm"):
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index 90c28a8af363fbfc0a2a8a78ea9fd77914a793d8..8548a33fa8cabb976b71dc7ebd543a896fa2621c 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -16,6 +16,7 @@ def test_add_pipeline():
     s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
 
     # compile to IR
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index 21639399e918fc3de2b682365b8c5b4e2f8de5f3..c47a295460890cdde06f7063878ac2f84e338b18 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -22,6 +22,8 @@ def test_bound2():
     A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
     s = tvm.create_schedule(A2.op)
     xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
+    # test normalize not affecting schedule
+    _ = s.normalize()
     s[A1].compute_at(s[A2], yo)
     bounds = tvm.schedule.InferBound(s)
     assert isinstance(bounds, tvm.collections.Map)
@@ -41,6 +43,8 @@ def test_bound3():
     xi0, xi1 = s[A2].split(xi, nparts=16)
     s[A2].bind(xi0, tvm.thread_axis("threadIdx.x"))
     yo, yi = s[A2].split(A2.op.axis[1], 16)
+    # test normalize not affecting schedule
+    _ = s.normalize()
     s[A2].reorder(xo, xi0, yo, xi1, yi)
     s[A1].compute_at(s[A2], yo)
 
@@ -63,7 +67,7 @@ def test_bound_scan():
     XX = s.cache_read(X, "local", s_update)
     xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
     s[XX].compute_at(s[s_update], xo)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     assert bounds[XX.op.axis[1]].extent.value == 4
@@ -77,7 +81,7 @@ def test_bound_conv1d():
     B = tvm.compute(n, computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[0])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
 
@@ -92,7 +96,7 @@ def test_bound_blur():
     B = tvm.compute((n-2, n-2), computeB, name='B')
     s = tvm.create_schedule(B.op)
     s[A].compute_at(s[B], B.op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A.op.axis[0]].extent.value == 3)
     assert(bounds[A.op.axis[1]].extent.value == 3)
@@ -106,7 +110,7 @@ def test_bound_rfactor():
     s = tvm.create_schedule(B.op)
     kf, ki = s[B].split(k, nparts=4)
     BF = s.rfactor(B, kf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BF.op.axis[0]].extent.value == 4)
 
@@ -123,7 +127,7 @@ def test_bound_group_schedule():
     g.compute_at(s[x2], x2.op.axis[0])
     assert s[x1].group == g
     assert s[x].group == g
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent == n
@@ -141,7 +145,7 @@ def test_bound_nest_group():
     assert s[x1].group == g2
     g2.compute_at(s[x2], x2.op.axis[0])
     g1.compute_at(s[x1], s[x1].op.axis[1])
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert bounds[x.op.axis[0]].extent.value == 1
     assert bounds[x.op.axis[1]].extent.value == 1
@@ -169,7 +173,7 @@ def test_bound_nest_thread():
     _, xi = s[A2].split(A2.op.axis[0], nparts=1)
     s[A2].bind(xi, thread_x)
     s[A1].compute_at(s[A3], tx)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[A1.op.axis[0]].extent.value==1)
     assert(bounds[A2.op.axis[0]].extent.value==32)
@@ -225,7 +229,7 @@ def test_gemm_bound():
     tx, xi = s[BB].split(xi, nparts=num_thread)
     s[BB].bind(ty, thread_y)
     s[BB].bind(tx, thread_x)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[BB.op.axis[0]].extent.value==64)
     assert(bounds[AA.op.axis[0]].extent.value==64)
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index f52f9eca385279cf8e47716176db972091b55d0e..ea02c60b81e5ffae97d9bae652be2172b41166cf 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -51,7 +51,7 @@ def test_schedule_scan():
 
     assert tuple(res.shape) == (m, n)
     s = tvm.create_schedule(res.op)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
@@ -68,7 +68,7 @@ def test_auto_inline():
 
     s = tvm.create_schedule(T2.op)
     tvm.schedule.AutoInlineElemWise(s)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
@@ -83,7 +83,7 @@ def test_inline_mixed():
     xo, xi = s[C].split(C.op.axis[0], factor=8)
     s[A1].compute_at(s[C], xo)
     s[A2].compute_inline()
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
diff --git a/tests/verilog/integration/test_codegen_verilog.py b/tests/verilog/integration/test_codegen_verilog.py
index 7003fa936bd1d258bbd90e344ac479c386e9f756..f1994e38c5f9009b52947bc484cddad1a6adee03 100644
--- a/tests/verilog/integration/test_codegen_verilog.py
+++ b/tests/verilog/integration/test_codegen_verilog.py
@@ -11,7 +11,7 @@ def lower(s, args, name):
         buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
         binds[x] = buf
         arg_list.append(buf)
-    s.normalize()
+    s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
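Usage note: the `# test normalize not affecting schedule` checks added to `test_bound2` and `test_bound3` exercise the new copy semantics, i.e. calling `normalize()` and discarding the result must leave the original schedule intact. A small sketch of that property, with the same caveat that the tensor names and shapes are illustrative only:

```python
import tvm

m = tvm.var("m")
A = tvm.placeholder((m,), name="A")
A1 = tvm.compute((m,), lambda i: A[i] + 1, name="A1")
A2 = tvm.compute((m,), lambda i: A1[i] * 2, name="A2")

s = tvm.create_schedule(A2.op)
# normalize() operates on a copy, so discarding its result
# leaves the original schedule's stages and relations untouched.
_ = s.normalize()
s[A1].compute_at(s[A2], A2.op.axis[0])
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
```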