diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 9e9dd0604916f1971be52a9c89e3d8b0934ec917..760a226a2facc22e72c96adcb6aa2193bac08d0b 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -246,44 +246,9 @@ class ForwardPrep : private ExprVisitor {
 // Per operator defs for FScaleAxisForward
 //----------------------------------------------
-// Helper functions
-Expr GetForwardScale(const Expr& expr, AxesSet out) {
-  static const Op& multiply = Op::Get("multiply");
-  static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
-
-  const CallNode* call = expr.as<CallNode>();
-  if (!call) return NullValue<Expr>();
-  auto f = fprep.get(call->op, nullptr);
-
-  if (call->op.same_as(multiply)) {
-    const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
-    const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
-    if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
-      return call->args[1];
-    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
-      return call->args[0];
-    } else {
-      return NullValue<Expr>();
-    }
-  } else if (f != nullptr) {
-    Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
-    for (size_t i = 0; i < call->args.size(); i++) {
-      auto scale = GetForwardScale(call->args[i], in_axes[i]);
-      if (scale.defined()) {
-        return scale;
-      }
-    }
-  }
-  return NullValue<Expr>();
-}
-
 // Intermediate operators
 Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
-  Expr scale = GetForwardScale(call->args[0], out);
-  if (IsPositiveConstant(scale)) {
-    return {out};
-  }
-  return {NullValue<AxesSet>()};
+  return {out};
 }
 
 Expr ReluForwardRewrite(const Call& ref_call,
@@ -391,16 +356,21 @@ Expr MultiplyForwardRewrite(const Call& ref_call,
   Expr lhs = new_args[0];
   Expr rhs = new_args[1];
   auto rnode = make_node<ScaledExprNode>();
-  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) {
+  if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) &&
+      IsAllPositiveConstant(rhs)) {
     rnode->value = lhs;
     rnode->scale = rhs;
     rnode->axes = expected_out_axes;
-  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs)) {
+    return Expr(rnode);
+  } else if (MatchBroadcastToLeftAxes(trhs, tlhs, expected_out_axes, &lhs) &&
+             IsAllPositiveConstant(lhs)) {
     rnode->value = rhs;
     rnode->scale = lhs;
     rnode->axes = expected_out_axes;
+    return Expr(rnode);
+  } else {
+    return Expr();
   }
-  return Expr(rnode);
 }
 
 RELAY_REGISTER_OP("multiply")
@@ -790,22 +760,6 @@ RELAY_REGISTER_OP("subtract")
 RELAY_REGISTER_OP("subtract")
 .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
 
-// Find relu in the backward path between multiply and conv2d
-bool FindBackwardRelu(const Expr& expr) {
-  const CallNode* call = expr.as<CallNode>();
-  static const Op& conv2d = Op::Get("nn.conv2d");
-  static const Op& relu = Op::Get("nn.relu");
-
-  if (!call) return false;
-  if (call->op.same_as(relu)) return true;
-  if (call->op.same_as(conv2d)) return false;
-
-  for (size_t i = 0; i < call->args.size(); i++) {
-    if (FindBackwardRelu(call->args[i])) return true;
-  }
-  return false;
-}
-
 // Producer operators
 // Multiply produces the scale-axis pair.
 Expr MultiplyBackwardTransform(const Call& call,
@@ -821,16 +775,16 @@ Expr MultiplyBackwardTransform(const Call& call,
     // NOTE we won't recursively call mutating on scale part.
     // since there won't be scale chance within scale part.
     Expr rhs = call->args[1];
+    // Only propagate positive scaling.
     if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
-        (!FindBackwardRelu(call->args[0]) ||
-         IsPositiveConstant(call->args[1]))) {
+        IsAllPositiveConstant(rhs)) {
       return transformer->Transform(call->args[0], lhs_axes, rhs);
     }
   } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
+    // Only propagate positive scaling.
     Expr lhs = call->args[0];
     if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
-        (!FindBackwardRelu(call->args[1]) ||
-         IsPositiveConstant(call->args[0]))) {
+        IsAllPositiveConstant(lhs)) {
       return transformer->Transform(call->args[1], rhs_axes, lhs);
     }
   }
diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h
index d42494409b53227f74d1a76537cb411156863756..ddd73901c452ae1e4e0e6f414bbb3beab5acf923 100644
--- a/src/relay/pass/pass_util.h
+++ b/src/relay/pass/pass_util.h
@@ -22,6 +22,15 @@ namespace relay {
 std::unordered_map<const Node*, size_t>
 GetExprRefCount(const Expr& body);
 
+
+/*!
+ * \brief Check if expr is positive constant.
+ * \param expr The expression to be checked.
+ * \return Whether all elements of expr is positive constant.
+ */
+bool IsAllPositiveConstant(const Expr& expr);
+
+
 /*!
  * \brief Substitute var with subst.
  * \param type The type to be substituted.
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index 5d76efd0124d44ef3a919b315675952e1800b783..e6e8415bd620fef196bf15127681a6969af012f3 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -190,57 +190,6 @@ Expr MakeConcatenate(Expr data, int axis);
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
 
-
-template <typename T>
-bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
-  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
-  CHECK(tensor->strides == nullptr);
-  CHECK_EQ(tensor->byte_offset, 0);
-  const T* data = static_cast<const T*>(tensor->data);
-  int64_t num_elems = 1;
-  for (int i = 0; i < tensor->ndim; ++i) {
-    num_elems *= tensor->shape[i];
-  }
-
-  for (int64_t i = 0; i < num_elems; i++) {
-    if (*data < value) {
-      return false;
-    }
-    data++;
-  }
-  return true;
-}
-
-
-inline bool IsPositiveConstant(const Expr& expr) {
-  const auto* constant = expr.as<ConstantNode>();
-  if (!constant) return false;
-  const auto& tensor = constant->data;
-  const auto& dtype = tensor->dtype;
-
-  if (dtype.lanes != 1) {
-    // pass
-  } else if (dtype.code == kDLFloat && dtype.bits == 32) {
-    return IsNDArrayAllGreaterEqual<float>(tensor, 0);
-  } else if (dtype.code == kDLFloat && dtype.bits == 64) {
-    return IsNDArrayAllGreaterEqual<double>(tensor, 0);
-  } else if (dtype.code == kDLInt && dtype.bits == 8) {
-    return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
-  } else if (dtype.code == kDLInt && dtype.bits == 32) {
-    return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
-  } else if (dtype.code == kDLUInt && dtype.bits == 8) {
-    return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
-  } else if (dtype.code == kDLUInt && dtype.bits == 32) {
-    return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
-  }
-
-  LOG(WARNING) << "Unsupported data type (code = " << dtype.code
-               << ", bits = " << dtype.bits << ", lanes = " << dtype.lanes
-               << ")";
-  return false;
-}
-
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
index 8f7179deea5351a1be3b7c16f4ad427fc1490339..b99d975135bea45484103cf13e5529507021a5c7 100644
--- a/src/relay/pass/util.cc
+++ b/src/relay/pass/util.cc
@@ -146,5 +146,67 @@ GetExprRefCount(const Expr& body) {
   return ExprRefCounter().Get(body);
 }
 
+template <typename T>
+bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
+  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
+  CHECK(tensor->strides == nullptr);
+  CHECK_EQ(tensor->byte_offset, 0);
+  const T* data = static_cast<const T*>(tensor->data);
+  int64_t num_elems = 1;
+  for (int i = 0; i < tensor->ndim; ++i) {
+    num_elems *= tensor->shape[i];
+  }
+
+  for (int64_t i = 0; i < num_elems; i++) {
+    if (*data < value) {
+      return false;
+    }
+    data++;
+  }
+  return true;
+}
+
+bool IsAllPositiveConstant(const Expr& expr) {
+  // peel through a few common transform ops.
+  static const auto& expand_dims = Op::Get("expand_dims");
+  static const auto& reshape = Op::Get("reshape");
+  static const auto& transpose = Op::Get("transpose");
+  static const auto& squeeze = Op::Get("squeeze");
+
+  if (const auto* constant = expr.as<ConstantNode>()) {
+    const auto& tensor = constant->data;
+    const auto& dtype = tensor->dtype;
+    if (dtype.lanes != 1) {
+      return false;
+    } else if (dtype.code == kDLFloat && dtype.bits == 32) {
+      return IsNDArrayAllGreaterEqual<float>(tensor, 0);
+    } else if (dtype.code == kDLFloat && dtype.bits == 64) {
+      return IsNDArrayAllGreaterEqual<double>(tensor, 0);
+    } else if (dtype.code == kDLInt && dtype.bits == 8) {
+      return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
+    } else if (dtype.code == kDLInt && dtype.bits == 32) {
+      return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
+    } else if (dtype.code == kDLUInt && dtype.bits == 8) {
+      return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
+    } else if (dtype.code == kDLUInt && dtype.bits == 32) {
+      return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
+    } else {
+      return false;
+    }
+  } else if (const auto* op = expr.as<CallNode>()) {
+    // tail recursion.
+    if (op->op.same_as(expand_dims) ||
+        op->op.same_as(reshape) ||
+        op->op.same_as(transpose) ||
+        op->op.same_as(squeeze)) {
+      return IsAllPositiveConstant(op->args[0]);
+    } else {
+      return false;
+    }
+  } else {
+    return false;
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index f42aa7b7b8d0bd17bf4a33af523125bb3227e0e3..57cb7c84b10d7cff87bac83cb96c47270b666d45 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -1,6 +1,9 @@
 from tvm import relay
 import numpy as np
 
+def _get_positive_scale(size):
+    return np.random.uniform(0.5, 1, size=size).astype('float32')
+
 
 def test_fold_fwd_simple():
     """Simple testcase."""
@@ -14,6 +17,7 @@ def test_fold_fwd_simple():
                             channels=channels,
                             kernel_size=(3, 3),
                             padding=(1, 1))
+
         return relay.Function(args, y)
 
     def expected(x, conv_weight, in_bias, in_scale, channels):
@@ -37,14 +41,14 @@ def test_fold_fwd_simple():
         in_channels = shape[1]
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype('float32'))
-
+        in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
         y1_expected = expected(x, weight, in_bias, in_scale, channels)
+        y1_folded = relay.ir_pass.infer_type(y1_folded)
         y1_expected = relay.ir_pass.infer_type(y1_expected)
         assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
@@ -107,7 +111,7 @@ def test_fold_fwd_dual_path():
         assert in_channels == channels
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
+        in_scale = relay.const(_get_positive_scale(in_channels,))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
@@ -141,7 +145,7 @@ def test_fold_fwd_fail():
         assert in_channels == channels
         weight = relay.var("weight")
        in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
+        in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
        y1 = before(x, weight, in_bias, in_scale, channels)
        y1 = relay.ir_pass.infer_type(y1)
        y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
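A minimal sketch, not part of the patch, of how the new positive-scale requirement behaves from Python. It mirrors the pattern in tests/python/relay/test_pass_fold_scale_axis.py and uses the same relay.ir_pass API shown above; the shapes and the helper name _build are illustrative assumptions.

import numpy as np
from tvm import relay

def _build(scale_value):
    # conv2d fed by a channel-wise multiply, as in test_fold_fwd_simple (hypothetical shapes).
    x = relay.var("x", shape=(1, 4, 10, 10))
    weight = relay.var("weight")
    in_scale = relay.const(np.full((4, 1, 1), scale_value, dtype="float32"))
    y = relay.nn.conv2d(relay.multiply(x, in_scale), weight,
                        channels=2, kernel_size=(3, 3), padding=(1, 1))
    return relay.ir_pass.infer_type(relay.Function([x, weight], y))

# Positive constant scale: forward folding can push the scale into the conv2d weight.
folded_pos = relay.ir_pass.forward_fold_scale_axis(_build(0.5))
# Scale that is not all-positive: IsAllPositiveConstant rejects it and the graph is left unchanged.
folded_neg = relay.ir_pass.forward_fold_scale_axis(_build(-0.5))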