Commit 078c767c authored by Leyuan Wang, committed by Tianqi Chen

MPS conv (#822)

parent 6731f660
-MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm, src/contrib/mps/*.cc)
+MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm)
 MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))

 ifeq ($(USE_MPS), 1)
@@ -6,9 +6,15 @@ FRAMEWORKS += -framework MetalPerformanceShaders
 CFLAGS +=
 ADD_LDFLAGS +=
 RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
+CONTRIB_OBJ += $(MPS_CONTRIB_OBJ)
 endif

-build/contrib/mps/%.o: src/contrib/mps/%.mm src/contrib/mps/%.cc
+build/contrib/mps/%.o: src/contrib/mps/%.mm
+	@mkdir -p $(@D)
+	$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
+	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
+
+build/contrib/mps/%.o: src/contrib/mps/%.cc
 	@mkdir -p $(@D)
 	$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
 	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
"""External function interface to MPS libraroes.""" """External function interface to MPS libraroes."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .. import api as _api from .. import api as _api
from .. import intrin as _intrin from .. import intrin as _intrin
# pylint: disable=C0103,W0612
def matmul(lhs, rhs, transa=False, transb=False): def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
...@@ -26,10 +26,46 @@ def matmul(lhs, rhs, transa=False, transb=False): ...@@ -26,10 +26,46 @@ def matmul(lhs, rhs, transa=False, transb=False):
C : Tensor C : Tensor
The result tensor. The result tensor.
""" """
m = lhs.shape[0] m = lhs.shape[0] if transa is False else lhs.shape[1]
n = rhs.shape[1] n = rhs.shape[1] if transb is False else rhs.shape[0]
if transa:
m = b
if transb:
n = c
return _api.extern( return _api.extern(
(n, m), [lhs, rhs], (m, n), [lhs, rhs],
lambda ins, outs: _intrin.call_packed( lambda ins, outs: _intrin.call_packed(
"tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb), "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
name="C") name="C")
def conv2d(data, weight, pad='SAME', stride=1):
    """Create an extern op that computes a 2-D convolution of data with weight
    and returns the result in the output tensor.

    Parameters
    ----------
    data: Tensor
        The input data, in NHWC layout.

    weight: Tensor
        The convolution weight, in (output_feature, kH, kW, input_feature) layout.

    pad: str
        Padding method, 'SAME' or 'VALID'.

    stride: int
        Convolution stride.

    Returns
    -------
    output: Tensor
        The result tensor.
    """
    n, hi, wi, ci = data.shape
    # ciw is unused; the input channel count also comes from data
    co, kh, kw, ciw = weight.shape
    # the packed function expects 0 for 'SAME' and 1 for 'VALID'
    padding = 0 if pad == 'SAME' else 1
    # output spatial dimensions assume 'SAME'-style padding
    ho = hi // stride
    wo = wi // stride
    return _api.extern(
        (n, ho, wo, co), [data, weight],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
        name="C")
@@ -70,6 +70,7 @@ class ForwardHandler(object):
         ProxyServerHandler.current.handler_ready(self)

     def on_data(self, message):
+        """on data"""
         assert isinstance(message, bytes)
         if self.forward_proxy:
             self.forward_proxy.send_data(message)
@@ -98,6 +99,7 @@ class ForwardHandler(object):
         self.close()

     def on_close_event(self):
+        """on close event"""
         assert not self._done
         logging.info("RPCProxy:on_close %s ...", self.name())
         self._done = True
......
#include "mps_utils.h"
namespace tvm {
namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *buf = args[0];
DLTensor *img = args[1];
// copy to temp
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr
);
MPSImageDescriptor *desc = [MPSImageDescriptor
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
width:buf->shape[2]
height:buf->shape[1]
featureChannels:buf->shape[3]];
MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);
[mpsimg writeBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
img->data = (__bridge void *)mpsimg;
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
});
TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *img = args[0];
DLTensor *buf = args[1];
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr);
});
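The two conversions above are registered as packed functions so the conv2d kernel below can fetch them through the global registry; a minimal sketch of checking for them from Python (assuming a Metal-enabled runtime):

    import tvm

    buf2img = tvm.get_global_func("tvm.contrib.mps.buffer2img", True)
    img2buf = tvm.get_global_func("tvm.contrib.mps.img2buffer", True)
    if buf2img is None or img2buf is None:
        print("MPS contrib functions are not registered in this build")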
TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
.set_body([](TVMArgs args, TVMRetValue *ret) {
// MPS-NHWC
DLTensor *data = args[0];
DLTensor *weight = args[1];
DLTensor *output = args[2];
int pad = args[3];
int stride = args[4];
CHECK_EQ(data->ndim, 4);
CHECK_EQ(weight->ndim, 4);
CHECK_EQ(output->ndim, 4);
CHECK(output->strides == nullptr);
CHECK(weight->strides == nullptr);
CHECK(data->strides == nullptr);
CHECK_EQ(data->shape[0], 1);
CHECK_EQ(output->shape[0], 1);
int oCh = weight->shape[0];
int kH = weight->shape[1];
int kW = weight->shape[2];
int iCh = weight->shape[3];
auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
// Get Metal device API
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
id<MTLCommandQueue> queue =
entry_ptr->metal_api->GetCommandQueue(data->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
// data to MPSImage
DLTensor tmp_in;
(*f_buf2img)(data, &tmp_in);
MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;
// weight to temp memory
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
weight->ctx, weight->ctx, nullptr);
float *ptr_w = (float *)[tempB contents];
// output to MPSImage
DLTensor tmp_out;
(*f_buf2img)(output, &tmp_out);
MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;
// conv desc
MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
cnnConvolutionDescriptorWithKernelWidth:kW
kernelHeight:kH
inputFeatureChannels:iCh
outputFeatureChannels:oCh];
[conv_desc setStrideInPixelsX:stride];
[conv_desc setStrideInPixelsY:stride];
MPSCNNConvolution *conv =
[[MPSCNNConvolution alloc] initWithDevice:dev
convolutionDescriptor:conv_desc
kernelWeights:ptr_w
biasTerms:nil
flags:MPSCNNConvolutionFlagsNone];
// pad == 0 selects 'SAME' padding, pad == 1 selects 'VALID', matching the Python wrapper
if (pad == 0) {
conv.padding = [MPSNNDefaultPadding
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
MPSNNPaddingMethodAlignCentered |
MPSNNPaddingMethodSizeSame];
} else if (pad == 1) {
conv.padding = [MPSNNDefaultPadding
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
MPSNNPaddingMethodAlignCentered |
MPSNNPaddingMethodSizeValidOnly];
}
[conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];
// encode the blit synchronization first; a command buffer cannot encode
// further commands once it has been committed
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder synchronizeResource:tempC.texture];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
(*f_img2buf)(&tmp_out, output);
});
} // namespace contrib
} // namespace tvm
#include "../../runtime/metal/metal_common.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h> #include "mps_utils.h"
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
namespace tvm { namespace tvm {
namespace contrib { namespace contrib {
...@@ -11,83 +7,81 @@ namespace contrib { ...@@ -11,83 +7,81 @@ namespace contrib {
using namespace runtime; using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *A = args[0]; DLTensor *A = args[0];
DLTensor *B = args[1]; DLTensor *B = args[1];
DLTensor *C = args[2]; DLTensor *C = args[2];
bool transa = args[3]; bool transa = args[3];
bool transb = args[4]; bool transb = args[4];
// call gemm for simple compact code. // call gemm for simple compact code.
CHECK_EQ(A->ndim, 2); CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2); CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2); CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr); CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr); CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr); CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32)); CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32)); CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32)); CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// Get Metal device API // Get Metal device API
MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
CHECK_EQ(A->ctx, B->ctx); // CHECK_EQ(A->ctx, B->ctx);
CHECK_EQ(A->ctx, C->ctx); // CHECK_EQ(A->ctx, C->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx); id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx); id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer]; id<MTLCommandBuffer> cb = [queue commandBuffer];
NSUInteger M = A->shape[0 + transa?1:0]; NSUInteger M = A->shape[0 + (transa ? 1 : 0)];
NSUInteger N = B->shape[1 - transb?1:0]; NSUInteger N = B->shape[1 - (transb ? 1 : 0)];
NSUInteger K = B->shape[0 + transb?1:0]; NSUInteger K = B->shape[0 + (transb ? 1 : 0)];
CHECK_EQ(A->shape[1-transa?1:0], K);
// mps a CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K);
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); // mps a
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
matrixDescriptorWithDimensions:M MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
columns:K matrixDescriptorWithDimensions:M
rowBytes:M * sizeof(dtype) columns:K
dataType:dtype]; rowBytes:K * sizeof(MPSDataTypeFloat32)
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data); dataType:MPSDataTypeFloat32];
MPSMatrix *matrixA = id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
[[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; MPSMatrix *matrixA =
// mps b [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
MPSMatrixDescriptor *descB = [MPSMatrixDescriptor // mps b
matrixDescriptorWithDimensions:K MPSMatrixDescriptor *descB =
columns:N [MPSMatrixDescriptor matrixDescriptorWithDimensions:K
rowBytes:K * sizeof(dtype) columns:N
dataType:dtype]; rowBytes:N * sizeof(dtype)
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data); dataType:dtype];
MPSMatrix *matrixB = id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
[[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; MPSMatrix *matrixB =
// mps c [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
MPSMatrixDescriptor *descC = [MPSMatrixDescriptor // mps c
matrixDescriptorWithDimensions:M MPSMatrixDescriptor *descC =
columns:N [MPSMatrixDescriptor matrixDescriptorWithDimensions:M
rowBytes:M * sizeof(dtype) columns:N
dataType:dtype]; rowBytes:N * sizeof(dtype)
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data); dataType:dtype];
MPSMatrix *matrixC = id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
[[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; MPSMatrix *matrixC =
// kernel [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
// kernel
MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init]; MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
transposeLeft:transa transposeLeft:transa
transposeRight:transb transposeRight:transb
resultRows:M resultRows:M
resultColumns:N resultColumns:N
interiorColumns:K interiorColumns:K
alpha:1.0f alpha:1.0f
beta:0.0f]; beta:0.0f];
CHECK(sgemm != nil); CHECK(sgemm != nil);
[sgemm encodeToCommandBuffer:cb [sgemm encodeToCommandBuffer:cb
leftMatrix:matrixA leftMatrix:matrixA
rightMatrix:matrixB rightMatrix:matrixB
resultMatrix:matrixC]; resultMatrix:matrixC];
[cb commit]; [cb commit];
[mul_obj dealloc];
[matrixA dealloc]; });
[matrixB dealloc];
[matrixC dealloc];
});
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
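A worked check of the rowBytes values used in the new descriptors above, assuming row-major float32 storage: a row of a matrix with `cols` columns spans cols * 4 bytes, so descA (M x K) uses K * 4 and descB / descC (K x N and M x N) use N * 4. A small arithmetic sketch, using the shapes from the test at the bottom of this diff:

    # row-stride arithmetic only; not TVM or MPS API code
    def row_bytes(cols, itemsize=4):
        return cols * itemsize

    M, N, K = 1024, 256, 128
    assert row_bytes(K) == 512    # descA: M x K
    assert row_bytes(N) == 1024   # descB: K x N and descC: M x N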
@@ -6,11 +6,15 @@
 #ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
 #define TVM_CONTRIB_MPS_MPS_UTILS_H_

+#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 #include <dmlc/logging.h>
+#include <dmlc/thread_local.h>
 #include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/util.h>
+#include <vector>
 #include "../../runtime/metal/metal_common.h"

 namespace tvm {
 namespace contrib {
@@ -19,12 +23,15 @@ struct MPSType {
   static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
 };  // struct MPSType

 struct MetalThreadEntry {
   MetalThreadEntry();
   ~MetalThreadEntry();
-  runtime::MetalWorkspace *metal_api{nullptr};
-  static MetalThreadEntry* ThreadLocal();
+  MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc);
+  MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb,
+                                    MPSImageDescriptor *desc);
+  runtime::metal::MetalWorkspace *metal_api{nullptr};
+  static MetalThreadEntry *ThreadLocal();
+  std::vector<MPSImage *> img_table;
 };  // MetalThreadEntry

 }  // namespace contrib
......
@@ -3,10 +3,6 @@
  * \file Use external mps utils function
  */
 #include "mps_utils.h"
-#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
-#include <dmlc/thread_local.h>
-#include <tvm/runtime/registry.h>

 namespace tvm {
 namespace contrib {
@@ -14,31 +10,54 @@ namespace contrib {

 // MPS Data Type
 MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
   switch (dtype.code) {
   case kDLInt:
-    if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeInt8;
-    else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeInt16;
-    else
-      LOG(FATAL) << "Unsupported type";
-    break;
-  case kDLUInt:
-    if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeUInt8;
-    else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeUInt16;
-    else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeUInt32;
-    LOG(FATAL) << "Unsupported type";
-    break;
-  case kDLFloat:
-    if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeFloat16;
-    else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeFloat32;
-    else
-      LOG(FATAL) << "Unsupported type";
-    break;
-  default:
-    LOG(FATAL) << "Unsupported type";
-  }
+    if (dtype.bits == 8 && dtype.lanes == 1)
+      return MPSDataTypeInt8;
+    else if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeInt16;
+    else
+      LOG(FATAL) << "Unsupported type";
+    break;
+  case kDLUInt:
+    if (dtype.bits == 8 && dtype.lanes == 1)
+      return MPSDataTypeUInt8;
+    else if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeUInt16;
+    else if (dtype.bits == 32 && dtype.lanes == 1)
+      return MPSDataTypeUInt32;
+    LOG(FATAL) << "Unsupported type";
+    break;
+  case kDLFloat:
+    if (dtype.bits == 16 && dtype.lanes == 1)
+      return MPSDataTypeFloat16;
+    else if (dtype.bits == 32 && dtype.lanes == 1)
+      return MPSDataTypeFloat32;
+    else
+      LOG(FATAL) << "Unsupported type";
+    break;
+  default:
+    LOG(FATAL) << "Unsupported type";
+  }
+  return MPSDataTypeFloat32;
 }

 // MetalThreadEntry
+MPSImage *MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev,
+                                          MPSImageDescriptor *desc) {
+  MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc];
+  img_table.push_back(mpsimg);
+  return mpsimg;
+}
+
+MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
+                                                    MPSImageDescriptor *desc) {
+  MPSTemporaryImage *mpsimg =
+      [MPSTemporaryImage temporaryImageWithCommandBuffer:cb
+                                         imageDescriptor:desc];
+  return mpsimg;
+}
+
 MetalThreadEntry::MetalThreadEntry() {
   auto func = runtime::Registry::Get("device_api.metal");
   void *ret = (*func)();
@@ -46,13 +65,16 @@ MetalThreadEntry::MetalThreadEntry() {
 }

 MetalThreadEntry::~MetalThreadEntry() {
+  for (int i = 0; i < img_table.size(); ++i) {
+    [img_table[i] dealloc];
+  }
 }

 typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;

-MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
+MetalThreadEntry *MetalThreadEntry::ThreadLocal() {
   return MetalThreadStore::Get();
 }

 }  // namespace contrib
 }  // namespace tvm
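A reference sketch of the DLTypeToMPSType mapping above, keyed by TVM dtype strings; anything outside this table falls through to the fatal "Unsupported type" branch (the trailing return MPSDataTypeFloat32 is presumably only there to satisfy the compiler, since LOG(FATAL) aborts):

    DL_TO_MPS = {
        "int8": "MPSDataTypeInt8",
        "int16": "MPSDataTypeInt16",
        "uint8": "MPSDataTypeUInt8",
        "uint16": "MPSDataTypeUInt16",
        "uint32": "MPSDataTypeUInt32",
        "float16": "MPSDataTypeFloat16",
        "float32": "MPSDataTypeFloat32",
    }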
@@ -126,10 +126,18 @@ void* MetalWorkspace::AllocDataSpace(
     TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
   this->Init();
   id<MTLDevice> dev = GetDevice(ctx);
-  // allocate buffer in GPU only mode.
+  // GPU memory only
+  MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
+  /*
+  #if TARGET_OS_IPHONE
+  storage_mode = MTLResourceStorageModeShared;
+  #else
+  storage_mode = MTLResourceStorageModeManaged;
+  #endif
+  */
   id<MTLBuffer> buf = [
       dev newBufferWithLength:nbytes
-          options:MTLResourceStorageModePrivate];
+          options:storage_mode];
   CHECK(buf != nil);
   return (__bridge void*)([buf retain]);
 }
......
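With MTLResourceStorageModePrivate the host never maps the buffer directly, so data reaches these allocations through staged copies; a minimal sketch of exercising that path from Python (assuming a Metal-enabled build):

    import tvm
    import numpy as np

    ctx = tvm.metal(0)
    a_np = np.random.uniform(size=(16,)).astype("float32")
    a = tvm.nd.array(a_np, ctx)                      # host -> device copy into the private buffer
    np.testing.assert_allclose(a.asnumpy(), a_np)    # device -> host copy back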
@@ -2,39 +2,83 @@ import tvm
 import numpy as np
 from tvm.contrib import mps

-def test_matmul_add():
+def test_matmul():
+    if not tvm.module.enabled("metal"):
+        print("skip because %s is not enabled..." % "metal")
+        return
     n = 1024
     l = 128
-    m = 235
-    bias = tvm.var('bias', dtype=tvm.float32)
+    m = 256
     A = tvm.placeholder((n, l), name='A')
     B = tvm.placeholder((l, m), name='B')
-    C1 = mps.matmul(A, B)
-    C2 = mps.matmul(B, A, True, True)
-    D1 = tvm.compute(C1.shape, lambda i, j: C1[i,j] + bias, name="D1")
-    D2 = tvm.compute(C2.shape, lambda i, j: C2[i,j] + bias, name="D2")
-    s1 = tvm.create_schedule(D1.op)
-    s2 = tvm.create_schedule(D2.op)
-
-    def verify(A, B, D, s, bias, target="llvm"):
-        if not tvm.module.enabled(target):
-            print("skip because %s is not enabled..." % target)
-            return
+    C = mps.matmul(A, B)
+    D = tvm.compute(
+        C.shape,
+        lambda *i: C(*i) + 1.
+    )
+    s = tvm.create_schedule(D.op)
+    yo, xo = D.op.axis
+    block_y = tvm.thread_axis("blockIdx.y")
+    block_x = tvm.thread_axis("blockIdx.x")
+    thread_y = tvm.thread_axis("threadIdx.y")
+    thread_x = tvm.thread_axis("threadIdx.x")
+    by, ty = s[D].split(yo, factor=16)
+    bx, tx = s[D].split(xo, factor=16)
+    s[D].bind(by, block_y)
+    s[D].bind(bx, block_x)
+    s[D].bind(ty, thread_y)
+    s[D].bind(tx, thread_x)
+
+    def verify(A, B, D, s, target="metal"):
         if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
             print("skip because extern function is not available")
             return
-        ctx = tvm.cpu(0)
-        f = tvm.build(s, [A, B, D, bias], target)
+        ctx = tvm.metal(0)
+        f = tvm.build(s, [A, B, D], "metal")
         a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
-        d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
-        bb = 10.0
-        f(a, b, d, bb)
+        c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
+        f(a, b, c)
         np.testing.assert_allclose(
-            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
-    verify(A, B, D1, s1, bias)
-    verify(A, B, D2, s2, bias)
+            c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5)
+    verify(A, B, D, s)
+
+
+def test_conv2d():
+    if not tvm.module.enabled("metal"):
+        print("skip because %s is not enabled..." % "metal")
+        return
+    n = 1
+    h = 14
+    w = 14
+    ci = 2
+    co = 4
+    kh = 3
+    kw = 3
+    stride = 2
+    A = tvm.placeholder((n, h, w, ci), name="x")
+    B = tvm.placeholder((co, kh, kw, ci), name="w")
+    C = mps.conv2d(A, B, 'SAME', 2)
+    s1 = tvm.create_schedule(C.op)
+
+    def verify(A, B, C, s1, target="metal"):
+        if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.metal(0)
+        f = tvm.build(s1, [A, B, C], "metal")
+        a = tvm.nd.array(np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), ctx)
+        f(a, b, c)
+        # print(c.asnumpy())
+        # print(c.shape)
+    verify(A, B, C, s1)
+

 if __name__ == "__main__":
-    test_matmul_add()
+    #test_matmul()
+    test_conv2d()
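The conv2d test above only runs the kernel without checking the numbers; a NumPy reference that could back an assert_allclose is sketched below. It is an illustrative helper, not part of the commit, and it assumes TensorFlow-style 'SAME' output sizing with the odd padding row/column placed at the top/left, which is what the MPSNNPaddingMethodAddRemainderToTopLeft flag in the .mm kernel suggests:

    import numpy as np

    def conv2d_nhwc_same_ref(x, w, stride):
        """Reference NHWC conv2d with 'SAME' padding, remainder padded at top/left."""
        n, h, wd, ci = x.shape
        co, kh, kw, _ = w.shape
        ho, wo = h // stride, wd // stride
        ph = max((ho - 1) * stride + kh - h, 0)
        pw = max((wo - 1) * stride + kw - wd, 0)
        x = np.pad(x, ((0, 0), (ph - ph // 2, ph // 2), (pw - pw // 2, pw // 2), (0, 0)),
                   mode='constant')
        out = np.zeros((n, ho, wo, co), dtype=x.dtype)
        for i in range(ho):
            for j in range(wo):
                patch = x[:, i * stride:i * stride + kh, j * stride:j * stride + kw, :]
                # contract over kh, kw, ci against the (co, kh, kw, ci) weights
                out[:, i, j, :] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
        return out

    # e.g. inside verify():
    # np.testing.assert_allclose(
    #     c.asnumpy(), conv2d_nhwc_same_ref(a.asnumpy(), b.asnumpy(), stride), rtol=1e-4)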