From e9debc9be42237b7a43d88f9b7f5fc541aadde4f Mon Sep 17 00:00:00 2001 From: ziheng <ziheng@apache.org> Date: Tue, 9 May 2017 22:44:30 -0700 Subject: [PATCH] [PASS] Use likely tag & enable LoopPartition by default (#132) * [PASS] Use likely tag & enable LoopPartition by default * [PASS] Support thread_axis partition * Take IfThenElse branch method * [PASS] Insert branch at the innermost thread scope * [PASS] Select candidates before trying to partition & add test for select * [PASS] Clean code * Fix * Remove print & assert vectorize happens --- python/tvm/build.py | 1 + python/tvm/ir_builder.py | 14 + src/op/compute_op.cc | 13 +- src/pass/loop_partition.cc | 284 ++++++++++++++---- tests/python/integration/test_ewise.py | 4 +- tests/python/unittest/test_codegen_device.py | 1 + .../unittest/test_pass_loop_partition.py | 100 ++++-- 7 files changed, 332 insertions(+), 85 deletions(-) diff --git a/python/tvm/build.py b/python/tvm/build.py index 6b0b8debd..54464ec9f 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -67,6 +67,7 @@ def lower(sch, sch = sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) + stmt = ir_pass.LoopPartition(stmt) stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.CanonicalSimplify(stmt) stmt = ir_pass.VectorizeLoop(stmt) diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index e288384cb..b14b7442f 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -9,6 +9,7 @@ from . import ir_pass as _pass from . import collections as _collections from ._ffi.base import string_types from ._ffi.node import NodeGeneric +from .expr import Call as _Call class WithScope(object): """Auxiliary scope with""" @@ -308,6 +309,19 @@ class IRBuilder(object): """ return BufferVar(self, buf.data, buf.dtype) + def likely(self, expr): + """Add likely tag for expression. + Parameters + ---------- + expr : Expr + The expression. Usually a condition expression. + Returns + ------- + expr : Expr + The expression will likely tag. + """ + return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0) + def get(self): """Return the builded IR. diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 6aeace7a1..a2d3b25e2 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -311,9 +311,10 @@ Stmt ComputeOpNode::BuildProvide( std::unordered_map<IterVar, Expr> value_map; auto nest = op::MakeLoopNest( stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); - nest.push_back(op::MakeIfNest(op::MakeBoundCheck( - stage, dom_map, false, - std::unordered_set<IterVar>(), value_map))); + auto preds = op::MakeBoundCheck(stage, dom_map, false, + std::unordered_set<IterVar>(), value_map); + for (auto& e : preds) e = likely(e); + nest.push_back(op::MakeIfNest(preds)); if (stage->store_predicate.defined()) { nest.emplace_back(op::MakeIfNest({stage->store_predicate})); } @@ -352,9 +353,9 @@ Stmt ComputeOpNode::BuildProvide( auto init_nest = op::MakeLoopNest( stage, dom_map, begin_loop, true, skip_iter, &init_value_map); - init_nest.push_back( - op::MakeIfNest( - op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map))); + auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map); + for (auto& e : preds) e = likely(e); + init_nest.push_back(op::MakeIfNest(preds)); init = Substitute(init, init_value_map); init = MergeNest(init_nest, init); // common nest diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 3a8f30e7d..bc8aea33d 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -10,6 +10,7 @@ #include <unordered_map> #include <unordered_set> #include "../arithmetic/int_set_internal.h" +#include "../runtime/thread_storage_scope.h" namespace tvm { namespace ir { @@ -37,12 +38,84 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) { return success; } +// Select potential candidate IRs that can be partitioned. +// Rule: +// - the range should not be const +// - there exist a condition expression in the scope that use the var +class CandidateSelector : public IRVisitor { + public: + using VarIsUsed = bool; + CandidateSelector() {} + + void Visit_(const For* op) { + if (!is_const(op->min) || !is_const(op->extent)) { + const Variable* var = op->loop_var.get(); + record_.insert({var, false}); + IRVisitor::Visit_(op); + if (record_.at(var)) { + candidates.insert(op); + } + record_.erase(var); + } else { + IRVisitor::Visit_(op); + } + } + + void Visit_(const AttrStmt* op) { + if (op->attr_key == attr::thread_extent) { + const IterVarNode *iv = op->node.as<IterVarNode>(); + CHECK(iv); + Var var = iv->var; + runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + if ((scope.rank == 0) && !is_const(op->value)) { + record_.insert({var.get(), false}); + IRVisitor::Visit_(op); + if (record_.at(var.get())) { + candidates.insert(op); + } + record_.erase(var.get()); + return; + } + } + IRVisitor::Visit_(op); + } + + void Visit_(const Call* op) { + if (op->is_intrinsic(Call::likely)) { + in_likely_ = true; + IRVisitor::Visit_(op); + in_likely_ = false; + } else { + IRVisitor::Visit_(op); + } + } + + void Visit_(const Variable* op) { + if (in_likely_ && record_.count(op)) { + record_.at(op) = true; + } + } + + std::unordered_set<const Node*> candidates; + + private: + bool in_likely_; + std::unordered_map<const Variable*, VarIsUsed> record_; +}; + +// Find valid partition for specific variable class PartitionFinder : public IRVisitor { public: explicit PartitionFinder(VarExpr current_var, - const std::unordered_map<const Variable*, IntSet>& dom_map) - : current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) { - for (const auto& kv : dom_map) out_vars_.insert(kv.first); + const std::unordered_map<const Variable*, IntSet>& hint_map, + const std::unordered_map<const Variable*, IntSet>& relax_map) + : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { + for (const auto& kv : hint_map) { + out_vars_.insert(kv.first); + } + for (const auto& kv : relax_map) { + out_vars_.insert(kv.first); + } } void Visit_(const For* op) { @@ -73,10 +146,15 @@ class PartitionFinder : public IRVisitor { } } - void Visit_(const IfThenElse* op) { - if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({current_var_.get()}))) { - IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_); - partitions[op->condition.get()] = Partition{op->condition, interval}; + void Visit_(const Call* op) { + if (op->is_intrinsic(Call::likely)) { + Expr cond = op->args[0]; + if (ExprUseVars(cond, + std::unordered_set<const Variable*>({current_var_.get()}))) { + IntSet interval = + DeduceBound(current_var_, cond, hint_map_, relax_map_); + partitions[cond.get()] = Partition{cond, interval}; + } } else { IRVisitor::Visit_(op); } @@ -91,54 +169,124 @@ class PartitionFinder : public IRVisitor { std::unordered_map<const Variable*, IntSet> relax_map_; }; -class PartitionReplacer : public IRMutator { +// Eliminate the condition expressions by partitions +class ConditionEliminator : public IRMutator { public: - explicit PartitionReplacer(const std::unordered_map<const Node*, Partition>& ps) + explicit ConditionEliminator(const std::unordered_map<const Node*, Partition>& ps) : ps_(ps) {} - Expr Mutate(Expr e) override { - if (ps_.count(e.get())) { - return Mutate(const_true()); - } + using IRMutator::Mutate; + Expr Mutate(Expr e) final { + if (ps_.count(e.get())) return Mutate(const_true()); return IRMutator::Mutate(e); } - using IRMutator::Mutate; private: const std::unordered_map<const Node*, Partition>& ps_; }; + +// Insert the partition branch at the innermost thread scope +class ThreadPartitionInserter : public IRMutator { + public: + explicit ThreadPartitionInserter(const std::unordered_map<const Node*, Partition>& ps, + Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} + + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + if (op->attr_key == attr::thread_extent) { + innermost_thread_scope_ = true; + Stmt stmt = IRMutator::Mutate_(op, s); + // add branch code inside the innermost thread scope + if (innermost_thread_scope_) { + Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body); + Stmt body = IfThenElse::make(cond_, simplified_body, op->body); + Expr value = this->Mutate(op->value); + stmt = AttrStmt::make(op->node, op->attr_key, value, body); + } + innermost_thread_scope_ = false; + return stmt; + } else { + return IRMutator::Mutate_(op, s); + } + } + + private: + const std::unordered_map<const Node*, Partition>& ps_; + Expr cond_; + bool innermost_thread_scope_; +}; + +// Try to do partition at the candidate IRs class LoopPartitioner : public IRMutator { public: - LoopPartitioner() {} + explicit LoopPartitioner(std::unordered_set<const Node*> candidates) + : candidates_(candidates) {} Stmt Mutate_(const For* op, const Stmt& stmt) { - if (!is_const(op->min) || !is_const(op->extent)) { - Stmt s = DoPartition(op, stmt); + if (candidates_.count(op)) { + Stmt s = TryPartition(op, stmt, op->loop_var, + op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; } - dom_map_.insert({op->loop_var.get(), + + // normal path when loop parittion fails + // normal loop variable can be put into hint map. + hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = IRMutator::Mutate_(op, stmt); - dom_map_.erase(op->loop_var.get()); + hint_map_.erase(op->loop_var.get()); + return res; + } + + Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) { + if (op->attr_key != attr::thread_extent) { + return IRMutator::Mutate_(op, stmt); + } + + const IterVarNode *iv = op->node.as<IterVarNode>(); + CHECK(iv); + Var var = iv->var; + if (candidates_.count(op)) { + Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true); + if (s.defined()) return s; + } + + // normal path when loop parittion fails. + runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + Stmt res; + if (scope.rank == 1) { + // threadIdx should be put into relax map, in case of divergence. + relax_map_.insert({var.get(), + IntSet::interval(make_zero(var.type()), op->value - 1)}); + res = IRMutator::Mutate_(op, stmt); + relax_map_.erase(var.get()); + } else { + hint_map_.insert({var.get(), + IntSet::interval(make_zero(var.type()), op->value - 1)}); + res = IRMutator::Mutate_(op, stmt); + hint_map_.erase(var.get()); + } return res; } private: - Stmt DoPartition(const For* op, const Stmt& stmt); + Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var, + Expr min, Expr max, Stmt body, bool partition_thread_scope); + inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); - std::unordered_map<const Variable*, IntSet> dom_map_; + /* Candidate IRs that may be partitioned potentially */ + std::unordered_set<const Node*> candidates_; + std::unordered_map<const Variable*, IntSet> hint_map_; + std::unordered_map<const Variable*, IntSet> relax_map_; }; -Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) { - PartitionFinder finder(op->loop_var, dom_map_); - finder.Visit(op->body); +Stmt LoopPartitioner::TryPartition(const Node* node, const Stmt& stmt, + VarExpr var, Expr min, Expr max, Stmt body, bool partition_thread_scope) { + PartitionFinder finder(var, hint_map_, relax_map_); + finder.Visit(body); const auto& partitions = finder.partitions; - if (partitions.empty()) return Stmt(); - Expr min = op->min; - Expr max = op->min + op->extent - 1; Array<IntSet> sets; // merge partitions (take their intersect) for (const auto& kv : partitions) { @@ -146,64 +294,92 @@ Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) { } IntSet true_itrv = Intersect(sets); - Stmt pre_stmt; Expr body_begin; + Stmt pre_stmt; if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) { body_begin = true_itrv.min(); if (!can_prove(body_begin == min)) { - if (!can_prove(body_begin - min >= 0)) { - LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0) + Expr cond = (body_begin - min >= 0); + if (!can_prove(cond)) { + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = Max::make(body_begin, min); } // [min, body_begin) - Stmt body = Substitute(op->body, - {{Var{op->loop_var}, op->loop_var + min}}); - pre_stmt = For::make(op->loop_var, 0, - body_begin - min, op->for_type, op->device_api, body); + if (!partition_thread_scope) { + Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); + pre_stmt = MakeFor(node, body_begin - min, pre_body); + } } } else { body_begin = min; } - Stmt post_stmt; Expr post_doubt_begin; + Stmt post_stmt; if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) { post_doubt_begin = true_itrv.max() + 1; if (!can_prove(true_itrv.max() == max)) { - if (!can_prove(max - post_doubt_begin >= 0)) { - LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0) + Expr cond = (max - post_doubt_begin >= 0); + if (!can_prove(cond)) { + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; post_doubt_begin = Min::make(post_doubt_begin, max); } // [post_doubt_begin, max] - Stmt body = Substitute(op->body, - {{Var{op->loop_var}, op->loop_var + post_doubt_begin}}); - post_stmt = For::make(op->loop_var, 0, - max - post_doubt_begin + 1, op->for_type, op->device_api, body); + if (!partition_thread_scope) { + Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); + post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); + } } } else { post_doubt_begin = max + 1; } - // [body_begin, post_doubt_begin) - Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body); - Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}}); - Stmt simplified_stmt = For::make(op->loop_var, 0, - post_doubt_begin - body_begin, op->for_type, op->device_api, body); - Stmt s = simplified_stmt; - if (pre_stmt.defined()) { - s = Block::make(pre_stmt, s); - } - if (post_stmt.defined()) { - s = Block::make(s, post_stmt); + Stmt s; + if (!partition_thread_scope) { + // [body_begin, post_doubt_begin) + Stmt simplified_body = ConditionEliminator(partitions).Mutate(body); + Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); + s = MakeFor(node, post_doubt_begin - body_begin, new_body); + if (pre_stmt.defined()) s = Block::make(pre_stmt, s); + if (post_stmt.defined()) s = Block::make(s, post_stmt); + } else { + Expr cond = const_true(); + if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); + if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); + s = ThreadPartitionInserter(partitions, cond).Mutate(stmt); } + s = ConvertSSA(s); + return s; +} - return Simplify(ConvertSSA(s)); +inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { + const For *for_node = static_cast<const For*>(node); + CHECK(for_node); + return For::make(for_node->loop_var, 0, extent, + for_node->for_type, for_node->device_api, body); } +class RemoveLikelyTags : public IRMutator { + public: + using IRMutator::Mutate; + + Expr Mutate_(const Call *op, const Expr& e) { + if (op->is_intrinsic(Call::likely)) { + CHECK_EQ(op->args.size(), 1); + return IRMutator::Mutate(op->args[0]); + } else { + return IRMutator::Mutate_(op, e); + } + } +}; + Stmt LoopPartition(Stmt stmt) { - stmt = LoopPartitioner().Mutate(stmt); + CandidateSelector selector; + selector.Visit(stmt); + stmt = LoopPartitioner(selector.candidates).Mutate(stmt); + stmt = RemoveLikelyTags().Mutate(stmt); return stmt; } diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 0c3ccaeaa..e2e867472 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -64,12 +64,12 @@ from tvm.contrib import nvcc_compiler @tvm.register_func def tvm_callback_cuda_compile(code): print(code) - ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"]) + ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_35"]) return ptx def test_add(): # graph - n = tvm.convert(1024) + n = tvm.var('n') A = tvm.placeholder((n,), name='A') B = tvm.placeholder((n,), name='B') bias = tvm.var("bias", dtype="float32") diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 1e3c4a53a..f7746f57d 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -22,6 +22,7 @@ def test_add_pipeline(): Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') Cb = tvm.decl_buffer(C.shape, C.dtype, name='C') + stmt = tvm.ir_pass.LoopPartition(stmt) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) stmt = tvm.ir_pass.Simplify(stmt) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0) diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index ce1747f22..b5213b3bc 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -17,8 +17,8 @@ def test_basic(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt) + stmt = tvm.ir_pass.Simplify(stmt) assert('if' not in str(stmt.body.body.body.first)) - print(stmt) def test_multi_loop(): ib = tvm.ir_builder.create() @@ -27,41 +27,40 @@ def test_multi_loop(): with ib.for_range(0, 4, "i") as i: with ib.for_range(0, n, "j") as j: with ib.for_range(0, m, "k") as k: - with ib.if_scope(i*m+j+k < n): + with ib.if_scope(ib.likely(i*m+j+k < n)): ib.emit(tvm.make.Evaluate(m)) with ib.else_scope(): ib.emit(tvm.make.Evaluate(n)) stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt) - assert(not any(collect_visit(stmt.body.first, - lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) def test_multi_if(): - i = tvm.var('i') - j = tvm.var('j') - k = tvm.var('k') + ib = tvm.ir_builder.create() m = tvm.var('m') n = tvm.var('n') - stmt = tvm.make.For( - i, 0, 4, 0, 0, - tvm.make.For( - j, 0, n, 0, 0, - tvm.make.For( - k, 0, m, 0, 0, - tvm.make.Block( - tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)), - tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)) - )))) + with ib.for_range(0, 4, 'i') as i: + with ib.for_range(0, n, 'j') as j: + with ib.for_range(0, m, 'k') as k: + with ib.if_scope(ib.likely(i*m+j+k < n)): + ib.emit(tvm.make.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.make.Evaluate(n)) + with ib.if_scope(ib.likely(i*m+j-k < n)): + ib.emit(tvm.make.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.make.Evaluate(n)) + stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt) + stmt = tvm.ir_pass.Simplify(stmt) assert('if' not in str(stmt.body.first)) - print(stmt) def test_thread_axis(): m = tvm.var('m') l = tvm.var('l') A = tvm.placeholder((m, l), name='A') B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B') - s = tvm.create_schedule(B.op) s[B].set_scope("shared") @@ -72,12 +71,67 @@ def test_thread_axis(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) - stmt_ = tvm.ir_pass.LoopPartition(stmt) - assert('if' not in str(stmt_.body.body.body.first)) - print(stmt_) + stmt = tvm.ir_pass.LoopPartition(stmt) + stmt = tvm.ir_pass.Simplify(stmt) + assert('if' not in str(stmt.body.body.body.first)) + +def test_vectorize(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + bias = tvm.var("bias", dtype="float32") + scale = tvm.var("scale", dtype="float32") + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C') + # schedule + s = tvm.create_schedule(C.op) + # create iter var and assign them tags. + num_thread = 32 + bx, x = s[C].split(C.op.axis[0], factor=num_thread*4) + tx, x = s[C].split(x, nparts=num_thread) + _, x = s[C].split(x, factor=4) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + s[C].vectorize(x) + stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False) + body = stmt.body.body.body.body.body + assert(x.var.name not in str(body.condition)) + assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp)))) + +def test_select(): + ib = tvm.ir_builder.create() + m = tvm.var('m') + n = tvm.var('n') + with ib.for_range(0, ((n+3)/4), 'i') as i: + with ib.for_range(0, 4, 'j') as j: + ib.emit(tvm.make.Evaluate( + tvm.make.Select(ib.likely(i*4+j<n), m, n))) + stmt = ib.get() + stmt = tvm.ir_pass.LoopPartition(stmt) + stmt = tvm.ir_pass.Simplify(stmt) + assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select)))) + +def test_thread_axis2(): + n = tvm.convert(4096) + m = tvm.var('m') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C') + s = tvm.create_schedule(C.op) + num_thread = 32 + bx, x = s[C].split(C.op.axis[0], factor=32) + tx, x = s[C].split(x, nparts=num_thread) + _, x = s[C].split(x, factor=m) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False) + for_body = stmt.body.body.body.body.body.first + assert('threadIdx' not in str(for_body.extent)) if __name__ == "__main__": - test_multi_loop() test_basic() + test_multi_loop() test_multi_if() test_thread_axis() + test_vectorize() + test_select() + test_thread_axis2() -- GitLab