diff --git a/python/tvm/build.py b/python/tvm/build.py
index 273c16a07ad4eea55e71e87ce90936efca262920..a592bb314d5f6f2c4682030afef3b95a3a8a1f3b 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -90,10 +90,10 @@ def lower(sch,
     sch = sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
-    if not simple_mode:
-        stmt = ir_pass.LoopPartition(stmt)
     stmt = ir_pass.StorageFlatten(stmt, binds)
     stmt = ir_pass.CanonicalSimplify(stmt)
+    if not simple_mode:
+        stmt = ir_pass.LoopPartition(stmt)
     stmt = ir_pass.VectorizeLoop(stmt)
     stmt = ir_pass.InjectVirtualThread(stmt)
     stmt = ir_pass.StorageRewrite(stmt)
diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc
index c4567b6ae9791d4023f94c8f4ccf0ba2ab429f52..1e524583b198011ac7f98b308b8177437af3f1eb 100644
--- a/src/pass/loop_partition.cc
+++ b/src/pass/loop_partition.cc
@@ -52,7 +52,7 @@ class CandidateSelector : public IRVisitor {
       const Variable* var = op->loop_var.get();
       record_.insert({var, false});
       IRVisitor::Visit_(op);
-      if (record_.at(var)) {
+      if (record_.at(var) && !no_split_) {
         candidates.insert(op);
       }
       record_.erase(var);
@@ -70,7 +70,7 @@ class CandidateSelector : public IRVisitor {
       if ((scope.rank == 0) && !is_const(op->value)) {
         record_.insert({var.get(), false});
         IRVisitor::Visit_(op);
-        if (record_.at(var.get())) {
+        if (record_.at(var.get()) && !no_split_) {
           candidates.insert(op);
         }
         record_.erase(var.get());
@@ -80,11 +80,25 @@ class CandidateSelector : public IRVisitor {
     IRVisitor::Visit_(op);
   }
 
+  void Visit_(const Block* op) {
+    bool temp = no_split_;
+    this->Visit(op->first);
+    // Visit rest with the flag as it was on entry; temp now holds first's result.
+    std::swap(temp, no_split_);
+    this->Visit(op->rest);
+    // The block as a whole is no-split if either half raised the flag.
+    no_split_ = no_split_ || temp;
+  }
+
   void Visit_(const Call* op) {
     if (op->is_intrinsic(Call::likely)) {
       in_likely_ = true;
       IRVisitor::Visit_(op);
       in_likely_ = false;
+    } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
+      // no split if the body contains allreduce.
+      no_split_ = true;
+      return;
     } else {
       IRVisitor::Visit_(op);
     }
@@ -100,6 +114,7 @@ class CandidateSelector : public IRVisitor {
 
  private:
   bool in_likely_;
+  bool no_split_{false};
   std::unordered_map<const Variable*, VarIsUsed> record_;
 };
 
diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc
index 9b01e24fee388be18122d38f967a7c36d2b84265..ed12160af01b7384da63343fc9e85e030caee53e 100644
--- a/src/pass/ssa.cc
+++ b/src/pass/ssa.cc
@@ -1,6 +1,8 @@
 /*!
  *  Copyright (c) 2016 by Contributors
  *  SSA related checks and pass.
+ *
+ *  SSA requires each variable to be defined only once.
  * \file ssa.cc
  */
 #include <tvm/ir.h>
@@ -14,138 +16,155 @@
 namespace tvm {
 namespace ir {
 namespace {
-
-// global functor to get var definition from
-struct FGetVarDef {
-  using FType = IRFunctor<VarExpr (const NodeRef&)>;
-  static FType& vtable() {  // NOLINT(*)
-    static FType inst; return inst;
-  }
-};
-TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
-.set_dispatch<Let>([](const Let* op) {
-    return op->var;
-  })
-.set_dispatch<LetStmt>([](const LetStmt* op) {
-    return op->var;
-  })
-.set_dispatch<For>([](const For* op) {
-    return op->loop_var;
-  })
-.set_dispatch<Allocate>([](const Allocate* op) {
-    return op->buffer_var;
-  });
-
-struct FSetVarDef {
-  using FTypeExpr = IRFunctor<Expr (const NodeRef&, VarExpr)>;
-  using FTypeStmt = IRFunctor<Stmt (const NodeRef&, VarExpr)>;
-  static FTypeExpr& vtable_expr() {  // NOLINT(*)
-    static FTypeExpr inst; return inst;
-  }
-  static FTypeStmt& vtable_stmt() {  // NOLINT(*)
-    static FTypeStmt inst; return inst;
-  }
-};
-TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_expr)
-.set_dispatch<Let>([](const Let* op, VarExpr var) {
-    std::shared_ptr<Let> x = std::make_shared<Let>(*op);
-    x->var = var;
-    return Expr(x);
-  });
-
-TVM_STATIC_IR_FUNCTOR(FSetVarDef, vtable_stmt)
-.set_dispatch<LetStmt>([](const LetStmt* op, VarExpr var) {
-    std::shared_ptr<LetStmt> x = std::make_shared<LetStmt>(*op);
-    x->var = var;
-    return Stmt(x);
-  })
-.set_dispatch<For>([](const For* op, VarExpr var) {
-    std::shared_ptr<For> x = std::make_shared<For>(*op);
-    x->loop_var = var;
-    return Stmt(x);
-  });
-
-class IRVerifySSA : public IRVisitor {
+class IRVerifySSA final : public IRVisitor {  // clears is_ssa on a repeated definition
  public:
   bool is_ssa{true};
 
   void Visit(const NodeRef& n) final {
     if (!is_ssa) return;
-    static auto& fget_var_def = FGetVarDef::vtable();
-    if (fget_var_def.can_dispatch(n)) {
-      VarExpr v = fget_var_def(n);
-      if (defined_.count(v.get()) != 0) {
-        is_ssa = false; return;
-      } else {
-        defined_[v.get()] = 1;
-      }
-    }
     IRVisitor::Visit(n);
   }
+  void Visit_(const Let* op) final {
+    MarkDef(op->var.get());  // Let defines op->var
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const LetStmt* op) final {
+    MarkDef(op->var.get());  // LetStmt defines op->var
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const For* op) final {
+    MarkDef(op->loop_var.get());  // For defines its loop variable
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const Allocate* op) final {
+    MarkDef(op->buffer_var.get());  // Allocate defines its buffer variable
+    IRVisitor::Visit_(op);
+  }
 
  private:
+  void MarkDef(const Variable* v) {  // record one definition of v
+    if (defined_.count(v) != 0) {  // v was already defined: not SSA
+      is_ssa = false; return;
+    } else {
+      defined_[v] = 1;  // first definition: remember it
+    }
+  }
   std::unordered_map<const Variable*, int> defined_;
 };
 
-class IRConvertSSA : public IRMutator {
+class IRConvertSSA final : public IRMutator {
  public:
-  Expr Mutate(Expr expr) final {
-    static auto& fget_var_def = FGetVarDef::vtable();
-    static auto& fset_var_def = FSetVarDef::vtable_expr();
-    if (fget_var_def.can_dispatch(expr)) {
-      VarExpr v = fget_var_def(expr);
-      VarExpr new_var = v;
-      if (defined_.count(v.get()) != 0) {
-        CHECK(expr.as<Allocate>() == nullptr)
-            << "One allocation in two places, cannot rename buffer in allocate";
-        new_var = Variable::make(v->type, v->name_hint);
-      } else {
-        defined_.insert(v.get());
-      }
+  Expr Mutate_(const Variable* op, const Expr& e) final {
+    // A finished rename leaves an empty stack in scope_ after pop_back();
+    // back() on it is undefined, so substitute only while a rename is live.
+    auto it = scope_.find(op);
+    if (it == scope_.end() || it->second.empty()) return e;
+    return it->second.back();
+  }
+  Expr Mutate_(const Let* op, const Expr& e) final {
+    const VarExpr& v = op->var;
+    if (defined_.count(v.get())) {  // v was defined earlier: rename this binding
+      Expr value = IRMutator::Mutate(op->value);  // mutated before new_var is pushed
+      VarExpr new_var = Variable::make(v.type(), v->name_hint);
       scope_[v.get()].push_back(new_var);
-      Expr new_expr = IRMutator::Mutate(expr);
+      Expr body = IRMutator::Mutate(op->body);  // mutated while new_var is on the stack
       scope_[v.get()].pop_back();
-
-      if (!new_var.same_as(v)) {
-        return fset_var_def(new_expr, new_var);
-      } else {
-        return new_expr;
-      }
-    } else if (expr.as<Variable>()) {
-      const Variable* v = expr.as<Variable>();
-      if (scope_.count(v) != 0) {
-        return scope_[v].back();
-      } else {
-        return expr;
-      }
+      return Let::make(new_var, value, body);
     } else {
-      Expr e = IRMutator::Mutate(expr);
-      return e;
+      defined_.insert(v.get());  // first definition of v: keep it and remember it
+      return IRMutator::Mutate_(op, e);
     }
   }
-
-  Stmt Mutate(Stmt stmt) final {
-    static auto& fget_var_def = FGetVarDef::vtable();
-    static auto& fset_var_def = FSetVarDef::vtable_stmt();
-    if (fget_var_def.can_dispatch(stmt)) {
-      VarExpr v = fget_var_def(stmt);
-      VarExpr new_var = v;
-      if (defined_.count(v.get()) != 0) {
-        new_var = Variable::make(v->type, v->name_hint);
-      } else {
-        defined_.insert(v.get());
-      }
+  Expr Mutate_(const Load* op, const Expr& e) final {
+    Expr expr = IRMutator::Mutate_(op, e);
+    op = expr.as<Load>();
+    // A finished rename leaves an empty stack in scope_; back() on it is
+    // undefined, so rewrite only while a replacement buffer var is live.
+    auto it = scope_.find(op->buffer_var.get());
+    if (it == scope_.end() || it->second.empty()) return expr;
+    return Load::make(
+        op->type, it->second.back(),
+        op->index, op->predicate);
+  }
+  Stmt Mutate_(const Store* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Store>();
+    // A finished rename leaves an empty stack in scope_; back() on it is
+    // undefined, so rewrite only while a replacement buffer var is live.
+    auto it = scope_.find(op->buffer_var.get());
+    if (it == scope_.end() || it->second.empty()) return stmt;
+    return Store::make(
+        it->second.back(), op->value,
+        op->index, op->predicate);
+  }
+  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
+    const VarExpr& v = op->var;
+    if (defined_.count(v.get())) {  // v was defined earlier: rename this binding
+      Expr value = IRMutator::Mutate(op->value);  // mutated before new_var is pushed
+      VarExpr new_var = Variable::make(v.type(), v->name_hint);
       scope_[v.get()].push_back(new_var);
-      Stmt new_stmt = IRMutator::Mutate(stmt);
+      Stmt body = IRMutator::Mutate(op->body);  // mutated while new_var is on the stack
       scope_[v.get()].pop_back();
-
-      if (!new_var.same_as(v)) {
-        return fset_var_def(new_stmt, new_var);
+      return LetStmt::make(new_var, value, body);
+    } else {
+      defined_.insert(v.get());  // first definition of v: keep it and remember it
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  Stmt Mutate_(const For* op, const Stmt& s) final {
+    const VarExpr& v = op->loop_var;
+    if (!defined_.count(v.get())) {
+      // First loop over this variable: record it and keep the name.
+      defined_.insert(v.get());
+      return IRMutator::Mutate_(op, s);
+    }
+    VarExpr fresh = Variable::make(v.type(), v->name_hint);
+    scope_[v.get()].push_back(fresh);
+    Stmt mutated = IRMutator::Mutate_(op, s);
+    scope_[v.get()].pop_back();
+    const For* loop = mutated.as<For>();
+    return For::make(fresh, loop->min, loop->extent,
+                     loop->for_type, loop->device_api, loop->body);
+  }
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    const VarExpr& v = op->buffer_var;
+    if (!defined_.count(v.get())) {
+      // First allocation of this buffer variable: record it and keep the name.
+      defined_.insert(v.get());
+      return IRMutator::Mutate_(op, s);
+    }
+    VarExpr fresh = Variable::make(v.type(), v->name_hint);
+    scope_[v.get()].push_back(fresh);
+    Stmt mutated = IRMutator::Mutate_(op, s);
+    scope_[v.get()].pop_back();
+    const Allocate* alloc = mutated.as<Allocate>();
+    return Allocate::make(
+        fresh, alloc->type, alloc->extents, alloc->condition,
+        alloc->body, alloc->new_expr, alloc->free_function);
+  }
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (const Variable* v = op->node.as<Variable>()) {  // attr attached to a variable
+      if (op->attr_key == attr::storage_scope) {
+        const Allocate* alloc = op->body.as<Allocate>();
+        if (alloc && op->node.same_as(alloc->buffer_var)) {  // attr wraps its own Allocate
+          Stmt new_alloc = Mutate(op->body);
+          if (new_alloc.same_as(op->body)) return s;  // body unchanged: reuse s
+          alloc = new_alloc.as<Allocate>();
+          CHECK(alloc);
+          return AttrStmt::make(  // node becomes the mutated Allocate's buffer var
+              alloc->buffer_var, op->attr_key, op->value, new_alloc);
+        }
+      }
+      Stmt stmt = IRMutator::Mutate_(op, s);
+      op = stmt.as<AttrStmt>();
+      if (scope_.count(v) && scope_[v].size() != 0) {  // a rename of v is currently live
+        return AttrStmt::make(
+            scope_[v].back(), op->attr_key, op->value, op->body);
       } else {
-        return new_stmt;
+        return stmt;
       }
     } else {
-      return IRMutator::Mutate(stmt);
+      return IRMutator::Mutate_(op, s);
     }
   }
 
diff --git a/tests/python/unittest/test_build_lower.py b/tests/python/unittest/test_build_lower.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ef2a41741de2d2911928266f5a351bef28f705
--- /dev/null
+++ b/tests/python/unittest/test_build_lower.py
@@ -0,0 +1,20 @@
+import tvm
+
+def test_lower_rfactor():
+    """Lower an rfactor schedule whose reduce axis is bound to threadIdx.x."""
+    n, m = tvm.var("n"), tvm.var("m")
+    A = tvm.placeholder((n, m), name='A')
+    k = tvm.reduce_axis((0, m), "k")
+    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
+    s = tvm.create_schedule(B.op)
+    _, ki = s[B].split(B.op.reduce_axis[0], factor=16)
+    BF = s.rfactor(B, ki)
+    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
+    s[B.op].bind(xo, tvm.thread_axis("blockIdx.x"))
+    s[B.op].bind(xi, tvm.thread_axis("threadIdx.y"))
+    s[B].bind(s[B].op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
+    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+    tvm.lower(s, [A, B])  # the test passes if lowering completes without raising
+
+if __name__ == "__main__":
+    test_lower_rfactor()