From 93949456eedee33b1a39f59a30189bbce6732247 Mon Sep 17 00:00:00 2001
From: Wuwei Lin <vincentl13x@gmail.com>
Date: Wed, 5 Dec 2018 04:58:16 +0800
Subject: [PATCH] [RELAY][PASS] Check Positiveness in FoldScaleAxis (#2220)

---
 python/tvm/relay/build_module.py              |   5 +-
 src/relay/pass/fold_scale_axis.cc             |  61 +++++++++-
 src/relay/pass/pattern_util.h                 |  51 ++++++++
 .../python/relay/test_pass_fold_scale_axis.py | 111 +++++++++++++----
 4 files changed, 192 insertions(+), 36 deletions(-)

diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 7af22431a..5b05bc445 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -150,13 +150,14 @@ def optimize(func, params=None):
         func = ir_pass.infer_type(func)
         func = ir_pass.combine_parallel_conv2d(func)
 
+    if cfg.pass_enabled("FoldConstant"):
+        func = ir_pass.fold_constant(func)
+
     if cfg.pass_enabled("FoldScaleAxis"):
         func = ir_pass.infer_type(func)
         func = ir_pass.backward_fold_scale_axis(func)
         func = ir_pass.infer_type(func)
         func = ir_pass.forward_fold_scale_axis(func)
-
-    if cfg.pass_enabled("FoldConstant"):
         func = ir_pass.fold_constant(func)
 
     if cfg.pass_enabled("AlterOpLayout"):
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 1cd6606bd..9e9dd0604 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -246,9 +246,44 @@ class ForwardPrep : private ExprVisitor {
 // Per operator defs for FScaleAxisForward
 //----------------------------------------------
+// Helper functions
+Expr GetForwardScale(const Expr& expr, AxesSet out) {
+  static const Op& multiply = Op::Get("multiply");
+  static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
+
+  const CallNode* call = expr.as<CallNode>();
+  if (!call) return NullValue<Expr>();
+  auto f = fprep.get(call->op, nullptr);
+
+  if (call->op.same_as(multiply)) {
+    const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
+    const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, out)) {
+      return call->args[1];
+    } else if (MatchBroadcastToLeftAxes(trhs, tlhs, out)) {
+      return call->args[0];
+    } else {
+      return NullValue<Expr>();
+    }
+  } else if (f != nullptr) {
+    Array<AxesSet> in_axes = f(GetRef<Call>(call), out);
+    for (size_t i = 0; i < call->args.size(); i++) {
+      auto scale = GetForwardScale(call->args[i], in_axes[i]);
+      if (scale.defined()) {
+        return scale;
+      }
+    }
+  }
+  return NullValue<Expr>();
+}
+
 
 // Intermediate operators
 Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
-  return {out};
+  Expr scale = GetForwardScale(call->args[0], out);
+  if (IsPositiveConstant(scale)) {
+    return {out};
+  }
+  return {NullValue<AxesSet>()};
 }
 
 Expr ReluForwardRewrite(const Call& ref_call,
@@ -755,6 +790,22 @@ RELAY_REGISTER_OP("subtract")
 RELAY_REGISTER_OP("subtract")
 .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", AddSubBackwardTransform);
 
+// Find relu in the backward path between multiply and conv2d
+bool FindBackwardRelu(const Expr& expr) {
+  const CallNode* call = expr.as<CallNode>();
+  static const Op& conv2d = Op::Get("nn.conv2d");
+  static const Op& relu = Op::Get("nn.relu");
+
+  if (!call) return false;
+  if (call->op.same_as(relu)) return true;
+  if (call->op.same_as(conv2d)) return false;
+
+  for (size_t i = 0; i < call->args.size(); i++) {
+    if (FindBackwardRelu(call->args[i])) return true;
+  }
+  return false;
+}
+
 // Producer operators
 // Multiply produces the scale-axis pair.
 Expr MultiplyBackwardTransform(const Call& call,
@@ -770,12 +821,16 @@ Expr MultiplyBackwardTransform(const Call& call,
     // NOTE we won't recursively call mutating on scale part.
     // since there won't be scale chance within scale part.
     Expr rhs = call->args[1];
-    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs)) {
+    if (MatchBroadcastToLeftAxes(tlhs, trhs, lhs_axes, &rhs) &&
+        (!FindBackwardRelu(call->args[0]) ||
+         IsPositiveConstant(call->args[1]))) {
       return transformer->Transform(call->args[0], lhs_axes, rhs);
     }
   } else if (rhs_axes.defined() && rhs_axes.size() != 0) {
     Expr lhs = call->args[0];
-    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs)) {
+    if (MatchBroadcastToLeftAxes(trhs, tlhs, rhs_axes, &lhs) &&
+        (!FindBackwardRelu(call->args[1]) ||
+         IsPositiveConstant(call->args[0]))) {
       return transformer->Transform(call->args[1], rhs_axes, lhs);
     }
   }
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index e6e8415bd..5d76efd01 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -190,6 +190,57 @@ Expr MakeConcatenate(Expr data, int axis);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
 
+
+template <typename T>
+bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
+  CHECK_EQ(tensor->ctx.device_type, kDLCPU);
+  CHECK(tensor->strides == nullptr);
+  CHECK_EQ(tensor->byte_offset, 0);
+  const T* data = static_cast<const T*>(tensor->data);
+  int64_t num_elems = 1;
+  for (int i = 0; i < tensor->ndim; ++i) {
+    num_elems *= tensor->shape[i];
+  }
+
+  for (int64_t i = 0; i < num_elems; i++) {
+    if (*data < value) {
+      return false;
+    }
+    data++;
+  }
+  return true;
+}
+
+
+inline bool IsPositiveConstant(const Expr& expr) {
+  const auto* constant = expr.as<ConstantNode>();
+  if (!constant) return false;
+  const auto& tensor = constant->data;
+  const auto& dtype = tensor->dtype;
+
+  if (dtype.lanes != 1) {
+    // pass
+  } else if (dtype.code == kDLFloat && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<float>(tensor, 0);
+  } else if (dtype.code == kDLFloat && dtype.bits == 64) {
+    return IsNDArrayAllGreaterEqual<double>(tensor, 0);
+  } else if (dtype.code == kDLInt && dtype.bits == 8) {
+    return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
+  } else if (dtype.code == kDLInt && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
+  } else if (dtype.code == kDLUInt && dtype.bits == 8) {
+    return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
+  } else if (dtype.code == kDLUInt && dtype.bits == 32) {
+    return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
+  }
+
+  LOG(WARNING) << "Unsupported data type (code = " << dtype.code
+               << ", bits = " << dtype.bits << ", lanes = " << dtype.lanes
+               << ")";
+  return false;
+}
+
+
 } // namespace relay
 } // namespace tvm
 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index e6e008f80..f42aa7b7b 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -1,11 +1,11 @@
 from tvm import relay
+import numpy as np
 
 
 def test_fold_fwd_simple():
     """Simple testcase."""
     def before(x, conv_weight, in_bias, in_scale, channels):
-        args = [x, conv_weight, in_bias, in_scale]
-        in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, in_bias]
         in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
         x = relay.multiply(x, in_scale)
         x = relay.nn.relu(x)
@@ -18,8 +18,7 @@ def test_fold_fwd_simple():
 
     def expected(x, conv_weight, in_bias, in_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, in_bias, in_scale]
-        in_scale = relay.expand_dims(in_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, in_bias]
         in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
         squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
         x = relay.nn.relu(x)
@@ -38,7 +37,7 @@ def test_fold_fwd_simple():
         in_channels = shape[1]
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.var("in_scale", shape=(in_channels,))
+        in_scale = relay.const(np.random.uniform(size=(in_channels, 1, 1)).astype("float32"))
 
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -56,7 +55,7 @@
 def test_fold_fwd_dual_path():
     """scale axis being consumed by two consumers"""
     def before(x, conv_weight, in_bias, in_scale, channels):
-        args = [x, conv_weight, in_bias, in_scale]
+        args = [x, conv_weight, in_bias]
         x = relay.multiply(in_scale, x)
         x = relay.nn.relu(x)
         x = relay.subtract(x, in_bias)
@@ -78,7 +77,7 @@ def test_fold_fwd_dual_path():
         return relay.Function(args, z)
 
     def expected(x, conv_weight, in_bias, in_scale, channels):
-        args = [x, conv_weight, in_bias, in_scale]
+        args = [x, conv_weight, in_bias]
         x = relay.nn.relu(x)
         in_bias = relay.divide(in_bias, in_scale)
         x = relay.subtract(x, in_bias)
@@ -108,7 +107,7 @@ def test_fold_fwd_dual_path():
         assert in_channels == channels
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.var("in_scale", shape=(in_channels,))
+        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
@@ -142,7 +141,7 @@ def test_fold_fwd_fail():
         assert in_channels == channels
         weight = relay.var("weight")
         in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.var("in_scale", shape=(in_channels,))
+        in_scale = relay.const(np.random.uniform(size=(in_channels,)).astype("float32"))
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
@@ -151,11 +150,41 @@
     check((2, 11, 10, 4), 4)
 
 
+def test_fold_fwd_relu_fail():
+    """testcase where we cannot fold because the scale cannot pass relu"""
+    def before(x, conv_weight, in_bias, in_scale, channels):
+        x = relay.multiply(x, in_scale)
+        xx = relay.nn.relu(x)
+        y1 = relay.nn.conv2d(xx, conv_weight,
+                             channels=channels,
+                             kernel_size=(3, 3),
+                             data_layout="NHWC",
+                             padding=(1, 1))
+        z = relay.add(y1, x)
+        return relay.Function(relay.ir_pass.free_vars(z), z)
+
+    def check(shape, channels, in_scale):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[-1]
+        # test depthwise
+        assert in_channels == channels
+        weight = relay.var("weight")
+        in_bias = relay.var("in_bias", shape=(in_channels,))
+        y1 = before(x, weight, in_bias, in_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
+        assert relay.ir_pass.alpha_equal(y1, y1_folded)
+
+    in_scale = relay.var("in_scale", shape=(4,))
+    check((2, 11, 10, 4), 4, in_scale)
+    in_scale = relay.const(np.random.uniform(size=(4,), low=-1.0, high=0.0).astype("float32"))
+    check((2, 11, 10, 4), 4, in_scale)
+
+
 def test_fold_bwd_simple():
     """Simple testcase."""
     def before(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
@@ -168,8 +197,7 @@ def test_fold_bwd_simple():
 
     def expected(x, conv_weight, out_bias, out_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
         conv_weight = relay.multiply(
@@ -190,7 +218,7 @@ def test_fold_bwd_simple():
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
 
         y1 = before(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -208,9 +236,7 @@
 def test_fold_bwd_dual_path():
     """Dual path testcase."""
     def before(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
-        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
@@ -227,8 +253,7 @@ def test_fold_bwd_dual_path():
 
     def expected(x, conv_weight, out_bias, out_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
         def fold_conv_weight():
@@ -253,7 +278,7 @@ def test_fold_bwd_dual_path():
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
 
         y1 = before(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -270,8 +295,7 @@
 
 def test_fold_bwd_dual_consumer():
     def before(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         y0 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
@@ -298,8 +322,7 @@ def test_fold_bwd_dual_consumer():
 
     def expected(x, conv_weight, out_bias, out_scale, channels):
         # use a fixed order of args so alpha equal check can pass
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         def fold_conv_weight():
             squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
             return relay.multiply(
@@ -328,7 +351,7 @@ def test_fold_bwd_dual_consumer():
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
 
         y1 = before(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
@@ -346,8 +369,7 @@
 def test_fold_bwd_fail():
     """Dual path testcase."""
     def fail1(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
@@ -367,8 +389,7 @@ def test_fold_bwd_fail():
         return relay.Function(args, y)
 
     def fail2(x, conv_weight, out_bias, out_scale, channels):
-        args = [x, conv_weight, out_bias, out_scale]
-        out_scale = relay.expand_dims(out_scale, axis=1, num_newaxis=2)
+        args = [x, conv_weight, out_bias]
         out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
@@ -380,13 +401,12 @@ def test_fold_bwd_fail():
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
-
     def check(shape, channels, fbefore):
         x = relay.var("x", shape=shape)
         in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.var("out_scale", shape=(channels,))
+        out_scale = relay.const(np.random.uniform(size=(channels, 1, 1)).astype("float32"))
         y1 = fbefore(x, weight, out_bias, out_scale, channels)
         y1 = relay.ir_pass.infer_type(y1)
         y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
@@ -396,11 +416,40 @@ def test_fold_bwd_fail():
     check((4, 4, 10, 10), 4, fail2)
 
 
+def test_fold_bwd_relu_fail():
+    """testcase where we cannot fold because the scale cannot pass relu"""
+    def before(x, conv_weight, out_scale, channels):
+        y = relay.nn.conv2d(x, conv_weight,
+                            channels=channels,
+                            kernel_size=(3, 3),
+                            data_layout="NCHW",
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.multiply(y, out_scale)
+        return relay.Function(relay.ir_pass.free_vars(y), y)
+
+    def check(shape, channels, out_scale):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[1]
+        weight = relay.var("weight")
+        y1 = before(x, weight, out_scale, channels)
+        y1 = relay.ir_pass.infer_type(y1)
+        y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
+        assert relay.ir_pass.alpha_equal(y1, y1_folded)
+
+    out_scale = relay.var("out_scale", shape=(4, 1, 1))
+    check((4, 4, 10, 10), 4, out_scale)
+    out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0).astype("float32"))
+    check((4, 4, 10, 10), 4, out_scale)
+
+
 if __name__ == "__main__":
     test_fold_fwd_simple()
     test_fold_fwd_dual_path()
     test_fold_fwd_fail()
+    test_fold_fwd_relu_fail()
     test_fold_bwd_simple()
     test_fold_bwd_dual_path()
     test_fold_bwd_dual_consumer()
     test_fold_bwd_fail()
+    test_fold_bwd_relu_fail()
--
GitLab
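
Note on why positiveness matters: relu(s * x) == s * relu(x) holds element-wise only when s >= 0, so a scale may be folded past nn.relu only if it is a provably non-negative constant. That is exactly what ReluForwardPrep and MultiplyBackwardTransform now check via IsPositiveConstant (which, despite its name, also accepts zero). A minimal standalone numpy sketch of the identity, not part of the patch; folds_cleanly is a hypothetical helper name, not a TVM API:

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def folds_cleanly(scale):
    # x carries both signs so the relu nonlinearity actually matters.
    x = np.array([[-2.0, -0.5, 0.5, 2.0]], dtype="float32")
    before = relu(x * scale)  # scale applied before relu (graph as written)
    after = relu(x) * scale   # scale folded past relu (graph after the pass)
    return np.allclose(before, after)

assert folds_cleanly(np.float32(0.5))       # non-negative scale: identity holds
assert not folds_cleanly(np.float32(-0.5))  # negative scale: folding would change results

This is also why FindBackwardRelu only triggers the constant check when a relu sits between the multiply and the conv2d: the purely linear ops commute with the scale regardless of its sign.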
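And a minimal usage sketch of the pass entry points the new tests exercise, assuming the 2018-era relay.ir_pass API used throughout this patch (later TVM releases moved these passes under relay.transform):

import numpy as np
from tvm import relay

# multiply(non-negative const) -> relu -> conv2d: the forward pass should
# fold the scale through relu into the conv2d weight.
x = relay.var("x", shape=(1, 4, 8, 8))
w = relay.var("w")
scale = relay.const(np.random.uniform(size=(4, 1, 1)).astype("float32"))
y = relay.multiply(x, scale)
y = relay.nn.relu(y)
y = relay.nn.conv2d(y, w, channels=8, kernel_size=(3, 3), padding=(1, 1))
f = relay.Function(relay.ir_pass.free_vars(y), y)

f = relay.ir_pass.infer_type(f)
folded = relay.ir_pass.forward_fold_scale_axis(f)
print(folded)  # the standalone multiply should be gone, absorbed by the conv

Had the constant been drawn from a range containing negative values, ReluForwardPrep would refuse to propagate the axes and the function would come back unchanged.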