diff --git a/make/contrib/mps.mk b/make/contrib/mps.mk
index 501e62b2a6716721da9325921bef5e8153f527c3..0fe8a7f128893e5fb071782cabc83ff73a53f248 100644
--- a/make/contrib/mps.mk
+++ b/make/contrib/mps.mk
@@ -1,4 +1,4 @@
-MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm, src/contrib/mps/*.cc)
+MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm)
 MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))
 
 ifeq ($(USE_MPS), 1)
@@ -6,9 +6,15 @@
 FRAMEWORKS += -framework MetalPerformanceShaders
 CFLAGS +=
 ADD_LDFLAGS +=
 RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
+CONTRIB_OBJ += $(MPS_CONTRIB_OBJ)
 endif
-build/contrib/mps/%.o: src/contrib/mps/%.mm src/contrib/mps/%.cc
+build/contrib/mps/%.o: src/contrib/mps/%.mm
+	@mkdir -p $(@D)
+	$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
+	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
+
+build/contrib/mps/%.o: src/contrib/mps/%.cc
 	@mkdir -p $(@D)
 	$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
 	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
diff --git a/python/tvm/contrib/mps.py b/python/tvm/contrib/mps.py
index d214d4b93631ae9550c756cd64c03ba036a22681..43b3b9fb48dbe2cdf1a86f39634d1262e51963b0 100644
--- a/python/tvm/contrib/mps.py
+++ b/python/tvm/contrib/mps.py
@@ -1,9 +1,9 @@
-"""External function interface to MPS libraroes."""
+"""External function interface to MPS libraries."""
 from __future__ import absolute_import as _abs
-
 from .. import api as _api
 from .. import intrin as _intrin
+# pylint: disable=C0103,W0612
 
 
 def matmul(lhs, rhs, transa=False, transb=False):
-    """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
+    """Create an extern op that computes matrix mult of lhs and rhs with MPS
@@ -26,10 +26,42 @@ def matmul(lhs, rhs, transa=False, transb=False):
     C : Tensor
         The result tensor.
     """
-    m = lhs.shape[0]
-    n = rhs.shape[1]
+    m = lhs.shape[1] if transa else lhs.shape[0]
+    n = rhs.shape[0] if transb else rhs.shape[1]
     return _api.extern(
-        (n, m), [lhs, rhs],
+        (m, n), [lhs, rhs],
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
         name="C")
+
+def conv2d(data, weight, pad='SAME', stride=1):
+    """
+    Create an extern op that computes the 2D convolution of data with weight
+
+    Parameters
+    ----------
+    data: Tensor
+        The input data, format NHWC
+    weight: Tensor
+        The conv weight, format output_feature * kH * kW * input_feature
+    pad: str
+        Padding method, 'SAME' or 'VALID'
+    stride: int
+        Convolution stride
+
+    Returns
+    -------
+    output: Tensor
+        The result tensor
+    """
+    n, hi, wi, ci = data.shape
+    co, kh, kw, ciw = weight.shape
+    padding = 0 if pad == 'SAME' else 1
+    ho = hi // stride
+    wo = wi // stride
+
+    return _api.extern(
+        (n, ho, wo, co), [data, weight],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
+        name="C")
diff --git a/python/tvm/contrib/rpc_proxy.py b/python/tvm/contrib/rpc_proxy.py
index fe289935e7cbf68a64c1bff37c5f9ef949e6174a..9634c258b39fccde263e8613783b3cfbf39003c3 100644
--- a/python/tvm/contrib/rpc_proxy.py
+++ b/python/tvm/contrib/rpc_proxy.py
@@ -70,6 +70,7 @@ class ForwardHandler(object):
         ProxyServerHandler.current.handler_ready(self)
 
     def on_data(self, message):
+        """Callback when a data message is received."""
         assert isinstance(message, bytes)
         if self.forward_proxy:
             self.forward_proxy.send_data(message)
@@ -98,6 +99,7 @@ class ForwardHandler(object):
         self.close()
 
     def on_close_event(self):
+        """Callback when the connection close event fires."""
         assert not self._done
         logging.info("RPCProxy:on_close %s ...", self.name())
         self._done = True
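For reference, a minimal usage sketch of the two Python entry points added above. This is a sketch only: it assumes the 0.x-era TVM API used elsewhere in this patch (tvm.placeholder, tvm.create_schedule, tvm.build, tvm.metal) and purely illustrative shapes; the test file at the end of this patch exercises the same path with explicit thread binding for a fused compute stage.

    import tvm
    import numpy as np
    from tvm.contrib import mps

    # Matrix multiply offloaded to the packed function tvm.contrib.mps.matmul.
    n, l, m = 1024, 128, 256
    A = tvm.placeholder((n, l), name="A")
    B = tvm.placeholder((l, m), name="B")
    C = mps.matmul(A, B)                  # extern op, result shape (n, m)
    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], "metal")

    ctx = tvm.metal(0)
    a = tvm.nd.array(np.random.uniform(size=(n, l)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(l, m)).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((n, m), dtype="float32"), ctx)
    f(a, b, c)

    # Convolution offloaded to tvm.contrib.mps.conv2d: NHWC data, OHWI weights.
    X = tvm.placeholder((1, 14, 14, 2), name="x")
    W = tvm.placeholder((4, 3, 3, 2), name="w")
    Y = mps.conv2d(X, W, 'SAME', 2)       # extern op, result shape (1, 7, 7, 4)

Note that the extern output tensor is passed as the last argument of the built function, matching the outs[0] slot wired up by call_packed.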
diff --git a/src/contrib/mps/conv.mm b/src/contrib/mps/conv.mm
new file mode 100644
index 0000000000000000000000000000000000000000..fa279bd5cc9513e500269c8348edaf3cf61156d8
--- /dev/null
+++ b/src/contrib/mps/conv.mm
@@ -0,0 +1,154 @@
+#include "mps_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+using namespace runtime;
+
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor *buf = args[0];
+  DLTensor *img = args[1];
+  // copy to temp
+  id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  runtime::metal::MetalThreadEntry *rt =
+      runtime::metal::MetalThreadEntry::ThreadLocal();
+  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
+  id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
+  entry_ptr->metal_api->CopyDataFromTo(
+      (__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
+      buf->ctx, buf->ctx, nullptr
+  );
+
+  MPSImageDescriptor *desc = [MPSImageDescriptor
+      imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
+                                 width:buf->shape[2]
+                                height:buf->shape[1]
+                       featureChannels:buf->shape[3]];
+
+  MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);
+
+  [mpsimg writeBytes:[temp contents]
+          dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
+          imageIndex:0];
+
+  img->data = (__bridge void *)mpsimg;
+
+  [mpsimg readBytes:[temp contents]
+         dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
+         imageIndex:0];
+
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor *img = args[0];
+  DLTensor *buf = args[1];
+  id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
+  MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  runtime::metal::MetalThreadEntry *rt =
+      runtime::metal::MetalThreadEntry::ThreadLocal();
+  id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
+
+  [mpsimg readBytes:[temp contents]
+         dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
+         imageIndex:0];
+
+  entry_ptr->metal_api->CopyDataFromTo(
+      (__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
+      buf->ctx, buf->ctx, nullptr);
+
+});
+
+TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  // MPS-NHWC
+  DLTensor *data = args[0];
+  DLTensor *weight = args[1];
+  DLTensor *output = args[2];
+  int pad = args[3];
+  int stride = args[4];
+
+  CHECK_EQ(data->ndim, 4);
+  CHECK_EQ(weight->ndim, 4);
+  CHECK_EQ(output->ndim, 4);
+  CHECK(output->strides == nullptr);
+  CHECK(weight->strides == nullptr);
+  CHECK(data->strides == nullptr);
+
+  CHECK_EQ(data->shape[0], 1);
+  CHECK_EQ(output->shape[0], 1);
+
+  int oCh = weight->shape[0];
+  int kH = weight->shape[1];
+  int kW = weight->shape[2];
+  int iCh = weight->shape[3];
+
+  auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
+  auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
+  // Get Metal device API
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  runtime::metal::MetalThreadEntry *rt =
+      runtime::metal::MetalThreadEntry::ThreadLocal();
+  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
+  id<MTLCommandQueue> queue =
+      entry_ptr->metal_api->GetCommandQueue(data->ctx);
+  id<MTLCommandBuffer> cb = [queue commandBuffer];
+  // data to MPSImage
+  DLTensor tmp_in;
+  (*f_buf2img)(data, &tmp_in);
+  MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;
+  // weight to temp memory
+  id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
+  id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
+  entry_ptr->metal_api->CopyDataFromTo(
+      (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
+      weight->ctx, weight->ctx, nullptr);
+  float *ptr_w = (float *)[tempB contents];
+  // output to MPSImage
+  DLTensor tmp_out;
+  (*f_buf2img)(output, &tmp_out);
+  MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;
+  // conv desc
+
+  MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
+      cnnConvolutionDescriptorWithKernelWidth:kW
+                                 kernelHeight:kH
+                         inputFeatureChannels:iCh
+                        outputFeatureChannels:oCh];
+  [conv_desc setStrideInPixelsX:stride];
+  [conv_desc setStrideInPixelsY:stride];
+
+  MPSCNNConvolution *conv =
+      [[MPSCNNConvolution alloc] initWithDevice:dev
+                          convolutionDescriptor:conv_desc
+                                  kernelWeights:ptr_w
+                                      biasTerms:nil
+                                          flags:MPSCNNConvolutionFlagsNone];
+  if (pad == 0) {
+    conv.padding = [MPSNNDefaultPadding
+        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
+                          MPSNNPaddingMethodAlignCentered |
+                          MPSNNPaddingMethodSizeSame];
+  } else if (pad == 1) {
+    conv.padding = [MPSNNDefaultPadding
+        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
+                          MPSNNPaddingMethodAlignCentered |
+                          MPSNNPaddingMethodSizeValidOnly];
+  }
+  [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];
+
+  [cb commit];
+  id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
+  [encoder synchronizeResource:tempC.texture];
+  [encoder endEncoding];
+  [cb waitUntilCompleted];
+
+  (*f_img2buf)(&tmp_out, output);
+
+});
+
+} // namespace contrib
+} // namespace tvm
diff --git a/src/contrib/mps/gemm.mm b/src/contrib/mps/gemm.mm
index f877cb8b0ea1de43cbe2a1be52037e0b293db981..1d92ad2851d042f36fdba10847b31f34ea834a16 100644
--- a/src/contrib/mps/gemm.mm
+++ b/src/contrib/mps/gemm.mm
@@ -1,9 +1,5 @@
-#include "../../runtime/metal/metal_common.h"
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#include <dmlc/logging.h>
-#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/util.h>
+
+#include "mps_utils.h"
 
 namespace tvm {
 namespace contrib {
@@ -11,83 +7,81 @@
 using namespace runtime;
 
 TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
-    .set_body([](TVMArgs args, TVMRetValue *ret) {
-      DLTensor *A = args[0];
-      DLTensor *B = args[1];
-      DLTensor *C = args[2];
-      bool transa = args[3];
-      bool transb = args[4];
-      // call gemm for simple compact code.
-      CHECK_EQ(A->ndim, 2);
-      CHECK_EQ(B->ndim, 2);
-      CHECK_EQ(C->ndim, 2);
-      CHECK(C->strides == nullptr);
-      CHECK(B->strides == nullptr);
-      CHECK(A->strides == nullptr);
-      CHECK(TypeMatch(A->dtype, kDLFloat, 32));
-      CHECK(TypeMatch(B->dtype, kDLFloat, 32));
-      CHECK(TypeMatch(C->dtype, kDLFloat, 32));
-      // Get Metal device API
-      MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
-      CHECK_EQ(A->ctx, B->ctx);
-      CHECK_EQ(A->ctx, C->ctx);
-      id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
-      id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
-      id<MTLCommandBuffer> cb = [queue commandBuffer];
-      NSUInteger M = A->shape[0 + transa?1:0];
-      NSUInteger N = B->shape[1 - transb?1:0];
-      NSUInteger K = B->shape[0 + transb?1:0];
-      CHECK_EQ(A->shape[1-transa?1:0], K);
-      // mps a
-      MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
-      MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
-          matrixDescriptorWithDimensions:M
-                                 columns:K
-                                rowBytes:M * sizeof(dtype)
-                                dataType:dtype];
-      id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
-      MPSMatrix *matrixA =
-          [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
-      // mps b
-      MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
-          matrixDescriptorWithDimensions:K
-                                 columns:N
-                                rowBytes:K * sizeof(dtype)
-                                dataType:dtype];
-      id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
-      MPSMatrix *matrixB =
-          [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
-      // mps c
-      MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
-          matrixDescriptorWithDimensions:M
-                                 columns:N
-                                rowBytes:M * sizeof(dtype)
-                                dataType:dtype];
-      id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
-      MPSMatrix *matrixC =
-          [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
-      // kernel
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  DLTensor *A = args[0];
+  DLTensor *B = args[1];
+  DLTensor *C = args[2];
+  bool transa = args[3];
+  bool transb = args[4];
+  // call gemm for simple compact code.
+  CHECK_EQ(A->ndim, 2);
+  CHECK_EQ(B->ndim, 2);
+  CHECK_EQ(C->ndim, 2);
+  CHECK(C->strides == nullptr);
+  CHECK(B->strides == nullptr);
+  CHECK(A->strides == nullptr);
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+  CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+  CHECK(TypeMatch(C->dtype, kDLFloat, 32));
+  // Get Metal device API
+  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
+  // CHECK_EQ(A->ctx, B->ctx);
+  // CHECK_EQ(A->ctx, C->ctx);
+  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
+  id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
+  id<MTLCommandBuffer> cb = [queue commandBuffer];
+  NSUInteger M = A->shape[0 + (transa ? 1 : 0)];
+  NSUInteger N = B->shape[1 - (transb ? 1 : 0)];
+  NSUInteger K = B->shape[0 + (transb ? 1 : 0)];
+
+  CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K);
+  // mps a
+  MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
+  MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
+      matrixDescriptorWithDimensions:M
+                             columns:K
+                            rowBytes:K * sizeof(MPSDataTypeFloat32)
+                            dataType:MPSDataTypeFloat32];
+  id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
+  MPSMatrix *matrixA =
+      [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
+  // mps b
+  MPSMatrixDescriptor *descB =
+      [MPSMatrixDescriptor matrixDescriptorWithDimensions:K
+                                                  columns:N
+                                                 rowBytes:N * sizeof(dtype)
+                                                 dataType:dtype];
+  id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
+  MPSMatrix *matrixB =
+      [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
+  // mps c
+  MPSMatrixDescriptor *descC =
+      [MPSMatrixDescriptor matrixDescriptorWithDimensions:M
+                                                  columns:N
+                                                 rowBytes:N * sizeof(dtype)
+                                                 dataType:dtype];
+  id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
+  MPSMatrix *matrixC =
+      [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
+  // kernel
 
-      MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
-      MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
-                                                  transposeLeft:transa
-                                                 transposeRight:transb
-                                                     resultRows:M
-                                                  resultColumns:N
-                                                interiorColumns:K
-                                                          alpha:1.0f
-                                                           beta:0.0f];
-      CHECK(sgemm != nil);
-      [sgemm encodeToCommandBuffer:cb
-                        leftMatrix:matrixA
-                       rightMatrix:matrixB
-                      resultMatrix:matrixC];
-      [cb commit];
-      [mul_obj dealloc];
-      [matrixA dealloc];
-      [matrixB dealloc];
-      [matrixC dealloc];
-    });
+  MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
+  MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
+                                             transposeLeft:transa
+                                            transposeRight:transb
+                                                resultRows:M
+                                             resultColumns:N
+                                           interiorColumns:K
+                                                     alpha:1.0f
+                                                      beta:0.0f];
+  CHECK(sgemm != nil);
+  [sgemm encodeToCommandBuffer:cb
+                    leftMatrix:matrixA
+                   rightMatrix:matrixB
+                  resultMatrix:matrixC];
+  [cb commit];
+
+  });
 
 } // namespace contrib
 } // namespace tvm
diff --git a/src/contrib/mps/mps_utils.cc b/src/contrib/mps/mps_utils.cc
deleted file mode 100644
index 2e3ca6218bb4d3ad30f96f0bb31a0aea9286c2bf..0000000000000000000000000000000000000000
--- a/src/contrib/mps/mps_utils.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-/*!
- * Copyright (c) 2017 by Contributors
- * \file Use external mps utils function
- */
-#include "mps_utils.h"
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>
-
-
-namespace tvm {
-namespace contrib {
-
-// MPS Data Type
-MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
-  switch (dtype.code) {
-  case kDLInt:
-    if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeInt8;
-    else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeInt16;
-    else
-      LOG(FATAL) << "Unsupported type";
-    break;
-  case kDLUInt:
-    if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeUInt8;
-    else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeUInt16;
-    else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeUInt32;
-    LOG(FATAL) << "Unsupported type";
-    break;
-  case kDLFloat:
-    if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeFloat16;
-    else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeFloat32;
-    else
-      LOG(FATAL) << "Unsupported type";
-    break;
-  default:
-    LOG(FATAL) << "Unsupported type";
-  }
-}
-
-// MetalThreadEntry
-
-MetalThreadEntry::MetalThreadEntry() {
-  auto func = runtime::Registry::Get("device_api.metal");
-  void *ret = (*func)();
-  metal_api = static_cast<runtime::metal::MetalWorkspace *>(ret);
-}
-
-MetalThreadEntry::~MetalThreadEntry() {
-}
-
-typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
-
-MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
-  return MetalThreadStore::Get();
-}
-
-} // namespace contrib
-} // namespace tvm
diff --git a/src/contrib/mps/mps_utils.h b/src/contrib/mps/mps_utils.h
index 91336ce44eddf4dfb0740a6363e412ef90d46c82..f07156a252a3dd080e2be0400a41e7778d74a6b5 100644
--- a/src/contrib/mps/mps_utils.h
+++ b/src/contrib/mps/mps_utils.h
@@ -6,11 +6,15 @@
 #ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
 #define TVM_CONTRIB_MPS_MPS_UTILS_H_
 
+#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 #include <dmlc/logging.h>
+#include <dmlc/thread_local.h>
 #include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/util.h>
+#include <vector>
 #include "../../runtime/metal/metal_common.h"
 
-
 namespace tvm {
 namespace contrib {
 
@@ -19,12 +23,15 @@ struct MPSType {
   static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
 };  // struct MPSType
 
-
 struct MetalThreadEntry {
   MetalThreadEntry();
   ~MetalThreadEntry();
-  runtime::MetalWorkspace *metal_api{nullptr};
-  static MetalThreadEntry* ThreadLocal();
+  MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc);
+  MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb,
+                                    MPSImageDescriptor *desc);
+  runtime::metal::MetalWorkspace *metal_api{nullptr};
+  static MetalThreadEntry *ThreadLocal();
+  std::vector<MPSImage *> img_table;
 };  // MetalThreadEntry
 
 } // namespace contrib
diff --git a/src/contrib/mps/mps_utils.mm b/src/contrib/mps/mps_utils.mm
new file mode 100644
index 0000000000000000000000000000000000000000..bed8278a1d50e03517b63b3c589162f8875907cf
--- /dev/null
+++ b/src/contrib/mps/mps_utils.mm
@@ -0,0 +1,80 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file Use external mps utils function
+ */
+#include "mps_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+// MPS Data Type
+MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
+  switch (dtype.code) {
+  case kDLInt:
+    if (dtype.bits == 8 && dtype.lanes == 1)
+      return MPSDataTypeInt8;
+    else if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeInt16;
+    else
+      LOG(FATAL) << "Unsupported type";
+    break;
+  case kDLUInt:
+    if (dtype.bits == 8 && dtype.lanes == 1)
+      return MPSDataTypeUInt8;
+    else if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeUInt16;
+    else if (dtype.bits == 32 && dtype.lanes == 1)
+      return MPSDataTypeUInt32;
+    LOG(FATAL) << "Unsupported type";
+    break;
+  case kDLFloat:
+    if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeFloat16;
+    else if (dtype.bits == 32 && dtype.lanes == 1)
+      return MPSDataTypeFloat32;
+    else
+      LOG(FATAL) << "Unsupported type";
+    break;
+  default:
+    LOG(FATAL) << "Unsupported type";
+  }
+  return MPSDataTypeFloat32;
+}
+
+// MetalThreadEntry
+
+MPSImage *MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev,
+                                          MPSImageDescriptor *desc) {
+  MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc];
+  img_table.push_back(mpsimg);
+  return mpsimg;
+}
+
+MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
+                                                    MPSImageDescriptor *desc) {
+  MPSTemporaryImage *mpsimg =
+      [MPSTemporaryImage temporaryImageWithCommandBuffer:cb
+                                         imageDescriptor:desc];
+  return mpsimg;
+}
+
+MetalThreadEntry::MetalThreadEntry() {
+  auto func = runtime::Registry::Get("device_api.metal");
+  void *ret = (*func)();
+  metal_api = static_cast<runtime::metal::MetalWorkspace *>(ret);
+}
+
+MetalThreadEntry::~MetalThreadEntry() {
+  for (int i = 0; i < img_table.size(); ++i) {
+    [img_table[i] dealloc];
+  }
+}
+
+typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
+
+MetalThreadEntry *MetalThreadEntry::ThreadLocal() {
+  return MetalThreadStore::Get();
+}
+
+} // namespace contrib
+} // namespace tvm
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index d87a9eac4f72347e0e61c4f598723076c84fc479..1768d6334b5ce7d44fec7b2b51d4e374a14d6f40 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -126,10 +126,18 @@ void* MetalWorkspace::AllocDataSpace(
     TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
   this->Init();
   id<MTLDevice> dev = GetDevice(ctx);
-  // allocate buffer in GPU only mode.
+  // GPU memory only
+  MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
+  /*
+  #if TARGET_OS_IPHONE
+  storage_mode = MTLResourceStorageModeShared;
+  #else
+  storage_mode = MTLResourceStorageModeManaged;
+  #endif
+  */
   id<MTLBuffer> buf = [
       dev newBufferWithLength:nbytes
-      options:MTLResourceStorageModePrivate];
+      options:storage_mode];
   CHECK(buf != nil);
   return (__bridge void*)([buf retain]);
 }
diff --git a/tests/python/contrib/test_mps.py b/tests/python/contrib/test_mps.py
index 68dcb135e90823fefbb24100f48bfa2ccf62a222..25437605525b8149c60d324e2c0a189afc0d0861 100644
--- a/tests/python/contrib/test_mps.py
+++ b/tests/python/contrib/test_mps.py
@@ -2,39 +2,83 @@ import tvm
 import numpy as np
 from tvm.contrib import mps
 
 
-def test_matmul_add():
+def test_matmul():
+    if not tvm.module.enabled("metal"):
+        print("skip because %s is not enabled..." % "metal")
+        return
     n = 1024
     l = 128
-    m = 235
-    bias = tvm.var('bias', dtype=tvm.float32)
+    m = 256
     A = tvm.placeholder((n, l), name='A')
     B = tvm.placeholder((l, m), name='B')
-    C1 = mps.matmul(A, B)
-    C2 = mps.matmul(B, A, True, True)
-    D1 = tvm.compute(C1.shape, lambda i, j: C1[i,j] + bias, name="D1")
-    D2 = tvm.compute(C2.shape, lambda i, j: C2[i,j] + bias, name="D2")
-    s1 = tvm.create_schedule(D1.op)
-    s2 = tvm.create_schedule(D2.op)
-
-    def verify(A, B, D, s, bias, target="llvm"):
-        if not tvm.module.enabled(target):
-            print("skip because %s is not enabled..." % target)
-            return
+    C = mps.matmul(A, B)
+    D = tvm.compute(
+        C.shape,
+        lambda *i: C(*i) + 1.
+    )
+    s = tvm.create_schedule(D.op)
+    yo, xo = D.op.axis
+    block_y = tvm.thread_axis("blockIdx.y")
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_y = tvm.thread_axis("threadIdx.y")
+    thread_x = tvm.thread_axis("threadIdx.x")
+    by, ty = s[D].split(yo, factor=16)
+    bx, tx = s[D].split(xo, factor=16)
+    s[D].bind(by, block_y)
+    s[D].bind(bx, block_x)
+    s[D].bind(ty, thread_y)
+    s[D].bind(tx, thread_x)
+
+
+
+    def verify(A, B, D, s, target="metal"):
         if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
             print("skip because extern function is not available")
             return
-        ctx = tvm.cpu(0)
-        f = tvm.build(s, [A, B, D, bias], target)
+        ctx = tvm.metal(0)
+        f = tvm.build(s, [A, B, D], "metal")
         a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
-        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
-        bb = 10.0
-        f(a, b, d, bb)
+        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
+        f(a, b, c)
         np.testing.assert_allclose(
-            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
-    verify(A, B, D1, s1, bias)
-    verify(A, B, D2, s2, bias)
+            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5)
+    verify(A, B, D, s)
+
+def test_conv2d():
+    if not tvm.module.enabled("metal"):
+        print("skip because %s is not enabled..." % "metal")
+        return
+    n = 1
+    h = 14
+    w = 14
+    ci = 2
+    co = 4
+    kh = 3
+    kw = 3
+    stride = 2
+    A = tvm.placeholder((n, h, w, ci), name="x")
+    B = tvm.placeholder((co, kh, kw, ci), name="w")
+    C = mps.conv2d(A, B, 'SAME', 2)
+    s1 = tvm.create_schedule(C.op)
+
+    def verify(A, B, C, s1, target="metal"):
+        if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.metal(0)
+        f = tvm.build(s1, [A, B, C], "metal")
+        a = tvm.nd.array(np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), ctx)
+        f(a, b, c)
+        # print(c.asnumpy())
+        # print(c.shape)
+
+    verify(A, B, C, s1)
 
 if __name__ == "__main__":
-    test_matmul_add()
+    test_matmul()
+    test_conv2d()
+