From 81b42bc3716b80e821651d3dcd6942a653e57e9f Mon Sep 17 00:00:00 2001
From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com>
Date: Tue, 14 Aug 2018 01:43:09 +0530
Subject: [PATCH] Split_indices negative axis added (#1595)

---
 topi/include/topi/transform.h                | 5 +++++
 topi/tests/python_cpp/test_topi_transform.py | 1 +
 2 files changed, 6 insertions(+)

diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index 09af612b9..245b38cfb 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -475,6 +475,11 @@ inline Array<Tensor> split_sections(const Tensor& x,
                            int axis,
                            std::string name = "tensor",
                            std::string tag = kInjective) {
+  if (axis < 0) {
+    axis += static_cast<int>(x->shape.size());
+  }
+  CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
+
   auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
 
   CHECK_GT(num_sections, 0) << "Slice count must be > 0";
diff --git a/topi/tests/python_cpp/test_topi_transform.py b/topi/tests/python_cpp/test_topi_transform.py
index c8b7c3906..3f7bdbfdd 100644
--- a/topi/tests/python_cpp/test_topi_transform.py
+++ b/topi/tests/python_cpp/test_topi_transform.py
@@ -340,6 +340,7 @@ def test_concatenate():
 
 def test_split():
     verify_split((2, 12, 3), 3, 1)
+    verify_split((2, 12, 3), 3, -1)
     verify_split((2, 12, 3), [2, 4], 1)
     verify_split((10, 12, 24), [5, 7, 9], -1)
 
-- 
GitLab