From 2849930465259e86de660e5a67754e76a85d32ad Mon Sep 17 00:00:00 2001
From: Siju <sijusamuel@gmail.com>
Date: Tue, 20 Nov 2018 22:50:09 +0530
Subject: [PATCH] [RELAY]Slice_like support (#2014)

---
 docs/langref/relay_op.rst             |   3 +-
 include/tvm/relay/attrs/transform.h   |  13 +++
 python/tvm/relay/op/_transform.py     |   6 +-
 python/tvm/relay/op/transform.py      |  26 +++++
 src/relay/op/tensor/transform.cc      | 147 ++++++++++++++++++++++++++
 tests/python/relay/test_op_level10.py |  62 +++++++++++
 6 files changed, 255 insertions(+), 2 deletions(-)

diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index e99ac3c97..95581a54e 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -143,6 +143,7 @@ This level support backpropagation of broadcast operators. It is temporary.
 
    tvm.relay.broadcast_to_like
    tvm.relay.collapse_sum_like
+   tvm.relay.slice_like
 
 
 Level 1 Definitions
@@ -231,7 +232,6 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.strided_slice
 
 
-
 Level 5 Definitions
 -------------------
 .. autofunction:: tvm.relay.image.resize
@@ -241,3 +241,4 @@ Level 10 Definitions
 --------------------
 .. autofunction:: tvm.relay.broadcast_to_like
 .. autofunction:: tvm.relay.collapse_sum_like
+.. autofunction:: tvm.relay.slice_like
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index fc539f3ce..7a8129180 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -138,6 +138,19 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
         .describe("Stride values of the slice");
   }
 };
+
+
+struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
+  Array<Integer> axes;
+
+  TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") {
+    TVM_ATTR_FIELD(axes)
+        .describe("List of axes on which input data will be sliced according to the "
+                  "corresponding size of the second input. By default will slice "
+                  "on all axes. Negative axes mean counting in reverse.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 7867336d0..01814e0f7 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -2,7 +2,11 @@
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 from . import op as _reg
-from .op import schedule_injective
+from .op import schedule_injective, OpPattern
 
 # strided_slice
 _reg.register_schedule("strided_slice", schedule_injective)
+
+# slice_like
+_reg.register_schedule("slice_like", schedule_injective)
+_reg.register_pattern("slice_like", OpPattern.INJECTIVE)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index e43a4a573..c5fedab05 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -361,3 +361,29 @@ def strided_slice(data, begin, end, strides=None):
     """
     strides = strides or []
     return _make.strided_slice(data, list(begin), list(end), list(strides))
+
+
+def slice_like(data, shape_like, axes=None):
+    """Slice the first input with respect to the second input.
+
+    For an input array with shape ``(d1, d2, ..., dk)``, `slice_like` operation slices the
+    the input array corresponding size of second array. By default will slice on all axes.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The source array.
+
+    shape_like : tvm.relay.Expr
+        The new shape.
+
+    axes : Optional[Tuple[int]]
+        List of axes on which input data will be sliced according to the corresponding size of
+        the second input. By default will slice on all axes. Negative axes mean counting in reverse.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.slice_like(data, shape_like, axes)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 98ac1c30b..7a3a21511 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1153,5 +1153,152 @@ the entries indicate where along axis the array is split.
 .set_support_level(3)
 .add_type_rel("Split", SplitRel);
 
+
+TVM_REGISTER_NODE_TYPE(SliceLikeAttrs);
+
+/*!
+* \brief SliceLikeRel User defined type constraint function.
+* \param num_inputs Number of input types in the args.
+* \param attrs The additional attributes of the operator.
+* \param reporter The reporter to report solution to.
+* \return False if the relation has not been resolved, it might be resolved later.
+*  True if this relation has been resolved.
+*/
+bool SliceLikeRel(const Array<Type>& types,
+                  int num_inputs,
+                  const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+
+  const auto* target = types[1].as<TensorTypeNode>();
+  if (target == nullptr) {
+    return false;
+  }
+
+  const auto param = attrs.as<SliceLikeAttrs>();
+  CHECK(param != nullptr);
+
+  const Array<IndexExpr> dshape = data->shape;
+  const Array<IndexExpr> target_shape = target->shape;
+  std::vector<IndexExpr>&& oshape = AsVector(dshape);
+
+  if (!param->axes.defined()) {
+    for (size_t i = 0; i < dshape.size(); ++i) {
+      if (i < target_shape.size()) {
+        oshape[i] = target_shape[i];
+        CHECK(reporter->Assert(oshape[i] <= dshape[i]))
+          << "End index of axis " << i << " exceeds input shape: "
+          << oshape[i] << " vs " << dshape[i];
+      }
+    }
+  } else {
+    CHECK(param->axes.size() != 0) << "Axes cannot be empty.";
+    for (Integer val : param->axes) {
+      int axis = val->value;
+      if (axis < 0) {
+        axis += dshape.size();
+      }
+      CHECK(axis < static_cast<int>(target_shape.size()))
+        << "Axis " << axis << " exceeds dimension "
+        << target_shape.size() << " of target_shape.";
+      oshape[axis] = target_shape[axis];
+      CHECK(reporter->Assert(oshape[axis] <= dshape[axis]))
+        << "End index of axis " << axis << " exceeds input shape: "
+        << oshape[axis] << " vs " << dshape[axis];
+    }
+  }
+
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+
+Expr MakeSliceLike(Expr data,
+                   Expr shape_like,
+                   Array<Integer> axes) {
+  auto attrs = make_node<SliceLikeAttrs>();
+  attrs->axes = std::move(axes);
+  static const Op& op = Op::Get("slice_like");
+  return CallNode::make(op, {data, shape_like}, Attrs(attrs), {});
+}
+
+// Adapter function to make int array.
+Array<Integer> GetIntArray(Array<IndexExpr> arr) {
+  for (size_t i = 0; i < arr.size(); ++i) {
+    CHECK(!arr[i].defined() || arr[i].as<IntImm>())
+        << "Expect an int array";
+  }
+  return Array<Integer>(arr.node_);
+}
+
+template<typename AttrType>
+Array<Tensor> SliceLikeCompute(const Attrs& attrs,
+                               const Array<Tensor>& inputs,
+                               const Type& out_type,
+                               const Target& target) {
+  const auto* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  Array<IndexExpr> src_shape = inputs[0]->shape;
+  Array<IndexExpr> target_shape = inputs[1]->shape;
+  Array<IndexExpr> begin_idx, end_idx, strides;
+  for (size_t i = 0; i < src_shape.size(); ++i) {
+    begin_idx.push_back(0);
+    strides.push_back(1);
+  }
+  end_idx = Array<IndexExpr>(src_shape);
+  if (!param->axes.defined()) {
+    for (size_t i = 0; i < src_shape.size(); ++i) {
+      if (i < target_shape.size()) {
+        end_idx.Set(i, target_shape[i]);
+        CHECK_LE(topi::GetConstInt(end_idx[i]),
+                 topi::GetConstInt(src_shape[i]))
+          << "End index of axis " << i << " exceeds input shape: "
+          << topi::GetConstInt(end_idx[i]) << " vs "
+          << topi::GetConstInt(src_shape[i]);
+      }
+    }
+  } else {
+    for (int axis : param->axes) {
+      if (axis < 0) {
+        axis = static_cast<int>(src_shape.size()) + axis;
+      }
+      end_idx.Set(axis, target_shape[axis]);
+      CHECK_LE(topi::GetConstInt(end_idx[axis]),
+               topi::GetConstInt(src_shape[axis]))
+        << "End index of axis " << axis << " exceeds input shape: "
+        << topi::GetConstInt(end_idx[axis]) << " vs "
+        << topi::GetConstInt(src_shape[axis]);
+    }
+  }
+  return Array<Tensor>{
+    topi::strided_slice(inputs[0],
+                        GetIntArray(begin_idx),
+                        GetIntArray(end_idx),
+                        GetIntArray(strides))
+  };
+}
+
+
+TVM_REGISTER_API("relay.op._make.slice_like")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeSliceLike, args, rv);
+});
+
+
+RELAY_REGISTER_OP("slice_like")
+.describe(R"code(Slice the first input respect to the second input.
+)code" TVM_ADD_FILELINE)
+  .set_attrs_type_key("relay.attrs.SlicelikeAttrs")
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("shape_like", "Tensor", "Shape tensor.")
+.set_support_level(10)
+.add_type_rel("SliceLike", SliceLikeRel)
+.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute<SliceLikeAttrs>);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 9486d0298..ef1c57d26 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -1,7 +1,9 @@
 """ Support level10 operator test cases.
 """
+import numpy as np
 import tvm
 from tvm import relay
+from tvm.relay.testing import ctx_list
 
 def test_collapse_sum_like():
     x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
@@ -18,6 +20,66 @@ def test_broadcast_to_like():
     zz = relay.ir_pass.infer_type(z)
     assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8")
 
+
+def np_slice_like(np_data, np_shape_like, axis=None):
+    begin_idx = [0 for _ in np_data.shape]
+    end_idx = list(np_data.shape)
+    if axis:
+        for i in axis:
+            if i < 0:
+                i = len(np_data.shape) + i
+            end_idx[i] = np_shape_like.shape[i]
+    else:
+        for i in range(len(np_data.shape)):
+            if i < len(np_shape_like.shape):
+                end_idx[i] = np_shape_like.shape[i]
+    slice_idx = []
+    for b, e in zip(begin_idx, end_idx):
+        slice_idx.append(slice(b, e))
+    np_result = np_data[tuple(slice_idx)]
+    return np_result
+
+
+def verify_slice_like(data, slice_like, axes, output, dtype="float32"):
+    x = relay.var("data", relay.TensorType(data, dtype))
+    y = relay.var("slice_like", relay.TensorType(slice_like, dtype))
+    z = relay.slice_like(x, y, axes)
+    zz = relay.ir_pass.infer_type(z)
+    if axes:
+        assert "axes" in z.astext()
+    assert zz.checked_type == relay.ty.TensorType(output, dtype)
+
+    if all(isinstance(v, int) == 0 for v in data) or \
+        all(isinstance(v, int) == 0 for v in slice_like):
+        return
+
+    func = relay.Function([x, y], z)
+    x_data = np.random.uniform(size=data).astype(dtype)
+    y_data = np.random.uniform(size=slice_like).astype(dtype)
+    ref_res = np_slice_like(x_data, y_data, axes)
+
+    for target, ctx in ctx_list():
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(x_data, y_data)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+def test_slice_like():
+    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
+    verify_slice_like(data=(d1, d2, d3), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3))
+    verify_slice_like(data=(1, 2, 3), slice_like=(d1, d2, d3), axes=None, output=(d1, d2, d3))
+    verify_slice_like(data=(d2, d3, d4), slice_like=(d1, d2, d3), axes=(1,2), output=(d2, d2, d3))
+    verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=None, output=(1, 2, 3))
+    verify_slice_like(data=(3, 4, 5), slice_like=(1, 2), axes=None, output=(1, 2, 5))
+    verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(1, 2), output=(3, 2, 3))
+    verify_slice_like(data=(3, 4, 5), slice_like=(1, 2, 3), axes=(-1, -3), output=(1, 4, 3))
+    verify_slice_like(data=(1, 3, 224, 224),
+                      slice_like=(1, 3, 112, 112),
+                      axes=(2, 3),
+                      output=(1, 3, 112, 112))
+
+
 if __name__ == "__main__":
     test_collapse_sum_like()
     test_broadcast_to_like()
+    test_slice_like()
-- 
GitLab