diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 5a0ea067d433a91fa4c1c2861ea0620b596e0ef1..fe4295feb98284ea670486f998f68b9f0e351722 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -154,6 +154,8 @@ class LinearAccessPatternFinder final : public IRVisitor {
       in_thread_env_ = false;
     } else if (op->attr_key == attr::extern_scope) {
       VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
     } else if (op->attr_key == attr::storage_scope) {
       const Variable* buf = op->node.as<Variable>();
       alloc_info_[buf].storage_scope =
@@ -395,11 +397,10 @@ class StoragePlanRewriter : public IRMutator {
   }
 
   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
-    CHECK(op->attr_key != attr::virtual_thread)
-        << "InjectVirtualThread before StoragePlan";
     if (op->attr_key == attr::storage_scope) {
       return this->Mutate(op->body);
     } else if (op->attr_key == attr::thread_extent ||
+               op->attr_key == attr::virtual_thread ||
                op->attr_key == attr::pragma_scope) {
       // remake all the allocation at the attach scope.
       if (attach_map_.count(op)) {
@@ -481,11 +482,13 @@ class StoragePlanRewriter : public IRMutator {
                   Stmt body) {
     std::vector<Stmt> nest;
     for (StorageEntry* e : svec) {
-      nest.emplace_back(AttrStmt::make(
-          e->alloc_var, attr::storage_scope,
-          StringImm::make(e->scope.to_string()),
-          Evaluate::make(0)));
-      nest.push_back(e->new_alloc);
+      if (e->new_alloc.defined()) {
+        nest.emplace_back(AttrStmt::make(
+            e->alloc_var, attr::storage_scope,
+            StringImm::make(e->scope.to_string()),
+            Evaluate::make(0)));
+        nest.push_back(e->new_alloc);
+      }
     }
     return MergeNest(nest, body);
   }
@@ -716,7 +719,8 @@ class StoragePlanRewriter : public IRMutator {
       if (s.stmt->is_type<AttrStmt>()) {
         const auto* op = static_cast<const AttrStmt*>(s.stmt);
         if (op->attr_key == attr::thread_extent ||
-            op->attr_key == attr::pragma_scope) {
+            op->attr_key == attr::pragma_scope ||
+            op->attr_key == attr::virtual_thread) {
           PlanNewScope(op);
         } else {
           CHECK(op->attr_key == attr::extern_scope);
diff --git a/topi/tests/python_cpp/test_topi_reduce.py b/topi/tests/python_cpp/test_topi_reduce.py
index adfe18ba4ef9c4ba875c37bc80a0f08df9de9b79..b4c630395f6054bd7d913032e601b697e6b8c2bc 100644
--- a/topi/tests/python_cpp/test_topi_reduce.py
+++ b/topi/tests/python_cpp/test_topi_reduce.py
@@ -77,7 +77,22 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
         for _ in range(1):
             foo(data_tvm, out_tvm)
-        np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
+        if type == "argmax" or type == "argmin":
+            out_tvm_indices = out_tvm.asnumpy()
+            if keepdims:
+                out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
+            if axis is None:
+                out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
+            else:
+                other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis+1):]))
+                sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
+                out_tvm_val = in_npy_map[sel_indices]
+            if type == "argmax":
+                np.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1E-3, 1E-3)
+            elif type == "argmin":
+                np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
+        else:
+            np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
     for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
         check_device(device)