From 40ac206438a1ea7098ce7e6a553835c54b0ed99d Mon Sep 17 00:00:00 2001
From: Dayananda V <dayanandasiet@gmail.com>
Date: Wed, 4 Jul 2018 21:32:25 +0530
Subject: [PATCH] add take frontend (#1307)

---
 nnvm/include/nnvm/top/tensor.h                |  10 ++
 nnvm/python/nnvm/top/transform.py             |   4 +
 nnvm/src/top/tensor/transform.cc              | 120 ++++++++++++++++++
 nnvm/tests/python/compiler/test_top_level1.py |  35 +++++
 4 files changed, 169 insertions(+)

diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h
index 87128a663..22ee9d711 100644
--- a/nnvm/include/nnvm/top/tensor.h
+++ b/nnvm/include/nnvm/top/tensor.h
@@ -48,6 +48,16 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
   }
 };
 
+
+struct TakeParam : public dmlc::Parameter<TakeParam> {
+  dmlc::optional<int> axis;
+
+  DMLC_DECLARE_PARAMETER(TakeParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>())
+        .describe("the axis over which to select values.");
+  }
+};
+
 struct StridedSliceParam : public dmlc::Parameter<StridedSliceParam> {
   // numpy convention, only support indices, not support list.
   Tuple<int64_t> begin;
diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py
index c87b4735d..b5e00f012 100644
--- a/nnvm/python/nnvm/top/transform.py
+++ b/nnvm/python/nnvm/top/transform.py
@@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective)
 reg.register_pattern("split", OpPattern.INJECTIVE)
 reg.register_schedule("split", _fschedule_injective)
 
+# take
+reg.register_pattern("take", OpPattern.INJECTIVE)
+reg.register_schedule("take", _fschedule_injective)
+
 # strided_slice
 reg.register_pattern("strided_slice", OpPattern.INJECTIVE)
 reg.register_schedule("strided_slice", _fschedule_injective)
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 72e49a040..5bb2ec137 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -1001,6 +1001,126 @@ Examples::
     return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
 });
 
+
+// take
+DMLC_REGISTER_PARAMETER(TakeParam);
+
+inline bool TakeInferShape(const NodeAttrs& attrs,
+                           std::vector<TShape>* in_shape,
+                           std::vector<TShape>* out_shape) {
+  CHECK_EQ(in_shape->size(), 2U);
+  CHECK_EQ(out_shape->size(), 1U);
+  const TShape& dshape = (*in_shape)[0];
+  const TShape& indicesshape = (*in_shape)[1];
+  if (dshape.ndim() == 0) return false;
+  if (indicesshape.ndim() == 0) return false;
+
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
+  TShape oshape((!param.axis ? 0: dshape.ndim() - 1) + indicesshape.ndim());
+  if (!param.axis) {
+    for (size_t j = 0; j < indicesshape.ndim(); ++j) {
+      oshape[j] = indicesshape[j];
+    }
+  } else {
+    int axis = param.axis.value();
+    if (axis < 0) {
+      axis += dshape.ndim();
+    }
+    CHECK_LT(axis, dshape.ndim());
+
+    size_t posi = 0;
+    for (size_t i = 0; i < dshape.ndim(); ++i) {
+      if (static_cast<int>(i) == axis) {
+        for (size_t j = 0; j < indicesshape.ndim(); ++j) {
+          oshape[posi++] = indicesshape[j];
+        }
+      } else {
+        oshape[posi++] = dshape[i];
+      }
+    }
+  }
+  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
+  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, indicesshape);
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
+  return dshape.Size() != 0;
+}
+
+inline bool TakeInferType(const NodeAttrs& attrs,
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  CHECK_EQ((*in_attrs)[1], kInt32);
+  NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 0, (*in_attrs)[0]);
+  NNVM_ASSIGN_INPUT_TYPE(attrs, *in_attrs, 1, static_cast<int>(kInt32));
+  NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_attrs, 0, (*in_attrs)[0]);
+  return true;
+}
+
+inline bool TakeCorrectLayout(const NodeAttrs& attrs,
+                              std::vector<Layout> *ilayouts,
+                              const std::vector<Layout> *last_ilayouts,
+                              std::vector<Layout> *olayouts) {
+  CHECK_EQ(ilayouts->size(), last_ilayouts->size());
+  CHECK_EQ(olayouts->size(), 1U);
+
+  for (size_t i = 0; i < ilayouts->size(); ++i) {
+    const Layout& input = last_ilayouts->at(i).defined() ?
+                          last_ilayouts->at(i) : ilayouts->at(i);
+    NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
+  }
+
+  return true;
+}
+
+NNVM_REGISTER_OP(take)
+.describe(R"code(Take elements from an array along an axis.
+
+When axis is not None, this function does the same thing as 'fancy' indexing
+(indexing arrays using arrays); however, it can be easier to use if you need
+elements along a given axis.
+
+**Note** that when axis is none the flattened input array is used.
+
+Examples::
+
+  a = [[ 1, 2],
+       [ 3, 4]]
+  indices = [3, 0, 2]
+  take(a, indices) = [ 4, 1, 3]
+
+  a = [[ 1., 2.],
+       [ 3., 4.]]
+  indices = [1, 0]
+  take(a, indices, axis=1) = [[ 2., 1.],
+                              [ 4., 3.]]
+
+  )code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "Array to be indexed")
+.add_argument("indices", "Tensor", "The indices of the values to extract")
+.add_arguments(TakeParam::__FIELDS__())
+.set_attr_parser(ParamParser<TakeParam>)
+.set_attr<FInferShape>("FInferShape", TakeInferShape)
+.set_attr<FInferType>("FInferType", TakeInferType)
+.set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_support_level(1)
+.set_attr<FTVMCompute>(
+    "FTVMCompute", [](const NodeAttrs& attrs,
+                      const Array<Tensor>& inputs,
+                      const Array<Tensor>& out_info) {
+      const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
+      if (!param.axis) {
+        return Array<Tensor>{
+            topi::take(inputs[0], inputs[1]) };
+      } else {
+        return Array<Tensor>{
+            topi::take(inputs[0], inputs[1], param.axis.value()) };
+      }
+  });
+
+
 // SliceLike
 DMLC_REGISTER_PARAMETER(SliceLikeParam);
 
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index b97aff8ef..d9c6655fe 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -365,6 +365,40 @@ def test_strided_slice():
     verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4])
     verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3])
 
+def verify_take(src_shape, indices_src, axis=None):
+    src_dtype = "float32"
+    indices_dtype = "int32"
+    indices_src = np.array(indices_src, dtype=indices_dtype)
+    a = sym.Variable("a")
+    indices = sym.Variable("indices")
+    y = sym.take(a, indices, axis=axis)
+    for target, ctx in ctx_list():
+        # set input
+        shape_dict = {"a":src_shape, "indices":indices_src.shape}
+        type_dict = {"a":src_dtype, "indices":indices_dtype}
+        graph, lib, _ = nnvm.compiler.build(y, target, shape=shape_dict, dtype=type_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+
+        shape_size = 1
+        for i in range(len(src_shape)):
+            shape_size = shape_size * src_shape[i]
+        a_src = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
+        out_np = np.take(a_src, indices_src, axis=axis)
+        m.run(a=a_src, indices=indices_src)
+        out = m.get_output(0, tvm.nd.empty(out_np.shape, dtype=src_dtype))
+        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
+
+def test_take():
+    verify_take((4,), [1])
+    verify_take((4,), [[0,1,2,3]])
+    verify_take((3,3,3), [[11,25]])
+    verify_take((4,), [[0,1],[2,3]])
+    verify_take((4,), [1], 0)
+    verify_take((2,2), [[[1,0],[0,1]]], 0)
+    verify_take((2,2), [[[1,0],[0,1]]], 1)
+    verify_take((4,3,5,6), [[2,1,0,0]], -2)
+
+
 def verify_squeeze(dshape, axis):
     x = sym.Variable("x")
     if axis:
@@ -481,6 +515,7 @@ if __name__ == "__main__":
     test_softmax()
     test_squeeze()
     test_pad()
+    test_take()
     test_lrn()
     test_l2_normalize()
     test_strided_slice()
-- 
GitLab