diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 39cd82de83e267ebd70b7d5b3b3a07accb3731f8..3e56106df0c2d3a0f6ebb590415c6852d16a5ee7 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -151,6 +151,19 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> { } }; +// Clip +struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { + double a_min; + double a_max; + + TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { + TVM_ATTR_FIELD(a_min) + .describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max) + .describe("The maximum clip value."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 1681f9b87d2f732e4e038c56192b46170f2e3ab4..60b18218a313109ec2707e7292d4c864df3f6e97 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -182,6 +182,14 @@ class ExprMutator std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; }; +/*! + * \brief recursively visit the ir in post DFS order node, apply fvisit + * Each node is guaranteed to be visited only once. + * \param node The ir to be visited. + * \param fvisit The visitor function to be applied. + */ +void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit); + /* * \brief Bind function parameters or free variables. * diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ef0a59cd3f6d558e98c2a87544b2d4dd26d4bce2..6297e366070f792dea66fcb5201d8e99cf75baf8 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -10,6 +10,19 @@ from . import _make from .expr import Expr from .ty import Type +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order node, + apply fvisit. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + fvisit : function + The visitor function to be applied. + """ + return _ir_pass.post_order_visit(expr, fvisit) def infer_type(expr, mod=None): """Infer the type of expr under the context of mod. diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 5e3ee1761c38e3e8ab9496a84cc8152cefda7a07..bacbfea7c0638ec960b7869a358844dff4f5daae 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -228,6 +228,36 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { void ExprVisitor::VisitType(const Type& t) { return; } + +// visitor to implement apply +class ExprApplyVisit : public ExprVisitor { + public: + explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {} + void VisitExpr(const Expr& e) final { + if (visited_.count(e.get()) != 0) return; + visited_.insert(e.get()); + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function<void(const Expr&)> f_; + std::unordered_set<const Node*> visited_; +}; + +void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) { + ExprApplyVisit(fvisit).VisitExpr(e); +} + +TVM_REGISTER_API("relay._ir_pass.post_order_visit") +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc f = args[1]; + PostOrderVisit(args[0], [f](const Expr& n) { + f(n); + }); + }); + + // Implement bind. class ExprBinder : public ExprMutator { public: diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 6c94fe2adcc24c0b6a0d03f0b36fee2908e11d58..fef0302a05074735606cc6454328699a1fd1521d 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -5,6 +5,7 @@ */ #include <tvm/relay/expr.h> #include <tvm/relay/op.h> +#include <tvm/relay/attrs/transform.h> #include <topi/elemwise.h> #include "../type_relations.h" #include "../op_common.h" @@ -89,19 +90,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") .add_type_rel("Identity", IdentityRel) .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); - -// Clip -struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { - double a_min; - double a_max; - - TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { - TVM_ATTR_FIELD(a_min) - .describe("The minimum clip value."); - TVM_ATTR_FIELD(a_max) - .describe("The maximum clip value."); - } -}; +// relay.clip +TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_API("relay.op._make.clip") .set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {