From 3ff2d9583d1b371d94970b8211075314db376e2e Mon Sep 17 00:00:00 2001 From: Xingjian Shi <xshiab@ust.hk> Date: Mon, 15 Jan 2018 12:39:30 -0800 Subject: [PATCH] try to fix test (#784) try to fix fix --- topi/tests/python/test_topi_reduce.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py index 08e66e140..c8d95df25 100644 --- a/topi/tests/python/test_topi_reduce.py +++ b/topi/tests/python/test_topi_reduce.py @@ -54,7 +54,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): with tvm.target.create(device): s = topi.generic.schedule_reduce(B) ctx = tvm.context(device, 0) - foo = tvm.build(s, [A, B], device, name="sum") + foo = tvm.build(s, [A, B], device, name=type) # Test in_npy = np.random.uniform(size=in_shape).astype(np.float32) in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32) @@ -74,6 +74,21 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"): out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype) for _ in range(1): foo(data_tvm, out_tvm) + if type == "argmax" or type == "argmin": + out_tvm_indices = out_tvm.asnumpy() + if keepdims: + out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis) + if axis is None: + out_tvm_val = in_npy_map.ravel()[out_tvm_indices] + else: + other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):])) + sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:] + out_tvm_val = in_npy_map[sel_indices] + if type == "argmax": + np.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3) + elif type == "argmin": + np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3) + np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3) for device in ["cuda", "opencl", "metal", "llvm", "rocm"]: check_device(device) -- GitLab