diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index f2f47b8cb4b0a68167a8616cce2fe1564c9fcda1..834b1baf364faa8b080b89defd3a315191ad8b28 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
  *  binary operator with identity element
  */
 struct CommReducerNode : public Node {
-  /*! \brief The arguments of reducer */
-  Array<Var> args;
+  /*! \brief The left argument of reducer */
+  Array<Var> lhs;
+  /*! \brief The right argument of reducer */
+  Array<Var> rhs;
   /*! \brief The result of reducer */
-  Expr result;
+  Array<Expr> result;
   /*!
    * \brief The identity element of reducer, which leaves other
    *  elements unchanged when combined with it, with respect to
  *  the binary operation this reducer uses.
    */
-  Expr identity_element;
+  Array<Expr> identity_element;
   /*! \brief Function call operator to combine a and b */
-  Expr operator()(Expr a, Expr b) const;
+  Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
   /*! \brief construct CommReducer from args, result and identity_element */
-  static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
+  static CommReducer make(Array<Var> lhs, Array<Var> rhs,
+                          Array<Expr> result, Array<Expr> identity_element);
 
   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("args", &args);
+    v->Visit("lhs", &lhs);
+    v->Visit("rhs", &rhs);
     v->Visit("result", &result);
     v->Visit("identity_element", &identity_element);
   }
@@ -84,7 +88,7 @@ struct Reduce : public ExprNode<Reduce> {
   /*! \brief The commutative combiner */
   CommReducer combiner;
   /*! \brief The source operand */
-  Expr source;
+  Array<Expr> source;
   /*! \brief The reduction axis */
   Array<IterVar> axis;
   /*!
@@ -92,18 +96,22 @@ struct Reduce : public ExprNode<Reduce> {
    *  Only add the body to reduction if condition is true.
    */
   Expr condition;
+  /*! \brief The index of this reduce node's output in the combined reduction */
+  int value_index;
 
   /*! \brief construct expr from op and rdom */
   static Expr make(CommReducer combiner,
-                   Expr src,
+                   Array<Expr> src,
                    Array<IterVar> rdom,
-                   Expr condition = const_true());
+                   Expr condition,
+                   int value_index);
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("dtype", &type);
     v->Visit("source", &source);
     v->Visit("axis", &axis);
     v->Visit("condition", &condition);
+    v->Visit("value_index", &value_index);
   }
   static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
   static constexpr const char* _type_key = "Reduce";
@@ -292,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
 /*!
  * \brief See pseudo code
  *
- *  Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond,
- *                             Var thread_idx1, thread_idx2...) {
+ *  void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
+ *                            Var reduce_temp0, .., Var thread_idx1, ...) {
  *     // constrained such that the other thread_idx values remain the same.
- *     return reduce(combiner, value, cond,
- *                   over [thread_idx1, thread_idx2] passed by any caller)
+ *     // reduce_temp is used to save intermediate result.
+ *     reduce_temp0, ... = reduce(combiner, source0, ..., cond,
+ *       over [thread_idx1, thread_idx2] passed by any caller)
  *  }
  */
 constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 7e323870413768ef2df909cae90413beff828812..fc0bc1f1abd28930ccaf8e531f1bbc0ea59d214d 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
 /*!
  * \brief inline all calls of f in stmt.
  *
+ * \param stmt The statement to apply inline optimization.
  * \param f The function reference to be inlined
  * \param args The arguments variable of the function.
- * \param body The defintion body of the function.
- * \param stmt The statement to apply inline optimization.
+ * \param body The definition body of the function.
  * \return The result stmt
  *
  * \note All the passes in this file uses SSA form and outputs SSA form.
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index eb0ee37569f1c28238461234d90644c0234d4677..0533bdcea6fba8c25eb81d9bfa43435cc64486cc 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
   /*! \brief IterVar on each reduction axis, if the body is a Reduce */
   Array<IterVar> reduce_axis;
   /*! \brief the compute expression */
-  Expr body;
+  Array<Expr> body;
   /*! \brief constructor */
   ComputeOpNode() {}
   // override functions
@@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
   }
   static Operation make(std::string name,
                         Array<IterVar> axis,
-                        Expr body);
+                        Array<Expr> body);
 
   static constexpr const char* _type_key = "ComputeOp";
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
@@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;
 
+/*! \brief The compute function to specify the input sources of multiple Tensors */
+using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
+
 /*!
  * \brief create a place holder tensor.
  * \param shape The shape of the tensor.
@@ -377,6 +380,15 @@ Tensor placeholder(Array<Expr> shape,
  */
 Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");
 
+/*!
+ * \brief Construct new tensors by computing over shape,
+ *  using the computation rule: result_tensor[i][axis] = fcompute(axis)[i]
+ * \param shape Shape of the tensors.
+ * \param fcompute The compute function to create the tensors.
+ * \param name The optional name of the tensors.
+ * \return The created tensors, one per expression returned by fcompute.
+ */
+Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
+
 /*!
  * \brief Construct new tensors by scan.
  *
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 4479c4fbee80d0f3e85685d949f3e45055802735..9f8a4bd51f2f99124644d8e7a85e685611076c9a 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -252,15 +252,15 @@ class Schedule : public NodeRef {
   /*!
    * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
   * This will create a new stage that generates the new tensor with axis
-   * as the first dimension. The tensor's body wil be rewriten as a reduction
+   * as the first dimension. The tensor's body will be rewritten as a reduction
    * over the factored tensor.
    *
    * \param tensor The tensor to be factored.
    * \param axis The reduction axis in tensor's schedule to be factored.
-   * \return The created factored tensor.
+   * \return The created factored tensors.
    */
-  Tensor rfactor(const Tensor& tensor,
-                 const IterVar& axis);
+  Array<Tensor> rfactor(const Tensor& tensor,
+                        const IterVar& axis);
   /*!
    * \brief Normalize the schedule.
    *  This is needed before bound inference.
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 2ef18d210342a2429f193b44f675d594de9d6390..7ea6a8e81e6ba036ecb73ffcaba5cd1d543a966f 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -174,10 +174,14 @@ def compute(shape, fcompute, name="compute"):
 
     dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
     body = fcompute(*[v.var for v in dim_var])
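+    # fcompute may return a single expression or a tuple/list of expressions (multi-output compute)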
+    if not isinstance(body, (list, tuple)):
+        body = [body]
     body = convert(body)
     op_node = _api_internal._ComputeOp(
         name, dim_var, body)
-    return op_node.output(0)
+    num = op_node.num_outputs
+    outputs = tuple(op_node.output(i) for i in range(num))
+    return outputs[0] if num == 1 else outputs
 
 
 def scan(init, update, state_placeholder, inputs=None, name="scan"):
@@ -525,18 +529,45 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
         return res
 
     def _make_reduce(expr, axis, where=None):
-        expr = convert(expr)
-        dtype = expr.dtype
         code = fcombine.__code__
         assert fcombine.__code__.co_argcount == 2
-        arg_vars = [var(name, dtype) for name in code.co_varnames]
-        result = fcombine(*[v for v in arg_vars])
+        expr = convert(expr)
+        if isinstance(expr, _collections.Array):
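+            # tuple reduction: create one lhs/rhs placeholder var per input expression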
+            size = len(expr)
+            larr = []
+            rarr = []
+            dtypes = []
+            for i in range(size):
+                dtype = expr[i].dtype
+                dtypes.append(dtype)
+                lname = code.co_varnames[0] + '_' + str(i)
+                larr.append(var(lname, dtype))
+                rname = code.co_varnames[1] + '_' + str(i)
+                rarr.append(var(rname, dtype))
+            lhs = convert(larr)
+            rhs = convert(rarr)
+            result = fcombine(lhs, rhs)
+            id_elem = fidentity(*dtypes)
+        else:
+            assert isinstance(expr, _expr.Expr)
+            size = 1
+            dtype = expr.dtype
+            lvar = var(code.co_varnames[0], dtype)
+            rvar = var(code.co_varnames[1], dtype)
+            result = [fcombine(lvar, rvar)]
+            id_elem = [fidentity(dtype)]
+            lhs = convert([lvar])
+            rhs = convert([rvar])
+            expr = convert([expr])
         result = convert(result)
-        id_elem = fidentity(dtype)
-        assert isinstance(id_elem, _expr.Expr)
-        combiner = _make.CommReducer(arg_vars, result, id_elem)
-        axis = axis if isinstance(axis, list) else [axis]
-        return _make.Reduce(combiner, expr, axis, where)
+        id_elem = convert(id_elem)
+        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
+        axis = convert(axis if isinstance(axis, list) else [axis])
+        if where is None:
+            where = convert(True)
+        outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
+                        for i in range(size))
+        return outputs[0] if size == 1 else outputs
 
     def reducer(expr, axis, where=None, *args):
         if isinstance(axis, (_schedule.IterVar, list)):
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 8b2e63f9c884079423a9e318bcd45a0eec871269..e9c8a179c95f4252ecd9838d2a5b582b39df9dbc 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -181,7 +181,7 @@ class Schedule(NodeBase):
         """ Factor a reduction axis in tensor's schedule to be an explicit axis.
 
         This will create a new stage that generates the new tensor with axis
-        as the first dimension. The tensor's body wil be rewriten as a reduction
+        as the first dimension. The tensor's body will be rewritten as a reduction
         over the factored tensor.
 
         Parameters
@@ -193,10 +193,11 @@ class Schedule(NodeBase):
 
         Returns
         -------
-        tfactor : Tensor
+        tfactor : Tensor or Array of Tensor
             The created factored tensor or tensors.
         """
-        return _api_internal._ScheduleRFactor(self, tensor, axis)
+        factored = _api_internal._ScheduleRFactor(self, tensor, axis)
+        return factored[0] if len(factored) == 1 else factored
 
 
 @register_node
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
index 00ab79b19167fe1b2fd92b567fb7e17b9b2b8370..f66652b99157d5a29eaade5f7e8caef46f576a27 100644
--- a/src/api/api_ir.cc
+++ b/src/api/api_ir.cc
@@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call")
   });
 
 TVM_REGISTER_API("make.CommReducer")
-.set_body([](TVMArgs args,  TVMRetValue *ret) {
-    *ret = CommReducerNode::make(args[0], args[1], args[2]);
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = CommReducerNode::make(args[0],
+                                 args[1],
+                                 args[2],
+                                 args[3]);
   });
 
-
 // make from two arguments
 #define REGISTER_MAKE1(Node)                                 \
   TVM_REGISTER_API("make."#Node)                             \
@@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer")
       *ret = Node::make(a, b);                               \
     })
 
-REGISTER_MAKE4(Reduce);
+REGISTER_MAKE5(Reduce);
 REGISTER_MAKE4(AttrStmt);
 
 REGISTER_MAKE2(IntImm);
diff --git a/src/lang/expr.cc b/src/lang/expr.cc
index 795f34fe673fcd61a7b0e7b61cce79ce59b2561d..9e0feb44479f90482b6070b55171b38ce44a1577 100644
--- a/src/lang/expr.cc
+++ b/src/lang/expr.cc
@@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) {
   Var x("x"), y("y");
   Expr result = ir::Add::make(x, y);
   Expr identity_element = make_zero(source.type());
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
 Expr max(Expr source, Array<IterVar> rdom) {
   Var x("x"), y("y");
   Expr result = ir::Max::make(x, y);
   Expr identity_element = source.type().min();
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
 Expr min(Expr source, Array<IterVar> rdom) {
   Var x("x"), y("y");
   Expr result = ir::Min::make(x, y);
   Expr identity_element = source.type().max();
-  ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element);
-  return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true));
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
 std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index 52ba225253dae7f08f76d2852b060ac323c2c24a..e7903333562f9fd3f640f74a1298b775ec465be1 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -9,6 +9,7 @@
 #include <ir/IR.h>
 #include <ir/IRPrinter.h>
 #include <memory>
+#include "../pass/ir_util.h"
 
 namespace Halide {
 namespace Internal {
@@ -25,23 +26,20 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
   p->stream << "reduce(combiner="
-            << op->combiner
-            << ", ";
-  p->print(op->source);
+            << op->combiner;
+  p->stream << ", source=" << op->source;
   p->stream << ", axis=" << op->axis;
-  if (!is_const(op->condition, 1)) {
-    p->stream << ", where=" << op->condition;
-  }
+  p->stream << ", where=" << op->condition;
+  p->stream << ", value_index=" << op->value_index;
   p->stream << ")";
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
-  p->stream << "comm_reducer(result="
-            << op->result
-            << ", args=" << op->args
-            << ", identity_element="
-            << op->identity_element
+  p->stream << "comm_reducer(result=" << op->result
+            << ", lhs=" << op->lhs
+            << ", rhs=" << op->rhs
+            << ", identity_element=" << op->identity_element
             << ")";
 });
 }  // namespace Internal
@@ -50,23 +48,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 namespace tvm {
 namespace ir {
 
-CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) {
+CommReducer CommReducerNode::make(Array<Var> lhs,
+                                  Array<Var> rhs,
+                                  Array<Expr> result,
+                                  Array<Expr> identity_element) {
   auto node = std::make_shared<CommReducerNode>();
-  node->args   = args;
+  node->lhs = lhs;
+  node->rhs = rhs;
   node->result = result;
   node->identity_element = identity_element;
   return CommReducer(node);
 }
 
-Expr CommReducerNode::operator()(Expr a, Expr b) const {
+Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
+  CHECK_EQ(a.size(), b.size());
+  CHECK_EQ(lhs.size(), a.size());
+  CHECK_EQ(rhs.size(), b.size());
   Map<Var, Expr> value_map;
-  value_map.Set(args[0], a);
-  value_map.Set(args[1], b);
-  return Substitute(result, value_map);
+  for (size_t i = 0; i < a.size(); ++i) {
+    value_map.Set(lhs[i], a[i]);
+    value_map.Set(rhs[i], b[i]);
+  }
+  return UpdateArray(result, [&value_map] (const Expr& e) {
+      return Substitute(e, value_map);
+    });
 }
 
-Expr Reduce::make(CommReducer combiner, Expr source,
-                  Array<IterVar> axis, Expr condition) {
+Expr Reduce::make(CommReducer combiner, Array<Expr> source,
+                  Array<IterVar> axis, Expr condition, int value_index) {
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK_EQ(axis[i]->iter_type, kCommReduce)
         << "Can only take axis created by reduce_axis";
@@ -79,11 +88,12 @@ Expr Reduce::make(CommReducer combiner, Expr source,
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK(axis[i].defined());
   }
-  n->type = source.type();
-  n->combiner = combiner;
-  n->source = source;
-  n->axis = axis;
+  n->type = source[value_index].type();
+  n->combiner = std::move(combiner);
+  n->source = std::move(source);
+  n->axis = std::move(axis);
   n->condition = condition;
+  n->value_index = value_index;
   return Expr(n);
 }
 
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index a2d3b25e25e080423fb006cdf3bdcdb3a03c0886..be594a6b6e4a2d2433d4de14d779025e77483128 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -24,7 +24,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
 int ComputeOpNode::num_outputs() const {
-  return 1;
+  return body.size();
 }
 
 Array<IterVar> ComputeOpNode::root_iter_vars() const {
@@ -36,13 +36,14 @@ Array<IterVar> ComputeOpNode::root_iter_vars() const {
   return ret;
 }
 
-Type ComputeOpNode::output_dtype(size_t i) const {
-  CHECK_EQ(i, 0U);
-  return body.type();
+Type ComputeOpNode::output_dtype(size_t idx) const {
+  CHECK_LT(idx, num_outputs());
+  return body[idx].type();
 }
 
-Array<Expr> ComputeOpNode::output_shape(size_t i) const {
-  CHECK_EQ(i, 0U);
+Array<Expr> ComputeOpNode::output_shape(size_t idx) const {
+  CHECK_LT(idx, num_outputs());
+  // for now, all outputs of ComputeOp have the same shape
   std::vector<Expr> shape;
   for (size_t i = 0; i < axis.size(); ++i) {
     const Range& r = axis[i]->dom;
@@ -65,18 +66,55 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
     args.push_back(axis.back()->var);
   }
 
-  return ComputeOpNode::make(name, axis, fcompute(args)).output(0);
+  return ComputeOpNode::make(name, axis, {fcompute(args)}).output(0);
+}
+
+Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name) {
+  auto op_node = std::make_shared<ComputeOpNode>();
+  // compute dimension.
+  size_t ndim = shape.size();
+  std::vector<IterVar> axis;
+  std::vector<Var> args;
+  for (size_t i = 0; i < ndim; ++i) {
+    std::ostringstream os;
+    os << "ax" << i;
+    axis.emplace_back(IterVarNode::make(
+        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
+    args.push_back(axis.back()->var);
+  }
+
+  Operation op = ComputeOpNode::make(name, axis, fcompute(args));
+  Array<Tensor> outputs;
+  for (int idx = 0; idx < op->num_outputs(); ++idx) {
+    outputs.push_back(op.output(idx));
+  }
+  return outputs;
+}
+
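+// Whether two Reduce nodes describe the same reduction; they may differ only in value_index (and thus type).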
+bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
 }
 
 Operation ComputeOpNode::make(std::string name,
                               Array<IterVar> axis,
-                              Expr body) {
+                              Array<Expr> body) {
   auto n = std::make_shared<ComputeOpNode>();
   n->name = name;
   n->axis = axis;
   n->body = body;
-  if (n->body->is_type<ir::Reduce>()) {
-    n->reduce_axis = n->body.as<ir::Reduce>()->axis;
+  if (n->body[0]->is_type<ir::Reduce>()) {
+    const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
+    for (size_t i = 1; i < n->body.size(); ++i) {
+      const ir::Reduce* reduce_ = n->body[i].as<ir::Reduce>();
+      CHECK(reduce_);
+      CHECK(ReduceEqual(reduce_, reduce))
+        << "The Reduce inputs of ComputeOp should "
+        << "have the same attribute except value_index";
+    }
+    n->reduce_axis = reduce->axis;
   }
   return Operation(n);
 }
@@ -85,16 +123,18 @@ Operation ComputeOpNode::make(std::string name,
 Array<Tensor> ComputeOpNode::InputTensors() const {
   Array<Tensor> ret;
   std::unordered_set<Tensor> visited;
-  ir::PostOrderVisit(body, [&ret, &visited](const NodeRef& n) {
-      const ir::Call *call = n.as<ir::Call>();
-      if (call != nullptr && call->func.defined()) {
-        Tensor t = Operation(call->func.node_).output(call->value_index);
-        if (!visited.count(t)) {
-          ret.push_back(t);
-          visited.insert(t);
+  for (auto& e : body) {
+    ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
+        const ir::Call *call = n.as<ir::Call>();
+        if (call != nullptr && call->func.defined()) {
+          Tensor t = Operation(call->func.node_).output(call->value_index);
+          if (!visited.count(t)) {
+            ret.push_back(t);
+            visited.insert(t);
+          }
         }
-      }
-    });
+      });
+  }
   return ret;
 }
 
@@ -102,9 +142,11 @@ Operation ComputeOpNode::ReplaceInputs(
     const Operation& self,
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
-  Expr new_body = op::ReplaceTensor(this->body, rmap);
-  if (!new_body.same_as(this->body)) {
-    return ComputeOpNode::make(name, axis, new_body);
+  Array<Expr> arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
+      return op::ReplaceTensor(e, rmap);
+    });
+  if (!arr.same_as(this->body)) {
+    return ComputeOpNode::make(name, axis, arr);
   } else {
     return self;
   }
@@ -127,7 +169,7 @@ void ComputeOpNode::PropBoundToInputs(
       }
     }
   };
-  ir::PostOrderVisit(body, fvisit);
+  for (auto& e : body) ir::PostOrderVisit(e, fvisit);
 }
 
 void ComputeOpNode::GatherBound(
@@ -151,34 +193,50 @@ Stmt ComputeOpNode::BuildRealize(
     const std::unordered_map<IterVar, Range>& realize_map,
     const Stmt& realize_body) const {
   CHECK_EQ(self.operator->(), this);
-  Tensor t = self.output(0);
   Halide::Internal::Region bounds;
   for (IterVar iv : this->axis) {
     bounds.push_back(realize_map.at(iv));
   }
-  return ir::Realize::make(t->op, t->value_index, t->dtype,
-                           bounds, const_true(), realize_body);
+  Stmt realize = realize_body;
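+  // wrap the body with one Realize region per output tensor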
+  for (int i = self->num_outputs(); i > 0; --i) {
+    Tensor t = self.output(i-1);
+    realize = ir::Realize::make(t->op, t->value_index,
+      t->dtype, bounds, const_true(), realize);
+  }
+  return realize;
 }
 
 // Build a reduction body.
 void MakeReduction(const ComputeOpNode* op,
-                   const Tensor& t,
+                   const Array<Tensor>& tensors,
                    Stmt* init,
                    Stmt* provide) {
-  Stmt no_op = Evaluate::make(0);
-  std::vector<Stmt> nest;
   Array<Expr>  args;
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
-  const Reduce* reduce = op->body.as<Reduce>();
+  std::vector<Stmt> inits, provides;
+
+  size_t size = op->body.size();
+  const Reduce* reduce = op->body[0].as<Reduce>();
   CHECK(reduce);
   const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
   CHECK(combiner);
-  Expr init_value = combiner->identity_element;
-  Expr update_value = (*combiner)(t(args), reduce->source);
-  *init = Provide::make(t->op, t->value_index, init_value, args);
-  *provide = Provide::make(t->op, t->value_index, update_value, args);
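+  // lhs holds the current values of the output tensors; the combiner merges them with reduce->source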
+  Array<Expr> lhs;
+  for (size_t i = 0; i < size; ++i) {
+    lhs.push_back(tensors[i](args));
+  }
+  Array<Expr> init_value = combiner->identity_element;
+  Array<Expr> update_value = (*combiner)(lhs, reduce->source);
+  for (size_t i = 0; i < size; ++i) {
+    Tensor t = tensors[i];
+    inits.emplace_back(Provide::make(
+          t->op, t->value_index, init_value[i], args));
+    provides.emplace_back(Provide::make(
+          t->op, t->value_index, update_value[i], args));
+  }
+  *init = Block::make(inits);
+  *provide = Block::make(provides);
   if (!is_one(reduce->condition)) {
     *provide = IfThenElse::make(reduce->condition, *provide);
   }
@@ -225,22 +283,36 @@ Stmt MakeCrossThreadReduction(
   for (IterVar iv : self->axis) {
     args.push_back(iv->var);
   }
-  const Reduce* reduce = self->body.as<Reduce>();
-  CHECK(reduce);
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
   auto conds = op::MakeBoundCheck(
       stage, dom_map, false,
       std::unordered_set<IterVar>(), value_map);
-  Expr cond = reduce->condition;
+
+  size_t size = self->body.size();
+  CHECK_GT(size, 0);
+  std::vector<const Reduce*> reduces(size);
+  for (size_t i = 0; i < size; ++i) {
+    const Reduce* reduce = self->body[i].as<Reduce>();
+    CHECK(reduce);
+    reduces[i] = reduce;
+  }
+  Expr cond = reduces[0]->condition;
   for (Expr v : conds) {
     cond = cond && v;
   }
-  Var res_handle("reduce_temp", Handle());
   Array<Expr> freduce_args;
-  freduce_args.push_back(reduce->source);
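+  // Arguments follow the tvm_thread_allreduce convention declared in ir.h:
+  // (size, source0, ..., cond, reduce_temp0, ..., thread_idx0, ...).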
+  freduce_args.push_back(make_const(UInt(32), size));
+  for (size_t i = 0; i < size; ++i) {
+    freduce_args.push_back(reduces[0]->source[i]);
+  }
   freduce_args.push_back(cond);
+  std::vector<Var> res_handles(size);
+  for (size_t idx = 0; idx < size; ++idx) {
+    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
+    freduce_args.push_back(res_handles[idx]);
+  }
 
   for (IterVar iv : stage->leaf_iter_vars) {
     if (iv->iter_type == kCommReduce) {
@@ -257,28 +329,33 @@ Stmt MakeCrossThreadReduction(
   if (stage->store_predicate.defined()) {
     thread_head_check.emplace_back(stage->store_predicate);
   }
-  Type t = reduce->type;
-  Expr pred = const_true(t.lanes());
-  Stmt reduce_body = Store::make(res_handle,
-    Call::make(
-      reduce->type,
+
+  Stmt reduce_body = Evaluate::make(Call::make(
+      Handle(),
       ir::intrinsic::tvm_thread_allreduce,
-      freduce_args, Call::Intrinsic),
-     0, pred);
+      freduce_args, Call::Intrinsic));
   reduce_body = AttrStmt::make(
-      reduce->combiner,
+      reduces[0]->combiner,
       attr::reduce_scope,
-      make_zero(reduce->type),
+      make_zero(Handle()),
       reduce_body);
-  Stmt assign_body = Provide::make(
-      stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
+  std::vector<Stmt> assigns(size);
+  for (size_t idx = 0; idx < size; ++idx) {
+    Type t = reduces[idx]->type;
+    assigns[idx] = Provide::make(
+      stage->op, idx,
+      Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
+  }
+  Stmt assign_body = Block::make(assigns);
   assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
   assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
-  Stmt body = Allocate::make(
-      res_handle, reduce->type, {1}, const_true(),
-      Block::make(reduce_body, assign_body));
-  body = AttrStmt::make(
-      res_handle, attr::storage_scope, StringImm::make("local"), body);
+  Stmt body = Block::make(reduce_body, assign_body);
+  for (size_t idx = size; idx != 0; --idx) {
+    body = Allocate::make(
+      res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
+    body = AttrStmt::make(
+      res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
+  }
   body = Substitute(body, value_map);
   return MergeNest(nest, body);
 }
@@ -289,7 +366,7 @@ Stmt MakeProvide(const ComputeOpNode* op,
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
-  return Provide::make(t->op, t->value_index, op->body, args);
+  return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
 }
 
 Stmt ComputeOpNode::BuildProvide(
@@ -301,12 +378,24 @@ Stmt ComputeOpNode::BuildProvide(
     // specially handle cross thread reduction.
     return MakeCrossThreadReduction(this, stage, dom_map);
   }
-  Stmt init, provide;
+
+  size_t size = this->body.size();
+  Stmt init;
+  Stmt provide;
   if (this->reduce_axis.size() == 0) {
-    provide = MakeProvide(this, stage->op.output(0));
+    std::vector<Stmt> provides;
+    for (size_t i = 0; i < size; ++i) {
+      provides.emplace_back(MakeProvide(this, stage->op.output(i)));
+    }
+    provide = Block::make(provides);
   } else {
-    MakeReduction(this, stage->op.output(0), &init, &provide);
+    Array<Tensor> source;
+    for (size_t i = 0; i < size; ++i) {
+      source.push_back(stage->op.output(i));
+    }
+    MakeReduction(this, source, &init, &provide);
   }
+
   // make loop nest
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
@@ -357,7 +446,7 @@ Stmt ComputeOpNode::BuildProvide(
     for (auto& e : preds) e = likely(e);
     init_nest.push_back(op::MakeIfNest(preds));
     init = Substitute(init, init_value_map);
-    init  = MergeNest(init_nest, init);
+    init = MergeNest(init_nest, init);
     // common nest
     std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
     std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index ce179b79188affcda881018a99c61633ab770c21..b12f6648dffcdcb5842cafffda7d26c8eebe6d7e 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -4,6 +4,7 @@
  */
 #include <tvm/ir.h>
 #include <tvm/ir_mutator.h>
+#include "./ir_util.h"
 
 namespace tvm {
 namespace ir {
@@ -17,19 +18,7 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() {  // NOLINT(*)
 }
 
 inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
-  std::vector<Expr> new_arr(arr.size());
-  bool changed = false;
-  for (size_t i = 0; i < arr.size(); i++) {
-    Expr old_elem = arr[i];
-    Expr new_elem = m->Mutate(old_elem);
-    if (!new_elem.same_as(old_elem)) changed = true;
-    new_arr[i] = new_elem;
-  }
-  if (!changed) {
-    return arr;
-  } else {
-    return Array<Expr>(new_arr);
-  }
+  return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); });
 }
 
 inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
@@ -323,14 +312,15 @@ DEFINE_BIOP_EXPR_MUTATE_(Or)
 
 Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
   Array<IterVar> new_axis  = MutateIterVarArr(op->axis, this);
-  Expr new_source = this->Mutate(op->source);
+  Array<Expr> new_source = MutateArray(op->source, this);
   Expr new_cond = this->Mutate(op->condition);
   if (op->axis.same_as(new_axis) &&
       op->source.same_as(new_source) &&
       op->condition.same_as(new_cond)) {
     return e;
   } else {
-    return Reduce::make(op->combiner, new_source, new_axis, new_cond);
+    return Reduce::make(
+      op->combiner, new_source, new_axis, new_cond, op->value_index);
   }
 }
 
diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h
index 1982f977365f2bce182434f559e06ba8cbb5be88..472b408e32d52015d0321d707a49770beae65161 100644
--- a/src/pass/ir_util.h
+++ b/src/pass/ir_util.h
@@ -12,6 +12,32 @@
 namespace tvm {
 namespace ir {
 
+/*!
+ * \brief Update an array with a unary function.
+ * \param arr The input array.
+ * \param fupdate The unary update function.
+ * \tparam T The type of the array elements.
+ * \tparam F The type of the unary function.
+ * \return The new array if any element changed, otherwise the
+ *  original array.
+ */
+template<typename T, typename F>
+inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
+  std::vector<T> new_arr(arr.size());
+  bool changed = false;
+  for (size_t i = 0; i < arr.size(); ++i) {
+    T old_elem = arr[i];
+    T new_elem = fupdate(old_elem);
+    if (!new_elem.same_as(old_elem)) changed = true;
+    new_arr[i] = new_elem;
+  }
+  if (!changed) {
+    return arr;
+  } else {
+    return Array<T>(new_arr);
+  }
+}
+
 /*!
  * \brief combine the nest stmt, whose body is not defined.
  * \param nest A list of For and LetStmt, whose body is not defined.
diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc
index bb1b3678e0d836c2f96888fec3a2359d6a72276a..bae93f9d00b69bada1986046d3a1386c43530cb9 100644
--- a/src/pass/ir_visitor.cc
+++ b/src/pass/ir_visitor.cc
@@ -133,7 +133,7 @@ DEFINE_BINOP_VISIT_(Or)
 
 void IRVisitor::Visit_(const Reduce* op) {
   VisitRDom(op->axis, this);
-  this->Visit(op->source);
+  VisitArray(op->source, this);
 }
 
 void IRVisitor::Visit_(const Cast* op) {
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
index f9f99e8fcd265dbf42041a25ddc1f708d4b3d374..1e59723d59d5a184adebee0dafeacb02b1a89958 100644
--- a/src/pass/lower_thread_allreduce.cc
+++ b/src/pass/lower_thread_allreduce.cc
@@ -45,12 +45,12 @@ class ThreadAllreduceBuilder : public IRMutator {
       return IRMutator::Mutate_(op, s);
     }
   }
-  Stmt Mutate_(const Store* op, const Stmt& s) final {
+  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
     Stmt stmt = IRMutator::Mutate_(op, s);
-    op = stmt.as<Store>();
+    op = stmt.as<Evaluate>();
     const Call* call = op->value.as<Call>();
     if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
-      return MakeAllreduce(op, call);
+      return MakeAllreduce(call);
     } else {
       return stmt;
     }
@@ -97,18 +97,34 @@ class ThreadAllreduceBuilder : public IRMutator {
     }
   };
   // make allreduce.
-  Stmt MakeAllreduce(const Store* op, const Call* call) {
+  Stmt MakeAllreduce(const Call* call) {
     CHECK(!reduce_combiner_.empty());
     const CommReducerNode *combiner = reduce_combiner_.back();
-    Expr init  = combiner->identity_element;
-    Expr value = call->args[0];
-    Expr cond  = call->args[1];
-    if (!is_one(cond)) {
-      value = Select::make(cond, value, init);
+    size_t size = combiner->result.size();
+
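+    // call->args layout: (size, source0, ..., cond, reduce_temp0, ..., thread_idx0, ...),
+    // matching the tvm_thread_allreduce pseudo code in ir.h.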
+    const UIntImm *size_of_args = call->args[0].as<UIntImm>();
+    CHECK(size_of_args) << call->args[0]->type_key();
+    CHECK_EQ(size, size_of_args->value);
+    Array<Expr> inits = combiner->identity_element;
+    std::vector<Expr> values(size);
+    std::vector<Type> types(size);
+    Expr cond  = call->args[size+1];
+    for (size_t idx = 0; idx < size; ++idx) {
+      values[idx] = call->args[1+idx];
+      if (!is_one(cond)) {
+        values[idx] = Select::make(cond, values[idx], inits[idx]);
+      }
+      types[idx] = values[idx].type();
+    }
+    std::vector<const Variable*> buffers(size);
+    for (size_t idx = 0; idx < size; ++idx) {
+      const Variable* buffer = call->args[2+size+idx].as<Variable>();
+      CHECK(buffer);
+      buffers[idx] = buffer;
     }
 
     std::unordered_set<const Variable*> reduce_set;
-    for (size_t i = 2; i < call->args.size(); ++i) {
+    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
       const Variable* v = call->args[i].as<Variable>();
       CHECK(v);
       reduce_set.insert(v);
@@ -143,40 +159,50 @@ class ThreadAllreduceBuilder : public IRMutator {
     int threadx_extent = 1;
     Expr reduce_index = FlattenThread(vred, &reduce_extent);
     Expr group_index = FlattenThread(vpar, &group_extent);
-    Expr pred = const_true(value.type().lanes());
     if (reduce_extent == 1) {
       // special case, no reduction is needed.
-      return Store::make(op->buffer_var, value, 0, pred);
+      std::vector<Stmt> stores(size);
+      for (size_t i = 0; i < size; ++i) {
+        Expr pred = const_true(types[i].lanes());
+        Var buffer_var(call->args[2+size+i].node_);
+        stores[i] = Store::make(buffer_var, values[i], 0, pred);
+      }
+      return Block::make(stores);
     }
     // Whether the threadIdx.x is involved in reduction.
     if (vred[0].scope.dim_index == 0) {
       threadx_extent = vred[0].extent;
     }
-    Var shared_buf("red_buf", Handle());
     std::vector<Stmt> seq;
-    seq.emplace_back(Store::make(
-        shared_buf, value,
-        BufIndex(reduce_index, group_index, reduce_extent), pred));
+    std::vector<Var> shared_bufs(size);
+    for (size_t idx = 0; idx < size; ++idx) {
+      shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle());
+      Expr pred = const_true(types[idx].lanes());
+      seq.emplace_back(Store::make(
+          shared_bufs[idx], values[idx],
+          BufIndex(reduce_index, group_index, reduce_extent), pred));
+    }
     seq.emplace_back(SyncThread("shared"));
     seq.emplace_back(MakeBufAllreduce(
-        combiner, value.type(), shared_buf,
+        combiner, types, shared_bufs,
         reduce_index, group_index, reduce_extent, threadx_extent));
-    CHECK(!load_remap_.count(op->buffer_var.get()));
-    load_remap_[op->buffer_var.get()] =
-        Load::make(
-            value.type(), shared_buf,
-            BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent),
-            pred);
-    alloc_remap_[op->buffer_var.get()] =
-        Allocate::make(shared_buf, value.type(),
-                       {Expr(group_extent), Expr(reduce_extent)},
-                       pred, Evaluate::make(0));
+    for (size_t idx = 0; idx < size; ++idx) {
+      CHECK(!load_remap_.count(buffers[idx]));
+      Expr pred = const_true(types[idx].lanes());
+      load_remap_[buffers[idx]] = Load::make(
+        types[idx], shared_bufs[idx],
+        BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred);
+      alloc_remap_[buffers[idx]] = Allocate::make(
+        shared_bufs[idx], types[idx],
+        {Expr(group_extent), Expr(reduce_extent)},
+        pred, Evaluate::make(0));
+    }
     return MergeSeq(seq);
   }
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode *combiner,
-                        Type type,
-                        Var shared_buf,
+                        const std::vector<Type>& types,
+                        const Array<Var>& shared_bufs,
                         Expr reduce_index,
                         Expr group_index,
                         int reduce_extent,
@@ -189,14 +215,23 @@ class ThreadAllreduceBuilder : public IRMutator {
     CHECK_GT(reduce_align, 1);
     std::vector<Stmt> seq;
 
+    size_t size = shared_bufs.size();
     Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
     // make reduction
     auto freduce = [&](int offset) {
-      Expr b = Load::make(
-          type, shared_buf,
-          BufIndex(reduce_index + offset, group_index, reduce_extent), const_true());
-      Expr a = Load::make(type, shared_buf, buf_index, const_true());
-      return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
+      Array<Expr> a, b;
+      for (size_t i = 0; i < size; ++i) {
+        b.push_back(Load::make(types[i], shared_bufs[i],
+          BufIndex(reduce_index + offset, group_index, reduce_extent),
+          const_true()));
+        a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
+      }
+      Array<Expr> ret = (*combiner)(a, b);
+      std::vector<Stmt> stores(size);
+      for (size_t i = 0; i < size; ++i) {
+        stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
+      }
+      return Block::make(stores);
     };
     // Step one, check for
     if (reduce_align > reduce_extent) {
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index ff4c5912dae0ecf74e2105a0af157705dc4a0757..f807b92dceaa7ccec6d6535f4dce75936e86cc86 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -157,7 +157,9 @@ class StorageFlattener : public IRMutator {
     CHECK_EQ(extern_buf_remap_.size(), 0U);
     for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) {
       TensorKey key{func, static_cast<int>(i)};
-      CHECK(buf_map_.count(key));
+      CHECK(buf_map_.count(key))
+          << "Cannot find allocated buffer for " << key.f
+          << "(" << key.value_index << ")";
       extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] =
           buf_map_.at(key).buffer->data;
     }
diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc
index 312be19a08747f66d16cb55d7b8603c6ae417797..9fd073c0ac7afca44907e05d2bd822c6dd9c1a74 100644
--- a/src/schedule/auto_inline_elem_wise.cc
+++ b/src/schedule/auto_inline_elem_wise.cc
@@ -46,7 +46,7 @@ class ElemWiseDetector : public ir::IRVisitor {
 bool IsElemWise(const Operation& op) {
   if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
     ElemWiseDetector v = ElemWiseDetector(compute->axis);
-    v.Visit(compute->body);
+    for (auto& e : compute->body) v.Visit(e);
     return v.is_elem_wise_;
   }
   return false;
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 0fcf21def2fca9e1583cc9e502d0ac10bedde004..da0aeb0eccaa3f21fef70fb56297473459757457 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -260,7 +260,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
           }
         }
       };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+      for (auto& e : op.as<ComputeOpNode>()->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
     }
   }
   return reach;
@@ -321,11 +323,14 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
         }
       }
     } else if (op.as<ComputeOpNode>()) {
-      std::unordered_map<const Node*, TensorDimKey> vmap;
+      std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
       const auto& axis = op.as<ComputeOpNode>()->axis;
-      Tensor t = op.output(0);
       for (size_t i = 0; i < axis.size(); ++i) {
-        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
+        std::vector<TensorDimKey> keys;
+        for (int j = 0; j < op->num_outputs(); ++j) {
+          keys.emplace_back(op.output(j), i);
+        }
+        vmap[axis[i]->var.get()] = std::move(keys);
       }
       auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
           const NodeRef& n) {
@@ -335,7 +340,10 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
             auto it = vmap.find(call->args[i].get());
             TensorDimKey src(call, static_cast<int>(i));
             if (it != vmap.end()) {
-              f_merge_key(it->second, src);
+              const std::vector<TensorDimKey>& keys = it->second;
+              for (const auto& key : keys) {
+                f_merge_key(key, src);
+              }
             } else {
               if (exact_reach.count(src)) {
                 fail_set.insert(exact_reach.at(src));
@@ -344,7 +352,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
           }
         }
       };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+      for (auto& e : op.as<ComputeOpNode>()->body) {
+        ir::PostOrderVisit(e, fvisit);
+      }
     }
   }
   ReachGraph reach;
diff --git a/src/schedule/graph.h b/src/schedule/graph.h
index 7908dc9e1de64a4eccae1fc97031415490d94cf7..50d35355cc6482faf2317a1da2f8e26265fa3569 100644
--- a/src/schedule/graph.h
+++ b/src/schedule/graph.h
@@ -27,7 +27,7 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
 using AttachPath = Map<Operation, Array<IterVar> >;
 
 /*!
- * \brief The map beteen tensor and operation it feeds to.
+ * \brief The map between tensor and operation it feeds to.
  */
 using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
 
@@ -46,7 +46,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots);
  *  The operations contains node which input-reachable from any inputs
  *  output reachable to any outputs.
  *
- *  The inputs won't be included in the subgraph, the outputs will be inclued.
+ *  The inputs won't be included in the subgraph, the outputs will be included.
  *
  * \param outputs The outputs of the subgraph
  * \param inputs The inputs to the subgraph.
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 25319dc24eff7d755b326094185d7b5c0a6ad695..d24ba17bf6e3e97c276d5dbb8e7f1b7a77fac553 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -8,6 +8,7 @@
 #include <tvm/ir_pass.h>
 #include <unordered_set>
 #include "./message_passing.h"
+#include "../pass/ir_util.h"
 
 namespace tvm {
 
@@ -120,13 +121,13 @@ Tensor Schedule::cache_write(const Tensor& tensor,
     vsub[iv->var.get()] = new_iv->var;
   }
   VarReplacer repl(vsub);
-  Expr body = repl.Mutate(compute->body);
+  Expr body = repl.Mutate(compute->body[tensor->value_index]);
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, new_axis, body);
+      compute->name + "." + scope, new_axis, {body});
   Tensor cache_tensor = cache_op.output(0);
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->axis,
-      cache_tensor(args));
+      {cache_tensor(args)});
 
   std::unordered_map<Tensor, Tensor> vmap;
   vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
@@ -198,14 +199,15 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
 
-  std::vector<Expr> new_body(sch->stages.size());
+  std::vector<Array<Expr>> new_body(sch->stages.size());
+  std::vector<bool> changed(sch->stages.size(), false);
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage stage = sch->stages[i - 1];
     if (stage->attach_type == kInline) {
       stage->attach_type = kInlinedAlready;
       Array<Var> args;
-      Expr body;
+      Array<Expr> body;
       {
         // setup args
         const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -220,11 +222,14 @@ void InjectInline(ScheduleNode* sch) {
         Stage s = sch->stages[j];
         const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
         if (compute) {
-          if (!new_body[j].defined()) {
+          if (!new_body[j].size()) {
             new_body[j] = s->op.as<ComputeOpNode>()->body;
           }
-          new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
-                                   stage->op, args, body).as<ir::Evaluate>()->value;
+          for (size_t k = 0; k < body.size(); ++k) {
+            changed[j] = true;
+            new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
+                            stage->op, args, body[k]).as<ir::Evaluate>()->value);
+          }
         }
       }
     }
@@ -234,19 +239,21 @@ void InjectInline(ScheduleNode* sch) {
   for (size_t i = 0; i < sch->stages.size(); ++i) {
     Stage s = sch->stages[i];
     if (s->attach_type == kInlinedAlready) continue;
-    if (new_body[i].defined()) {
+    if (new_body[i].size()) {
       // Logics from ReplaceDataFlow
       const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
       CHECK(compute);
       Operation op = s->op;
-      if (!new_body[i].same_as(compute->body)) {
+      if (changed[i]) {
         op = ComputeOpNode::make(
             compute->name, compute->axis, new_body[i]);
       }
       op = op->ReplaceInputs(op, repl);
       if (!op.same_as(s->op)) {
-        repl[s->op.output(0)] = op.output(0);
-        s->op = op;
+        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+          repl[s->op.output(idx)] = op.output(idx);
+        }
+        s->op = op;
       }
     } else {
       Operation op = s->op->ReplaceInputs(s->op, repl);
@@ -268,15 +275,15 @@ Schedule Schedule::normalize() {
 }
 
 // Handle reduction factor.
-Tensor Schedule::rfactor(const Tensor& tensor,
-                         const IterVar& axis) {
+Array<Tensor> Schedule::rfactor(const Tensor& tensor,
+                                const IterVar& axis) {
   (*this)->InvalidateCache();
   using ir::Reduce;
   CHECK_EQ(axis->iter_type, kCommReduce)
       << "Can only factor reduction axis";
   Stage reduce_stage = operator[](tensor->op);
   const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
-  CHECK(compute_op) << "Can only factor  ComputeOp";
+  CHECK(compute_op) << "Can only factor ComputeOp";
   ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
   {
     size_t axis_pos = FindNodeRef(leaf_vars, axis);
@@ -329,7 +336,8 @@ Tensor Schedule::rfactor(const Tensor& tensor,
     }
   }
   // predicate generation, copy not touched axis.
-  const Reduce* reduce = compute_op->body.as<Reduce>();
+  int idx = tensor->value_index;
+  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   Expr predicate = reduce->condition;
   std::unordered_map<const Variable*, Expr> vsub;
@@ -359,10 +367,18 @@ Tensor Schedule::rfactor(const Tensor& tensor,
       n->reduce_axis.push_back(IterVar(ncpy));
     }
   }
-  n->body = Reduce::make(reduce->combiner,
-                         VarReplacer(vsub).Mutate(reduce->source),
-                         n->reduce_axis,
-                         predicate);
+  VarReplacer replacer(vsub);
+  Array<Expr> new_source = ir::UpdateArray(reduce->source,
+    [&replacer] (const Expr& e) { return replacer.Mutate(e); });
+  std::vector<Expr> body;
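+  // one Reduce node per output; they share combiner, source, axis and condition and differ only in value_index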
+  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
+    body.emplace_back(Reduce::make(reduce->combiner,
+                                   new_source,
+                                   n->reduce_axis,
+                                   predicate,
+                                   idx));
+  }
+  n->body = Array<Expr>(body);
   // refresh relations, keep the un-touched relations.
   Array<IterVarRelation> rels;
   for (IterVarRelation rel : reduce_stage->relations) {
@@ -397,26 +413,44 @@ Tensor Schedule::rfactor(const Tensor& tensor,
   // Replace the old reduction.
   IterVar repl_red_axis = reduce_axis(
       dom_map.at(axis), axis->var->name_hint + ".v");
-  Tensor factor_tensor = factor_op.output(0);
-  Tensor old_tensor = reduce_stage->op.output(0);
-  Tensor repl_tensor = compute(old_tensor->shape, [&](const Array<Var>& i) {
+  Array<Tensor> factor_tensors;
+  Array<Tensor> old_tensors;
+  int size = factor_op->num_outputs();
+  for (int idx = 0; idx < size; ++idx) {
+    factor_tensors.push_back(factor_op.output(idx));
+    old_tensors.push_back(reduce_stage->op.output(idx));
+  }
+  Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
+    [&](const Array<Var>& i) {
       Array<Expr> indices;
       indices.push_back(repl_red_axis->var);
       for (Var v : i) {
         indices.push_back(v);
       }
-      return Reduce::make(reduce->combiner,
-        factor_tensor(indices), {repl_red_axis}, const_true());
-    }, old_tensor->op->name + ".repl");
+      Array<Expr> factor_exprs;
+      for (int idx = 0; idx < size; ++idx) {
+        factor_exprs.push_back(factor_tensors[idx](indices));
+      }
+      Array<Expr> reductions;
+      Array<IterVar> axis = {repl_red_axis};
+      Expr cond = const_true();
+      for (int idx = 0; idx < size; ++idx) {
+        reductions.push_back(Reduce::make(reduce->combiner,
+          factor_exprs, axis, cond, idx));
+      }
+      return reductions;
+    }, reduce_stage->op->name + ".repl");
 
   std::unordered_map<Tensor, Tensor> vmap;
-  vmap[old_tensor] = repl_tensor;
+  for (int idx = 0; idx < size; ++idx) {
+    vmap[old_tensors[idx]] = repl_tensors[idx];
+  }
   ReplaceDataFlow((*this)->stages, &vmap);
   // revamp the reduction stage.
-  reduce_stage->op = repl_tensor->op;
-  reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars();
+  reduce_stage->op = repl_tensors[0]->op;
+  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
   reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
   reduce_stage->relations = Array<IterVarRelation>();
-  return factor_tensor;
+  return factor_tensors;
 }
 }  // namespace tvm
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index edbc2878a5a963d5de2d6482de592fd931a8b507..347cf69884b6934eaf13dfa8250e7126e1c996d7 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -253,7 +253,7 @@ class SchedulePostProc : public IRMutator {
       // This must be checked for all ops, including scan.
       if (!s->op.same_as(s->origin_op)) {
         for (int i = 0; i < s->op->num_outputs(); ++i) {
-          Tensor target = s->origin_op.output(0);
+          Tensor target = s->origin_op.output(i);
           AddReplace(s->op.output(i), target,
                      target, s->origin_op);
         }
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index bf3be27fd3bf2040c69237c0d4249c300c18d5ee..ffc4b79b58d26415a544ebb6caa6fe50dddbe857 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -49,7 +49,6 @@ def test_reduce_prims():
     test_prim(tvm.max, np.amax)
 
 
-
 def test_rfactor():
     n = tvm.convert(1027)
     A = tvm.placeholder((n,), name='A')
@@ -128,7 +127,115 @@ def test_rfactor_threads():
     check_target("metal")
     check_target("opencl")
 
+def test_argmax():
+    def fcombine(x, y):
+        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        return lhs, rhs
+
+    def fidentity(t0, t1):
+        return tvm.const(-1, t0), tvm.min_value(t1)
+
+    argmax = tvm.comm_reducer(fcombine,
+                              fidentity,
+                              name='argmax')
+    m = tvm.var('m')
+    n = tvm.var('n')
+    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
+    val = tvm.placeholder((m, n), name='val', dtype='float32')
+    k = tvm.reduce_axis((0, n), 'k')
+    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T')
+    s = tvm.create_schedule(T0.op)
+
+    def check_target():
+        device = 'cpu'
+        if not tvm.module.enabled(device):
+            print("skip because %s is not enabled.." % device)
+            return
+        ctx = tvm.context(device, 0)
+        fapi = tvm.lower(s, args=[idx, val, T0, T1])
+        fargmax = tvm.build(fapi,
+                            target='llvm',
+                            name="argmax")
+
+        mm = 12
+        nn = 16
+        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
+        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
+        np_res = np.argmax(np_val, axis=1)
+
+        nd_idx  = tvm.nd.array(np_idx, ctx)
+        nd_val  = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
+        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
+        np.testing.assert_allclose(np_res, nd_res0.asnumpy())
+
+    check_target()
+
+
+def test_rfactor_argmax():
+    def fcombine(x, y):
+        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        return lhs, rhs
+
+    def fidentity(t0, t1):
+        return tvm.const(-1, t0), tvm.min_value(t1)
+
+    argmax = tvm.comm_reducer(fcombine,
+                              fidentity,
+                              name='argmax')
+
+    nn = 1027
+    mm = 10
+    n = tvm.convert(nn)
+    m = tvm.convert(mm)
+    A0 = tvm.placeholder((m, n), name='A0', dtype='int32')
+    A1 = tvm.placeholder((m, n), name='A1', dtype='float32')
+    k = tvm.reduce_axis((0, n))
+    B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B')
+
+    # schedule
+    s = tvm.create_schedule(B0.op)
+    nthread = 16
+    ko, kf = s[B0].split(k, factor=nthread)
+    BF0, BF1 = s.rfactor(B0, kf)
+    bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
+    s[B0].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[B0].bind(ty, tvm.thread_axis("threadIdx.y"))
+    tx = s[B0].op.reduce_axis[0]
+    thread_x = tvm.thread_axis("threadIdx.x")
+    s[B0].bind(tx, thread_x)
+    s[BF0.op].compute_at(s[B0], tx)
+    s[B0].set_store_predicate(thread_x.var.equal(0))
+
+    def check_target(device):
+        if not tvm.module.enabled(device):
+            print("skip because %s is not enabled.." % device)
+            return
+        ctx = tvm.context(device, 0)
+        fapi = tvm.lower(s, args=[A0, A1, B0, B1])
+        fargmax = tvm.build(fapi,
+                            target=device,
+                            name="argmax")
+
+        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
+        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
+        np_res = np.argmax(np_val, axis=1)
+
+        nd_idx  = tvm.nd.array(np_idx, ctx)
+        nd_val  = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
+        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
+        np.testing.assert_allclose(np_res, nd_res0.asnumpy())
+
+    check_target("cuda")
+
 if __name__ == "__main__":
     test_rfactor_threads()
     test_rfactor()
     test_reduce_prims()
+    test_argmax()
+    test_rfactor_argmax()
diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py
index 5fed5a23f750bc6d0ddbf6eab2bb233eb2c2c89a..1b0eac15fe0782ba1201facba7e35b0e48f2dd4e 100644
--- a/tests/python/unittest/test_lang_schedule.py
+++ b/tests/python/unittest/test_lang_schedule.py
@@ -101,8 +101,8 @@ def test_rfactor():
     s = tvm.create_schedule(B.op)
     BF = s.rfactor(B, k1)
     assert(tuple(BF.shape) == (n, n))
-    assert(set(BF.op.body.axis) == set([k2]))
-    assert(s[B].op.body.axis[0].dom.extent == n)
+    assert(set(BF.op.body[0].axis) == set([k2]))
+    assert(s[B].op.body[0].axis[0].dom.extent == n)
     assert(len(s[B].all_iter_vars) == 2)
     # schedule with splot
     s = tvm.create_schedule(B.op)
@@ -111,9 +111,9 @@ def test_rfactor():
     BF = s.rfactor(B, ki)
     assert(BF.shape[0].value == 4)
     assert(BF.shape[1] == n)
-    assert(BF.op.body.axis[0] ==  k2)
-    assert(BF.op.body.axis[1].var ==  ko.var)
-    assert(s[B].op.body.axis[0].dom.extent.value == 4)
+    assert(BF.op.body[0].axis[0] ==  k2)
+    assert(BF.op.body[0].axis[1].var ==  ko.var)
+    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index e81b15ffa653c34df4604bdc83180e26866f3de0..9160baec3789379965c293bdbc094674cf53c0ad 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -118,6 +118,43 @@ def test_extern_multi_out():
     assert(len(res) == 2)
     assert(res[1].value_index == 1)
 
+def test_tuple_inputs():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A0')
+    A1 = tvm.placeholder((m, n), name='A1')
+    T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
+    s = tvm.create_schedule(T0.op)
+
+    for i in range(len(T0.shape)):
+        assert(T0.shape[i] == T1.shape[i])
+    assert(T0.op == T1.op)
+    assert(T0.value_index == 0)
+    assert(T1.value_index == 1)
+
+def test_tuple_with_different_deps():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A1')
+    A1 = tvm.placeholder((m, n), name='A2')
+    B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
+    C = tvm.compute((m, n), lambda i, j: B0[i, j] + 4, name='C')
+
+    s = tvm.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=10)
+    s[B0.op].compute_at(s[C], xo)
+    sch = s.normalize()
+    bounds = tvm.schedule.InferBound(sch)
+    stmt = tvm.schedule.ScheduleOps(sch, bounds)
+
+    ret = []
+    def get_B1_realize(x):
+        if isinstance(x, tvm.stmt.Realize) and \
+           x.func == B1.op and x.value_index == 1:
+            ret.append(x)
+    tvm.ir_pass.PostOrderVisit(stmt, get_B1_realize)
+
+    assert stmt.node == C.op and len(ret) == 1
 
 if __name__ == "__main__":
     test_conv1d()
@@ -128,3 +165,5 @@ if __name__ == "__main__":
     test_scan_multi_out()
     test_extern()
     test_extern_multi_out()
+    test_tuple_inputs()
+    test_tuple_with_different_deps()
diff --git a/tests/python/unittest/test_pass_inline.py b/tests/python/unittest/test_pass_inline.py
index 1988d54083c767feb4758f48908fb761ff394beb..398c0d34d58d6336c1bc45629be65375eb50ecac 100644
--- a/tests/python/unittest/test_pass_inline.py
+++ b/tests/python/unittest/test_pass_inline.py
@@ -6,7 +6,7 @@ def test_inline():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))
 
@@ -25,7 +25,7 @@ def test_inline2():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     def check(op):
         if isinstance(op, tvm.expr.Call):
             assert op.func != T.op
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index eceba04f5b0b1b1dcaec3b892e4626eb6db46195..297800e1632d18dc3e9a8797c97cbd801203cce0 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -89,7 +89,7 @@ def test_inline_mixed():
     def check(x):
         if isinstance(x, tvm.expr.Call):
             assert x.func != A2
-    tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
+    tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check)
 
 
 def test_scan_inline1():
diff --git a/tutorials/python/reduction.py b/tutorials/python/reduction.py
index 1bdc1f9b8e756f93d93edf4caad6f39f17f8225a..e7295cb927a32a0d41f48dfab7e921420fa5fb92 100644
--- a/tutorials/python/reduction.py
+++ b/tutorials/python/reduction.py
@@ -125,6 +125,8 @@ np.testing.assert_allclose(
     b.asnumpy(),  np.sum(a.asnumpy(), axis=1), rtol=1e-4)
 
 ######################################################################
+# .. _general-reduction:
+#
 # Define General Commutative Reduction Operation
 # ----------------------------------------------
 # Besides the built-in reduction operations like :any:`tvm.sum`,
@@ -140,6 +142,12 @@ A = tvm.placeholder((n, m), name='A')
 k = tvm.reduce_axis((0, m), name='k')
 B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
 
+######################################################################
+# .. note::
+#
+#   Sometimes we would like to perform reduction that involves multiple
+#   values like :code:`argmax`, which can be done by tuple inputs.
+#   See :ref:`reduction-with-tuple-inputs` for more detail.
 
 ######################################################################
 # Summary
diff --git a/tutorials/python/tuple_inputs_operation.py b/tutorials/python/tuple_inputs_operation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c101a59e86e2f04c97b43a68c00fa5f258e719e
--- /dev/null
+++ b/tutorials/python/tuple_inputs_operation.py
@@ -0,0 +1,103 @@
+"""
+Compute and Reduction with Tuple Inputs
+=======================================
+**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_
+
+Often we want to compute multiple outputs with the same shape within
+a single loop, or perform a reduction that involves multiple values,
+like :code:`argmax`. Both problems can be addressed with tuple inputs.
+
+In this tutorial, we will introduce the usage of tuple inputs in TVM.
+"""
+from __future__ import absolute_import, print_function
+
+import tvm
+import numpy as np
+
+######################################################################
+# Describe Batchwise Computation
+# ------------------------------
+# For operators whose outputs have the same shape, we can put them
+# together as the inputs of :any:`tvm.compute` if we want them to be
+# scheduled together in the subsequent schedule procedure.
+#
+n = tvm.var("n")
+m = tvm.var("m")
+A0 = tvm.placeholder((m, n), name='A0')
+A1 = tvm.placeholder((m, n), name='A1')
+B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name='B')
+
+# The generated IR code would be:
+s = tvm.create_schedule(B0.op)
+print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
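+
+# Since B0 and B1 are produced by one operation, they share a single
+# schedule stage; scheduling either of them schedules both. A minimal
+# sanity check of this shared-op behavior (the value_index layout is an
+# assumption drawn from the tests in this change):
+assert B0.op == B1.op
+assert (B0.value_index, B1.value_index) == (0, 1)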
+
+######################################################################
+# .. _reduction-with-tuple-inputs:
+#
+# Describe Reduction with Collaborative Inputs
+# --------------------------------------------
+# Sometimes we require multiple inputs to express a reduction
+# operator, and the inputs collaborate with each other, e.g. :code:`argmax`.
+# During the reduction, :code:`argmax` needs to compare the values of the
+# operands while also keeping track of their indices. It can be expressed
+# with :any:`comm_reducer` as below:
+
+# x and y are the operands of the reduction; each of them is a tuple of
+# (index, value).
+def fcombine(x, y):
+    lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
+    rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+    return lhs, rhs
+
+# The identity element also needs to be a tuple, so `fidentity` accepts
+# two types as inputs.
+def fidentity(t0, t1):
+    return tvm.const(-1, t0), tvm.min_value(t1)
+
+argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')
+
+# describe the reduction computation
+m = tvm.var('m')
+n = tvm.var('n')
+idx = tvm.placeholder((m, n), name='idx', dtype='int32')
+val = tvm.placeholder((m, n), name='val', dtype='int32')
+k = tvm.reduce_axis((0, n), 'k')
+T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
+
+# The generated IR code would be:
+s = tvm.create_schedule(T0.op)
+print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
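+
+# A minimal sketch of actually running the argmax computation and checking
+# it against numpy, mirroring the integration test in this change; the
+# llvm guard and the per-row permutation data are illustrative assumptions
+# (permutations guarantee a unique maximum in every row):
+if tvm.module.enabled("llvm"):
+    ctx = tvm.context("cpu", 0)
+    fargmax = tvm.build(tvm.lower(s, args=[idx, val, T0, T1]),
+                        target="llvm", name="argmax")
+    mm, nn = 12, 16
+    np_idx = np.repeat(np.arange(nn, dtype="int32").reshape(1, nn), mm, axis=0)
+    np_val = np.stack([np.random.permutation(nn) for _ in range(mm)]).astype("int32")
+    nd_res0 = tvm.nd.array(np.zeros(mm, dtype="int32"), ctx)
+    nd_res1 = tvm.nd.array(np.zeros(mm, dtype="int32"), ctx)
+    fargmax(tvm.nd.array(np_idx, ctx), tvm.nd.array(np_val, ctx), nd_res0, nd_res1)
+    np.testing.assert_allclose(np.argmax(np_val, axis=1), nd_res0.asnumpy())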
+
+######################################################################
+# .. note::
+#
+#   For readers who are not familiar with reduction, please refer to
+#   :ref:`general-reduction`.
+
+######################################################################
+# Schedule Operation with Tuple Inputs
+# ------------------------------------
+# It is worth mentioning that although one batch operation yields
+# multiple output tensors, they can only be scheduled together as a
+# single operation.
+
+n = tvm.var("n")
+m = tvm.var("m")
+A0 = tvm.placeholder((m, n), name='A0')
+B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name='B')
+A1 = tvm.placeholder((m, n), name='A1')
+C = tvm.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name='C')
+
+s = tvm.create_schedule(C.op)
+s[B0].compute_at(s[C], C.op.axis[0])
+# As you can see in the generated IR code below, both outputs of B are
+# computed inside the outer loop of C:
+print(tvm.lower(s, [A0, A1, C], simple_mode=True))
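+
+# Even though only B0 is consumed by C and only s[B0] was scheduled above,
+# compute_at applies to the whole operation, so both outputs of B are
+# produced in the same loop nest. A small check of the shared stage
+# (assuming the schedule maps both output tensors to their common op):
+assert s[B0].op == s[B1].op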
+
+######################################################################
+# Summary
+# -------
+# This tutorial introduces the usage of operations with tuple inputs.
+#
+# - Describe normal batchwise computation.
+# - Describe reduction operations with tuple inputs.
+# - Note that you can only schedule a computation in terms of its
+#   operation, not an individual output tensor.