From 54593ca1cc90e98abb09e64d73f1694f8082c757 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Fri, 31 Mar 2017 21:47:54 -0700
Subject: [PATCH] [LANG/GPU] Cross Thread Reduction (#79)

* [LANG/GPU] Cross Thread Reduction (usage sketch below).

* Fix doxygen error

* Upgrade verilog testcase to new one
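
For reviewers, a minimal sketch of the scheduling pattern this enables,
adapted from the new test in tests/python/integration/test_reduce.py
(the sizes and thread counts below are illustrative):

    import tvm

    nthread = 16
    A = tvm.placeholder((10, 1027), name='A')
    k = tvm.reduce_axis((0, 1027))
    B = tvm.compute((10,), lambda i: tvm.sum(A[i, k], axis=k), name='B')

    s = tvm.Schedule(B.op)
    ko, kf = s[B].split(k, factor=nthread)
    BF = s.rfactor(B, kf)                      # factor out part of the reduction
    tx = tvm.thread_axis((0, nthread), "threadIdx.x")
    ty = tvm.thread_axis((0, nthread), "threadIdx.y")
    bx = tvm.thread_axis(None, "blockIdx.x")
    xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx)
    s[B].rebase(xi, ty)                        # bind the data axis to threadIdx.y
    s[B].rebase(s[B].op.reduce_axis[0], tx)    # reduce across threadIdx.x
    s[BF].compute_at(s[B], tx)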
---
 include/tvm/expr.h                            |   2 +-
 include/tvm/ir.h                              |  30 ++
 include/tvm/ir_pass.h                         |   7 +
 include/tvm/schedule.h                        |   8 +
 python/tvm/build.py                           |   3 +-
 python/tvm/schedule.py                        |  18 ++
 src/api/api_lang.cc                           |   6 +
 src/api/api_pass.cc                           |   1 +
 src/codegen/codegen_c.cc                      |  82 ++++--
 src/codegen/codegen_c.h                       |  15 +-
 src/codegen/codegen_opencl.cc                 |  10 +-
 src/codegen/codegen_opencl.h                  |   5 +-
 src/lang/ir.cc                                |  26 ++
 src/lang/tensor.cc                            |   1 -
 src/op/compute_op.cc                          | 127 ++++++--
 src/op/op_util.cc                             |  23 +-
 src/op/op_util.h                              |  10 +-
 src/op/scan_op.cc                             |   3 +-
 src/pass/ir_util.h                            |  15 +
 src/pass/lower_thread_allreduce.cc            | 275 ++++++++++++++++++
 src/runtime/thread_storage_scope.h            |  14 +-
 src/schedule/message_passing.cc               |  12 +-
 src/schedule/schedule_dataflow_rewrite.cc     |  12 +-
 src/schedule/schedule_lang.cc                 |  19 ++
 tests/python/integration/test_reduce.py       |  52 +++-
 .../{ => unittest}/test_buffer_doublebuff.py  |  12 +-
 .../{ => unittest}/test_buffer_doublebuff.v   |   0
 .../{ => unittest}/test_buffer_fifo.py        |   4 +-
 .../verilog/{ => unittest}/test_buffer_fifo.v |   0
 .../{ => unittest}/test_buffer_linebuff.py    |  10 +-
 .../{ => unittest}/test_buffer_linebuff.v     |   0
 31 files changed, 689 insertions(+), 113 deletions(-)
 create mode 100644 src/pass/lower_thread_allreduce.cc
 rename tests/verilog/{ => unittest}/test_buffer_doublebuff.py (89%)
 rename tests/verilog/{ => unittest}/test_buffer_doublebuff.v (100%)
 rename tests/verilog/{ => unittest}/test_buffer_fifo.py (94%)
 rename tests/verilog/{ => unittest}/test_buffer_fifo.v (100%)
 rename tests/verilog/{ => unittest}/test_buffer_linebuff.py (92%)
 rename tests/verilog/{ => unittest}/test_buffer_linebuff.v (100%)
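
Reviewer note: LowerThreadAllreduce expands tvm_thread_allreduce into a
shared-memory tree reduction, with barriers between cross-warp steps and
volatile, sync-free steps once the active width fits within one warp. A
rough Python model of the halving schedule emitted by MakeBufAllreduce
(the helper name allreduce_steps is illustrative, not part of this patch):

    def allreduce_steps(reduce_extent, threadx_extent, warp_size):
        # Round the extent up to the next power of two.
        reduce_align = 1
        while reduce_extent > reduce_align:
            reduce_align <<= 1
        steps = []
        if reduce_align > reduce_extent:
            # Fold the non-power-of-two tail under a bound check.
            reduce_align >>= 1
            steps.append(("tail", reduce_align, "sync"))
        # Cross-warp steps need a storage sync after each combine.
        while reduce_align > threadx_extent or reduce_align > warp_size:
            reduce_align >>= 1
            steps.append(("combine", reduce_align, "sync"))
        # In-warp steps rely on volatile shared memory, no barrier.
        while reduce_align > 1:
            reduce_align >>= 1
            steps.append(("combine", reduce_align, "no_sync"))
        return steps

    # 1027 threads in x with warp size 32: one tail fold at offset 1024,
    # synced combines down to offset 32, then five sync-free in-warp steps.
    print(allreduce_steps(1027, 1027, 32))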

diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 91efe1727..7162c92d4 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -299,7 +299,7 @@ inline const char* IterVarType2String(IterVarType t) {
   switch (t) {
     case kDataPar: return "DataPar";
     case kThreadIndex: return "ThreadIndex";
-    case kCommReduce: return "CommRedude";
+    case kCommReduce: return "CommReduce";
     case kOrdered: return "Ordered";
     case kOpaque: return "Opaque";
     case kUnrolled: return "Unrolled";
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 5fdc6fa21..a70de3586 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -42,6 +42,21 @@ struct Reduce : public ExprNode<Reduce> {
   static Expr make(std::string op, Expr src,
                    Array<IterVar> rdom,
                    Expr condition = const_true());
+  /*!
+   * \brief Get initial value for reduction.
+   * \param op The operator.
+   * \param type The data type.
+   * \return The initial value that can be assigned to reduction.
+   */
+  static Expr InitValue(const std::string& op, Type type);
+  /*!
+   * \brief Combine two values with given reduction.
+   * \param op The operator.
+   * \param a The left operand.
+   * \param b The right operand.
+   * \return The combined reduction result.
+   */
+  static Expr Combine(const std::string& op, Expr a, Expr b);
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("dtype", &type);
@@ -86,6 +101,10 @@ constexpr const char* thread_extent = "thread_extent";
  * \brief Mark launching of a virtual thread.
  */
 constexpr const char* virtual_thread = "virtual_thread";
+/*!
+ * \brief Mark the scope as volatile access for certain handle.
+ */
+constexpr const char* volatile_scope = "volatile_scope";
 /*!
  * \brief Mark storage scope of buffers
  */
@@ -164,6 +183,17 @@ constexpr const char* tvm_call_packed = "tvm_call_packed";
  *  }
  */
 constexpr const char* tvm_storage_sync = "tvm_storage_sync";
+/*!
+ * \brief See pseudo code
+ *
+ *  Expr tvm_thread_allreduce(std::string op, Expr value, Expr cond,
+ *                             Var thread_idx1, thread_idx2...) {
+ *     // entries that share the same values of the other thread indices reduce together.
+ *     return reduce(op, value, cond,
+ *                   over [thread_idx1, thread_idx2] passed by any caller)
+ *  }
+ */
+constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
 
 /*! \brief The field id of each field in array */
 enum TVMArrayFieldKind {
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 082400b61..9f57724f7 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -234,6 +234,13 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
  */
 LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
 
+/*!
+ * \brief Lower cross thread allreduce in the stmt.
+ * \param f The device function to be lowered.
+ * \param warp_size The warp size within which no explicit sync is needed.
+ * \return Transformed function.
+ */
+LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
 }  // namespace ir
 }  // namespace tvm
 
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 93b93a62c..5f7c0e0eb 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -73,6 +73,14 @@ class Stage : public NodeRef {
    * \return reference to self.
    */
   Stage& compute_root();  // NOLINT(*)
+  /*!
+   * \brief Rebase the parent iter var onto the rebased variable.
+   *
+   * \param parent The parent iteration domain.
+   * \param rebased The variable to be used in rebase.
+   * \return reference to self.
+   */
+  Stage& rebase(IterVar parent, IterVar rebased);
   /*!
    * \brief Split the parent by factor, generate
    * \param parent The parent iteration domain.
diff --git a/python/tvm/build.py b/python/tvm/build.py
index ec5a0dba1..588ead632 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -71,7 +71,6 @@ def lower(sch,
     return fapi
 
 
-
 def build(sch,
           args=None,
           target="llvm",
@@ -128,6 +127,8 @@ def build(sch,
     fsplits = [x for x in fsplits]
     for i in range(1, len(fsplits)):
         fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
+        warp_size = 32 if target == "cuda" else 1
+        fsplits[i] = ir_pass.LowerThreadAllreduce(fsplits[i], warp_size)
 
     if len(fsplits) > 1:
         mhost = codegen.build(fsplits[0], target_host)
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index dcddafa4c..7d3562b3b 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -112,6 +112,24 @@ class Schedule(NodeBase):
 @register_node
 class Stage(NodeBase):
     """A Stage represents schedule for one operation."""
+    def rebase(self, parent, rebased):
+        """Rebase parent  by an existing thread axis.
+
+        Parameters
+        ----------
+        parent : IterVar
+             The parent iter var.
+
+        rebased : IterVar
+             The rebased iter var.
+
+        Returns
+        -------
+        rebased : IterVar
+            The rebased iter var.
+        """
+        _api_internal._StageRebase(self, parent, rebased)
+        return rebased
+
     def split(self, parent, factor=None, outer=None):
         """Split the stage either by factor providing outer scope, or both
 
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 933adc872..cb7437333 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -219,6 +219,12 @@ TVM_REGISTER_API(_StageSetScope)
         .set_scope(args[1]);
   });
 
+TVM_REGISTER_API(_StageRebase)
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    args[0].operator Stage()
+        .rebase(args[1], args[2]);
+  });
+
 TVM_REGISTER_API(_StageSplitByFactor)
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     IterVar outer, inner;
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 22624f608..96e94f2b8 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -73,6 +73,7 @@ REGISTER_PASS1(InjectVirtualThread);
 REGISTER_PASS1(LoopPartition);
 REGISTER_PASS1(RemoveNoOp);
 REGISTER_PASS2(SplitPipeline);
+REGISTER_PASS2(LowerThreadAllreduce);
 REGISTER_PASS1(NarrowChannelAccess);
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index b288ab82e..95b7901f4 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -88,14 +88,26 @@ void CodeGenC::PrintSSAAssign(
 }
 
 // Print a reference expression to a buffer.
-void CodeGenC::PrintBufferRef(
+std::string CodeGenC::GetBufferRef(
     const Variable* buffer,
-    Type t, Expr index,
-    std::ostream& os) {  // NOLINT(*)
+    Type t, Expr index) {
+  std::ostringstream os;
   std::string vid = GetVarID(buffer);
+  std::string scope;
+  if (alloc_storage_scope_.count(buffer)) {
+    scope = alloc_storage_scope_.at(buffer);
+  }
+  bool is_vol = volatile_buf_.count(buffer);
   if (t.lanes() == 1) {
-    if (!HandleTypeMatch(buffer, t)) {
+    if (!HandleTypeMatch(buffer, t) || is_vol) {
       os << "((";
+      if (is_vol) {
+        os << "volatile ";
+      }
+      if (scope.length() != 0) {
+        PrintStorageScope(scope, os);
+      }
+      os << ' ';
       PrintType(t, os);
       os << "*)" << vid << ')';
     } else {
@@ -107,17 +119,24 @@ void CodeGenC::PrintBufferRef(
   } else {
     // Buffer declared as vector type.
     // optimize for case where it is in register,
-    if (HandleTypeMatch(buffer, t)) {
+    if (HandleTypeMatch(buffer, t) && !is_vol) {
       // optimize for constant access
       int offset;
       if (arith::GetConstInt(index, &offset)) {
         CHECK_EQ(offset % t.lanes(), 0)
             << "Find unaligned vector load to a vector type";
         os << vid << '[' << (offset / t.lanes()) << ']';
-        return;
+        return os.str();
       }
     }
     os << "((";
+    if (is_vol) {
+      os << "volatile ";
+    }
+    if (scope.length() != 0) {
+      PrintStorageScope(scope, os);
+    }
+    os << ' ';
     PrintType(t, os);
     os << "*)(";
     if (!HandleTypeMatch(buffer, t.element_of())) {
@@ -129,6 +148,7 @@ void CodeGenC::PrintBufferRef(
     PrintExpr(index, os);
     os << "))[0]";
   }
+  return os.str();
 }
 
 
@@ -162,18 +182,17 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
          << " = " << value << ";\n";
 }
 
-void CodeGenC::PrintVecLoad(const Variable* buffer,
-                            Type t, Expr base,
-                            std::ostream& os) {
-  PrintBufferRef(buffer, t, base, os);
+std::string CodeGenC::GetVecLoad(const Variable* buffer,
+                                   Type t, Expr base) {
+  return GetBufferRef(buffer, t, base);
 }
 
 void CodeGenC::PrintVecStore(const Variable* buffer,
                              Type t, Expr base,
                              const std::string& value) {
+  std::string ref = GetBufferRef(buffer, t, base);
   this->PrintIndent();
-  PrintBufferRef(buffer, t, base, stream);
-  stream << " = " << value << ";\n";
+  stream << ref << " = " << value << ";\n";
 }
 
 void CodeGenC::PrintThreadIndexExpr(
@@ -483,24 +502,21 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
 
 void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
   int lanes = op->type.lanes();
-  std::string svalue = GetUniqueName("_");
   // delcare type.
-  this->PrintIndent();
-  this->PrintType(op->type, stream);
-  stream << ' ' << svalue;
   if (op->type.lanes() == 1) {
-    stream << " = ";
-    this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, stream);
-    stream << ";\n";
+    std::string ref = GetBufferRef(op->buffer_var.get(), op->type, op->index);
+    os << ref;
   } else {
     Expr base;
     if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
-      stream << " = ";
-      this->PrintVecLoad(op->buffer_var.get(), op->type, base, stream);
-      stream << ";\n";
+      std::string ref = GetVecLoad(op->buffer_var.get(), op->type, base);
+      os << ref;
     } else {
-      // Load elements seperately
-      stream << ";\n";
+      // load elements separately.
+      std::string svalue = GetUniqueName("_");
+      this->PrintIndent();
+      this->PrintType(op->type, stream);
+      stream << ' ' << svalue << ";\n";
       std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
       std::string vid = GetVarID(op->buffer_var.get());
       Type elem_type = op->type.element_of();
@@ -518,18 +534,18 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
         value_temp << ']';
         PrintVecElemStore(svalue, op->type, i, value_temp.str());
       }
+      os << svalue;
     }
   }
-  os << svalue;
 }
 
 void CodeGenC::VisitStmt_(const Store* op) {
   Type t = op->value.type();
   if (t.lanes() == 1) {
     std::string value = this->PrintExpr(op->value);
+    std::string ref  = this->GetBufferRef(op->buffer_var.get(), t, op->index);
     this->PrintIndent();
-    this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
-    stream << " = " << value << ";\n";
+    stream << ref << " = " << value << ";\n";
   } else {
     Expr base;
     if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
@@ -577,7 +593,13 @@ void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*
 }
 
 void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
-  LOG(FATAL) << "Select: not supported ";
+  os << "(";
+  PrintExpr(op->condition, os);
+  os << " ? ";
+  PrintExpr(op->true_value, os);
+  os << " : ";
+  PrintExpr(op->false_value, os);
+  os << ")";
 }
 
 void CodeGenC::VisitStmt_(const LetStmt* op) {
@@ -649,6 +671,10 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
     const Variable* v = op->node.as<Variable>();
     CHECK(v);
     alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+  } else if (op->type_key == ir::attr::volatile_scope) {
+    const Variable* v = op->node.as<Variable>();
+    CHECK(v);
+    volatile_buf_.insert(v);
   }
   this->PrintStmt(op->body);
 }
diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h
index bd7ef9ba3..e682e089b 100644
--- a/src/codegen/codegen_c.h
+++ b/src/codegen/codegen_c.h
@@ -13,6 +13,7 @@
 #include <string>
 #include <vector>
 #include <unordered_map>
+#include <unordered_set>
 #include "./codegen_source_base.h"
 
 namespace tvm {
@@ -132,9 +133,8 @@ class CodeGenC :
       const std::string&op, Type op_type,
       Expr lhs, Expr rhs, std::ostream& os);  // NOLINT(*)
   // print vector load
-  virtual void PrintVecLoad(const Variable* buffer,
-                            Type t, Expr base,
-                            std::ostream& os);  // NOLINT(*)
+  virtual std::string GetVecLoad(const Variable* buffer,
+                                 Type t, Expr base);
   // print vector store
   virtual void PrintVecStore(const Variable* buffer,
                              Type t, Expr base,
@@ -149,9 +149,8 @@ class CodeGenC :
 
  protected:
   // print reference to a buffer as type t in index.
-  void PrintBufferRef(const Variable* buffer,
-                      Type t, Expr index,
-                      std::ostream& os);  // NOLINT(*)
+  std::string GetBufferRef(const Variable* buffer,
+                           Type t, Expr index);
   /*!
    * \brief If buffer is allocated as type t.
    * \param buf_var The buffer variable.
@@ -172,9 +171,11 @@ class CodeGenC :
 
  private:
   /*! \brief whether to print in SSA form */
-  bool print_ssa_form_{true};
+  bool print_ssa_form_{false};
   /*! \brief the data type of allocated buffers */
   std::unordered_map<const Variable*, Type> handle_data_type_;
+  /*! \brief set of volatile buf access */
+  std::unordered_set<const Variable*> volatile_buf_;
 };
 
 }  // namespace codegen
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 715f72068..5b3c93c7a 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -95,12 +95,13 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
   os << GetVarID(buffer) << " + ";
   PrintExpr(base, os);
 }
-void CodeGenOpenCL::PrintVecLoad(const Variable* buffer,
-                                 Type t, Expr base,
-                                 std::ostream& os) {
+std::string CodeGenOpenCL::GetVecLoad(const Variable* buffer,
+                                      Type t, Expr base) {
+  std::ostringstream os;
   os << "vload" << t.lanes() << "(0, ";
   PrintVecAddr(buffer, t, base, os);
   os << ")";
+  return os.str();
 }
 
 void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
@@ -121,7 +122,8 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
   }
 }
 
-void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::PrintStorageScope(
+    const std::string& scope, std::ostream& os) { // NOLINT(*)
   if (scope == "global") {
     os << "__global";
   } else if (scope == "shared") {
diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h
index 55168fdfe..fdd8d5615 100644
--- a/src/codegen/codegen_opencl.h
+++ b/src/codegen/codegen_opencl.h
@@ -24,9 +24,8 @@ class CodeGenOpenCL : public CodeGenC {
   void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
   void PrintStorageSync(const std::string& scope) final;  // NOLINT(*)
   void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
-  void PrintVecLoad(const Variable* buffer,
-                    Type t, Expr base,
-                    std::ostream& os) final;  // NOLINT(*)
+  std::string GetVecLoad(const Variable* buffer,
+                         Type t, Expr base) final;
   void PrintVecStore(const Variable* buffer,
                       Type t, Expr base,
                       const std::string& value) final;  // NOLINT(*)
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index 55a4d7a0d..f7aa94b09 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -61,6 +61,32 @@ Expr Reduce::make(std::string op, Expr source,
   return Expr(n);
 }
 
+Expr Reduce::InitValue(const std::string& op, Type type) {
+  if (op == "Add") {
+    return make_zero(type);
+  } else if (op == "Max") {
+    return type.min();
+  } else if (op == "Min") {
+    return type.max();
+  } else {
+    LOG(FATAL) << "Unsupported reduction " << op;
+    return Expr();
+  }
+}
+
+Expr Reduce::Combine(const std::string& op, Expr a, Expr b) {
+  if (op == "Add") {
+    return Add::make(a, b);
+  } else if (op == "Max") {
+    return Max::make(a, b);
+  } else if (op == "Min") {
+    return Min::make(a, b);
+  } else {
+    LOG(FATAL) << "Unsupported reduction " << op;
+    return Expr();
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(Reduce);
 TVM_REGISTER_NODE_TYPE(AttrStmt);
 
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index 962960f56..3d894a04a 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -20,7 +20,6 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   return n;
 }
 
-
 Tensor TensorNode::make(Array<Expr> shape,
                         Type dtype,
                         Operation op,
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index e2467bc32..185714971 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -174,19 +174,8 @@ void MakeReduction(const ComputeOpNode* op,
   }
   const Reduce* reduce = op->body.as<Reduce>();
   CHECK(reduce);
-  Expr init_value, update_value;
-  if (reduce->op == "Add") {
-    init_value = make_zero(reduce->type);
-    update_value = Add::make(t(args), reduce->source);
-  } else if (reduce->op == "Max") {
-    init_value = reduce->type.min();
-    update_value = Max::make(t(args), reduce->source);
-  } else if (reduce->op == "Min") {
-    init_value = reduce->type.max();
-    update_value = Min::make(t(args), reduce->source);
-  } else {
-    LOG(FATAL) << "Unsupported reduction " << reduce->op;
-  }
+  Expr init_value = Reduce::InitValue(reduce->op, reduce->type);
+  Expr update_value = Reduce::Combine(reduce->op, t(args), reduce->source);
   *init = Provide::make(t->op, t->value_index, init_value, args);
   *provide = Provide::make(t->op, t->value_index, update_value, args);
   if (!is_one(reduce->condition)) {
@@ -194,15 +183,6 @@ void MakeReduction(const ComputeOpNode* op,
   }
 }
 
-Stmt MakeProvide(const ComputeOpNode* op,
-                 const Tensor& t) {
-  Array<Expr> args;
-  for (IterVar iv : op->axis) {
-    args.push_back(iv->var);
-  }
-  return Provide::make(t->op, t->value_index, op->body, args);
-}
-
 Stmt Substitute(Stmt s,
                 const std::unordered_map<IterVar, Expr>& value_map) {
   Map<Var, Expr> temp;
@@ -212,11 +192,107 @@ Stmt Substitute(Stmt s,
   return ir::Substitute(s, temp);
 }
 
+// Check whether the stage uses cross thread reduction.
+bool IsCrossThreadReduction(const ComputeOpNode* self,
+                            const Stage& stage) {
+  std::unordered_set<IterVar> rebase_thread;
+  for (IterVarRelation rel : stage->relations) {
+    if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (s->parent->iter_type == kCommReduce &&
+          s->rebased->iter_type == kThreadIndex) {
+        rebase_thread.insert(s->rebased);
+      }
+    }
+  }
+  if (rebase_thread.empty()) return false;
+  // Verify correctness of leaf nest.
+  bool reduce_start = false;
+  for (IterVar iv : stage->leaf_iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      LOG(FATAL) << "Cannot mix cross thread reduce with normal reduce";
+    } else if (rebase_thread.count(iv)) {
+      reduce_start = true;
+    } else {
+      CHECK(!reduce_start)
+          << "Cross thread reduce cannot swap with normal data axis";
+    }
+  }
+  return true;
+}
+
+Stmt MakeCrossThreadReduction(
+    const ComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map) {
+  Array<Expr>  args;
+  for (IterVar iv : self->axis) {
+    args.push_back(iv->var);
+  }
+  const Reduce* reduce = self->body.as<Reduce>();
+  CHECK(reduce);
+  std::unordered_map<IterVar, Expr> value_map;
+  auto nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
+  auto conds = op::MakeBoundCheck(
+      stage, dom_map, false,
+      std::unordered_set<IterVar>(), value_map);
+  Expr cond = reduce->condition;
+  for (Expr v : conds) {
+    cond = cond && v;
+  }
+  Var res_handle("reduce_temp", Handle());
+  Array<Expr> freduce_args;
+  freduce_args.push_back(StringImm::make(reduce->op));
+  freduce_args.push_back(reduce->source);
+  freduce_args.push_back(cond);
+
+  std::vector<Expr> thread_head_check;
+  for (IterVarRelation rel : stage->relations) {
+    if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (s->parent->iter_type == kCommReduce &&
+          s->rebased->iter_type == kThreadIndex) {
+        freduce_args.push_back(s->rebased->var);
+        thread_head_check.push_back(s->rebased->var == 0);
+      }
+    }
+  }
+  Stmt reduce_body = Store::make(
+      res_handle, Call::make(
+          reduce->type,
+          ir::intrinsic::tvm_thread_allreduce,
+          freduce_args, Call::Intrinsic),
+      0);
+  Stmt assign_body = Provide::make(
+      stage->op, 0, Load::make(reduce->type, res_handle, 0), args);
+  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
+  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
+  Stmt body = Allocate::make(
+      res_handle, reduce->type, {1}, const_true(),
+      Block::make(reduce_body, assign_body));
+  body = AttrStmt::make(
+      res_handle, attr::storage_scope, StringImm::make("local"), body);
+  body = Substitute(body, value_map);
+  return MergeNest(nest, body);
+}
+
+Stmt MakeProvide(const ComputeOpNode* op,
+                 const Tensor& t) {
+  Array<Expr> args;
+  for (IterVar iv : op->axis) {
+    args.push_back(iv->var);
+  }
+  return Provide::make(t->op, t->value_index, op->body, args);
+}
+
 Stmt ComputeOpNode::BuildProvide(
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map) const {
   CHECK_EQ(stage->op.operator->(), this);
 
+  if (IsCrossThreadReduction(this, stage)) {
+    // special handling for cross thread reduction.
+    return MakeCrossThreadReduction(this, stage, dom_map);
+  }
   Stmt init, provide;
   if (this->reduce_axis.size() == 0) {
     provide = MakeProvide(this, stage->op.output(0));
@@ -227,9 +303,9 @@ Stmt ComputeOpNode::BuildProvide(
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
-  nest.push_back(op::MakeBoundCheck(
+  nest.push_back(op::MakeIfNest(op::MakeBoundCheck(
       stage, dom_map, false,
-      std::unordered_set<IterVar>(), value_map));
+      std::unordered_set<IterVar>(), value_map)));
   provide = Substitute(provide, value_map);
 
   if (init.defined()) {
@@ -266,7 +342,8 @@ Stmt ComputeOpNode::BuildProvide(
         stage, dom_map, begin_loop, true,
         skip_iter, &init_value_map);
     init_nest.push_back(
-        op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map));
+        op::MakeIfNest(
+            op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map)));
     init = Substitute(init, init_value_map);
     init  = MergeNest(init_nest, init);
     // common nest
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index 487be17cc..640652008 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -160,37 +160,44 @@ void PassUpBoundCheck(const Stage& s,
   }
 }
 
-std::vector<Stmt> MakeBoundCheck(
+std::vector<Expr> MakeBoundCheck(
     const Stage& stage,
     const Map<IterVar, Range>& dom_map,
     bool skip_ivar_domain,
     const std::unordered_set<IterVar>& skip_iter,
     const std::unordered_map<IterVar, Expr>& value_map) {
-  Stmt no_op = Evaluate::make(0);
   std::unordered_map<IterVar, bool> bound_state;
   for (IterVar iv : stage->leaf_iter_vars) {
     bound_state[iv] = false;
   }
   PassUpBoundCheck(stage, dom_map, &bound_state);
-  // insert conditions
-  std::vector<Stmt> nest;
+  std::vector<Expr> preds;
   for (IterVar iv : stage->op->root_iter_vars()) {
     if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
     Range dom = dom_map.at(iv);
     if (bound_state.at(iv)) {
-      Expr condition = ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent;
-      nest.emplace_back(IfThenElse::make(condition, no_op));
+      preds.emplace_back(
+          ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent);
     }
     CHECK(iv->dom.defined());
     if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
-      Expr condition = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent;
-      nest.emplace_back(IfThenElse::make(condition, no_op));
+      preds.emplace_back(
+          ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent);
     }
   }
+  return preds;
+}
+
+std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
+  Stmt no_op = Evaluate::make(0);
+  std::vector<Stmt> nest;
+  for (const Expr& cond : predicates) {
+    nest.emplace_back(IfThenElse::make(cond, no_op));
+  }
   return nest;
 }
 
 
 // replacer to replace tensors
 class TensorReplacer : public ir::IRMutator {
  public:
diff --git a/src/op/op_util.h b/src/op/op_util.h
index ca37c0d5f..914815f9a 100644
--- a/src/op/op_util.h
+++ b/src/op/op_util.h
@@ -43,13 +43,21 @@ MakeLoopNest(const Stage& stage,
  * \param skip_ivar_domain Whether we can skip check for IterVar's original domain.
  * \param skip_iter Whether skip certain iteration.
  * \param value_map The result value of each IterVar.
+ * \return List of predicates that we need to check.
  */
-std::vector<Stmt>
+std::vector<Expr>
 MakeBoundCheck(const Stage& stage,
                const Map<IterVar, Range>& dom_map,
                bool skip_ivar_domain,
                const std::unordered_set<IterVar>& skip_iter,
                const std::unordered_map<IterVar, Expr>& value_map);
+/*!
+ * \brief Create a nest of if checking the predicates.
+ *
+ * \param predicates The predicates to be checked.
+ * \return List of If nest that checks the predicates.
+ */
+std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
 
 /*!
  * \brief Replace the tensor reference in stmt by the replace map.
diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc
index f34aee579..45031082d 100644
--- a/src/op/scan_op.cc
+++ b/src/op/scan_op.cc
@@ -269,7 +269,8 @@ Stmt ScanOpNode::BuildProvide(
       stage, dom_map, 0, false, empty, &vmap);
   nest[begin_scan].push_back(init);
   nest.push_back(
-      op::MakeBoundCheck(stage, dom_map, false, empty, vmap));
+      op::MakeIfNest(
+          op::MakeBoundCheck(stage, dom_map, false, empty, vmap)));
   return MergeNest(nest, provide);
 }
 }  // namespace tvm
diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h
index 2fbff8099..47a57a9ed 100644
--- a/src/pass/ir_util.h
+++ b/src/pass/ir_util.h
@@ -70,6 +70,21 @@ inline Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
   return body;
 }
 
+
+/*!
+ * \brief combine sequence of operations.
+ * \param seq The sequence.
+ * \return The combined Stmt
+ */
+inline Stmt MergeSeq(const std::vector<Stmt>& seq) {
+  if (seq.empty()) return Evaluate::make(0);
+  Stmt body = seq[0];
+  for (size_t i = 1; i < seq.size(); ++i) {
+    body = Block::make(body, seq[i]);
+  }
+  return body;
+}
+
 }  // namespace ir
 }  // namespace tvm
 #endif  // TVM_PASS_IR_UTIL_H_
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
new file mode 100644
index 000000000..1b70c52ab
--- /dev/null
+++ b/src/pass/lower_thread_allreduce.cc
@@ -0,0 +1,275 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ *  Lower allreduce to device-implementable IR.
+ * \file lower_thread_allreduce.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "./ir_util.h"
+#include "../arithmetic/compute_expr.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+class ThreadAllreduceBuilder : public IRMutator {
+ public:
+  explicit ThreadAllreduceBuilder(int warp_size)
+      : warp_size_(warp_size) {}
+
+  Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
+    if (op->type_key == attr::thread_extent) {
+      thread_extents_.push_back(op);
+      Stmt ret = IRMutator::Mutate_(op, s);
+      thread_extents_.pop_back();
+      return ret;
+    } else if (op->type_key == attr::storage_scope) {
+      Stmt ret = IRMutator::Mutate_(op, s);
+      op = ret.as<AttrStmt>();
+      const Variable* v = op->node.as<Variable>();
+      if (alloc_remap_.count(v)) {
+        return op->body;
+      } else {
+        return ret;
+      }
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+  Stmt Mutate_(const Store* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Store>();
+    const Call* call = op->value.as<Call>();
+    if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
+      return MakeAllreduce(op, call);
+    } else {
+      return stmt;
+    }
+  }
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Allocate>();
+    auto it = alloc_remap_.find(op->buffer_var.get());
+    if (it != alloc_remap_.end()) {
+      const Allocate* repl = it->second.as<Allocate>();
+      // use volatile access to shared buffer.
+      stmt = AttrStmt::make(
+          repl->buffer_var, attr::volatile_scope, 1, op->body);
+      stmt = Allocate::make(
+          repl->buffer_var, repl->type,
+          repl->extents, repl->condition, stmt);
+      stmt = AttrStmt::make(
+          repl->buffer_var, attr::storage_scope,
+          StringImm::make("shared"), stmt);
+      return stmt;
+    } else {
+      return stmt;
+    }
+  }
+  Expr Mutate_(const Load* op, const Expr& e) final {
+    auto it = load_remap_.find(op->buffer_var.get());
+    if (it != load_remap_.end()) {
+      CHECK(is_zero(op->index));
+      return it->second;
+    } else {
+      return IRMutator::Mutate_(op, e);
+    }
+  }
+
+ private:
+  // Thread entry
+  struct ThreadEntry {
+    runtime::ThreadScope scope;
+    IterVar iv;
+    int extent;
+    // comparator
+    bool operator<(const ThreadEntry& other) const {
+      return scope.dim_index < other.scope.dim_index;
+    }
+  };
+  // make allreduce.
+  Stmt MakeAllreduce(const Store* op, const Call* call) {
+    const std::string& op_code = call->args[0].as<StringImm>()->value;
+    Expr value = call->args[1];
+    Expr cond = call->args[2];
+    if (!is_one(cond)) {
+      value = Select::make(
+          cond, value, Reduce::InitValue(op_code, value.type()));
+    }
+
+    std::unordered_set<const Variable*> reduce_index_;
+    for (size_t i = 3; i < call->args.size(); ++i) {
+      const Variable* v = call->args[i].as<Variable>();
+      CHECK(v);
+      reduce_index.insert(v);
+    }
+    size_t nmatch = 0;
+    std::vector<ThreadEntry> vred, vpar;
+    for (const AttrStmt* attr : thread_extents_) {
+      ThreadEntry e;
+      IterVar iv(attr->node.node_);
+      e.scope = runtime::ThreadScope::make(iv->thread_tag);
+      e.iv = iv;
+      CHECK(arith::GetConstInt(attr->value, &(e.extent)))
+          << "Need constant extent for thread group";
+      CHECK_LE(e.scope.rank, 1);
+      CHECK_GE(e.scope.dim_index, 0)
+          << "vthread do not work with cross thread reduction";
+      if (e.scope.rank == 1) {
+        if (reduce_index.count(iv->var.get())) {
+          vred.push_back(e);
+          ++nmatch;
+        } else {
+          vpar.push_back(e);
+        }
+      }
+    }
+    CHECK_EQ(nmatch, reduce_index.size())
+        << "Not all reduce indices are present in the context";
+    std::sort(vred.begin(), vred.end());
+    std::sort(vpar.begin(), vpar.end());
+    // the size of each index.
+    int reduce_extent, group_extent;
+    int threadx_extent = 1;
+    Expr reduce_index = FlattenThread(vred, &reduce_extent);
+    Expr group_index = FlattenThread(vpar, &group_extent);
+    if (reduce_extent == 1) {
+      // special case, no reduction is needed.
+      return Store::make(op->buffer_var, value, 0);
+    }
+    // Whether threadIdx.x is involved in the reduction.
+    if (vred[0].scope.dim_index == 0) {
+      threadx_extent = vred[0].extent;
+    }
+    Var shared_buf("red_buf", Handle());
+    std::vector<Stmt> seq;
+    seq.emplace_back(Store::make(
+        shared_buf, value,
+        BufIndex(reduce_index, group_index, reduce_extent)));
+    seq.emplace_back(SyncThread());
+    seq.emplace_back(MakeBufAllreduce(
+        op_code, value.type(), shared_buf,
+        reduce_index, group_index, reduce_extent, threadx_extent));
+    CHECK(!load_remap_.count(op->buffer_var.get()));
+    load_remap_[op->buffer_var.get()] =
+        Load::make(
+            value.type(), shared_buf,
+            BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent));
+    alloc_remap_[op->buffer_var.get()] =
+        Allocate::make(shared_buf, value.type(),
+                       {Expr(group_extent), Expr(reduce_extent)},
+                       const_true(), Evaluate::make(0));
+    return MergeSeq(seq);
+  }
+  // tree reduction within the shared buffer.
+  Stmt MakeBufAllreduce(const std::string& op,
+                        Type type,
+                        Var shared_buf,
+                        Expr reduce_index,
+                        Expr group_index,
+                        int reduce_extent,
+                        int threadx_extent) {
+    // Get next power of two
+    int reduce_align = 1;
+    while (reduce_extent > reduce_align) {
+      reduce_align = reduce_align << 1;
+    }
+    CHECK_GT(reduce_align, 1);
+    std::vector<Stmt> seq;
+
+    Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
+    // make reduction
+    auto freduce = [&](int offset) {
+      Expr b = Load::make(
+          type, shared_buf,
+          BufIndex(reduce_index + offset, group_index, reduce_extent));
+      Expr a = Load::make(type, shared_buf, buf_index);
+      return Store::make(shared_buf, Reduce::Combine(op, a, b), buf_index);
+    };
+    // Step one: fold the non-power-of-two tail under a bound check.
+    if (reduce_align > reduce_extent) {
+      // reduction with the boundary condition
+      reduce_align = reduce_align >> 1;
+      Expr cond = reduce_index < (reduce_extent - reduce_align);
+      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+      seq.emplace_back(SyncThread());
+    }
+    CHECK(threadx_extent >= 1 && warp_size_ >= 1);
+    // normal synchronization
+    while (reduce_align > threadx_extent ||
+           reduce_align > warp_size_) {
+      reduce_align = reduce_align >> 1;
+      Expr cond = reduce_index < reduce_align;
+      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+      seq.emplace_back(SyncThread());
+    }
+    // in warp synchronization.
+    std::vector<Stmt> in_warp_seq;
+    Expr in_warp_cond = reduce_index < (reduce_align >> 1);
+    while (reduce_align > 1) {
+      reduce_align = reduce_align >> 1;
+      in_warp_seq.emplace_back(freduce(reduce_align));
+    }
+    if (!in_warp_seq.empty()) {
+      Stmt warp_body = MergeSeq(in_warp_seq);
+      seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
+    }
+    return MergeSeq(seq);
+  }
+  // Flatten the thread index.
+  // Also return the total extent via out_total_extent.
+  Expr FlattenThread(const std::vector<ThreadEntry>& tvec,
+                     int* out_total_extent) {
+    int& total_extent = *out_total_extent;
+    total_extent = 1;
+    if (tvec.empty()) {
+      return make_zero(Int(32));
+    }
+
+    Expr ret;
+    for (const ThreadEntry& e : tvec) {
+      if (ret.defined()) {
+        ret = ret + e.iv->var * total_extent;
+      } else {
+        CHECK_EQ(total_extent, 1);
+        ret = e.iv->var;
+      }
+      total_extent *= e.extent;
+    }
+    return ret;
+  }
+  // sync thread op.
+  static Stmt SyncThread() {
+    return Evaluate::make(
+        Call::make(Int(32), intrinsic::tvm_storage_sync,
+                   {StringImm::make("shared")},
+                   Call::Intrinsic));
+  }
+  // The local buffer index.
+  static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
+    if (!is_zero(group_index)) {
+      return ir::Simplify(group_index * reduce_extent + reduce_index);
+    } else {
+      return reduce_index;
+    }
+  }
+  // The warp size of the device.
+  int warp_size_{1};
+  // surrounding scope of thread extent.
+  std::vector<const AttrStmt*> thread_extents_;
+  // The load remap
+  std::unordered_map<const Variable *, Expr> load_remap_;
+  // Allocate remap
+  std::unordered_map<const Variable *, Stmt> alloc_remap_;
+};
+
+LoweredFunc
+LowerThreadAllreduce(LoweredFunc f, int warp_size) {
+  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
+  n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body);
+  return LoweredFunc(n);
+}
+}  // namespace ir
+}  // namespace tvm
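
Reviewer note: the pass is exposed to Python through REGISTER_PASS2 in
src/api/api_pass.cc, so its output can be inspected on a lowered function,
following the pattern in the updated test (here fapi stands for a result
of tvm.lower, as in test_rfactor_threads):

    flowered = tvm.ir_pass.LowerThreadAllreduce(fapi, 32)
    print(flowered.body)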
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index da623567b..b523b9318 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -112,18 +112,10 @@ class ThreadAxisConfig {
       arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
       filled[ts.rank * 3 + ts.dim_index] = true;
     }
-    work_dim_ = 3;
+    work_dim_ = 1;
     for (int i = 0; i < 3; ++i) {
-      if (!filled[i]) {
-        for (int j = i; j < 3; ++j) {
-          CHECK(!filled[j] && !filled[j + 3])
-              << "Invalid thread group configuration";
-        }
-        work_dim_ = i;
-        break;
-      } else {
-        CHECK(filled[i])
-            << "Must have both threadIdx and blockIdx";
+      if (filled[i] || filled[i + 3]) {
+        work_dim_ = i + 1;
       }
     }
   }
diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc
index 68d28df2c..51d1aa229 100644
--- a/src/schedule/message_passing.cc
+++ b/src/schedule/message_passing.cc
@@ -75,8 +75,16 @@ void PassDownDomain(const Stage& stage,
         CHECK(allow_missing);
         continue;
       }
-      state[r->rebased] = Range::make_with_min_extent(
-          0, state.at(r->parent)->extent);
+      Range res = Range::make_with_min_extent(
+            0, state.at(r->parent)->extent);
+      if (r->rebased->dom.defined()) {
+        Range rebase_rng = r->rebased->dom;
+        bool match = is_zero(rebase_rng->min);
+        if (!prove_equal(rebase_rng->extent, res->extent)) match = false;
+        CHECK(match) << r->rebased
+                     << " does not match parent scope's range";
+      }
+      state[r->rebased] = res;
     } else {
       LOG(FATAL) << "unknown relation type";
     }
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index b577f0a43..9545a35cd 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -305,8 +305,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
     }
   }
   // predicate generation, copy not touched axis.
+  const Reduce* reduce = compute_op->body.as<Reduce>();
+  CHECK(reduce) << "Can only rfactor non-inline reductions";
+  Expr predicate = reduce->condition;
   std::unordered_map<const Variable*, Expr> vsub;
-  Expr predicate;
   for (IterVar iv : compute_op->reduce_axis) {
     if (!touch_map.count(iv)) {
       n->reduce_axis.push_back(iv);
@@ -316,10 +318,10 @@ Tensor Schedule::rfactor(const Tensor& tensor,
       vsub[iv->var.get()] = index;
       if (!index.same_as(iv->var)) {
         Expr cond = (index < dom_map.at(iv)->extent);
-        if (predicate.defined()) {
-          predicate = predicate && cond;
-        } else {
+        if (is_one(predicate)) {
           predicate = cond;
+        } else {
+          predicate = predicate && cond;
         }
       }
     }
@@ -333,8 +335,6 @@ Tensor Schedule::rfactor(const Tensor& tensor,
       n->reduce_axis.push_back(IterVar(ncpy));
     }
   }
-  const Reduce* reduce = compute_op->body.as<Reduce>();
-  CHECK(reduce) << "Can only rfactor non-inline reductions";
   n->body = Reduce::make(reduce->op,
                          VarReplacer(vsub).Mutate(reduce->source),
                          n->reduce_axis,
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index 318e9b057..723588e20 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -136,6 +136,25 @@ Stage& Stage::compute_root() {   // NOLINT(*)
   return *this;
 }
 
+Stage& Stage::rebase(IterVar parent, IterVar rebased) {  // NOLINT(*)
+  CHECK(parent->iter_type == kDataPar ||
+        parent->iter_type == kCommReduce)
+      << "Cannot rebase " << IterVarType2String(parent->iter_type);
+  CHECK(rebased->iter_type == kThreadIndex)
+      << "Cannot rebase by " << IterVarType2String(rebased->iter_type)
+      << ", only thread axis is allowed so far";
+  ArrayNode* all_vars = (*this)->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = (*this)->leaf_iter_vars.CopyOnWrite();
+  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+  (*this)->relations.push_back(RebaseNode::make(parent, rebased));
+  // add vars to all vars
+  all_vars->data.push_back(rebased.node_);
+  // replace the position.
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, rebased.node_);
+  return *this;
+}
+
 Stage& Stage::split(
     IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
   CheckSplit(operator->(), parent, IterVar());
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index 726cd3f11..fbb3c9f10 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -51,7 +51,7 @@ def test_rfactor():
     n = tvm.convert(1027)
     A = tvm.placeholder((n,), name='A')
     k = tvm.reduce_axis((0, n))
-    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k, where=(i>1)), name='B')
+    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
     kf = tvm.reduce_axis((0, 4))
     # schedule
     s = tvm.Schedule(B.op)
@@ -78,6 +78,56 @@ def test_rfactor():
 
     check_target()
 
+
+def test_rfactor_threads():
+    nn = 1027
+    mm = 10
+    n = tvm.convert(nn)
+    m = tvm.convert(mm)
+    A = tvm.placeholder((m, n), name='A')
+    k = tvm.reduce_axis((0, n))
+    nthread = 16
+    B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k, where=(i>1)), name='B')
+    tx = tvm.thread_axis((0, nthread), "threadIdx.x")
+    ty = tvm.thread_axis((0, nthread), "threadIdx.y")
+    bx = tvm.thread_axis(None, "blockIdx.x")
+    # schedule
+    s = tvm.Schedule(B.op)
+    ko, kf = s[B].split(k, factor=nthread)
+    BF = s.rfactor(B, kf)
+    xo, xi = s[B].split(s[B].op.axis[0], factor=nthread, outer=bx)
+    s[B].rebase(xi, ty)
+    s[B].rebase(s[B].op.reduce_axis[0], tx)
+    s[BF].compute_at(s[B], tx)
+
+    # one line to build the function.
+    def check_target(device, host="stackvm"):
+        if not tvm.codegen.enabled(device):
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        fapi = tvm.lower(s, args=[A, B])
+        fapi2 = tvm.ir_pass.LowerThreadAllreduce(fapi, 32)  # smoke test: the pass runs standalone
+        fsum = tvm.build(fapi,
+                         target=device,
+                         name="mysum")
+        print(fsum.imported_modules[0].get_source())
+        # launch the kernel.
+        n = nn
+        m = mm
+        a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
+        fsum(a, b)
+        res = np.sum(a.asnumpy(), axis=1)
+        res[:2] = 0
+        np.testing.assert_allclose(
+            b.asnumpy(), res, rtol=1e-4)
+
+    if tvm.module.enabled("opencl"):
+        tvm.module.init_opencl()
+    check_target("cuda")
+    check_target("opencl")
+
 if __name__ == "__main__":
+    test_rfactor_threads()
     test_rfactor()
     test_sum()
diff --git a/tests/verilog/test_buffer_doublebuff.py b/tests/verilog/unittest/test_buffer_doublebuff.py
similarity index 89%
rename from tests/verilog/test_buffer_doublebuff.py
rename to tests/verilog/unittest/test_buffer_doublebuff.py
index e0439d9c9..7d8cb1d98 100644
--- a/tests/verilog/test_buffer_doublebuff.py
+++ b/tests/verilog/unittest/test_buffer_doublebuff.py
@@ -35,11 +35,11 @@ def test_buffer_doublebuff():
     write_data.put_int(0)
 
     # De-assert reset
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     rst.put_int(0)
 
     # Leave the following signals set to true
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     write_valid.put_int(1)
 
     # Main simulation loop
@@ -50,15 +50,15 @@ def test_buffer_doublebuff():
         if (write_idx < len(test_data)):
             write_advance.put_int(0)
             if (write_ready.get_int()):
-                write_data.put_int(test_data[write_idx])
-                write_addr.put_int(write_idx%window_width)
+                write_data.put_int(int(test_data[write_idx]))
+                write_addr.put_int(write_idx % window_width)
                 if (write_idx%window_width==window_width-1):
                     write_advance.put_int(1)
                 write_idx += 1
         else:
             write_advance.put_int(0)
             write_valid.put_int(0)
-            
+
         # correctness checks
         if (read_data_valid.get_int()):
             assert(read_data.get_int()==test_data[read_idx])
@@ -66,7 +66,7 @@ def test_buffer_doublebuff():
             read_idx += 1
 
         # step
-        sess.yield_until_posedge()
+        sess.yield_until_next_cycle()
 
 
 if __name__ == "__main__":
diff --git a/tests/verilog/test_buffer_doublebuff.v b/tests/verilog/unittest/test_buffer_doublebuff.v
similarity index 100%
rename from tests/verilog/test_buffer_doublebuff.v
rename to tests/verilog/unittest/test_buffer_doublebuff.v
diff --git a/tests/verilog/test_buffer_fifo.py b/tests/verilog/unittest/test_buffer_fifo.py
similarity index 94%
rename from tests/verilog/test_buffer_fifo.py
rename to tests/verilog/unittest/test_buffer_fifo.py
index 3255ceafb..f95fe9796 100644
--- a/tests/verilog/test_buffer_fifo.py
+++ b/tests/verilog/unittest/test_buffer_fifo.py
@@ -27,7 +27,7 @@ def test_buffer_fifo():
     write_data.put_int(0)
 
     # De-assert reset
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     rst.put_int(0)
 
     # Main simulation loop
@@ -46,7 +46,7 @@ def test_buffer_fifo():
             assert(read_data.get_int()==test_data[read_idx])
             read_idx += 1
         # step
-        sess.yield_until_posedge()
+        sess.yield_until_next_cycle()
 
 
 if __name__ == "__main__":
diff --git a/tests/verilog/test_buffer_fifo.v b/tests/verilog/unittest/test_buffer_fifo.v
similarity index 100%
rename from tests/verilog/test_buffer_fifo.v
rename to tests/verilog/unittest/test_buffer_fifo.v
diff --git a/tests/verilog/test_buffer_linebuff.py b/tests/verilog/unittest/test_buffer_linebuff.py
similarity index 92%
rename from tests/verilog/test_buffer_linebuff.py
rename to tests/verilog/unittest/test_buffer_linebuff.py
index da01f3fc0..b4d2b34c1 100644
--- a/tests/verilog/test_buffer_linebuff.py
+++ b/tests/verilog/unittest/test_buffer_linebuff.py
@@ -33,11 +33,11 @@ def test_buffer_linebuff():
     write_data.put_int(0)
 
     # De-assert reset
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     rst.put_int(0)
 
     # Leave the following signals set to true
-    sess.yield_until_posedge()
+    sess.yield_until_next_cycle()
     write_advance.put_int(1)
     write_valid.put_int(1)
 
@@ -48,12 +48,12 @@ def test_buffer_linebuff():
         # write logic
         if (write_idx < len(test_data)):
             if (write_ready.get_int()):
-                write_data.put_int(test_data[write_idx])
+                write_data.put_int(int(test_data[write_idx]))
                 write_idx += 1
         else:
             write_advance.put_int(0)
             write_valid.put_int(0)
-            
+
         # correctness checks
         if (read_data_valid.get_int()):
             # Derive convolution window indices
@@ -67,7 +67,7 @@ def test_buffer_linebuff():
             read_idx += 1
 
         # step
-        sess.yield_until_posedge()
+        sess.yield_until_next_cycle()
 
 
 if __name__ == "__main__":
diff --git a/tests/verilog/test_buffer_linebuff.v b/tests/verilog/unittest/test_buffer_linebuff.v
similarity index 100%
rename from tests/verilog/test_buffer_linebuff.v
rename to tests/verilog/unittest/test_buffer_linebuff.v
-- 
GitLab