From 211ab97836763ada19e292ee9230fddbcbfa5dc2 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY <32511895+ANSHUMAN87@users.noreply.github.com> Date: Tue, 3 Jul 2018 08:20:00 +0530 Subject: [PATCH] Transpose core dump resolved (#1355) --- topi/include/topi/transform.h | 12 ++++++++++++ topi/tests/python_cpp/test_topi_transform.py | 1 + 2 files changed, 13 insertions(+) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index ac408ba63..10fb5bc47 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -89,6 +89,18 @@ inline Tensor transpose(const Tensor& x, } auto axes_val = GetConstIntValues(axes, "axes"); + for (size_t i = 0; i < axes_val.size(); ++i) { + int axis = axes_val[i]; + if (axes_val[i] < 0) { + axes_val[i] = static_cast<int>(x->shape.size()) + axes_val[i]; + } + CHECK((0 <= axes_val[i]) && (axes_val[i] < static_cast<int>(x->shape.size()))) + << "axis=" << axis << " is invalid for the " + << static_cast<int>(x->shape.size()) << "-dimensional input tensor"; + + CHECK(1 == std::count(std::begin(axes_val), std::end(axes_val), axes_val[i])) + << "repeated axis in transpose"; + } Array<Expr> new_shape; for (size_t i = 0; i < axes_val.size(); ++i) { diff --git a/topi/tests/python_cpp/test_topi_transform.py b/topi/tests/python_cpp/test_topi_transform.py index f1355bae2..a94fc8932 100644 --- a/topi/tests/python_cpp/test_topi_transform.py +++ b/topi/tests/python_cpp/test_topi_transform.py @@ -281,6 +281,7 @@ def test_tranpose(): verify_tranpose((3, 10, 2), (1, 0, 2)) verify_tranpose((3, 10, 5), (2, 0, 1)) verify_tranpose((3, 10), None) + verify_tranpose((3, 10, 5), (2, -3, 1)) def test_reshape(): -- GitLab