diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc
index eba80b955d963b39d333f8fb861c9dee3f0e1dfc..6d7f4bdf75331ed462b5103eb7bb04c05bdc23d5 100644
--- a/apps/extension/src/tvm_ext.cc
+++ b/apps/extension/src/tvm_ext.cc
@@ -60,4 +60,9 @@ TVM_REGISTER_GLOBAL("tvm_ext.sym_add")
     Var b = args[1];
     *rv = a + b;
   });
+
+TVM_REGISTER_GLOBAL("device_api.ext_dev")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
+  });
 }  // namespace tvm_ext
diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py
index 1cd867d4bb54931c66d393a130f2f0eff8b9abe6..0bbfff14eeef42b5fcb9316cb03104b5f9b5ac91 100644
--- a/apps/extension/tests/test_ext.py
+++ b/apps/extension/tests/test_ext.py
@@ -1,5 +1,6 @@
 import tvm_ext
 import tvm
+import numpy as np
 
 def test_bind_add():
     def add(a, b):
@@ -7,6 +8,24 @@ def test_bind_add():
     f = tvm_ext.bind_add(add, 1)
     assert f(2)  == 3
 
+def test_ext_dev():
+    n = 10
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
+    s = tvm.create_schedule(B.op)
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        f = tvm.build(s, [A, B], "ext_dev", "llvm")
+        ctx = tvm.ext_dev(0)
+        # launch the kernel.
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
+        f(a, b)
+        np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)
+    check_llvm()
+
+
 def test_sym_add():
     a = tvm.var('a')
     b = tvm.var('b')
@@ -26,6 +45,7 @@ def test_ext_vec():
     tvm.convert(ivec_cb)(ivec)
 
 if __name__ == "__main__":
+    test_ext_dev()
     test_ext_vec()
     test_bind_add()
     test_sym_add()
diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 5bf5b72492c4b96351dd417b94eeaf6661b9d1fd..91175f671a5670abca95b18da705865e7d49621a 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -55,6 +55,9 @@ typedef int64_t tvm_index_t;
 
 /*! \brief Extension device types in TVM */
 typedef enum {
+  // Extension DRAM type, used for quickly testing extension devices.
+  // The device API can differ depending on the XPU driver registered.
+  kExtDev = 12
   // AddExtraTVMType which is not in DLPack here
 } TVMDeviceExtType;
 
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 4bc20aac7dd6c109e36952af72a503f1a56fe125..90ac45988bcb9752723df54030a9a2bfbdfe0129 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -17,7 +17,7 @@ from . import ir_builder
 from . import target
 
 from . import ndarray as nd
-from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm
+from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, ext_dev
 
 from ._ffi.runtime_ctypes import TypeCode
 from ._ffi.function import Function
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 9ea8ef579e10b9d92a86917228b6a5ea699db195..596e13f1f3fe10f2acfa8847c9fde3083bf203cb 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -96,7 +96,8 @@ class TVMContext(ctypes.Structure):
         4 : 'opencl',
         8 : 'metal',
         9 : 'vpi',
-        10: 'rocm'
+        10: 'rocm',
+        12: 'ext_dev',
     }
     STR2MASK = {
         'cpu': 1,
@@ -106,7 +107,8 @@ class TVMContext(ctypes.Structure):
         'opencl': 4,
         'metal': 8,
         'vpi': 9,
-        'rocm': 10
+        'rocm': 10,
+        'ext_dev': 12,
     }
     def __init__(self, device_type, device_id):
         super(TVMContext, self).__init__()
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index d08aa8ad0e854f9e1825f84fccb56b32905048a5..621f8a72727477bc0b91846a7279bf609fe49ab9 100644
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -345,7 +345,7 @@ def build(sch,
         else:
             raise ValueError("unknown function type %d" % func.func_type)
 
-    if not target.startswith("llvm") and target != "stackvm" and not fdevice:
+    if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
         warnings.warn(
             "Specified target %s, but cannot find device code, did you do bind?" % target)
 
diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py
index 3ad77cb3474398490cb64a8fd48d7d2d43046606..7b29b1ddac01c76b4b629fdb4918797a75594b01 100644
--- a/python/tvm/contrib/rpc.py
+++ b/python/tvm/contrib/rpc.py
@@ -247,6 +247,10 @@ class RPCSession(object):
         """Construct remote Metal device."""
         return self.context(8, dev_id)
 
+    def ext_dev(self, dev_id=0):
+        """Construct remote extension device."""
+        return self.context(12, dev_id)
+
     def upload(self, data, target=None):
         """Upload file to remote runtime temp folder
 
diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py
index 578f63f08d5d4c388c82f1b3cd0fd613989bc8b3..1556c4912a3520198da91ef9ae772aaf2232af75 100644
--- a/python/tvm/ndarray.py
+++ b/python/tvm/ndarray.py
@@ -120,6 +120,27 @@ def vpi(dev_id=0):
     """
     return TVMContext(9, dev_id)
 
+def ext_dev(dev_id=0):
+    """Construct an extension device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+
+    Returns
+    -------
+    ctx : TVMContext
+        The created context
+
+    Note
+    ----
+    This API is reserved for quickly testing a new
+    device by plugging in its device API as ext_dev.
+    """
+    return TVMContext(12, dev_id)
+
+
 cl = opencl
 mtl = metal
 
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 9ff24f4b761d62301c0de823e9d1397b74c8bb67..ce4a65dc79e2911f0a336975d215a1bb18ff61f7 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -31,6 +31,7 @@ inline std::string DeviceName(int type) {
     case kMetal: return "metal";
     case kVPI: return "vpi";
     case kROCM: return "rocm";
+    case kExtDev: return "ext_dev";
     default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
   }
 }
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index 3edb1c67c4d401250a1d31828ace3d128ce4eab1..4fa91dcbeb3f9ca61c02e272d057770008e7dc70 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
 
 ROCMThreadEntry::ROCMThreadEntry()
-    : pool(kGPU, ROCMDeviceAPI::Global()) {
+    : pool(kROCM, ROCMDeviceAPI::Global()) {
 }
 
 ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index c4fbb4eccac1782d632c5f4e53055427c2296146..bbdd65e4be1c02ede61ba1b1ab5dd6800353fb6a 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -7,7 +7,7 @@ def test_add_pipeline():
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
     C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
-    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='C')
+    D = tvm.compute(A.shape, lambda *i: C(*i) + 1, name='D')
     s = tvm.create_schedule(D.op)
 
     # GPU schedule have to split by gridIdx and threadIdx
@@ -26,11 +26,11 @@ def test_add_pipeline():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
     Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
-    Cb = tvm.decl_buffer(C.shape, C.dtype, name='C')
+    Db = tvm.decl_buffer(D.shape, D.dtype, name='D')
     stmt = tvm.ir_pass.LoopPartition(stmt)
-    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
     stmt = tvm.ir_pass.Simplify(stmt)
-    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
+    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
     fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
     fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
 
@@ -49,10 +49,10 @@ def test_add_pipeline():
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
-        f(a, b, c)
+        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
+        f(a, b, d)
         np.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
 
     def check_module_save(device, host="stackvm"):
         if not tvm.module.enabled(host):
@@ -73,10 +73,10 @@ def test_add_pipeline():
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
-        f(a, b, c)
+        d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
+        f(a, b, d)
         np.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+            d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
 
     check_target("cuda", host="stackvm")
     check_target("cuda", host="llvm")