diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 0de8a88edb00da9e577f75b863584bdc530f5c9f..95ce130785d73d191ab3576fe013baaa06d9c13c 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -239,11 +239,16 @@ class ThreadPartitionInserter : public IRMutator { // Try to do partition at the candidate IRs class LoopPartitioner : public IRMutator { public: - explicit LoopPartitioner(std::unordered_set<const Node*> candidates) - : candidates_(candidates) {} + explicit LoopPartitioner(bool split_const_loop) + : selector(CandidateSelector(split_const_loop)) {} + + Stmt VisitAndMutate(const Stmt& stmt) { + selector.Visit(stmt); + return Mutate(stmt); + } Stmt Mutate_(const For* op, const Stmt& stmt) { - if (candidates_.count(op)) { + if (selector.candidates.count(op)) { Stmt s = TryPartition(op, stmt, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; @@ -266,7 +271,7 @@ class LoopPartitioner : public IRMutator { const IterVarNode *iv = op->node.as<IterVarNode>(); CHECK(iv); Var var = iv->var; - if (candidates_.count(op)) { + if (selector.candidates.count(op)) { Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true); if (s.defined()) return s; } @@ -295,9 +300,9 @@ class LoopPartitioner : public IRMutator { inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); /* Candidate IRs that may be partitioned potentially */ - std::unordered_set<const Node*> candidates_; std::unordered_map<const Variable*, IntSet> hint_map_; std::unordered_map<const Variable*, IntSet> relax_map_; + CandidateSelector selector; }; Stmt LoopPartitioner::TryPartition(const Node* node, @@ -322,7 +327,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr body_begin; Stmt pre_stmt; if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) { - body_begin = true_itrv.min(); + body_begin = ir::Simplify(true_itrv.min()); if (!can_prove(body_begin == min)) { Expr cond = (body_begin - min >= 0); if (!can_prove(cond)) { @@ -343,7 +348,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr post_doubt_begin; Stmt post_stmt; if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) { - post_doubt_begin = true_itrv.max() + 1; + post_doubt_begin = ir::Simplify(true_itrv.max() + 1); if (!can_prove(true_itrv.max() == max)) { // require the extent to be non-negative Expr cond = (max - post_doubt_begin + 1 >= 0); @@ -354,8 +359,17 @@ Stmt LoopPartitioner::TryPartition(const Node* node, } // [post_doubt_begin, max] if (!partition_thread_scope) { - Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); - post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); + Stmt post_body; + // If the loop is going from 0 to 1, replace the loop var with min value + if (as_const_int(max) && as_const_int(post_doubt_begin)) { + if (*as_const_int(max) == *as_const_int(post_doubt_begin)) { + post_body = Substitute(body, {{Var{var}, post_doubt_begin}}); + post_stmt = post_body; + } + } else { + post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); + post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); + } } } } else { @@ -368,8 +382,15 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Stmt simplified_body = ConditionEliminator(partitions).Mutate(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); s = MakeFor(node, post_doubt_begin - body_begin, new_body); - if (pre_stmt.defined()) s = Block::make(pre_stmt, s); - if (post_stmt.defined()) s = Block::make(s, post_stmt); + + if (!(pre_stmt.defined() && post_stmt.defined())) s = VisitAndMutate(s); + if (pre_stmt.defined()) s = Block::make(pre_stmt, s); + if (post_stmt.defined()) { + if (as_const_int(max) && as_const_int(post_doubt_begin)) { + post_stmt = VisitAndMutate(post_stmt); + } + s = Block::make(s, post_stmt); + } } else { Expr cond = const_true(); if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); @@ -402,9 +423,7 @@ class RemoveLikelyTags : public IRMutator { }; Stmt LoopPartition(Stmt stmt, bool split_const_loop) { - CandidateSelector selector(split_const_loop); - selector.Visit(stmt); - stmt = LoopPartitioner(selector.candidates).Mutate(stmt); + stmt = LoopPartitioner(split_const_loop).VisitAndMutate(stmt); stmt = RemoveLikelyTags().Mutate(stmt); return stmt; } diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index a1025e1f662cfdfbba0a632fa4d69cfbb8c31764..85860ce824d000c59f1030748770d70dc5f87e19 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -177,6 +177,157 @@ def test_everything_during_deduction(): stmt = tvm.ir_pass.Simplify(stmt) assert(isinstance(stmt.body.body, tvm.stmt.IfThenElse)) +def test_single_likely(): + n = 60 + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i]+B[i]) + s = tvm.create_schedule(T.op) + x = T.op.axis[0] + xo, xi = s[T].split(x, factor=16) + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.LoopPartition(stmt, True) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + +def test_multi_likely(): + n = 94 + m = 62 + A = tvm.placeholder((n, m), name='A') + B = tvm.placeholder((n, m), name='B') + + T = tvm.compute((n, m), lambda i, j: A[i, j]+B[i, j]) + s = tvm.create_schedule(T.op) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + x, y = T.op.axis + xo, xi = s[T].split(x, factor=16) + yo, yi = s[T].split(y, factor=16) + s[T].reorder(xo, yo, xi, yi) + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.LoopPartition(stmt, True) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + +def test_oneD_pool(): + m = tvm.var('m') + ib = tvm.ir_builder.create() + #data = tvm.placeholder((16,), name = 'data') + data = ib.pointer("float32", name="A") + out = ib.pointer("float32", name="A") + with ib.for_range(0, 16, 'ow') as ow: + with ib.for_range(0, 3, 'kw') as kw: + with ib.if_scope(ib.likely(ow > 0)): + with ib.if_scope(ib.likely(ow < 15)): + out[ow] = tvm.max(out[ow], data[ow + kw - 1]) + with ib.for_range(0, 16, 'ow') as ow: + with ib.for_range(0, 3, 'kw') as kw: + with ib.if_scope(ib.likely(ow < 1)): + with ib.if_scope(ib.likely(kw > 0)): + out[ow] = tvm.max(out[ow], data[ow + kw - 1]) + with ib.for_range(0, 16, 'ow') as ow: + with ib.for_range(0, 3, 'kw') as kw: + with ib.if_scope(ib.likely(ow > 14)): + with ib.if_scope(ib.likely(kw < 2)): + out[ow] = tvm.max(out[ow], data[ow + kw - 1]) + + stmt = ib.get() + stmt = tvm.ir_pass.LoopPartition(stmt, True) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + +def test_cce_loop_1(): + ib = tvm.ir_builder.create() + dtype = 'float16' + n = 514 + m = 514 + _A = tvm.placeholder((n*m,), name = 'A') + Ab = tvm.decl_buffer((n*m,), dtype, name="A") + A = ib.buffer_ptr(Ab) + _B = tvm.placeholder((n*m,), name = 'B') + Bb = tvm.decl_buffer((n*m,), dtype, name="B") + B = ib.buffer_ptr(Bb) + #for i in 0 to n-1: + with ib.for_range(0, 11, name="i") as i: + with ib.for_range(0, 160, name="j") as j: + with ib.if_scope(ib.likely(((i*160) + j) < 1600)): + A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1] + stmt = ib.get() + stmt = tvm.ir_pass.LoopPartition(stmt, True) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + +def test_cce_loop_2(): + ib = tvm.ir_builder.create() + len = 112 + tile = 32 + loop = (len + tile - 1) // tile + with ib.for_range(0, loop, 'i') as i: + head = i * tile + with ib.if_scope(ib.likely(head + tile > len)): + tail = len + ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail)) + with ib.else_scope(): + tail = head + tile + ib.emit(tvm.call_extern('float32', "cce_intrisic", head, tail)) + + stmt = ib.get() + stmt = tvm.ir_pass.LoopPartition(stmt, True) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + + +def test_cce_loop_3(): + ib = tvm.ir_builder.create() + loop1 = 4 + loop2 = 9998 + tile = 39991 + with ib.for_range(0,loop2,'i') as i: + with ib.for_range(0,loop1,'j') as j: + head1 = i + head2 = j + with ib.if_scope(ib.likely(head1*loop1 + head2 < tile)): + ib.emit(tvm.call_extern('float16',"cce_intrisic",head1)) + + stmt = ib.get() + stmt = tvm.ir_pass.LoopPartition(stmt,True) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + +def test_conv_tiling(): + HSTR = WSTR = 1 + in_channel = 128 + kernel_height = kernel_width = 3 + out_channel = 64 + batch_size = 1 + in_height = in_width = 64 + out_height = out_width = in_height - kernel_height + 1 + data = tvm.placeholder((batch_size, in_channel, in_height, in_width), name='data') + kernel = tvm.placeholder((kernel_height, kernel_width, in_channel, + out_channel), name='kernel') + ic = tvm.reduce_axis((0, in_channel), name='ic') + kh = tvm.reduce_axis((0, kernel_height), name='kh') + kw = tvm.reduce_axis((0, kernel_width), name='kw') + conv = tvm.compute((batch_size, out_channel, out_height, out_width), + lambda n, oc, oh, ow: tvm.sum(data[n, ic, oh*HSTR + kh, ow*WSTR + kw] * + kernel[kh, kw, ic, oc], + axis=[ic, kh, kw]), + name="conv2d") + s = tvm.create_schedule(conv.op) + + n, oc, oh, ow = conv.op.axis + oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.LoopPartition(stmt, True) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + if __name__ == "__main__": test_basic() test_const_loop() @@ -187,3 +338,10 @@ if __name__ == "__main__": test_select() test_thread_axis2() test_everything_during_deduction() + test_single_likely() + test_multi_likely() + test_oneD_pool() + test_cce_loop_1() + test_cce_loop_2() + test_cce_loop_3() + test_conv_tiling()