From dd248af6b8b9e82ef06f02af81dadf7b00eccd46 Mon Sep 17 00:00:00 2001
From: Ding <37059654+dingobye@users.noreply.github.com>
Date: Fri, 16 Mar 2018 13:30:20 +1100
Subject: [PATCH] [LANGUAGE] Verify Compute with respect to Reduce operations
 (#1006)

---
 src/op/compute_op.cc                          | 94 ++++++++++++++++---
 .../unittest/test_lang_verify_compute.py      | 64 +++++++++++++
 2 files changed, 143 insertions(+), 15 deletions(-)
 create mode 100644 tests/python/unittest/test_lang_verify_compute.py

diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 8b8bfbfe6..f3f8335c1 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -24,6 +24,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
+/// Verify if ComputeOp is valid with respect to Reduce operations.
+static void VerifyComputeOp(const ComputeOpNode *op);
+
 inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
   return (a->combiner.same_as(b->combiner)) &&
          (a->source.same_as(b->source)) &&
@@ -116,15 +119,9 @@ Operation ComputeOpNode::make(std::string name,
   n->body = body;
   if (n->body[0]->is_type<ir::Reduce>()) {
     const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
-    for (size_t i = 1; i < n->body.size(); ++i) {
-      const ir::Reduce* reduce_ = n->body[i].as<ir::Reduce>();
-      CHECK(reduce_);
-      CHECK(ReduceEqual(reduce_, reduce))
-        << "The Reduce inputs of ComputeOp should "
-        << "have the same attribute except value_index";
-    }
     n->reduce_axis = reduce->axis;
   }
+  VerifyComputeOp(n.get());
   return Operation(n);
 }
 
@@ -151,18 +148,11 @@ Operation ComputeOpNode::ReplaceInputs(
     const Operation& self,
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
+  VerifyComputeOp(this);
   Array<Expr> arr;
   if (this->body[0]->is_type<ir::Reduce>()) {
     // Specially handle reduce so the replaced op
     // still share all the components
-    const ir::Reduce* reduce = this->body[0].as<ir::Reduce>();
-    for (size_t i = 1; i < this->body.size(); ++i) {
-      const ir::Reduce* reduce_ = this->body[i].as<ir::Reduce>();
-      CHECK(reduce_);
-      CHECK(ReduceEqual(reduce_, reduce))
-        << "The Reduce inputs of ComputeOp should "
-        << "have the same attribute except value_index";
-    }\
     Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
     if (!new_reduce.same_as(this->body[0])) {
       const ir::Reduce* r = new_reduce.as<ir::Reduce>();
@@ -466,4 +456,78 @@ ComputeLoopNest ComputeLoopNest::make(
   // copy elison here.
   return ret;
 }
+
+namespace {
+/*!
+ * \brief Verify if ComputeOp is valid with respect to Reduce operations.
+ *
+ *  The following two properties are verified:
+ *  (1) All Reduce operations must exist at top level.
+ *  (2) For a list of operations, if one is Reduce, then the others
+ *      must be Reduce as well; and their inputs should have the
+ *      same attribute except value_index.
+ */
+class ComputeVerifier final : protected ir::IRVisitor {
+ public:
+  /// Special member functions
+  //@{
+  explicit ComputeVerifier(const ComputeOpNode* compute)
+      : compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
+  virtual ~ComputeVerifier() = default;
+  ComputeVerifier(const ComputeVerifier&) = delete;
+  ComputeVerifier(ComputeVerifier&&) = delete;
+  ComputeVerifier& operator=(const ComputeVerifier&) = delete;
+  ComputeVerifier& operator=(ComputeVerifier&&) = delete;
+  //@}
+
+  /// Interface to perform compute verification
+  void Run() {
+    for (const Expr& e : compute_->body) {
+      // Check for consistency of top level reductions
+      const ir::Reduce* reduce = e.as<ir::Reduce>();
+      CHECK((reduce && reduce_) || (!reduce && !reduce_))
+          << "All ComputeOp should be consistent "
+          << "with being Reduce operation or not.";
+
+      if (reduce && reduce_) {
+        CHECK(ReduceEqual(reduce, reduce_))
+            << "The Reduce inputs of ComputeOp should "
+            << "have the same attribute except value_index";
+      }
+
+      level_ = 0;
+      ir::IRVisitor::Visit(e);
+    }
+  }
+
+ protected:
+  /// Visitor implementation
+  //@{
+  void Visit(const NodeRef& n) final {
+    ++level_;
+    ir::IRVisitor::Visit(n);
+    --level_;
+  }
+
+  void Visit_(const ir::Reduce* op) final {
+    // Check for non top level reductions
+    CHECK(0 == level_)
+        << "Reductions are only allowed at the top level of compute. "
+        << "Please create another tensor for further composition.";
+  }
+  //@}
+
+ private:
+  const ComputeOpNode* compute_{nullptr};  ///< ComputeOpNode to verify
+  const ir::Reduce* reduce_{nullptr};      ///< Top level Reduce operation
+  int level_{0};                           ///< Level of op being processed
+};
+}  // namespace
+
+/// Verify if ComputeOp is valid with respect to Reduce operations.
+static void VerifyComputeOp(const ComputeOpNode* op) {
+  ComputeVerifier v(op);
+  v.Run();
+}
+
 }  // namespace tvm
diff --git a/tests/python/unittest/test_lang_verify_compute.py b/tests/python/unittest/test_lang_verify_compute.py
new file mode 100644
index 000000000..1b9ecf453
--- /dev/null
+++ b/tests/python/unittest/test_lang_verify_compute.py
@@ -0,0 +1,64 @@
+import tvm
+
+def test_verify_compute():
+  n = tvm.var("n")
+  m = tvm.var("m")
+  A = tvm.placeholder((n, m), name='A')
+  k = tvm.reduce_axis((0, m), "k")
+  k_ = tvm.reduce_axis((0, m-1), "k_")
+  f1 = lambda i: tvm.sum(A[i, k], axis=k)
+  f2 = lambda i: A[i,0] + 1
+  f3 = lambda i: tvm.sum(A[i, k], axis=k) + 1
+  f4 = lambda i: A[i,0] * (tvm.sum(A[i, k], axis=k) + 1)
+  f5 = lambda i: (tvm.sum(A[i, k], axis=k), A[i,0] + 1)
+  f6 = lambda i: (tvm.sum(A[i, k], axis=k), tvm.sum(A[i, k_], axis=k_))
+
+  #
+  # Valid compute
+  try:
+    B = tvm.compute((n,), f1, name="B")
+  except tvm._ffi.base.TVMError as ex:
+    assert False
+
+  #
+  # Valid compute
+  try:
+    B = tvm.compute((n,), f2, name="B")
+  except tvm._ffi.base.TVMError as ex:
+    assert False
+
+  #
+  # Invalid compute with non top level reduction
+  try:
+    B = tvm.compute((n,), f3, name="B")
+    assert False
+  except tvm._ffi.base.TVMError as ex:
+    pass
+
+  #
+  # Invalid compute with non top level reduction
+  try:
+    B = tvm.compute((n,), f4, name="B")
+    assert False
+  except tvm._ffi.base.TVMError as ex:
+    pass
+
+  #
+  # Invalid compute with reduction and non-reduction batch ops
+  try:
+    B0, B1 = tvm.compute((n,), f5, name="B")
+    assert False
+  except tvm._ffi.base.TVMError as ex:
+    pass
+
+  #
+  # Invalid compute with unequal batch reduction ops
+  try:
+    B0, B1 = tvm.compute((n,), f6, name="B")
+    assert False
+  except tvm._ffi.base.TVMError as ex:
+    pass
+
+
+if __name__ == "__main__":
+  test_verify_compute()
-- 
GitLab