From dd248af6b8b9e82ef06f02af81dadf7b00eccd46 Mon Sep 17 00:00:00 2001 From: Ding <37059654+dingobye@users.noreply.github.com> Date: Fri, 16 Mar 2018 13:30:20 +1100 Subject: [PATCH] [LANGUAGE] Verify Compute with respect to Reduce operations (#1006) --- src/op/compute_op.cc | 94 ++++++++++++++++--- .../unittest/test_lang_verify_compute.py | 64 +++++++++++++ 2 files changed, 143 insertions(+), 15 deletions(-) create mode 100644 tests/python/unittest/test_lang_verify_compute.py diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 8b8bfbfe6..f3f8335c1 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -24,6 +24,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ComputeOpNode); +/// Verify if ComputeOp is valid with respect to Reduce operations. +static void VerifyComputeOp(const ComputeOpNode *op); + inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) { return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && @@ -116,15 +119,9 @@ Operation ComputeOpNode::make(std::string name, n->body = body; if (n->body[0]->is_type<ir::Reduce>()) { const ir::Reduce* reduce = n->body[0].as<ir::Reduce>(); - for (size_t i = 1; i < n->body.size(); ++i) { - const ir::Reduce* reduce_ = n->body[i].as<ir::Reduce>(); - CHECK(reduce_); - CHECK(ReduceEqual(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; - } n->reduce_axis = reduce->axis; } + VerifyComputeOp(n.get()); return Operation(n); } @@ -151,18 +148,11 @@ Operation ComputeOpNode::ReplaceInputs( const Operation& self, const std::unordered_map<Tensor, Tensor>& rmap) const { CHECK_EQ(self.operator->(), this); + VerifyComputeOp(this); Array<Expr> arr; if (this->body[0]->is_type<ir::Reduce>()) { // Specially handle reduce so the replaced op // still share all the components - const ir::Reduce* reduce = this->body[0].as<ir::Reduce>(); - for (size_t i = 1; i < this->body.size(); ++i) { - const ir::Reduce* reduce_ = this->body[i].as<ir::Reduce>(); - CHECK(reduce_); - CHECK(ReduceEqual(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; - }\ Expr new_reduce = op::ReplaceTensor(this->body[0], rmap); if (!new_reduce.same_as(this->body[0])) { const ir::Reduce* r = new_reduce.as<ir::Reduce>(); @@ -466,4 +456,78 @@ ComputeLoopNest ComputeLoopNest::make( // copy elison here. return ret; } + +namespace { +/*! + * \brief Verify if ComputeOp is valid with respect to Reduce operations. + * + * The following two properties are verified: + * (1) All Reduce operations must exist at top level. + * (2) For a list of operations, if one is Reduce, then the others + * must be Reduce as well; and their inputs should have the + * same attribute except value_index. + */ +class ComputeVerifier final : protected ir::IRVisitor { + public: + /// Special member functions + //@{ + explicit ComputeVerifier(const ComputeOpNode* compute) + : compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {} + virtual ~ComputeVerifier() = default; + ComputeVerifier(const ComputeVerifier&) = delete; + ComputeVerifier(ComputeVerifier&&) = delete; + ComputeVerifier& operator=(const ComputeVerifier&) = delete; + ComputeVerifier& operator=(ComputeVerifier&&) = delete; + //@} + + /// Interface to perform compute verification + void Run() { + for (const Expr e : compute_->body) { + // Check for consistency of top level reductions + const ir::Reduce* reduce = e.as<ir::Reduce>(); + CHECK((reduce && reduce_) || (!reduce && !reduce_)) + << "All ComputeOp should be consistent " + << "with being Reduce operation or not."; + + if (reduce && reduce_) { + CHECK(ReduceEqual(reduce, reduce_)) + << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; + } + + level_ = 0; + ir::IRVisitor::Visit(e); + } + } + + protected: + /// Visitor implementation + //@{ + void Visit(const NodeRef& n) final { + ++level_; + ir::IRVisitor::Visit(n); + --level_; + } + + void Visit_(const ir::Reduce* op) final { + // Check for non top level reductions + CHECK(0 == level_) + << "Reductions are only allowed at the top level of compute. " + << "Please create another tensor for further composition."; + } + //@} + + private: + const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify + const ir::Reduce* reduce_{nullptr}; ///< Top level Reduce operation + int level_{0}; ///< Level of op being processed +}; +} // namespace + +/// Verify if ComputeOp is valid with respect to Reduce operations. +static void VerifyComputeOp(const ComputeOpNode* op) { + ComputeVerifier v(op); + v.Run(); +} + } // namespace tvm diff --git a/tests/python/unittest/test_lang_verify_compute.py b/tests/python/unittest/test_lang_verify_compute.py new file mode 100644 index 000000000..1b9ecf453 --- /dev/null +++ b/tests/python/unittest/test_lang_verify_compute.py @@ -0,0 +1,64 @@ +import tvm + +def test_verify_compute(): + n = tvm.var("n") + m = tvm.var("m") + A = tvm.placeholder((n, m), name='A') + k = tvm.reduce_axis((0, m), "k") + k_ = tvm.reduce_axis((0, m-1), "k_") + f1 = lambda i: tvm.sum(A[i, k], axis=k) + f2 = lambda i: A[i,0] + 1 + f3 = lambda i: tvm.sum(A[i, k], axis=k) + 1 + f4 = lambda i: A[i,0] * (tvm.sum(A[i, k], axis=k) + 1) + f5 = lambda i: (tvm.sum(A[i, k], axis=k), A[i,0] + 1) + f6 = lambda i: (tvm.sum(A[i, k], axis=k), tvm.sum(A[i, k_], axis=k_)) + + # + # Valid compute + try: + B = tvm.compute((n,), f1, name="B") + except tvm._ffi.base.TVMError as ex: + assert False + + # + # Valid compute + try: + B = tvm.compute((n,), f2, name="B") + except tvm._ffi.base.TVMError as ex: + assert False + + # + # Invalid compute with non top level reduction + try: + B = tvm.compute((n,), f3, name="B") + assert False + except tvm._ffi.base.TVMError as ex: + pass + + # + # Invalid compute with non top level reduction + try: + B = tvm.compute((n,), f4, name="B") + assert False + except tvm._ffi.base.TVMError as ex: + pass + + # + # Invalid compute with reduction and non-reduction batch ops + try: + B0, B1 = tvm.compute((n,), f5, name="B") + assert False + except tvm._ffi.base.TVMError as ex: + pass + + # + # Invalid compute with unequal batch reduction ops + try: + B0, B1 = tvm.compute((n,), f6, name="B") + assert False + except tvm._ffi.base.TVMError as ex: + pass + + +if __name__ == "__main__": + test_verify_compute() \ No newline at end of file -- GitLab