diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 3c17e47030cfd52f3c0ac113bf8a01ae660b131d..11731361148d6b528e4f5a3c10ddc39ad93e5c74 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -24,6 +24,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
 int ComputeOpNode::num_outputs() const {
   return body.size();
 }
@@ -98,13 +105,6 @@ Array<Tensor> compute(Array<Expr> shape,
   return outputs;
 }
 
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
-}
-
 Operation ComputeOpNode::make(std::string name,
                               std::string tag,
                               Array<IterVar> axis,
@@ -151,9 +151,35 @@ Operation ComputeOpNode::ReplaceInputs(
     const Operation& self,
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
-  Array<Expr> arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
-      return op::ReplaceTensor(e, rmap);
-    });
+  Array<Expr> arr;
+  if (this->body[0]->is_type<ir::Reduce>()) {
+    // Specially handle Reduce so that the replaced op
+    // still shares all of the reduction components
+    const ir::Reduce* reduce = this->body[0].as<ir::Reduce>();
+    for (size_t i = 1; i < this->body.size(); ++i) {
+      const ir::Reduce* reduce_ = this->body[i].as<ir::Reduce>();
+      CHECK(reduce_);
+      CHECK(ReduceEqual(reduce_, reduce))
+        << "The Reduce inputs of ComputeOp should "
+        << "have the same attribute except value_index";
+    }
+    Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
+    if (!new_reduce.same_as(this->body[0])) {
+      const ir::Reduce* r = new_reduce.as<ir::Reduce>();
+      for (size_t k = 0; k < this->body.size(); ++k) {
+        std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
+        n->value_index = static_cast<int>(k);
+        n->type = r->source[k].type();
+        arr.push_back(Expr(n));
+      }
+    } else {
+      arr = this->body;
+    }
+  } else {
+    arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
+        return op::ReplaceTensor(e, rmap);
+      });
+  }
   if (!arr.same_as(this->body)) {
     return ComputeOpNode::make(name, tag, axis, arr);
   } else {
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index cd0d5e43625e90f290f31cfb797d23a8d3ec56b4..7cf6711d2270c5209ecd7d0729cb53b7fc04b49a 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -162,6 +162,7 @@ class TensorReplacer : public ir::IRMutator {
  public:
   explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
       : vmap_(vmap) {}
+
   Expr Mutate_(const ir::Call* op, const Expr& e) {
     if (op->call_type == ir::Call::Halide) {
       Tensor t = Operation(op->func.node_).output(op->value_index);
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 252cea21f7cec58ae0f92ddf3e790ecbd1bdfeea..a85db2a23e86cb1dd39598b84b282aedcd32ea19 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -68,16 +68,18 @@ def test_inline_multi_reduce():
     m = tvm.var('m')
     n = tvm.var('n')
     val = tvm.placeholder((m, n), name='val', dtype='float32')
-    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
+    val1 = tvm.compute((m, n), lambda i, j: val[i, j]+1, name='val1')
+    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val1[i, j]), name='val2')
     k = tvm.reduce_axis((0, n), 'k')
     T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
     s = tvm.create_schedule(T_idx.op)
-    s[val2].compute_inline()
+    s[val1].compute_inline()
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
 
+
 def test_auto_inline():
     m = tvm.var('m')
     n = tvm.var('n')