From 201cfdc59a7cc8f1ef2d930df8eb97180775ba3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
 <lolisa@marisa.moe>
Date: Mon, 15 Oct 2018 09:46:21 -0700
Subject: [PATCH] [Relay] [Op] Squeeze (#1858)

---
 include/tvm/relay/attrs/transform.h  | 14 ++++++
 python/tvm/relay/op/transform.py     | 25 +++++++++-
 src/relay/op/tensor/transform.cc     | 72 +++++++++++++++++++++++++++-
 tests/python/relay/test_op_level3.py | 41 ++++++++++++++++
 4 files changed, 149 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 278826bc8..d304a5956 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -82,6 +82,20 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
   }
 };  // struct InitOpAttrs
 
+/*! \brief Attributes used in squeeze operators */
+struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
+  Array<IndexExpr> axes;
+
+  TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") {
+    TVM_ATTR_FIELD(axes)
+        .describe("The axes to squeeze in the input tensor."
+                  "If `axes = []`, all axis of dimension 1 get squeezed;"
+                  "Else, the dimension in axes get squeezed."
+                  "It is an error if an axes does not has dimension 1.")
+        .set_default(Array<IndexExpr>({}));
+  }
+};  // struct SqueezeAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 75fbba846..c2036f509 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -42,12 +42,35 @@ def transpose(data, axes=None):
     Returns
     -------
     result : relay.Expr
-        The reshaped result.
+        The transposed result.
     """
     axes = axes or []
     return _make.transpose(data, list(axes))
 
 
+def squeeze(data, axes=None):
+    """Squeeze axes in the array.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axes : None or List[int]
+        Axes to remove.
+        If axes = [] or = None, remove all axis of dimensions 1.
+        Otherwise, remove all axis in axes.
+        If any axis in axes has dimension that does not equal 1, it is an error.
+
+    Returns
+    -------
+    result : relay.Expr
+        The squeezed result.
+    """
+    axes = axes or []
+    return _make.squeeze(data, list(axes))
+
+
 def reshape(data, newshape):
     """Reshapes the input array.
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index fb7b09fd3..956883476 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -80,8 +80,6 @@ RELAY_REGISTER_OP("expand_dims")
 .set_support_level(1)
 .add_type_rel("ExpandDims", ExpandDimsRel);
 
-/* relay.concatenate */
-
 TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
 
 bool ConcatenateRel(const Array<Type>& types,
@@ -633,5 +631,75 @@ Examples::
 .set_support_level(4)
 .add_type_rel("Where", WhereRel);
 
+Expr MakeSqueeze(Expr data,
+                 Array<IndexExpr> axes) {
+  auto attrs = make_node<SqueezeAttrs>();
+  attrs->axes = std::move(axes);
+  static const Op& op = Op::Get("squeeze");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.squeeze")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
+  });
+
+bool SqueezeRel(const Array<Type>& types,
+                int num_inputs,
+                const Attrs& attrs,
+                const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  const auto* param = attrs.as<SqueezeAttrs>();
+  CHECK(param != nullptr);
+  std::vector<IndexExpr> result_shape;
+  // if axes is empty, squeeze all axes of dimension 1
+  if (param->axes.size() == 0) {
+    for (const auto& e : data->shape) {
+      const int64_t* axis_ptr = as_const_int(e);
+      CHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
+      if (*axis_ptr != 1) {
+        result_shape.push_back(e);
+      }
+    }
+  } else {
+    // pair up original shape with a boolean which control whether it will be in the final shape.
+    std::vector<std::pair<IndexExpr, bool> > original_shape;
+    for (const auto& e : data->shape) {
+      original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
+    }
+    for (const auto& e : param->axes) {
+      const int64_t* axis_ptr = as_const_int(e);
+      CHECK(axis_ptr != nullptr);
+      original_shape.at(*axis_ptr).second = false;
+    }
+    for (const auto p : original_shape) {
+      if (p.second) {
+        result_shape.push_back(p.first);
+      } else {
+        const int64_t* axis_ptr = as_const_int(p.first);
+        CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input tensor";
+        CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
+      }
+    }
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
+  return true;
+}
+
+RELAY_REGISTER_OP("squeeze")
+.describe(R"code(Squeeze the input tensor at the dimensions given by axes
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Squeeze", SqueezeRel);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 7d949b210..13ab483f9 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -6,6 +6,7 @@ from tvm import relay
 from tvm.relay.ir_pass import infer_type
 from tvm.relay.ir_builder import IRBuilder, func_type
 from tvm.relay.env import Environment
+from nose.tools import raises
 
 def test_zeros_ones():
     for op in [relay.zeros, relay.ones]:
@@ -67,6 +68,44 @@ def test_transpose_infer_type():
         (t, n, 100), "float32")
 
 
+def test_squeeze_default_axes_infer_type():
+    ib = relay.ir_builder.IRBuilder()
+    n, t, d = 1, 4, 1
+    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
+    with ib.function(x) as func:
+        ib.ret(relay.squeeze(x))
+    ib.ret(func)
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TensorType(
+        (4,), "float32")
+
+
+def test_squeeze_axes_infer_type():
+    ib = relay.ir_builder.IRBuilder()
+    n, t, d = 1, 4, 1
+    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
+    with ib.function(x) as func:
+        ib.ret(relay.squeeze(x, axes=(2,)))
+    ib.ret(func)
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+    assert ftype.ret_type == relay.ty.TensorType(
+        (1, 4), "float32")
+
+
+@raises(tvm._ffi.base.TVMError)
+def test_squeeze_bad_axes_infer_type():
+    ib = relay.ir_builder.IRBuilder()
+    n, t, d = 1, 4, 1
+    x = ib.param("x", relay.ty.TensorType((n, t, d), "float32"))
+    with ib.function(x) as func:
+        ib.ret(relay.squeeze(x, axes=(1,)))
+    ib.ret(func)
+    func = relay.ir_pass.infer_type(ib.env, func.to_func())
+    ftype = func.checked_type
+
+
 def test_reshape_infer_type():
     ib = relay.ir_builder.IRBuilder()
     n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20
@@ -181,3 +220,5 @@ if __name__ == "__main__":
     test_take_infer_type()
     test_full()
     test_full_like()
+    test_squeeze_axes_infer_type()
+    test_squeeze_default_axes_infer_type()
-- 
GitLab