diff --git a/make/contrib/mps.mk b/make/contrib/mps.mk
index 501e62b2a6716721da9325921bef5e8153f527c3..0fe8a7f128893e5fb071782cabc83ff73a53f248 100644
--- a/make/contrib/mps.mk
+++ b/make/contrib/mps.mk
@@ -1,4 +1,4 @@
-MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm, src/contrib/mps/*.cc)
+MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm)
 MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))
 
 ifeq ($(USE_MPS), 1)
@@ -6,9 +6,15 @@ FRAMEWORKS += -framework MetalPerformanceShaders
 CFLAGS += 
 ADD_LDFLAGS += 
 RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
+CONTRIB_OBJ += $(MPS_CONTRIB_OBJ)
 endif
 
-build/contrib/mps/%.o: src/contrib/mps/%.mm src/contrib/mps/%.cc
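+# Pattern rules for the MPS contrib sources; each compile also writes a .d
+# dependency file next to the object file.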
+build/contrib/mps/%.o: src/contrib/mps/%.mm
+	@mkdir -p $(@D)
+	$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
+	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
+
+build/contrib/mps/%.o: src/contrib/mps/%.cc
 	@mkdir -p $(@D)
 	$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
 	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
diff --git a/python/tvm/contrib/mps.py b/python/tvm/contrib/mps.py
index d214d4b93631ae9550c756cd64c03ba036a22681..43b3b9fb48dbe2cdf1a86f39634d1262e51963b0 100644
--- a/python/tvm/contrib/mps.py
+++ b/python/tvm/contrib/mps.py
@@ -1,9 +1,9 @@
 """External function interface to MPS libraroes."""
 from __future__ import absolute_import as _abs
-
 from .. import api as _api
 from .. import intrin as _intrin
 
+# pylint: disable=C0103,W0612
 
 def matmul(lhs, rhs, transa=False, transb=False):
     """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
@@ -26,10 +26,46 @@ def matmul(lhs, rhs, transa=False, transb=False):
     C : Tensor
         The result tensor.
     """
-    m = lhs.shape[0]
-    n = rhs.shape[1]
+    m = lhs.shape[1] if transa else lhs.shape[0]
+    n = rhs.shape[0] if transb else rhs.shape[1]
     return _api.extern(
-        (n, m), [lhs, rhs],
+        (m, n), [lhs, rhs],
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
         name="C")
+
+def conv2d(data, weight, pad='SAME', stride=1):
+    """
+    Create an extern op that computes the 2D convolution of data with weight.
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data, in NHWC layout.
+    weight : Tensor
+        The convolution weight, in (output_feature, kH, kW, input_feature) layout.
+    pad : str
+        Padding method, 'SAME' or 'VALID'.
+    stride : int
+        Convolution stride.
+
+    Returns
+    -------
+    output : Tensor
+        The result tensor.
+    """
+    n, hi, wi, ci = data.shape
+    co, kh, kw, ciw = weight.shape
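+    # Map the padding string to the integer flag consumed by the external
+    # function: 0 selects 'SAME', 1 selects 'VALID'.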
+    padding = 0 if pad == 'SAME' else 1
+    ho = hi // stride
+    wo = wi // stride
+
+    return _api.extern(
+        (n, ho, wo, co), [data, weight],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
+        name="C")
diff --git a/python/tvm/contrib/rpc_proxy.py b/python/tvm/contrib/rpc_proxy.py
index fe289935e7cbf68a64c1bff37c5f9ef949e6174a..9634c258b39fccde263e8613783b3cfbf39003c3 100644
--- a/python/tvm/contrib/rpc_proxy.py
+++ b/python/tvm/contrib/rpc_proxy.py
@@ -70,6 +70,7 @@ class ForwardHandler(object):
         ProxyServerHandler.current.handler_ready(self)
 
     def on_data(self, message):
+        """on data"""
         assert isinstance(message, bytes)
         if self.forward_proxy:
             self.forward_proxy.send_data(message)
@@ -98,6 +99,7 @@ class ForwardHandler(object):
         self.close()
 
     def on_close_event(self):
+        """on close event"""
         assert not self._done
         logging.info("RPCProxy:on_close %s ...", self.name())
         self._done = True
diff --git a/src/contrib/mps/conv.mm b/src/contrib/mps/conv.mm
new file mode 100644
index 0000000000000000000000000000000000000000..fa279bd5cc9513e500269c8348edaf3cf61156d8
--- /dev/null
+++ b/src/contrib/mps/conv.mm
@@ -0,0 +1,154 @@
+#include "mps_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+using namespace runtime;
+
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
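+  // Copy an NHWC tensor held in a Metal buffer into an MPSImage and hand the
+  // image back through img->data.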
+  DLTensor *buf = args[0];
+  DLTensor *img = args[1];
+  // copy to temp
+  id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  runtime::metal::MetalThreadEntry *rt =
+      runtime::metal::MetalThreadEntry::ThreadLocal();
+  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
+  id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
+  entry_ptr->metal_api->CopyDataFromTo(
+      (__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
+      buf->ctx, buf->ctx, nullptr
+  );
+
+  MPSImageDescriptor *desc = [MPSImageDescriptor
+      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
+                                 width:buf->shape[2]
+                                height:buf->shape[1]
+                       featureChannels:buf->shape[3]];
+
+  MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);
+
+  [mpsimg writeBytes:[temp contents]
+          dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
+          imageIndex:0];
+
+  img->data = (__bridge void *)mpsimg;
+
+  [mpsimg readBytes:[temp contents]
+         dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
+         imageIndex:0];
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
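+  // Read an MPSImage back into a temporary Metal buffer and copy the bytes
+  // into the destination buffer.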
+  DLTensor *img = args[0];
+  DLTensor *buf = args[1];
+  id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
+  MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  runtime::metal::MetalThreadEntry *rt =
+      runtime::metal::MetalThreadEntry::ThreadLocal();
+  id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
+
+  [mpsimg readBytes:[temp contents]
+         dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
+         imageIndex:0];
+
+  entry_ptr->metal_api->CopyDataFromTo(
+      (__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
+      buf->ctx, buf->ctx, nullptr);
+
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  // MPS-NHWC
+  DLTensor *data = args[0];
+  DLTensor *weight = args[1];
+  DLTensor *output = args[2];
+  int pad = args[3];
+  int stride = args[4];
+
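+  // data is NHWC with N == 1; weight is (oCh, kH, kW, iCh); pad: 0 = 'SAME',
+  // 1 = 'VALID' (see python/tvm/contrib/mps.py).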
+  CHECK_EQ(data->ndim, 4);
+  CHECK_EQ(weight->ndim, 4);
+  CHECK_EQ(output->ndim, 4);
+  CHECK(output->strides == nullptr);
+  CHECK(weight->strides == nullptr);
+  CHECK(data->strides == nullptr);
+
+  CHECK_EQ(data->shape[0], 1);
+  CHECK_EQ(output->shape[0], 1);
+
+  int oCh = weight->shape[0];
+  int kH = weight->shape[1];
+  int kW = weight->shape[2];
+  int iCh = weight->shape[3];
+
+  auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
+  auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
+  // Get Metal device API
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  runtime::metal::MetalThreadEntry *rt =
+      runtime::metal::MetalThreadEntry::ThreadLocal();
+  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
+  id<MTLCommandQueue> queue =
+      entry_ptr->metal_api->GetCommandQueue(data->ctx);
+  id<MTLCommandBuffer> cb = [queue commandBuffer];
+  // data to MPSImage
+  DLTensor tmp_in;
+  (*f_buf2img)(data, &tmp_in);
+  MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;
+  // weight to temp memory
+  id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
+  id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
+  entry_ptr->metal_api->CopyDataFromTo(
+      (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
+      weight->ctx, weight->ctx, nullptr);
+  float *ptr_w = (float *)[tempB contents];
+  // output to MPSImage
+  DLTensor tmp_out;
+  (*f_buf2img)(output, &tmp_out);
+  MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;
+  // conv desc
+
+  MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
+      cnnConvolutionDescriptorWithKernelWidth:kW
+                                 kernelHeight:kH
+                         inputFeatureChannels:iCh
+                        outputFeatureChannels:oCh];
+  [conv_desc setStrideInPixelsX:stride];
+  [conv_desc setStrideInPixelsY:stride];
+
+  MPSCNNConvolution *conv =
+      [[MPSCNNConvolution alloc] initWithDevice:dev
+                          convolutionDescriptor:conv_desc
+                                  kernelWeights:ptr_w
+                                      biasTerms:nil
+                                          flags:MPSCNNConvolutionFlagsNone];
+  if (pad == 0) {
+    conv.padding = [MPSNNDefaultPadding
+        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
+                          MPSNNPaddingMethodAlignCentered |
+                          MPSNNPaddingMethodSizeSame];
+  } else if (pad == 1) {
+    conv.padding = [MPSNNDefaultPadding
+        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
+                          MPSNNPaddingMethodAlignCentered |
+                          MPSNNPaddingMethodSizeValidOnly];
+  }
+  [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];
+
+  // The blit that synchronizes the result texture must be encoded before the
+  // command buffer is committed.
+  id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
+  [encoder synchronizeResource:tempC.texture];
+  [encoder endEncoding];
+  [cb commit];
+  [cb waitUntilCompleted];
+
+  (*f_img2buf)(&tmp_out, output);
+});
+
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/mps/gemm.mm b/src/contrib/mps/gemm.mm
index f877cb8b0ea1de43cbe2a1be52037e0b293db981..1d92ad2851d042f36fdba10847b31f34ea834a16 100644
--- a/src/contrib/mps/gemm.mm
+++ b/src/contrib/mps/gemm.mm
@@ -1,9 +1,5 @@
-#include "../../runtime/metal/metal_common.h"
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#include <dmlc/logging.h>
-#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+
+#include "mps_utils.h"
 
 namespace tvm {
 namespace contrib {
@@ -11,83 +7,81 @@ namespace contrib {
 using namespace runtime;
 
 TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
-    .set_body([](TVMArgs args, TVMRetValue *ret) {
-      DLTensor *A = args[0];
-      DLTensor *B = args[1];
-      DLTensor *C = args[2];
-      bool transa = args[3];
-      bool transb = args[4];
-      // call gemm for simple compact code.
-      CHECK_EQ(A->ndim, 2);
-      CHECK_EQ(B->ndim, 2);
-      CHECK_EQ(C->ndim, 2);
-      CHECK(C->strides == nullptr);
-      CHECK(B->strides == nullptr);
-      CHECK(A->strides == nullptr);
-      CHECK(TypeMatch(A->dtype, kDLFloat, 32));
-      CHECK(TypeMatch(B->dtype, kDLFloat, 32));
-      CHECK(TypeMatch(C->dtype, kDLFloat, 32));
-      // Get Metal device API
-      MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
-      CHECK_EQ(A->ctx, B->ctx);
-      CHECK_EQ(A->ctx, C->ctx);
-      id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
-      id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
-      id<MTLCommandBuffer> cb = [queue commandBuffer];
-      NSUInteger M = A->shape[0 + transa?1:0];
-      NSUInteger N = B->shape[1 - transb?1:0];
-      NSUInteger K = B->shape[0 + transb?1:0];
-      CHECK_EQ(A->shape[1-transa?1:0], K);
-      // mps a
-      MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
-      MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
-          matrixDescriptorWithDimensions:M
-                                 columns:K
-                                rowBytes:M * sizeof(dtype)
-                                dataType:dtype];
-      id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
-      MPSMatrix *matrixA =
-          [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
-      // mps b
-      MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
-          matrixDescriptorWithDimensions:K
-                                 columns:N
-                                rowBytes:K * sizeof(dtype)
-                                dataType:dtype];
-      id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
-      MPSMatrix *matrixB =
-          [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
-      // mps c
-      MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
-          matrixDescriptorWithDimensions:M
-                                 columns:N
-                                rowBytes:M * sizeof(dtype)
-                                dataType:dtype];
-      id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
-      MPSMatrix *matrixC =
-          [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
-      // kernel
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor *A = args[0];
+  DLTensor *B = args[1];
+  DLTensor *C = args[2];
+  bool transa = args[3];
+  bool transb = args[4];
+  // call gemm for simple compact code.
+  CHECK_EQ(A->ndim, 2);
+  CHECK_EQ(B->ndim, 2);
+  CHECK_EQ(C->ndim, 2);
+  CHECK(C->strides == nullptr);
+  CHECK(B->strides == nullptr);
+  CHECK(A->strides == nullptr);
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+  CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+  CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+  // Get Metal device API
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  // CHECK_EQ(A->ctx, B->ctx);
+  // CHECK_EQ(A->ctx, C->ctx);
+  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
+  id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
+  id<MTLCommandBuffer> cb = [queue commandBuffer];
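+  // BLAS-style dimensions: C(M, N) = op(A)(M, K) * op(B)(K, N), where op()
+  // applies the optional transpose.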
+  NSUInteger M = A->shape[0 + (transa ? 1 : 0)];
+  NSUInteger N = B->shape[1 - (transb ? 1 : 0)];
+  NSUInteger K = B->shape[0 + (transb ? 1 : 0)];
+
+  CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K);
+  // mps a
+  MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
+  MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
+      matrixDescriptorWithDimensions:M
+                             columns:K
+                            rowBytes:K * sizeof(float)
+                            dataType:dtype];
+  id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
+  MPSMatrix *matrixA =
+      [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
+  // mps b
+  MPSMatrixDescriptor *descB =
+      [MPSMatrixDescriptor matrixDescriptorWithDimensions:K
+                                                  columns:N
+                                                  rowBytes:N * sizeof(float)
+                                                  dataType:dtype];
+  id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
+  MPSMatrix *matrixB =
+      [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
+  // mps c
+  MPSMatrixDescriptor *descC =
+      [MPSMatrixDescriptor matrixDescriptorWithDimensions:M
+                                                  columns:N
+                                                 rowBytes:N * sizeof(float)
+                                                 dataType:dtype];
+  id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
+  MPSMatrix *matrixC =
+      [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
+  // kernel
 
-      MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
-      MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
-                                                 transposeLeft:transa
-                                                transposeRight:transb
-                                                    resultRows:M
-                                                 resultColumns:N
-                                               interiorColumns:K
-                                                         alpha:1.0f
-                                                          beta:0.0f];
-      CHECK(sgemm != nil);
-      [sgemm encodeToCommandBuffer:cb
-                        leftMatrix:matrixA
-                       rightMatrix:matrixB
-                      resultMatrix:matrixC];
-      [cb commit];
-      [mul_obj dealloc];
-      [matrixA dealloc];
-      [matrixB dealloc];
-      [matrixC dealloc];
-    });
+  MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
+  MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
+                                             transposeLeft:transa
+                                            transposeRight:transb
+                                                resultRows:M
+                                             resultColumns:N
+                                           interiorColumns:K
+                                                     alpha:1.0f
+                                                      beta:0.0f];
+  CHECK(sgemm != nil);
+  [sgemm encodeToCommandBuffer:cb
+                    leftMatrix:matrixA
+                   rightMatrix:matrixB
+                  resultMatrix:matrixC];
+  [cb commit];
+});
 
 } // namespace contrib
 } // namespace tvm
diff --git a/src/contrib/mps/mps_utils.cc b/src/contrib/mps/mps_utils.cc
deleted file mode 100644
index 2e3ca6218bb4d3ad30f96f0bb31a0aea9286c2bf..0000000000000000000000000000000000000000
--- a/src/contrib/mps/mps_utils.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-/*!
- *  Copyright (c) 2017 by Contributors
- * \file Use external mps utils function
- */
-#include "mps_utils.h"
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
-
-
-namespace tvm {
-namespace contrib {
-
-// MPS Data Type
-MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
-  switch (dtype.code) {
-      case kDLInt:
-        if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeInt8;
-        else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeInt16;
-        else
-          LOG(FATAL) << "Unsupported type";
-        break;
-      case kDLUInt:
-        if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeUInt8;
-        else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeUInt16;
-        else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeUInt32;
-        LOG(FATAL) << "Unsupported type";
-        break;
-      case kDLFloat:
-        if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeFloat16;
-        else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeFloat32;
-        else
-          LOG(FATAL) << "Unsupported type";
-        break;
-      default:
-        LOG(FATAL) << "Unsupported type";
-    }
-}
-
-// MetalThreadEntry
-
-MetalThreadEntry::MetalThreadEntry() {
-  auto func = runtime::Registry::Get("device_api.metal");
-  void *ret = (*func)();
-  metal_api = static_cast<runtime::metal::MetalWorkspace *>(ret);
-}
-
-MetalThreadEntry::~MetalThreadEntry() {
-}
-
-typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
-
-MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
-  return MetalThreadStore::Get();
-}
-
-}  // namespace contrib
-}  // namespace tvm
diff --git a/src/contrib/mps/mps_utils.h b/src/contrib/mps/mps_utils.h
index 91336ce44eddf4dfb0740a6363e412ef90d46c82..f07156a252a3dd080e2be0400a41e7778d74a6b5 100644
--- a/src/contrib/mps/mps_utils.h
+++ b/src/contrib/mps/mps_utils.h
@@ -6,11 +6,15 @@
 #ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
 #define TVM_CONTRIB_MPS_MPS_UTILS_H_
 
+#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 #include <dmlc/logging.h>
+#include <dmlc/thread_local.h>
 #include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/util.h>
+#include <vector>
 #include "../../runtime/metal/metal_common.h"
 
-
 namespace tvm {
 namespace contrib {
 
@@ -19,12 +23,15 @@ struct MPSType {
   static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
 };  // struct MPSType
 
-
 struct MetalThreadEntry {
   MetalThreadEntry();
   ~MetalThreadEntry();
-  runtime::MetalWorkspace *metal_api{nullptr};
-  static MetalThreadEntry* ThreadLocal();
+  MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc);
+  MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb,
+                                    MPSImageDescriptor *desc);
+  runtime::metal::MetalWorkspace *metal_api{nullptr};
+  static MetalThreadEntry *ThreadLocal();
+  std::vector<MPSImage *> img_table;
 };  // MetalThreadEntry
 
 }  // namespace contrib
diff --git a/src/contrib/mps/mps_utils.mm b/src/contrib/mps/mps_utils.mm
new file mode 100644
index 0000000000000000000000000000000000000000..bed8278a1d50e03517b63b3c589162f8875907cf
--- /dev/null
+++ b/src/contrib/mps/mps_utils.mm
@@ -0,0 +1,80 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file Use external mps utils function
+ */
+#include "mps_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+// MPS Data Type
+MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
+  switch (dtype.code) {
+  case kDLInt:
+    if (dtype.bits == 8 && dtype.lanes == 1)
+      return MPSDataTypeInt8;
+    else if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeInt16;
+    else
+      LOG(FATAL) << "Unsupported type";
+    break;
+  case kDLUInt:
+    if (dtype.bits == 8 && dtype.lanes == 1)
+      return MPSDataTypeUInt8;
+    else if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeUInt16;
+    else if (dtype.bits == 32 && dtype.lanes == 1)
+      return MPSDataTypeUInt32;
+    LOG(FATAL) << "Unsupported type";
+    break;
+  case kDLFloat:
+    if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeFloat16;
+    else if (dtype.bits == 32 && dtype.lanes == 1)
+      return MPSDataTypeFloat32;
+    else
+      LOG(FATAL) << "Unsupported type";
+    break;
+  default:
+    LOG(FATAL) << "Unsupported type";
+  }
+  return MPSDataTypeFloat32;
+}
+
+// MetalThreadEntry
+
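+// Allocate an MPSImage on the given device and track it in img_table so the
+// destructor can release it.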
+MPSImage *MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev,
+                                          MPSImageDescriptor *desc) {
+  MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc];
+  img_table.push_back(mpsimg);
+  return mpsimg;
+}
+
+MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
+                                                    MPSImageDescriptor *desc) {
+  MPSTemporaryImage *mpsimg =
+      [MPSTemporaryImage temporaryImageWithCommandBuffer:cb
+                                         imageDescriptor:desc];
+  return mpsimg;
+}
+
+MetalThreadEntry::MetalThreadEntry() {
+  auto func = runtime::Registry::Get("device_api.metal");
+  void *ret = (*func)();
+  metal_api = static_cast<runtime::metal::MetalWorkspace *>(ret);
+}
+
+MetalThreadEntry::~MetalThreadEntry() {
+  for (size_t i = 0; i < img_table.size(); ++i) {
+    [img_table[i] dealloc];
+  }
+}
+
+typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
+
+MetalThreadEntry *MetalThreadEntry::ThreadLocal() {
+  return MetalThreadStore::Get();
+}
+
+} // namespace contrib
+} // namespace tvm
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index d87a9eac4f72347e0e61c4f598723076c84fc479..1768d6334b5ce7d44fec7b2b51d4e374a14d6f40 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -126,10 +126,18 @@ void* MetalWorkspace::AllocDataSpace(
     TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
   this->Init();
   id<MTLDevice> dev = GetDevice(ctx);
-  // allocate buffer in GPU only mode.
+  // Allocate the buffer in GPU-only (private) storage mode.
+  MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
+  /*
+  #if TARGET_OS_IPHONE
+  storage_mode = MTLResourceStorageModeShared;
+  #else
+  storage_mode = MTLResourceStorageModeManaged;
+  #endif
+  */
   id<MTLBuffer> buf = [
       dev newBufferWithLength:nbytes
-          options:MTLResourceStorageModePrivate];
+          options:storage_mode];
   CHECK(buf != nil);
   return (__bridge void*)([buf retain]);
 }
diff --git a/tests/python/contrib/test_mps.py b/tests/python/contrib/test_mps.py
index 68dcb135e90823fefbb24100f48bfa2ccf62a222..25437605525b8149c60d324e2c0a189afc0d0861 100644
--- a/tests/python/contrib/test_mps.py
+++ b/tests/python/contrib/test_mps.py
@@ -2,39 +2,83 @@ import tvm
 import numpy as np
 from tvm.contrib import mps
 
-def test_matmul_add():
+def test_matmul():
+    if not tvm.module.enabled("metal"):
+        print("skip because %s is not enabled..." % "metal")
+        return
     n = 1024
     l = 128
-    m = 235
-    bias = tvm.var('bias', dtype=tvm.float32)
+    m = 256
     A = tvm.placeholder((n, l), name='A')
     B = tvm.placeholder((l, m), name='B')
-    C1 = mps.matmul(A, B)
-    C2 = mps.matmul(B, A, True, True)
-    D1 = tvm.compute(C1.shape, lambda i, j: C1[i,j] + bias, name="D1")
-    D2 = tvm.compute(C2.shape, lambda i, j: C2[i,j] + bias, name="D2")
-    s1 = tvm.create_schedule(D1.op)
-    s2 = tvm.create_schedule(D2.op)
-
-    def verify(A, B, D, s, bias, target="llvm"):
-        if not tvm.module.enabled(target):
-            print("skip because %s is not enabled..." % target)
-            return
+    C = mps.matmul(A, B)
+    D = tvm.compute(
+        C.shape,
+        lambda *i: C(*i) + 1.
+    )
+    s = tvm.create_schedule(D.op)
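+    # D (the element-wise +1 epilogue) is computed on the GPU, so its axes
+    # must be bound to Metal block/thread indices.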
+    yo, xo = D.op.axis
+    block_y = tvm.thread_axis("blockIdx.y")
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_y = tvm.thread_axis("threadIdx.y")
+    thread_x = tvm.thread_axis("threadIdx.x")
+    by, ty = s[D].split(yo, factor=16)
+    bx, tx = s[D].split(xo, factor=16)
+    s[D].bind(by, block_y)
+    s[D].bind(bx, block_x)
+    s[D].bind(ty, thread_y)
+    s[D].bind(tx, thread_x)
+
+    def verify(A, B, D, s, target="metal"):
         if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
             print("skip because extern function is not avalable")
             return
-        ctx = tvm.cpu(0)
-        f = tvm.build(s, [A, B, D, bias], target)
+        ctx = tvm.metal(0)
+        f = tvm.build(s, [A, B, D], "metal")
         a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
-        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
-        bb = 10.0
-        f(a, b, d, bb)
+        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
+        f(a, b, c)
         np.testing.assert_allclose(
-            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
-    verify(A, B, D1, s1, bias)
-    verify(A, B, D2, s2, bias)
+            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5)
+    verify(A, B, D, s)
+
+def test_conv2d():
+    if not tvm.module.enabled("metal"):
+        print("skip because %s is not enabled..." % "metal")
+        return
+    n = 1
+    h = 14
+    w = 14
+    ci = 2
+    co = 4
+    kh = 3
+    kw = 3
+    stride = 2
+    A = tvm.placeholder((n, h, w, ci), name="x")
+    B = tvm.placeholder((co, kh, kw, ci), name="w")
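+    # data in NHWC layout, weight in (co, kh, kw, ci), matching mps.conv2d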
+    C = mps.conv2d(A, B, 'SAME', 2)
+    s1 = tvm.create_schedule(C.op)
+
+    def verify(A, B, C, s, target="metal"):
+        if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
+            print("skip because extern function is not avalable")
+            return
+        ctx = tvm.metal(0)
+        f = tvm.build(s, [A, B, C], "metal")
+        a = tvm.nd.array(np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), ctx)
+        f(a, b, c)
+        # TODO: compare c against a reference convolution result
+    verify(A, B, C, s1)
 
 
 if __name__ == "__main__":
-    test_matmul_add()
+    test_matmul()
+    test_conv2d()