diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index ac408ba633b0d22e5eb89f2b811cb4c2fc701ee2..10fb5bc478cb098972d17ea45615bb8e5b82202d 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -89,6 +89,18 @@ inline Tensor transpose(const Tensor& x, } auto axes_val = GetConstIntValues(axes, "axes"); + for (size_t i = 0; i < axes_val.size(); ++i) { + int axis = axes_val[i]; + if (axes_val[i] < 0) { + axes_val[i] = static_cast<int>(x->shape.size()) + axes_val[i]; + } + CHECK((0 <= axes_val[i]) && (axes_val[i] < static_cast<int>(x->shape.size()))) + << "axis=" << axis << " is invalid for the " + << static_cast<int>(x->shape.size()) << "-dimensional input tensor"; + + CHECK(1 == std::count(std::begin(axes_val), std::end(axes_val), axes_val[i])) + << "repeated axis in transpose"; + } Array<Expr> new_shape; for (size_t i = 0; i < axes_val.size(); ++i) { diff --git a/topi/tests/python_cpp/test_topi_transform.py b/topi/tests/python_cpp/test_topi_transform.py index f1355bae234146c586c013cfd297814665bff4c2..a94fc89328af9f86eb38f621a948efbaf1c6fd28 100644 --- a/topi/tests/python_cpp/test_topi_transform.py +++ b/topi/tests/python_cpp/test_topi_transform.py @@ -281,6 +281,7 @@ def test_tranpose(): verify_tranpose((3, 10, 2), (1, 0, 2)) verify_tranpose((3, 10, 5), (2, 0, 1)) verify_tranpose((3, 10), None) + verify_tranpose((3, 10, 5), (2, -3, 1)) def test_reshape():