diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 5a0ea067d433a91fa4c1c2861ea0620b596e0ef1..fe4295feb98284ea670486f998f68b9f0e351722 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -154,6 +154,8 @@ class LinearAccessPatternFinder final : public IRVisitor {
       in_thread_env_ = false;
     } else if (op->attr_key == attr::extern_scope) {
       VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
     } else if (op->attr_key == attr::storage_scope) {
       const Variable* buf = op->node.as<Variable>();
       alloc_info_[buf].storage_scope =
@@ -395,11 +397,10 @@ class StoragePlanRewriter : public IRMutator {
 
   }
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
-    CHECK(op->attr_key != attr::virtual_thread)
-        << "InjectVirtualThread before StoragePlan";
     if (op->attr_key == attr::storage_scope) {
       return this->Mutate(op->body);
     } else if (op->attr_key == attr::thread_extent ||
+               op->attr_key == attr::virtual_thread ||
                op->attr_key == attr::pragma_scope) {
       // remake all the allocation at the attach scope.
       if (attach_map_.count(op)) {
@@ -481,11 +482,13 @@ class StoragePlanRewriter : public IRMutator {
                   Stmt body) {
     std::vector<Stmt> nest;
     for (StorageEntry* e : svec) {
-      nest.emplace_back(AttrStmt::make(
-          e->alloc_var, attr::storage_scope,
-          StringImm::make(e->scope.to_string()),
-          Evaluate::make(0)));
-      nest.push_back(e->new_alloc);
+      if (e->new_alloc.defined()) {
+        nest.emplace_back(AttrStmt::make(
+            e->alloc_var, attr::storage_scope,
+            StringImm::make(e->scope.to_string()),
+            Evaluate::make(0)));
+        nest.push_back(e->new_alloc);
+      }
     }
     return MergeNest(nest, body);
   }
@@ -716,7 +719,8 @@ class StoragePlanRewriter : public IRMutator {
     if (s.stmt->is_type<AttrStmt>()) {
       const auto* op = static_cast<const AttrStmt*>(s.stmt);
       if (op->attr_key == attr::thread_extent ||
-          op->attr_key == attr::pragma_scope) {
+          op->attr_key == attr::pragma_scope ||
+          op->attr_key == attr::virtual_thread) {
         PlanNewScope(op);
       } else {
         CHECK(op->attr_key == attr::extern_scope);
diff --git a/topi/tests/python_cpp/test_topi_reduce.py b/topi/tests/python_cpp/test_topi_reduce.py
index adfe18ba4ef9c4ba875c37bc80a0f08df9de9b79..b4c630395f6054bd7d913032e601b697e6b8c2bc 100644
--- a/topi/tests/python_cpp/test_topi_reduce.py
+++ b/topi/tests/python_cpp/test_topi_reduce.py
@@ -77,7 +77,22 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
         for _ in range(1):
             foo(data_tvm, out_tvm)
-        np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
+        if type == "argmax" or type == "argmin":
+            out_tvm_indices = out_tvm.asnumpy()
+            if keepdims:
+                out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
+            if axis is None:
+                out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
+            else:
+                other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):]))
+                sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
+                out_tvm_val = in_npy_map[sel_indices]
+            if type == "argmax":
+                np.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3)
+            elif type == "argmin":
+                np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
+        else:
+            np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
 
     for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
         check_device(device)
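Note on the storage-rewrite half: `StoragePlanRewriter` now treats `attr::virtual_thread` as an attach scope (both when scanning the linear access pattern and when planning/attaching allocations) instead of rejecting it with a CHECK, so storage planning can run on IR that still contains unexpanded vthread scopes. A minimal sketch of a schedule that produces such a scope, assuming the `tvm.*` schedule API of this era; the names `A`/`B`/`C` and the shapes are illustrative, not from this patch:

```python
import tvm

# Illustrative shapes; any small sizes work.
n, m = 16, 4
A = tvm.placeholder((n, m), name='A')
B = tvm.compute((n, m), lambda i, j: A[i, j] * 2.0, name='B')
C = tvm.compute((n, m), lambda i, j: B[i, j] + 1.0, name='C')

s = tvm.create_schedule(C.op)
# Keep B's buffer local to each virtual thread: compute it at C's
# outer axis and bind that axis to a vthread.
s[B].compute_at(s[C], C.op.axis[0])
s[C].bind(C.op.axis[0], tvm.thread_axis("vthread"))

# Depending on where storage rewrite sits relative to vthread
# injection in the lowering pipeline, B's allocation is wrapped in
# an attr::virtual_thread scope -- exactly the case the removed
# CHECK used to reject.
print(tvm.lower(s, [A, C], simple_mode=True))
```

The `MakeAttach` change is a companion fix: once allocations can attach under a vthread scope, an entry whose allocation was merged away has no `new_alloc`, so the rewriter now skips undefined allocations rather than emitting an empty node.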
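Note on the test half: argmax/argmin results are validated by mapping the returned indices back to values with `np.indices`-based fancy indexing and comparing against `max`/`min`, rather than comparing indices directly. A self-contained numpy illustration of that selection trick, with hypothetical shapes:

```python
import numpy as np

in_npy = np.random.uniform(size=(3, 4, 5))
axis = 1
idx = in_npy.argmax(axis=axis)  # shape (3, 5)

# Index grids for every axis except `axis` ...
other = tuple(np.indices(in_npy.shape[:axis] + in_npy.shape[axis + 1:]))
# ... with the argmax indices spliced back in at position `axis`.
sel = other[:axis] + (idx,) + other[axis:]

# Selecting through `sel` recovers the per-slice maxima.
np.testing.assert_allclose(in_npy[sel], in_npy.max(axis=axis))
```

Comparing values instead of raw indices keeps the check robust when the extremum is attained at more than one position along the reduced axis.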