diff --git a/python/tvm/build.py b/python/tvm/build.py
index 6b0b8debd68e45bed340a93e15ea5bad65a2f811..54464ec9fb80d16d42699c9d83d6cf2e25f7fca6 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -67,6 +67,7 @@ def lower(sch,
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
+    stmt = ir_pass.LoopPartition(stmt)
     stmt = ir_pass.StorageFlatten(stmt, binds)
     stmt = ir_pass.CanonicalSimplify(stmt)
     stmt = ir_pass.VectorizeLoop(stmt)
diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py
index e288384cb8459fbe604430cf6b2ca7bd8241c3de..b14b7442f5f2a1dcd0befd95da103fec3cecd32c 100644
--- a/python/tvm/ir_builder.py
+++ b/python/tvm/ir_builder.py
@@ -9,6 +9,7 @@ from . import ir_pass as _pass
 from . import collections as _collections
 from ._ffi.base import string_types
 from ._ffi.node import NodeGeneric
+from .expr import Call as _Call
 
 class WithScope(object):
     """Auxiliary scope  with"""
@@ -308,6 +309,19 @@ class IRBuilder(object):
         """
         return BufferVar(self, buf.data, buf.dtype)
 
+    def likely(self, expr):
+        """Add likely tag for expression.
+        Parameters
+        ----------
+        expr : Expr
+            The expression. Usually a condition expression.
+        Returns
+        -------
+        expr : Expr
+            The expression will likely tag.
+        """
+        return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0)
+
     def get(self):
         """Return the builded IR.
 
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 6aeace7a1aee68c414276e92f1b5c2fbcbea2efc..a2d3b25e25e080423fb006cdf3bdcdb3a03c0886 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -311,9 +311,10 @@ Stmt ComputeOpNode::BuildProvide(
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
-  nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
-      stage, dom_map, false,
-      std::unordered_set<IterVar>(), value_map)));
+  auto preds = op::MakeBoundCheck(stage, dom_map, false,
+      std::unordered_set<IterVar>(), value_map);
+  for (auto& e : preds) e = likely(e);
+  nest.push_back(op::MakeIfNest(preds));
   if (stage->store_predicate.defined()) {
     nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
   }
@@ -352,9 +353,9 @@ Stmt ComputeOpNode::BuildProvide(
     auto init_nest = op::MakeLoopNest(
         stage, dom_map, begin_loop, true,
         skip_iter, &init_value_map);
-    init_nest.push_back(
-        op::MakeIfNest(
-            op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map)));
+    auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
+    for (auto& e : preds) e = likely(e);
+    init_nest.push_back(op::MakeIfNest(preds));
     init = Substitute(init, init_value_map);
     init  = MergeNest(init_nest, init);
     // common nest
diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc
index 3a8f30e7d46b62a53f17d9be6ad1b8f5f233dd8c..bc8aea33dbc9594dc323d173e4f2c4016b8df78f 100644
--- a/src/pass/loop_partition.cc
+++ b/src/pass/loop_partition.cc
@@ -10,6 +10,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include "../arithmetic/int_set_internal.h"
+#include "../runtime/thread_storage_scope.h"
 
 namespace tvm {
 namespace ir {
@@ -37,12 +38,84 @@ bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
   return success;
 }
 
+// Select potential candidate IRs that can be partitioned.
+// Rule:
+//   - the range should not be const
+//   - there exist a condition expression in the scope that use the var
+class CandidateSelector : public IRVisitor {
+ public:
+  using VarIsUsed = bool;
+  CandidateSelector() {}
+
+  void Visit_(const For* op) {
+    if (!is_const(op->min) || !is_const(op->extent)) {
+      const Variable* var = op->loop_var.get();
+      record_.insert({var, false});
+      IRVisitor::Visit_(op);
+      if (record_.at(var)) {
+        candidates.insert(op);
+      }
+      record_.erase(var);
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void Visit_(const AttrStmt* op) {
+    if (op->attr_key == attr::thread_extent) {
+      const IterVarNode *iv = op->node.as<IterVarNode>();
+      CHECK(iv);
+      Var var = iv->var;
+      runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+      if ((scope.rank == 0) && !is_const(op->value)) {
+        record_.insert({var.get(), false});
+        IRVisitor::Visit_(op);
+        if (record_.at(var.get())) {
+          candidates.insert(op);
+        }
+        record_.erase(var.get());
+        return;
+      }
+    }
+    IRVisitor::Visit_(op);
+  }
+
+  void Visit_(const Call* op) {
+    if (op->is_intrinsic(Call::likely)) {
+      in_likely_ = true;
+      IRVisitor::Visit_(op);
+      in_likely_ = false;
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void Visit_(const Variable* op) {
+    if (in_likely_ && record_.count(op)) {
+      record_.at(op) = true;
+    }
+  }
+
+  std::unordered_set<const Node*> candidates;
+
+ private:
+  bool in_likely_;
+  std::unordered_map<const Variable*, VarIsUsed> record_;
+};
+
+// Find valid partition for specific variable
 class PartitionFinder : public IRVisitor {
  public:
   explicit PartitionFinder(VarExpr current_var,
-    const std::unordered_map<const Variable*, IntSet>& dom_map)
-      : current_var_(current_var), out_vars_(dom_map.size()), hint_map_(dom_map) {
-        for (const auto& kv : dom_map) out_vars_.insert(kv.first);
+    const std::unordered_map<const Variable*, IntSet>& hint_map,
+    const std::unordered_map<const Variable*, IntSet>& relax_map)
+      : current_var_(current_var), hint_map_(hint_map),  relax_map_(relax_map) {
+        for (const auto& kv : hint_map) {
+          out_vars_.insert(kv.first);
+        }
+        for (const auto& kv : relax_map) {
+          out_vars_.insert(kv.first);
+        }
       }
 
   void Visit_(const For* op) {
@@ -73,10 +146,15 @@ class PartitionFinder : public IRVisitor {
     }
   }
 
-  void Visit_(const IfThenElse* op) {
-    if (ExprUseVars(op->condition, std::unordered_set<const Variable*>({current_var_.get()}))) {
-      IntSet interval = DeduceBound(current_var_, op->condition, hint_map_, relax_map_);
-      partitions[op->condition.get()] = Partition{op->condition, interval};
+  void Visit_(const Call* op) {
+    if (op->is_intrinsic(Call::likely)) {
+      Expr cond = op->args[0];
+      if (ExprUseVars(cond,
+          std::unordered_set<const Variable*>({current_var_.get()}))) {
+        IntSet interval =
+          DeduceBound(current_var_, cond, hint_map_, relax_map_);
+        partitions[cond.get()] = Partition{cond, interval};
+      }
     } else {
       IRVisitor::Visit_(op);
     }
@@ -91,54 +169,124 @@ class PartitionFinder : public IRVisitor {
   std::unordered_map<const Variable*, IntSet> relax_map_;
 };
 
-class PartitionReplacer : public IRMutator {
+// Eliminate the condition expressions by partitions
+class ConditionEliminator : public IRMutator {
  public:
-  explicit PartitionReplacer(const std::unordered_map<const Node*, Partition>& ps)
+  explicit ConditionEliminator(const std::unordered_map<const Node*, Partition>& ps)
     : ps_(ps) {}
 
-  Expr Mutate(Expr e) override {
-    if (ps_.count(e.get())) {
-      return Mutate(const_true());
-    }
+  using IRMutator::Mutate;
+  Expr Mutate(Expr e) final {
+    if (ps_.count(e.get())) return Mutate(const_true());
     return IRMutator::Mutate(e);
   }
-  using IRMutator::Mutate;
 
  private:
   const std::unordered_map<const Node*, Partition>& ps_;
 };
 
+
+// Insert the partition branch at the innermost thread scope
+class ThreadPartitionInserter : public IRMutator {
+ public:
+  explicit ThreadPartitionInserter(const std::unordered_map<const Node*, Partition>& ps,
+    Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->attr_key == attr::thread_extent) {
+      innermost_thread_scope_ = true;
+      Stmt stmt = IRMutator::Mutate_(op, s);
+      // add branch code inside the innermost thread scope
+      if (innermost_thread_scope_) {
+        Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body);
+        Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
+        Expr value = this->Mutate(op->value);
+        stmt = AttrStmt::make(op->node, op->attr_key, value, body);
+      }
+      innermost_thread_scope_ = false;
+      return stmt;
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+ private:
+  const std::unordered_map<const Node*, Partition>& ps_;
+  Expr cond_;
+  bool innermost_thread_scope_;
+};
+
+// Try to do partition at the candidate IRs
 class LoopPartitioner : public IRMutator {
  public:
-  LoopPartitioner() {}
+  explicit LoopPartitioner(std::unordered_set<const Node*> candidates)
+    : candidates_(candidates) {}
 
   Stmt Mutate_(const For* op, const Stmt& stmt) {
-    if (!is_const(op->min) || !is_const(op->extent)) {
-      Stmt s = DoPartition(op, stmt);
+    if (candidates_.count(op)) {
+      Stmt s = TryPartition(op, stmt, op->loop_var,
+          op->min, op->min + op->extent - 1, op->body, false);
       if (s.defined()) return s;
     }
-    dom_map_.insert({op->loop_var.get(),
+
+    // normal path when loop parittion fails
+    // normal loop variable can be put into hint map.
+    hint_map_.insert({op->loop_var.get(),
       IntSet::interval(op->min, op->min + op->extent - 1)});
     Stmt res = IRMutator::Mutate_(op, stmt);
-    dom_map_.erase(op->loop_var.get());
+    hint_map_.erase(op->loop_var.get());
+    return res;
+  }
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
+    if (op->attr_key != attr::thread_extent) {
+      return IRMutator::Mutate_(op, stmt);
+    }
+
+    const IterVarNode *iv = op->node.as<IterVarNode>();
+    CHECK(iv);
+    Var var = iv->var;
+    if (candidates_.count(op)) {
+      Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true);
+      if (s.defined()) return s;
+    }
+
+    // normal path when loop parittion fails.
+    runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+    Stmt res;
+    if (scope.rank == 1) {
+      // threadIdx should be put into relax map, in case of divergence.
+      relax_map_.insert({var.get(),
+        IntSet::interval(make_zero(var.type()), op->value - 1)});
+      res = IRMutator::Mutate_(op, stmt);
+      relax_map_.erase(var.get());
+    } else {
+      hint_map_.insert({var.get(),
+        IntSet::interval(make_zero(var.type()), op->value - 1)});
+      res = IRMutator::Mutate_(op, stmt);
+      hint_map_.erase(var.get());
+    }
     return res;
   }
 
  private:
-  Stmt DoPartition(const For* op, const Stmt& stmt);
+  Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var,
+      Expr min, Expr max, Stmt body, bool partition_thread_scope);
+  inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
 
-  std::unordered_map<const Variable*, IntSet> dom_map_;
+  /* Candidate IRs that may be partitioned potentially */
+  std::unordered_set<const Node*> candidates_;
+  std::unordered_map<const Variable*, IntSet> hint_map_;
+  std::unordered_map<const Variable*, IntSet> relax_map_;
 };
 
-Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) {
-  PartitionFinder finder(op->loop_var, dom_map_);
-  finder.Visit(op->body);
+Stmt LoopPartitioner::TryPartition(const Node* node, const Stmt& stmt,
+    VarExpr var, Expr min, Expr max, Stmt body, bool partition_thread_scope) {
+  PartitionFinder finder(var, hint_map_, relax_map_);
+  finder.Visit(body);
   const auto& partitions = finder.partitions;
-
   if (partitions.empty()) return Stmt();
 
-  Expr min = op->min;
-  Expr max = op->min + op->extent - 1;
   Array<IntSet> sets;
   // merge partitions (take their intersect)
   for (const auto& kv : partitions) {
@@ -146,64 +294,92 @@ Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) {
   }
   IntSet true_itrv  = Intersect(sets);
 
-  Stmt pre_stmt;
   Expr body_begin;
+  Stmt pre_stmt;
   if (true_itrv.as<arith::IntervalSet>()->i.has_lower_bound()) {
     body_begin = true_itrv.min();
     if (!can_prove(body_begin == min)) {
-      if (!can_prove(body_begin - min >= 0)) {
-        LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0)
+      Expr cond = (body_begin - min >= 0);
+      if (!can_prove(cond)) {
+        LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the pre doubt loop";
         body_begin = Max::make(body_begin, min);
       }
       // [min, body_begin)
-      Stmt body = Substitute(op->body,
-        {{Var{op->loop_var}, op->loop_var + min}});
-      pre_stmt = For::make(op->loop_var, 0,
-        body_begin - min, op->for_type, op->device_api, body);
+      if (!partition_thread_scope) {
+        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+        pre_stmt = MakeFor(node, body_begin - min, pre_body);
+      }
     }
   } else {
     body_begin = min;
   }
 
-  Stmt post_stmt;
   Expr post_doubt_begin;
+  Stmt post_stmt;
   if (true_itrv.as<arith::IntervalSet>()->i.has_upper_bound()) {
     post_doubt_begin = true_itrv.max() + 1;
     if (!can_prove(true_itrv.max() == max)) {
-      if (!can_prove(max - post_doubt_begin >= 0)) {
-        LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0)
+      Expr cond = (max - post_doubt_begin >= 0);
+      if (!can_prove(cond)) {
+        LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the post doubt loop";
         post_doubt_begin = Min::make(post_doubt_begin, max);
       }
       // [post_doubt_begin, max]
-      Stmt body = Substitute(op->body,
-        {{Var{op->loop_var}, op->loop_var + post_doubt_begin}});
-      post_stmt = For::make(op->loop_var, 0,
-        max - post_doubt_begin + 1, op->for_type, op->device_api, body);
+      if (!partition_thread_scope) {
+        Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
+        post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
+      }
     }
   } else {
     post_doubt_begin = max + 1;
   }
 
-  // [body_begin, post_doubt_begin)
-  Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body);
-  Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}});
-  Stmt simplified_stmt = For::make(op->loop_var, 0,
-    post_doubt_begin - body_begin, op->for_type, op->device_api, body);
-  Stmt s = simplified_stmt;
-  if (pre_stmt.defined()) {
-    s = Block::make(pre_stmt, s);
-  }
-  if (post_stmt.defined()) {
-    s = Block::make(s, post_stmt);
+  Stmt s;
+  if (!partition_thread_scope) {
+    // [body_begin, post_doubt_begin)
+    Stmt simplified_body = ConditionEliminator(partitions).Mutate(body);
+    Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
+    s = MakeFor(node, post_doubt_begin - body_begin, new_body);
+    if (pre_stmt.defined())  s = Block::make(pre_stmt, s);
+    if (post_stmt.defined()) s = Block::make(s, post_stmt);
+  } else {
+    Expr cond = const_true();
+    if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
+    if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
+    s = ThreadPartitionInserter(partitions, cond).Mutate(stmt);
   }
+  s = ConvertSSA(s);
+  return s;
+}
 
-  return Simplify(ConvertSSA(s));
+inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
+  const For *for_node = static_cast<const For*>(node);
+  CHECK(for_node);
+  return For::make(for_node->loop_var, 0, extent,
+    for_node->for_type, for_node->device_api, body);
 }
 
+class RemoveLikelyTags : public IRMutator {
+ public:
+  using IRMutator::Mutate;
+
+  Expr Mutate_(const Call *op, const Expr& e) {
+    if (op->is_intrinsic(Call::likely)) {
+      CHECK_EQ(op->args.size(), 1);
+      return IRMutator::Mutate(op->args[0]);
+    } else {
+      return IRMutator::Mutate_(op, e);
+    }
+  }
+};
+
 Stmt LoopPartition(Stmt stmt) {
-  stmt = LoopPartitioner().Mutate(stmt);
+  CandidateSelector selector;
+  selector.Visit(stmt);
+  stmt = LoopPartitioner(selector.candidates).Mutate(stmt);
+  stmt = RemoveLikelyTags().Mutate(stmt);
   return stmt;
 }
 
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 0c3ccaeaa4848cf6b537aaf131bfd201087cb9bc..e2e867472710f28c95d2e359c082fad70ffc0537 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -64,12 +64,12 @@ from tvm.contrib import nvcc_compiler
 @tvm.register_func
 def tvm_callback_cuda_compile(code):
     print(code)
-    ptx =  nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
+    ptx =  nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_35"])
     return ptx
 
 def test_add():
     # graph
-    n = tvm.convert(1024)
+    n = tvm.var('n')
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
     bias = tvm.var("bias", dtype="float32")
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index 1e3c4a53a9167a7b2dcc1e889b84a1930de090ce..f7746f57de94aadb75b7dad4ceab767ca017a241 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -22,6 +22,7 @@ def test_add_pipeline():
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
     Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
+    stmt = tvm.ir_pass.LoopPartition(stmt)
     stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
     stmt = tvm.ir_pass.Simplify(stmt)
     fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py
index ce1747f220fe38d13cbad3521c75ef8bd4b5fb13..b5213b3bcabaa0f286c545bf4c7a0aa6725113ff 100644
--- a/tests/python/unittest/test_pass_loop_partition.py
+++ b/tests/python/unittest/test_pass_loop_partition.py
@@ -17,8 +17,8 @@ def test_basic():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
     assert('if' not in str(stmt.body.body.body.first))
-    print(stmt)
 
 def test_multi_loop():
     ib = tvm.ir_builder.create()
@@ -27,41 +27,40 @@ def test_multi_loop():
     with ib.for_range(0, 4, "i") as i:
         with ib.for_range(0, n, "j") as j:
             with ib.for_range(0, m, "k") as k:
-                with ib.if_scope(i*m+j+k < n):
+                with ib.if_scope(ib.likely(i*m+j+k < n)):
                     ib.emit(tvm.make.Evaluate(m))
                 with ib.else_scope():
                     ib.emit(tvm.make.Evaluate(n))
     stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt)
-    assert(not any(collect_visit(stmt.body.first,
-                                 lambda x: isinstance(x, tvm.stmt.IfThenElse))))
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse))))
 
 def test_multi_if():
-    i = tvm.var('i')
-    j = tvm.var('j')
-    k = tvm.var('k')
+    ib = tvm.ir_builder.create()
     m = tvm.var('m')
     n = tvm.var('n')
-    stmt = tvm.make.For(
-        i, 0, 4, 0, 0,
-        tvm.make.For(
-            j, 0, n, 0, 0,
-            tvm.make.For(
-                k, 0, m, 0, 0,
-                tvm.make.Block(
-                    tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)),
-                    tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))
-                    ))))
+    with ib.for_range(0, 4, 'i') as i:
+        with ib.for_range(0, n, 'j') as j:
+            with ib.for_range(0, m, 'k') as k:
+                with ib.if_scope(ib.likely(i*m+j+k < n)):
+                    ib.emit(tvm.make.Evaluate(m))
+                with ib.else_scope():
+                    ib.emit(tvm.make.Evaluate(n))
+                with ib.if_scope(ib.likely(i*m+j-k < n)):
+                    ib.emit(tvm.make.Evaluate(m))
+                with ib.else_scope():
+                    ib.emit(tvm.make.Evaluate(n))
+    stmt = ib.get()
     stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
     assert('if' not in str(stmt.body.first))
-    print(stmt)
 
 def test_thread_axis():
     m = tvm.var('m')
     l = tvm.var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.compute((m, l), lambda i, j: A[i, j] + 3, name='B')
-
     s = tvm.create_schedule(B.op)
 
     s[B].set_scope("shared")
@@ -72,12 +71,67 @@ def test_thread_axis():
 
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    stmt_ = tvm.ir_pass.LoopPartition(stmt)
-    assert('if' not in str(stmt_.body.body.body.first))
-    print(stmt_)
+    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert('if' not in str(stmt.body.body.body.first))
+
+def test_vectorize():
+    n = tvm.var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    bias = tvm.var("bias", dtype="float32")
+    scale = tvm.var("scale", dtype="float32")
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
+    # schedule
+    s = tvm.create_schedule(C.op)
+    # create iter var and assign them tags.
+    num_thread = 32
+    bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
+    tx, x = s[C].split(x, nparts=num_thread)
+    _, x = s[C].split(x, factor=4)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[C].vectorize(x)
+    stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
+    body = stmt.body.body.body.body.body
+    assert(x.var.name not in str(body.condition))
+    assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))
+
+def test_select():
+    ib = tvm.ir_builder.create()
+    m = tvm.var('m')
+    n = tvm.var('n')
+    with ib.for_range(0, ((n+3)/4), 'i') as i:
+      with ib.for_range(0, 4, 'j') as j:
+        ib.emit(tvm.make.Evaluate(
+          tvm.make.Select(ib.likely(i*4+j<n), m, n)))
+    stmt = ib.get()
+    stmt = tvm.ir_pass.LoopPartition(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
+
+def test_thread_axis2():
+    n = tvm.convert(4096)
+    m = tvm.var('m')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name='C')
+    s = tvm.create_schedule(C.op)
+    num_thread = 32
+    bx, x = s[C].split(C.op.axis[0], factor=32)
+    tx, x = s[C].split(x, nparts=num_thread)
+    _,  x = s[C].split(x, factor=m)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
+    for_body = stmt.body.body.body.body.body.first
+    assert('threadIdx' not in str(for_body.extent))
 
 if __name__ == "__main__":
-    test_multi_loop()
     test_basic()
+    test_multi_loop()
     test_multi_if()
     test_thread_axis()
+    test_vectorize()
+    test_select()
+    test_thread_axis2()