diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 1e5d0e0a94c0b18f19a253169896ed8d76e4e14d..91efe172759385f82499c81a4781c241dfad235d 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -40,6 +40,8 @@ using Halide::Internal::as_const_uint;
 using Halide::Internal::const_true;
 using Halide::Internal::const_false;
 using Halide::Internal::is_no_op;
+using Halide::likely;
+using Halide::likely_if_innermost;
 
 inline Type TVMShapeIndexType() {
   if (std::is_signed<tvm_index_t>::value) {
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 2672a49b1bdadeba72fdcf5571903ec8697ce51f..5fdc6fa212409af53dc21cdfc32640a5fb27014f 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -41,7 +41,7 @@ struct Reduce : public ExprNode<Reduce> {
   /*! \brief construct expr from op and rdom */
   static Expr make(std::string op, Expr src,
                    Array<IterVar> rdom,
-                   Expr condition = make_const(Bool(1), true));
+                   Expr condition = const_true());
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("dtype", &type);
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index dff7991862175cc2cc61cbe4469abb05e1765559..93b93a62cc2cd7a3c7889070dc36f9e61e8ea957 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -210,6 +210,18 @@ class Schedule : public NodeRef {
    * \return The created tensor.
    */
   Tensor cache_write(const Tensor& tensor, const std::string& scope);
+  /*!
+   * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
+   * This will create a new stage that generated the new tensor with axis
+   * as the first dimension. The tensor's body wil be rewriten as a reduction
+   * over the factored tensor.
+   *
+   * \param tensor The tensor to be factored.
+   * \param axis The reduction axis in tensor's schedule to be factored.
+   * \return The created factored tensor.
+   */
+  Tensor rfactor(const Tensor& tensor,
+                 const IterVar& axis);
   /*!
    * \brief Normalize the schedule.
    *  This is needed before bound inference.
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 51cb4a17943658c653b046faca545a585f7ae76f..75d1a727c7837f7cf33e47218a57760bf31bf371 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -19,4 +19,4 @@ from .ndarray import cpu, gpu, opencl, cl, vpi
 
 from ._base import TVMError
 from .api import *
-from .build import build
+from .build import build, lower
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 9b55801b1f578e81d4648aae5e66b73a1c690b46..9c44f8f5556eb64b152db85ea30f794da5478c94 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -372,7 +372,7 @@ def reduce_axis(dom, name="rv"):
     return _IterVar(dom, name, 2)
 
 
-def sum(expr, axis):
+def sum(expr, axis, where=None):
     """Create a sum expression over axis
 
     Parameters
@@ -382,13 +382,16 @@ def sum(expr, axis):
 
     axis : IterVar
         The reduction IterVar axis
+
+    where : optional, Expr
+        Filtering predicate of the reduction.
     """
     axis = axis if isinstance(axis, list) else [axis]
-    x = _make.Reduce("Add", expr, axis)
+    x = _make.Reduce("Add", expr, axis, where)
     return x
 
 
-def min(lhs, rhs=None, axis=None):
+def min(lhs, rhs=None, axis=None, where=None):
     """Create a min expression.
 
     Parameters
@@ -401,6 +404,9 @@ def min(lhs, rhs=None, axis=None):
 
     axis : IterVar, optional
         The reduction IterVar axis
+
+    where : optional, Expr
+        Filtering predicate of the reduction.
     """
     if rhs and axis:
         raise ValueError("Can only take one argument, rhs or axis")
@@ -409,11 +415,11 @@ def min(lhs, rhs=None, axis=None):
     if rhs:
         return _make.Min(lhs, rhs)
     axis = axis if isinstance(axis, list) else [axis]
-    x = _make.Reduce("Min", expr, axis)
+    x = _make.Reduce("Min", expr, axis, where)
     return x
 
 
-def max(lhs, rhs=None, axis=None):
+def max(lhs, rhs=None, axis=None, where=None):
     """Create a max expression.
 
     Parameters
@@ -426,6 +432,9 @@ def max(lhs, rhs=None, axis=None):
 
     axis : IterVar, optional
         The reduction IterVar axis
+
+    where : optional, Expr
+        Filtering predicate of the reduction.
     """
     if rhs and axis:
         raise ValueError("Can only take one argument, rhs or axis")
@@ -434,7 +443,7 @@ def max(lhs, rhs=None, axis=None):
     if rhs:
         return _make.Max(lhs, rhs)
     axis = axis if isinstance(axis, list) else [axis]
-    x = _make.Reduce("Max", expr, axis)
+    x = _make.Reduce("Max", expr, axis, where)
     return x
 
 
diff --git a/python/tvm/build.py b/python/tvm/build.py
index 8d5ba26c857033903cdd39f5d2b7ad56b3d12dae..ec5a0dba1c4f2ca34be2b89a3aa3603c1d98971a 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -9,16 +9,15 @@ from . import tensor
 from . import schedule
 from . import expr
 from . import ir_pass
+from . import collections
 from . import codegen
 
-def build(sch,
+def lower(sch,
           args,
-          target,
-          target_host="stackvm",
           name="default_function",
           binds=None,
           max_auto_unroll_step=8):
-    """Build a function with arguments as signiture.
+    """Lowering step before build into target.
 
     Parameters
     ----------
@@ -28,12 +27,6 @@ def build(sch,
     args : list of Buffer or Tensor or Var
         The argument lists to the function.
 
-    target : str
-        The target of the compilation.
-
-    target_host :
-        Host compilation target, if target is device.
-
     name : str
         The name of result function.
 
@@ -46,10 +39,8 @@ def build(sch,
 
     Returns
     -------
-    f : Function, or pair of functions
+    f : LoweredFunc
        The result function.
-       If the function requires host space allocation,
-       a pair of functions will be returned.
     """
     binds = {} if binds is None else binds.copy()
     arg_list = []
@@ -77,6 +68,62 @@ def build(sch,
     stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
     stmt = ir_pass.Simplify(stmt)
     fapi = ir_pass.MakeAPI(stmt, name, arg_list, 0)
+    return fapi
+
+
+
+def build(sch,
+          args=None,
+          target="llvm",
+          target_host="stackvm",
+          name="default_function",
+          binds=None,
+          max_auto_unroll_step=8):
+    """Build a function with arguments as signiture.
+
+    Parameters
+    ----------
+    sch : tvm.Schedule, or LoweredFunc
+        The schedule to be builded
+
+    args : list of Buffer or Tensor or Var
+        The argument lists to the function.
+
+    target : str
+        The target of the compilation.
+
+    target_host :
+        Host compilation target, if target is device.
+
+    name : str
+        The name of result function.
+
+    binds : dict, optional
+        Dictionary that maps the binding of symbolic buffer to Tensor.
+        By default, a new buffer is created for each tensor in the argument.
+
+    max_auto_unroll_step: int
+        Maximum step to perform automatic unrolling
+
+    Returns
+    -------
+    f : Function, or pair of functions
+       The result function.
+    """
+    if isinstance(sch, schedule.Schedule):
+        if args is None:
+            raise ValueError("args must be given for build from schedule")
+        fapi = lower(sch, args,
+                     name=name,
+                     binds=binds,
+                     max_auto_unroll_step=max_auto_unroll_step)
+    elif isinstance(sch, collections.LoweredFunc):
+        if args:
+            raise ValueError("args must be done when build from LoweredFunc")
+        fapi = sch
+    else:
+        raise ValueError("sch have to be Schedule or LoweredFunc")
+
     fsplits = ir_pass.SplitHostDevice(fapi)
     fsplits = [x for x in fsplits]
     for i in range(1, len(fsplits)):
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 813b48b6fc330b942c9cb314478febb63409505b..dcddafa4c7a694d658e61c7b0499c58ec8449c91 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -87,6 +87,27 @@ class Schedule(NodeBase):
         """
         return _api_internal._ScheduleCacheWrite(self, tensor, scope)
 
+    def rfactor(self, tensor, axis):
+        """ Factor a reduction axis in tensor's schedule to be an explicit axis.
+
+        This will create a new stage that generated the new tensor with axis
+        as the first dimension. The tensor's body wil be rewriten as a reduction
+        over the factored tensor.
+
+        Parameters
+        ----------
+        tensor : Tensor
+            The tensor to be factored.
+        axis : IterVar
+            The reduction axis in the schedule to be factored.
+
+        Returns
+        -------
+        tfactor : Tensor
+            The created factored tensor.
+        """
+        return _api_internal._ScheduleRFactor(self, tensor, axis)
+
 
 @register_node
 class Stage(NodeBase):
@@ -114,8 +135,6 @@ class Stage(NodeBase):
             The inner variable of iteration.
         """
         if outer is not None:
-            if outer.thread_tag == '':
-                raise ValueError("split by outer must have special thread_tag")
             inner = _api_internal._StageSplitByOuter(self, parent, outer, factor)
         else:
             if factor is None:
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
index af07958000fa622339291bcd29f2acd4bf7b191e..61cf3e365f9f360c97bc4c0259c93312baf4b560 100644
--- a/src/api/api_ir.cc
+++ b/src/api/api_ir.cc
@@ -89,7 +89,7 @@ TVM_REGISTER_API(_make_Allocate)
       *ret = Node::make(a, b);                               \
     })
 
-REGISTER_MAKE3(Reduce);
+REGISTER_MAKE4(Reduce);
 REGISTER_MAKE4(AttrStmt);
 
 REGISTER_MAKE2(IntImm);
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 14c555fb3fb0f897a65cd7339cee39213c76ba71..933adc872cc251a8d3223cb6db324e1c8842877e 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -318,4 +318,10 @@ TVM_REGISTER_API(_ScheduleCacheWrite)
         .cache_write(args[1], args[2]);
   });
 
+TVM_REGISTER_API(_ScheduleRFactor)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Schedule()
+        .rfactor(args[1], args[2]);
+  });
+
 }  // namespace tvm
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index 55971fe865d3335596783f90b177f20537e2906f..b288ab82e46d9048c61e89358b01343de2c8f031 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -526,8 +526,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
 void CodeGenC::VisitStmt_(const Store* op) {
   Type t = op->value.type();
   if (t.lanes() == 1) {
-    this->PrintIndent();
     std::string value = this->PrintExpr(op->value);
+    this->PrintIndent();
     this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
     stream << " = " << value << ";\n";
   } else {
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index edd93dac1e457c1879447a734e0695cca06eaf2f..55a4d7a0de5648757d133bb2af75783e338e83e3 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -28,7 +28,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->print(op->source);
   p->stream << ", axis=" << op->axis;
   if (!is_const(op->condition, 1)) {
-    p->stream << ", condition=" << op->condition;
+    p->stream << ", where=" << op->condition;
   }
   p->stream << ")";
 });
@@ -45,6 +45,9 @@ Expr Reduce::make(std::string op, Expr source,
     CHECK_EQ(axis[i]->iter_type, kCommReduce)
         << "Can only take axis created by reduce_axis";
   }
+  if (!condition.defined()) {
+    condition = const_true();
+  }
   auto n = std::make_shared<Reduce>();
   CHECK(source.defined());
   for (size_t i = 0; i < axis.size(); ++i) {
diff --git a/src/lang/operation.cc b/src/lang/operation.cc
deleted file mode 100644
index f6cdaa72b4f0a1c303f5d682916864c66e3dd3f8..0000000000000000000000000000000000000000
--- a/src/lang/operation.cc
+++ /dev/null
@@ -1,6 +0,0 @@
-/*!
- *  Copyright (c) 2016 by Contributors
- * \file operation.cc
- */
-#include <tvm/operation.h>
-#include <tvm/tensor.h>
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 7884e454e8daf9c8847cc202b4d3f611a1481742..e2467bc32fcc8be03529cbf43870aef31b25e53c 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -10,6 +10,7 @@
 #include <tvm/ir_pass.h>
 #include <unordered_set>
 #include "./op_util.h"
+#include "../schedule/message_passing.h"
 
 namespace tvm {
 
@@ -64,10 +65,7 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
     args.push_back(axis.back()->var);
   }
 
-  op_node->axis = Array<IterVar>(axis);
-  op_node->body = fcompute(args);
-  op_node->name = name;
-  return Operation(op_node).output(0);
+  return ComputeOpNode::make(name, axis, fcompute(args)).output(0);
 }
 
 Operation ComputeOpNode::make(std::string name,
@@ -191,6 +189,9 @@ void MakeReduction(const ComputeOpNode* op,
   }
   *init = Provide::make(t->op, t->value_index, init_value, args);
   *provide = Provide::make(t->op, t->value_index, update_value, args);
+  if (!is_one(reduce->condition)) {
+    *provide = IfThenElse::make(reduce->condition, *provide);
+  }
 }
 
 Stmt MakeProvide(const ComputeOpNode* op,
@@ -202,31 +203,6 @@ Stmt MakeProvide(const ComputeOpNode* op,
   return Provide::make(t->op, t->value_index, op->body, args);
 }
 
-// message passing to find if IterVar is related to reduction.
-void PassDownReduceFlag(const Stage& s,
-                        std::unordered_map<IterVar, int>* p_state) {
-  auto& state = *p_state;
-  for (IterVarRelation rel : s->relations) {
-    if (rel.as<SplitNode>()) {
-      const SplitNode* s = rel.as<SplitNode>();
-      int flag = state.at(s->parent);
-      state[s->outer] = flag;
-      state[s->inner] = flag;
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* s = rel.as<FuseNode>();
-      int flag_outer = state.at(s->outer);
-      int flag_inner = state.at(s->inner);
-      state[s->fused] = flag_outer | flag_inner;
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* s = rel.as<RebaseNode>();
-      int flag = state.at(s->parent);
-      state[s->rebased] = flag;
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
 Stmt Substitute(Stmt s,
                 const std::unordered_map<IterVar, Expr>& value_map) {
   Map<Var, Expr> temp;
@@ -267,7 +243,7 @@ Stmt ComputeOpNode::BuildProvide(
       update_state[iv] = 1;
     }
     // find which iter var is related to reduction and which is related to axis.
-    PassDownReduceFlag(stage, &update_state);
+    schedule::PassDownBitMaskOr(stage, &update_state);
     auto leaf_iter_vars = stage->leaf_iter_vars;
     std::unordered_map<IterVar, Expr> init_value_map;
     // first first loop that is related to reduction.
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index 7aa31405748f31f98ba598f9ed2b0db13519c19f..487be17cc80be7f1d11223c3d4c54f17730b8a16 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -8,6 +8,7 @@
 #include <tvm/operation.h>
 #include <tvm/ir_mutator.h>
 #include "./op_util.h"
+#include "../schedule/message_passing.h"
 #include "../arithmetic/compute_expr.h"
 
 namespace tvm {
@@ -16,61 +17,6 @@ namespace op {
 using namespace arith;
 using namespace ir;
 
-/*!
- * \brief use message passing to calculate the assignment of each Var inside the loop body.
- * \param s The schedule to be used.
- * \param dom_map The domain map of each iteration variable's domain
- * \param p_state The message passing state
- *     IterVar->The assignment.
- */
-void PassUpOffset(const Stage& s,
-                  const Map<IterVar, Range>& dom_map,
-                  std::unordered_map<IterVar, Expr>* p_state) {
-  auto& state = *p_state;
-  for (size_t i = s->relations.size(); i != 0; --i) {
-    IterVarRelation rel = s->relations[i - 1];
-    if (rel.as<SplitNode>()) {
-      const SplitNode* s = rel.as<SplitNode>();
-      Expr outer = state.at(s->outer);
-      Expr inner = state.at(s->inner);
-      Expr factor = dom_map.at(s->inner)->extent;
-      Expr parent_min = dom_map.at(s->parent)->min;
-      state[s->parent] = inner + outer * factor;
-      // add min if they exist
-      if (!is_zero(parent_min)) {
-        state[s->parent] = state[s->parent] + parent_min;
-      }
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* s = rel.as<FuseNode>();
-      Expr value = state.at(s->fused);
-      Expr factor = dom_map.at(s->inner)->extent;
-      Expr outer_min = dom_map.at(s->outer)->min;
-      Expr inner_min = dom_map.at(s->inner)->min;
-      state[s->outer] = value / factor;
-      state[s->inner] = value % factor;
-      // add min if they exist
-      if (!is_zero(outer_min)) {
-        state[s->outer] = state[s->outer] + outer_min;
-      }
-      if (!is_zero(inner_min)) {
-        state[s->inner] = state[s->inner] + inner_min;
-      }
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* s = rel.as<RebaseNode>();
-      Expr value = state.at(s->rebased);
-      Expr parent_min = dom_map.at(s->parent)->min;
-      // add min if they exist
-      if (!is_zero(parent_min)) {
-        state[s->parent] = value + parent_min;
-      } else {
-        state[s->parent] = value;
-      }
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
 std::vector<std::vector<Stmt> >
 MakeLoopNest(const Stage& stage,
              const std::unordered_map<IterVar, Range>& dom_map,
@@ -166,7 +112,7 @@ MakeLoopNest(const Stage& stage,
     }
   }
   // message passing to get offset of root iter vars.
-  PassUpOffset(stage, dom_map, &value_map);
+  schedule::PassUpIndex(stage, dom_map, &value_map);
   return nest;
 }
 
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 39fbc929ffe6374e78fdaa3486c632f51fc28fb9..ae6cf678e0237d39d7f9939e3e01a4ff7324e33b 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -3,200 +3,18 @@
  * \file bound.cc
  * \brief The bound inference logic.
  */
-#include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
-#include <tvm/ir_pass.h>
 #include <tvm/schedule_pass.h>
-#include <tvm/arithmetic.h>
 #include <tvm/operation.h>
 #include <unordered_map>
 #include <unordered_set>
 #include "./graph.h"
+#include "./message_passing.h"
 #include "../runtime/thread_storage_scope.h"
 
 namespace tvm {
 namespace schedule {
 
-using namespace arith;
-
-// result = ceil((a / b)), both a and b are positive integer
-inline Expr DivCeil(Expr a, Expr b) {
-  return ir::Simplify((a + b - 1) / b);
-}
-
-inline bool prove_equal(Expr lhs, Expr rhs) {
-  return is_zero(ir::Simplify(lhs - rhs));
-}
-
-// Downward message passing algorithm on stage schedule s,
-// pass the range state down from the root to the leaves
-// after this pass, every IterVar in the stage hyper graph will have a range(domain)
-void PassDown(const Stage& s,
-              std::unordered_map<IterVar, Range>* p_state) {
-  auto& state = *p_state;
-  // forwar iteration on relations
-  for (IterVarRelation rel : s->relations) {
-    if (rel.as<SplitNode>()) {
-      const SplitNode* r = rel.as<SplitNode>();
-      CHECK(state.count(r->parent));
-      CHECK(!state.count(r->inner));
-      const Range& range_parent = state.at(r->parent);
-      if (r->factor.defined()) {
-        state[r->inner] = Range::make_with_min_extent(0, r->factor);
-        if (r->outer->dom.defined()) {
-          state[r->outer] = r->outer->dom;
-        } else {
-          if (!state.count(r->outer)) {
-            state[r->outer] = Range::make_with_min_extent(
-                0, DivCeil(range_parent->extent, r->factor));
-          } else {
-            Expr outer_ext = DivCeil(range_parent->extent, r->factor);
-            Range outer_rng = state.at(r->outer);
-            bool match = is_zero(outer_rng->min);
-            if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
-            CHECK(match)
-                << r->outer
-                << "IterVar is used in two places as outer scope,"
-                << " cannot prove their extents are the same "
-                << outer_ext << " vs " << outer_rng->extent;
-          }
-        }
-      } else {
-        CHECK(r->outer->dom.defined());
-        state[r->outer] = r->outer->dom;
-        state[r->inner] = Range::make_with_min_extent(
-            0, DivCeil(range_parent->extent, r->outer->dom->extent));
-      }
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* r = rel.as<FuseNode>();
-      CHECK(state.count(r->outer));
-      CHECK(state.count(r->inner));
-      const Range& range_outer = state.at(r->outer);
-      const Range& range_inner = state.at(r->inner);
-      state[r->fused] = Range::make_with_min_extent(
-          0, range_outer->extent * range_inner->extent);
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* r = rel.as<RebaseNode>();
-      CHECK(state.count(r->parent));
-      state[r->rebased] = Range::make_with_min_extent(
-          0, state.at(r->parent)->extent);
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
-// upward message passing algorithm
-// pass the integer set on each leave loop up to the root
-// dom_map is the result of PassDown, it records the domain of each IterVar.
-// dom_map can be used to get cached result in reverse construction.
-// Implementation of Evaluations and passing.
-void PassUp(const SplitNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& outer,
-            const IntSet& inner,
-            IntSet* parent) {
-  if (dom_map.count(s->outer) &&
-      dom_map.count(s->inner) &&
-      dom_map.count(s->parent) &&
-      outer.match_range(dom_map.at(s->outer)) &&
-      inner.match_range(dom_map.at(s->inner))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
-    return;
-  }
-  Expr factor = dom_map.at(s->inner)->extent;
-  Expr parent_min = dom_map.at(s->parent)->min;
-  CHECK(outer.defined());
-  CHECK(inner.defined());
-  CHECK(factor.defined());
-  *parent = EvalSet(
-      s->outer->var * factor + s->inner->var + parent_min,
-      {{s->outer, outer}, {s->inner, inner}});
-}
-
-void PassUp(const FuseNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& fused,
-            IntSet* outer,
-            IntSet* inner) {
-  CHECK(dom_map.count(s->outer));
-  CHECK(dom_map.count(s->inner));
-  CHECK(dom_map.count(s->fused));
-
-  if (fused.match_range(dom_map.at(s->fused))) {
-    *outer = IntSet::range(dom_map.at(s->outer));
-    *inner = IntSet::range(dom_map.at(s->inner));
-    return;
-  }
-  Expr outer_min = dom_map.at(s->outer)->min;
-  Expr inner_min = dom_map.at(s->inner)->min;
-
-  if (fused.is_single_point()) {
-    Expr value = fused.point_value();
-    Expr factor = dom_map.at(s->inner)->extent;
-    Expr v_outer  = value / factor;
-    Expr v_inner  = value % factor;
-    if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
-    if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
-    *outer = IntSet::single_point(v_outer);
-    *inner = IntSet::single_point(v_inner);
-  } else {
-    LOG(WARNING) << "use fallback inference rule in fuse";
-    // simply use the entire set, this rule can be enhanced.
-    *outer = IntSet::range(dom_map.at(s->outer));
-    *inner = IntSet::range(dom_map.at(s->inner));
-    return;
-  }
-}
-
-void PassUp(const RebaseNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& rebased,
-            IntSet* parent) {
-  CHECK(dom_map.count(s->parent));
-  if (rebased.match_range(dom_map.at(s->rebased))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
-    return;
-  }
-  Expr parent_min = dom_map.at(s->parent)->min;
-  *parent = EvalSet(s->rebased->var + parent_min,
-                    {{s->rebased, rebased}});
-}
-
-void PassUp(const Stage& s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            std::unordered_map<IterVar, IntSet>* p_state) {
-  auto& state = *p_state;
-  for (size_t i = s->relations.size(); i != 0; --i) {
-    IterVarRelation rel = s->relations[i - 1];
-    if (rel.as<SplitNode>()) {
-      IntSet parent;
-      const SplitNode* r = rel.as<SplitNode>();
-      PassUp(r, dom_map,
-             state.at(r->outer), state.at(r->inner),
-             &parent);
-      state[r->parent] = parent;
-    } else if (rel.as<FuseNode>()) {
-      IntSet outer, inner;
-      const FuseNode* r = rel.as<FuseNode>();
-      PassUp(r, dom_map,
-             state.at(r->fused),
-             &outer, &inner);
-      state[r->outer] = outer;
-      state[r->inner] = inner;
-    } else if (rel.as<RebaseNode>()) {
-      IntSet parent;
-      const RebaseNode* r = rel.as<RebaseNode>();
-      PassUp(r, dom_map,
-             state.at(r->rebased),
-             &parent);
-      state[r->parent] = parent;
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
 // check if scope
 inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
   using runtime::ThreadScope;
@@ -285,7 +103,7 @@ void InferRootBound(const Stage& stage,
       }
     }
     // get the bound of the root IterVars given current location.
-    PassUp(parent, *rmap, &up_state);
+    PassUpDomain(parent, *rmap, &up_state);
 
     std::unordered_map<const Variable*, IntSet> dom_map;
     for (auto iv : parent->op->root_iter_vars()) {
@@ -358,7 +176,7 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
     const Stage& stage = sch->stages[i - 1];
     InferRootBound(stage, ctx, attach_path, &ret);
     // pass down to get bound of all iter vars.
-    PassDown(stage, &ret);
+    PassDownDomain(stage, &ret);
     // setup outer most threads.
     for (IterVar iv : stage->outermost_threads) {
       CHECK(iv->dom.defined());
diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc
new file mode 100644
index 0000000000000000000000000000000000000000..68d28df2c1dcde20e952a8d642b61c4b57e2ddab
--- /dev/null
+++ b/src/schedule/message_passing.cc
@@ -0,0 +1,343 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file message_passing.cc
+ * \brief The message passing domain.
+ */
+#include <tvm/arithmetic.h>
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include "./message_passing.h"
+
+namespace tvm {
+namespace schedule {
+
+using namespace arith;
+
+// result = ceil((a / b)), both a and b are positive integer
+inline Expr DivCeil(Expr a, Expr b) {
+  return ir::Simplify((a + b - 1) / b);
+}
+
+inline bool prove_equal(Expr lhs, Expr rhs) {
+  return is_zero(ir::Simplify(lhs - rhs));
+}
+
+void PassDownDomain(const Stage& stage,
+                    std::unordered_map<IterVar, Range>* p_state,
+                    bool allow_missing) {
+  auto& state = *p_state;
+  // forwar iteration on relations
+  for (IterVarRelation rel : stage->relations) {
+    if (const SplitNode* r = rel.as<SplitNode>()) {
+      if (!state.count(r->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      CHECK(!state.count(r->inner));
+      const Range& range_parent = state.at(r->parent);
+      if (r->factor.defined()) {
+        state[r->inner] = Range::make_with_min_extent(0, r->factor);
+        if (r->outer->dom.defined()) {
+          state[r->outer] = r->outer->dom;
+        } else {
+          if (!state.count(r->outer)) {
+            state[r->outer] = Range::make_with_min_extent(
+                0, DivCeil(range_parent->extent, r->factor));
+          } else {
+            Expr outer_ext = DivCeil(range_parent->extent, r->factor);
+            Range outer_rng = state.at(r->outer);
+            bool match = is_zero(outer_rng->min);
+            if (!prove_equal(outer_ext, outer_rng->extent)) match = false;
+            CHECK(match)
+                << r->outer
+                << "IterVar is used in two places as outer scope,"
+                << " cannot prove their extents are the same "
+                << outer_ext << " vs " << outer_rng->extent;
+          }
+        }
+      } else {
+        CHECK(r->outer->dom.defined());
+        state[r->outer] = r->outer->dom;
+        state[r->inner] = Range::make_with_min_extent(
+            0, DivCeil(range_parent->extent, r->outer->dom->extent));
+      }
+    } else if (const FuseNode* r = rel.as<FuseNode>()) {
+      if (!state.count(r->outer) || !state.count(r->inner)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      const Range& range_outer = state.at(r->outer);
+      const Range& range_inner = state.at(r->inner);
+      state[r->fused] = Range::make_with_min_extent(
+          0, range_outer->extent * range_inner->extent);
+    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+      if (!state.count(r->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      state[r->rebased] = Range::make_with_min_extent(
+          0, state.at(r->parent)->extent);
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+void PassUpIndex(const Stage& stage,
+                 const Map<IterVar, Range>& dom_map,
+                 std::unordered_map<IterVar, Expr>* p_state,
+                 bool allow_missing) {
+  auto& state = *p_state;
+  for (size_t i = stage->relations.size(); i != 0; --i) {
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->outer) || !state.count(s->inner)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Expr outer = state.at(s->outer);
+      Expr inner = state.at(s->inner);
+      Expr factor = dom_map.at(s->inner)->extent;
+      Expr parent_min = dom_map.at(s->parent)->min;
+      state[s->parent] = inner + outer * factor;
+      // add min if they exist
+      if (!is_zero(parent_min)) {
+        state[s->parent] = state[s->parent] + parent_min;
+      }
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->fused)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Expr value = state.at(s->fused);
+      Expr factor = dom_map.at(s->inner)->extent;
+      Expr outer_min = dom_map.at(s->outer)->min;
+      Expr inner_min = dom_map.at(s->inner)->min;
+      state[s->outer] = value / factor;
+      state[s->inner] = value % factor;
+      // add min if they exist
+      if (!is_zero(outer_min)) {
+        state[s->outer] = state[s->outer] + outer_min;
+      }
+      if (!is_zero(inner_min)) {
+        state[s->inner] = state[s->inner] + inner_min;
+      }
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->rebased)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Expr value = state.at(s->rebased);
+      Expr parent_min = dom_map.at(s->parent)->min;
+      // add min if they exist
+      if (!is_zero(parent_min)) {
+        state[s->parent] = value + parent_min;
+      } else {
+        state[s->parent] = value;
+      }
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+// Domain message passing.
+void PassUpDomain(const SplitNode* s,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  const IntSet& outer,
+                  const IntSet& inner,
+                  IntSet* parent) {
+  if (dom_map.count(s->outer) &&
+      dom_map.count(s->inner) &&
+      dom_map.count(s->parent) &&
+      outer.match_range(dom_map.at(s->outer)) &&
+      inner.match_range(dom_map.at(s->inner))) {
+    *parent = IntSet::range(dom_map.at(s->parent));
+    return;
+  }
+  Expr factor = dom_map.at(s->inner)->extent;
+  Expr parent_min = dom_map.at(s->parent)->min;
+  CHECK(outer.defined());
+  CHECK(inner.defined());
+  CHECK(factor.defined());
+  *parent = EvalSet(
+      s->outer->var * factor + s->inner->var + parent_min,
+      {{s->outer, outer}, {s->inner, inner}});
+}
+
+void PassUpDomain(const FuseNode* s,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  const IntSet& fused,
+                  IntSet* outer,
+                  IntSet* inner) {
+  CHECK(dom_map.count(s->outer));
+  CHECK(dom_map.count(s->inner));
+  CHECK(dom_map.count(s->fused));
+
+  if (fused.match_range(dom_map.at(s->fused))) {
+    *outer = IntSet::range(dom_map.at(s->outer));
+    *inner = IntSet::range(dom_map.at(s->inner));
+    return;
+  }
+  Expr outer_min = dom_map.at(s->outer)->min;
+  Expr inner_min = dom_map.at(s->inner)->min;
+
+  if (fused.is_single_point()) {
+    Expr value = fused.point_value();
+    Expr factor = dom_map.at(s->inner)->extent;
+    Expr v_outer  = value / factor;
+    Expr v_inner  = value % factor;
+    if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
+    if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
+    *outer = IntSet::single_point(v_outer);
+    *inner = IntSet::single_point(v_inner);
+  } else {
+    LOG(WARNING) << "use fallback inference rule in fuse";
+    // simply use the entire set, this rule can be enhanced.
+    *outer = IntSet::range(dom_map.at(s->outer));
+    *inner = IntSet::range(dom_map.at(s->inner));
+    return;
+  }
+}
+
+void PassUpDomain(const RebaseNode* s,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  const IntSet& rebased,
+                  IntSet* parent) {
+  CHECK(dom_map.count(s->parent));
+  if (rebased.match_range(dom_map.at(s->rebased))) {
+    *parent = IntSet::range(dom_map.at(s->parent));
+    return;
+  }
+  Expr parent_min = dom_map.at(s->parent)->min;
+  *parent = EvalSet(s->rebased->var + parent_min,
+                    {{s->rebased, rebased}});
+}
+
+void PassUpDomain(const Stage& stage,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  std::unordered_map<IterVar, IntSet>* p_state) {
+  auto& state = *p_state;
+  for (size_t i = stage->relations.size(); i != 0; --i) {
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* r = rel.as<SplitNode>()) {
+      IntSet parent;
+      PassUpDomain(r, dom_map,
+                   state.at(r->outer), state.at(r->inner),
+                   &parent);
+      state[r->parent] = parent;
+    } else if (const FuseNode* r = rel.as<FuseNode>()) {
+      IntSet outer, inner;
+      PassUpDomain(r, dom_map,
+                   state.at(r->fused),
+                   &outer, &inner);
+      state[r->outer] = outer;
+      state[r->inner] = inner;
+    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+      IntSet parent;
+      PassUpDomain(r, dom_map,
+                   state.at(r->rebased),
+                   &parent);
+      state[r->parent] = parent;
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+// Pass up bit mask with or relation.
+void PassUpBitMaskOr(const Stage& stage,
+                     std::unordered_map<IterVar, int>* p_state,
+                     bool allow_missing) {
+  auto& state = *p_state;
+  for (size_t i = stage->relations.size(); i != 0; --i) {
+    IterVarRelation rel = stage->relations[i - 1];
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->inner) && !state.count(s->outer)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      int res = 0;
+      if (!state.count(s->parent)) res |= state[s->parent];
+      if (!state.count(s->inner)) res |= state[s->inner];
+      if (!state.count(s->outer)) res |= state[s->outer];
+      state[s->parent] = res;
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->fused)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->outer)) {
+        state[s->outer] = state[s->fused];
+      } else {
+        state[s->outer] |= state[s->fused];
+      }
+      if (!state.count(s->inner)) {
+        state[s->inner] = state[s->fused];
+      } else {
+        state[s->inner] |= state[s->fused];
+      }
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->rebased)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->parent)) {
+        state[s->parent] = state[s->rebased];
+      } else {
+        state[s->parent] |= state[s->rebased];
+      }
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+void PassDownBitMaskOr(const Stage& stage,
+                       std::unordered_map<IterVar, int>* p_state,
+                       bool allow_missing) {
+  auto& state = *p_state;
+  for (IterVarRelation rel : stage->relations) {
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->outer)) {
+        state[s->outer] = state.at(s->parent);
+      } else {
+        state[s->outer] |= state.at(s->parent);
+      }
+      if (!state.count(s->inner)) {
+        state[s->inner] = state.at(s->parent);
+      } else {
+        state[s->inner] |= state.at(s->parent);
+      }
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->outer) && !state.count(s->inner)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      int res = 0;
+      if (state.count(s->outer)) res |= state.at(s->outer);
+      if (state.count(s->inner)) res |= state.at(s->inner);
+      if (state.count(s->fused)) res |= state.at(s->fused);
+      state[s->fused] = res;
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      if (!state.count(s->rebased)) {
+        state[s->rebased] = state.at(s->parent);
+      } else {
+        state[s->rebased] |= state.at(s->parent);
+      }
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+}  // namespace schedule
+}  // namespace tvm
diff --git a/src/schedule/message_passing.h b/src/schedule/message_passing.h
new file mode 100644
index 0000000000000000000000000000000000000000..5b7cf9d2400f157a8cda5498874f993381daae8c
--- /dev/null
+++ b/src/schedule/message_passing.h
@@ -0,0 +1,81 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file message_passing.h
+ * \brief Common utilities to do message passing
+ *  on the schedule hyper graph.
+ */
+#ifndef TVM_SCHEDULE_MESSAGE_PASSING_H_
+#define TVM_SCHEDULE_MESSAGE_PASSING_H_
+
+#include <tvm/expr.h>
+#include <tvm/schedule.h>
+#include <tvm/operation.h>
+#include <tvm/arithmetic.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace schedule {
+/*!
+ * \brief Downward inference of domain of each IterVar.
+ *  Caller set the range of the root, then the function
+ *  propagates it towards the leaves.
+ *
+ * \param stage The stage to operate on.
+ * \param p_state The state of the message passing.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownDomain(
+    const Stage& stage,
+    std::unordered_map<IterVar, Range>* p_state,
+    bool allow_missing = false);
+
+/*!
+ * \param Upward inference of index of each IterVar.
+ *  given index assignement of the leaves,
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's domain.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassUpIndex(const Stage& stage,
+                 const Map<IterVar, Range>& dom_map,
+                 std::unordered_map<IterVar, Expr>* p_state,
+                 bool allow_missing = false);
+
+/*!
+ * \param Upward inference of domain set of each IterVar.
+ *  given domain assignment of the leaves,
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable's maximum domain.
+ * \param p_state The index state of each IterVar.
+ */
+void PassUpDomain(const Stage& stage,
+                  const std::unordered_map<IterVar, Range>& dom_map,
+                  std::unordered_map<IterVar, IntSet>* p_state);
+
+/*!
+ * \brief Upward message passing of bitmask with or relation.
+ * \param stage The stage to operate on.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassUpBitMaskOr(const Stage& stage,
+                     std::unordered_map<IterVar, int>* p_state,
+                     bool allow_missing = false);
+
+/*!
+ * \brief Downward message passing of bitmask with or relation.
+ * \param stage The stage to operate on.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether allow missing value.
+ */
+void PassDownBitMaskOr(const Stage& stage,
+                       std::unordered_map<IterVar, int>* p_state,
+                       bool allow_missing = false);
+}  // namespace schedule
+}  // namespace tvm
+#endif  // TVM_SCHEDULE_MESSAGE_PASSING_H_
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index a6c193e876dd718c2fa264fb0cdedd8f8f800943..b577f0a431a7c3178af0caa39c9a58023ee73fe1 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -7,6 +7,7 @@
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
 #include <unordered_set>
+#include "./message_passing.h"
 
 namespace tvm {
 
@@ -139,7 +140,6 @@ Tensor Schedule::cache_write(const Tensor& tensor,
   return cache_tensor;
 }
 
-
 void RebaseNonZeroMinLoop(const Schedule& sch) {
   std::unordered_map<IterVar, IterVar> rebase_map;
   std::unordered_map<const Node*, int> attach_mark;
@@ -244,4 +244,151 @@ void Schedule::normalize() {
   InjectInline(*this);
 }
 
+// Handle reduction factor.
+Tensor Schedule::rfactor(const Tensor& tensor,
+                         const IterVar& axis) {
+  using ir::Reduce;
+  CHECK_EQ(axis->iter_type, kCommReduce)
+      << "Can only factor reduction axis";
+  Stage reduce_stage = operator[](tensor->op);
+  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
+  CHECK(compute_op) << "Can only factor  ComputeOp";
+  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
+  {
+    size_t axis_pos = FindNodeRef(leaf_vars, axis);
+    CHECK_NE(axis_pos, leaf_vars->data.size())
+        << "Cannot find IterVar " << axis << " in leaf iter vars";
+  }
+  // Find touched reduction axis.
+  std::unordered_map<IterVar, int> touch_map;
+  touch_map[axis] = 1;
+  schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
+  schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
+  // Verify normal axis are not touched.
+  for (IterVar iv : compute_op->axis) {
+    CHECK(!touch_map.count(iv))
+        << "Factor axis touches normal axis.";
+  }
+  // Get the replace index
+  std::unordered_map<IterVar, Range> dom_map;
+  std::unordered_map<IterVar, Expr> value_map;
+  for (IterVar iv : compute_op->reduce_axis) {
+    if (touch_map.count(iv)) dom_map[iv] = iv->dom;
+  }
+  schedule::PassDownDomain(reduce_stage, &dom_map, true);
+  for (IterVar iv : reduce_stage->leaf_iter_vars) {
+    if (touch_map.count(iv)) {
+      Range dom = dom_map.at(iv);
+      if (is_one(dom->extent)) {
+        value_map[iv] = dom->min;
+      } else {
+        value_map[iv] = iv->var;
+      }
+    }
+  }
+  schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
+  // Get the factored op node.
+  auto n = std::make_shared<ComputeOpNode>();
+  n->name = compute_op->name + ".rf";
+  {
+    // axis relacement.
+    auto iv_node = std::make_shared<IterVarNode>();
+    iv_node->dom = dom_map.at(axis);
+    CHECK(is_zero(iv_node->dom->min))
+        << "Can only factor reduction domain starting from 0";
+    iv_node->var = axis->var;
+    iv_node->iter_type = kDataPar;
+    n->axis.push_back(IterVar(iv_node));
+
+    for (IterVar iv : compute_op->axis) {
+      n->axis.push_back(iv);
+    }
+  }
+  // predicate generation, copy not touched axis.
+  std::unordered_map<const Variable*, Expr> vsub;
+  Expr predicate;
+  for (IterVar iv : compute_op->reduce_axis) {
+    if (!touch_map.count(iv)) {
+      n->reduce_axis.push_back(iv);
+    } else {
+      CHECK(value_map.count(iv));
+      Expr index = value_map.at(iv);
+      vsub[iv->var.get()] = index;
+      if (!index.same_as(iv->var)) {
+        Expr cond = (index < dom_map.at(iv)->extent);
+        if (predicate.defined()) {
+          predicate = predicate && cond;
+        } else {
+          predicate = cond;
+        }
+      }
+    }
+  }
+  // Copy touched axis.
+  for (IterVar iv : reduce_stage->leaf_iter_vars) {
+    if (touch_map.count(iv) && !iv.same_as(axis)) {
+      CHECK_EQ(iv->iter_type, kCommReduce);
+      auto ncpy = std::make_shared<IterVarNode>(*iv.operator->());
+      ncpy->dom = dom_map.at(iv);
+      n->reduce_axis.push_back(IterVar(ncpy));
+    }
+  }
+  const Reduce* reduce = compute_op->body.as<Reduce>();
+  CHECK(reduce) << "Can only rfactor non-inline reductions";
+  n->body = Reduce::make(reduce->op,
+                         VarReplacer(vsub).Mutate(reduce->source),
+                         n->reduce_axis,
+                         predicate);
+  // refresh relations, keep the un-touched relations.
+  Array<IterVarRelation> rels;
+  for (IterVarRelation rel : reduce_stage->relations) {
+    bool touched = false;
+    if (const SplitNode* r = rel.as<SplitNode>()) {
+      if (touch_map.count(r->parent)) touched = true;
+    } else if (const FuseNode* r = rel.as<FuseNode>()) {
+      if (touch_map.count(r->fused)) touched = true;
+    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
+      if (touch_map.count(r->parent)) touched = true;
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+    if (!touched) {
+      rels.push_back(rel);
+    }
+  }
+  // initialize the factored stage.
+  Operation factor_op(n);
+  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  size_t stage_pos = FindNodeRef(stages, reduce_stage);
+  Stage factor_stage = Stage(factor_op);
+  factor_stage->relations = rels;
+  CHECK_LT(stage_pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + stage_pos,
+                      factor_stage.node_);
+  (*this)->stage_map.Set(factor_op, factor_stage);
+  // Replace the old reduction.
+  IterVar repl_red_axis = reduce_axis(
+      dom_map.at(axis), axis->var->name_hint + ".v");
+  Tensor factor_tensor = factor_op.output(0);
+  Tensor old_tensor = reduce_stage->op.output(0);
+  Tensor repl_tensor = compute(old_tensor->shape, [&](const Array<Var>& i) {
+      Array<Expr> indices;
+      indices.push_back(repl_red_axis->var);
+      for (Var v : i) {
+        indices.push_back(v);
+      }
+      return Reduce::make(
+          reduce->op, factor_tensor(indices), {repl_red_axis}, const_true());
+    }, old_tensor->op->name + ".repl");
+
+  std::unordered_map<Tensor, Tensor> vmap;
+  vmap[old_tensor] = repl_tensor;
+  ReplaceDataFlow((*this)->stages, &vmap);
+  // revamp the reduction stage.
+  reduce_stage->op = repl_tensor->op;
+  reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars();
+  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
+  reduce_stage->relations = Array<IterVarRelation>();
+  return factor_tensor;
+}
 }  // namespace tvm
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index bbcd832e269e08c9ad8a89c8631ed52a1ccc67a0..318e9b057a0d01e4506bbe402f48adade9a6e166 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -43,11 +43,18 @@ void CheckSplit(StageNode* self, IterVar parent, IterVar outer) {
         << "Cannot split on axis[0] of scan update";
   }
   if (outer.defined()) {
-    CHECK_EQ(outer->iter_type, kThreadIndex)
-        << "outer in split have to be ThreadIndex";
-    CHECK_EQ(parent->iter_type, kDataPar)
-        << "Split by by kThreadIndex requires kDataPar IterVar "
-        << " given " << IterVarType2String(parent->iter_type);
+    if (outer->iter_type == kThreadIndex) {
+      CHECK_EQ(parent->iter_type, kDataPar)
+          << "Split by by kThreadIndex requires kDataPar IterVar "
+          << " given " << IterVarType2String(parent->iter_type);
+    } else if (outer->iter_type == kCommReduce) {
+      CHECK_EQ(parent->iter_type, kCommReduce)
+          << "Split by by kCommReduce requires kCommReduce IterVar "
+          << " given " << IterVarType2String(parent->iter_type);
+    } else {
+      LOG(FATAL) << "Cannot take " << IterVarType2String(parent->iter_type)
+                 << " as outer IterVar";
+    }
   } else {
     CHECK(parent->iter_type == kDataPar ||
           parent->iter_type == kCommReduce ||
@@ -73,18 +80,6 @@ void Split(StageNode* self, IterVar parent,
 
 }  // namespace
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
-  p->stream << "stage("
-            << op->op
-            << ")";
-});
-
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
-    p->stream << IterVarType2String(op->iter_type);
-  });
-
 Stage::Stage(Operation op) {
   auto n = std::make_shared<StageNode>();
   n->op = op;
@@ -374,4 +369,42 @@ TVM_REGISTER_NODE_TYPE(FuseNode);
 TVM_REGISTER_NODE_TYPE(RebaseNode);
 TVM_REGISTER_NODE_TYPE(ScheduleNode);
 
+// Printer
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
+    p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
+})
+.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
+    p->stream << IterVarType2String(op->iter_type);
+})
+.set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) {
+    p->stream << "split(parent=";
+    p->print(op->parent);
+    p->stream << ", outer=";
+    p->print(op->outer);
+    p->stream << ", inner=";
+    p->print(op->inner);
+    p->stream << ')';
+})
+.set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) {
+    p->stream << "split(";
+    p->stream << "outer=";
+    p->print(op->outer);
+    p->stream << ", inner=";
+    p->print(op->inner);
+    p->stream << ", fused=";
+    p->print(op->fused);
+    p->stream << ')';
+})
+.set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) {
+    p->stream << "rebase(";
+    p->stream << "parent=";
+    p->print(op->parent);
+    p->stream << ", rebased=";
+    p->print(op->rebased);
+    p->stream << ')';
+})
+.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
+    p->stream << "schedule(" << op << ")";
+  });
 }  // namespace tvm
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index 4c341444fd39f5774f5015aa7fc5da7f8d155d0e..726cd3f11ed306fb19cdd34c584e04521d1a6b8b 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -7,7 +7,7 @@ def test_sum():
     m = tvm.Var('m')
     A = tvm.placeholder((n, m), name='A')
     k = tvm.reduce_axis((0, m))
-    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name='B')
+    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
     # schedule
     s = tvm.Schedule(B.op)
     # create iter var and assign them tags.
@@ -28,14 +28,17 @@ def test_sum():
                          args=[A, B],
                          target=device, target_host=host,
                          name="mysum")
+        print(fsum.imported_modules[0].get_source())
         # launch the kernel.
         n = 1028
         m = 129
         a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
         b  = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         fsum(a, b)
+        res = np.sum(a.asnumpy(), axis=1)
+        res[:2] = 0
         np.testing.assert_allclose(
-            b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
+            b.asnumpy(), res, rtol=1e-4)
 
     if tvm.module.enabled("opencl"):
         tvm.module.init_opencl()
@@ -43,5 +46,38 @@ def test_sum():
     check_device("cuda")
     check_device("opencl")
 
+
+def test_rfactor():
+    n = tvm.convert(1027)
+    A = tvm.placeholder((n,), name='A')
+    k = tvm.reduce_axis((0, n))
+    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
+    kf = tvm.reduce_axis((0, 4))
+    # schedule
+    s = tvm.Schedule(B.op)
+    _, ki = s[B].split(k, outer=kf)
+    BF = s.rfactor(B, kf)
+    s[BF].parallel(BF.op.axis[0])
+    # one line to build the function.
+    def check_target(target="llvm"):
+        if not tvm.codegen.enabled(target):
+            return
+        ctx = tvm.cpu(0)
+        fapi = tvm.lower(s, args=[A, B])
+        fsum = tvm.build(fapi,
+                         target=target,
+                         name="mysum")
+        # launch the kernel.
+        n = 1027
+        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
+        b  = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        fsum(a, b)
+        res = np.sum(a.asnumpy(), axis=0)
+        np.testing.assert_allclose(
+            b.asnumpy(), res, rtol=1e-4)
+
+    check_target()
+
 if __name__ == "__main__":
+    test_rfactor()
     test_sum()
diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py
index b4ca987ed30b2d9f9ebad72b7272c00791bbecae..c38dc59a019e19b9fb08a5aa0e8ea96a22f4c375 100644
--- a/tests/python/unittest/test_lang_schedule.py
+++ b/tests/python/unittest/test_lang_schedule.py
@@ -91,8 +91,33 @@ def test_vectorize():
     assert s[T].iter_var_attrs[xi].iter_type == UNROLL
     assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
 
+def test_rfactor():
+    n = tvm.Var('n')
+    k1 = tvm.reduce_axis((0, n), name="k1")
+    k2 = tvm.reduce_axis((0, n), name="k2")
+    A = tvm.placeholder((n, n, n), name='A')
+    B = tvm.compute((n, ), lambda i: tvm.sum(A[i, k1, k2], axis=[k1, k2]))
+    # normal schedule
+    s = tvm.Schedule(B.op)
+    BF = s.rfactor(B, k1)
+    assert(tuple(BF.shape) == (n, n))
+    assert(set(BF.op.body.axis) == set([k2]))
+    assert(s[B].op.body.axis[0].dom.extent == n)
+    assert(len(s[B].all_iter_vars) == 2)
+    # schedule with splot
+    s = tvm.Schedule(B.op)
+    ko, ki = s[B].split(k1, factor=4)
+    xo, xi = s[B].split(B.op.axis[0], factor=8)
+    BF = s.rfactor(B, ki)
+    assert(BF.shape[0].value == 4)
+    assert(BF.shape[1] == n)
+    assert(BF.op.body.axis[0] ==  k2)
+    assert(BF.op.body.axis[1].var ==  ko.var)
+    assert(s[B].op.body.axis[0].dom.extent.value == 4)
+
 
 if __name__ == "__main__":
+    test_rfactor()
     test_schedule_create()
     test_reorder()
     test_tile()
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index e60edd660a8a6ac687fbb976807017d1b5dcee78..15b18ec2f60bedd281bb90d8e4e07f9f49241638 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -100,7 +100,24 @@ def test_bound_blur():
     assert(bounds[A.op.axis[0]].extent.value == 3)
     assert(bounds[A.op.axis[1]].extent.value == 3)
 
+def test_bound_rfactor():
+    n = tvm.Var('n')
+    A = tvm.placeholder((n,), name='A')
+    k = tvm.reduce_axis((0, n))
+    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
+    kf = tvm.reduce_axis((0, 4))
+    # schedule
+    s = tvm.Schedule(B.op)
+    _, ki = s[B].split(k, outer=kf)
+    BF = s.rfactor(B, kf)
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    assert(bounds[BF.op.axis[0]].extent.value == 4)
+    assert(bounds[BF.op.axis[1]].extent.value == 1)
+
+
 if __name__ == "__main__":
+    test_bound_rfactor()
     test_bound_blur()
     test_bound_conv1d()
     test_bound_scan()