diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 09af612b957b83e8abea4f07934eb9031c627f95..245b38cfb63d3ae7380a0443962762c41dbd8361 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -475,6 +475,11 @@ inline Array<Tensor> split_sections(const Tensor& x, int axis, std::string name = "tensor", std::string tag = kInjective) { + if (axis < 0) { + axis += static_cast<int>(x->shape.size()); + } + CHECK_LT(axis, x->shape.size()) << "axis out of bounds"; + auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis])); CHECK_GT(num_sections, 0) << "Slice count must be > 0"; diff --git a/topi/tests/python_cpp/test_topi_transform.py b/topi/tests/python_cpp/test_topi_transform.py index c8b7c3906caaf061b8fd2a5b6c07f90f561bebfb..3f7bdbfdd499da894854ceb750cc2b1e093cfadd 100644 --- a/topi/tests/python_cpp/test_topi_transform.py +++ b/topi/tests/python_cpp/test_topi_transform.py @@ -340,6 +340,7 @@ def test_concatenate(): def test_split(): verify_split((2, 12, 3), 3, 1) + verify_split((2, 12, 3), 3, -1) verify_split((2, 12, 3), [2, 4], 1) verify_split((10, 12, 24), [5, 7, 9], -1)