Commit 078c767c by Leyuan Wang Committed by Tianqi Chen

MPS conv (#822)

parent 6731f660
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm, src/contrib/mps/*.cc)
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm)
MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))
ifeq ($(USE_MPS), 1)
......@@ -6,9 +6,15 @@ FRAMEWORKS += -framework MetalPerformanceShaders
CFLAGS +=
ADD_LDFLAGS +=
RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
CONTRIB_OBJ += $(MPS_CONTRIB_OBJ)
endif
build/contrib/mps/%.o: src/contrib/mps/%.mm src/contrib/mps/%.cc
build/contrib/mps/%.o: src/contrib/mps/%.mm
@mkdir -p $(@D)
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/contrib/mps/%.o: src/contrib/mps/%.cc
@mkdir -p $(@D)
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
"""External function interface to MPS libraroes."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
# pylint: disable=C0103,W0612
def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
......@@ -26,10 +26,46 @@ def matmul(lhs, rhs, transa=False, transb=False):
C : Tensor
The result tensor.
"""
m = lhs.shape[0]
n = rhs.shape[1]
m = lhs.shape[0] if transa is False else lhs.shape[1]
n = rhs.shape[1] if transb is False else rhs.shape[0]
if transa:
m = b
if transb:
n = c
return _api.extern(
(n, m), [lhs, rhs],
(m, n), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
name="C")
def conv2d(data, weight, pad='SAME', stride=1):
"""
Create an extern op that compute data * weight and return result in output
Parameters:
----------
data: Tensor
The input data, format NHWC
weight: Tensor
The conv weight, format output_feature * kH * kW * input_feature
pad: str
Padding method, 'SAME' or 'VALID'
stride: int
convolution stride
Returns
-------
output: Tensor
The result tensor
"""
n, hi, wi, ci = data.shape
co, kh, kw, ciw = weight.shape
padding = 0 if pad == 'SAME' else 1
ho = hi // stride
wo = wi // stride
return _api.extern(
(n, ho, wo, co), [data, weight],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
name="C")
......@@ -70,6 +70,7 @@ class ForwardHandler(object):
ProxyServerHandler.current.handler_ready(self)
def on_data(self, message):
"""on data"""
assert isinstance(message, bytes)
if self.forward_proxy:
self.forward_proxy.send_data(message)
......@@ -98,6 +99,7 @@ class ForwardHandler(object):
self.close()
def on_close_event(self):
"""on close event"""
assert not self._done
logging.info("RPCProxy:on_close %s ...", self.name())
self._done = True
......
#include "mps_utils.h"
namespace tvm {
namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *buf = args[0];
DLTensor *img = args[1];
// copy to temp
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr
);
MPSImageDescriptor *desc = [MPSImageDescriptor
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
width:buf->shape[2]
height:buf->shape[1]
featureChannels:buf->shape[3]];
MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);
[mpsimg writeBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
img->data = (__bridge void *)mpsimg;
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
});
TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *img = args[0];
DLTensor *buf = args[1];
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr);
});
TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
.set_body([](TVMArgs args, TVMRetValue *ret) {
// MPS-NHWC
DLTensor *data = args[0];
DLTensor *weight = args[1];
DLTensor *output = args[2];
int pad = args[3];
int stride = args[4];
CHECK_EQ(data->ndim, 4);
CHECK_EQ(weight->ndim, 4);
CHECK_EQ(output->ndim, 4);
CHECK(output->strides == nullptr);
CHECK(weight->strides == nullptr);
CHECK(data->strides == nullptr);
CHECK_EQ(data->shape[0], 1);
CHECK_EQ(output->shape[0], 1);
int oCh = weight->shape[0];
int kH = weight->shape[1];
int kW = weight->shape[2];
int iCh = weight->shape[3];
auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
// Get Metal device API
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
id<MTLCommandQueue> queue =
entry_ptr->metal_api->GetCommandQueue(data->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
// data to MPSImage
DLTensor tmp_in;
(*f_buf2img)(data, &tmp_in);
MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;
// weight to temp memory
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
weight->ctx, weight->ctx, nullptr);
float *ptr_w = (float *)[tempB contents];
// output to MPSImage
DLTensor tmp_out;
(*f_buf2img)(output, &tmp_out);
MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;
// conv desc
MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
cnnConvolutionDescriptorWithKernelWidth:kW
kernelHeight:kH
inputFeatureChannels:iCh
outputFeatureChannels:oCh];
[conv_desc setStrideInPixelsX:stride];
[conv_desc setStrideInPixelsY:stride];
MPSCNNConvolution *conv =
[[MPSCNNConvolution alloc] initWithDevice:dev
convolutionDescriptor:conv_desc
kernelWeights:ptr_w
biasTerms:nil
flags:MPSCNNConvolutionFlagsNone];
if (pad == 0) {
conv.padding = [MPSNNDefaultPadding
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
MPSNNPaddingMethodAlignCentered |
MPSNNPaddingMethodSizeSame];
} else if (pad == 1) {
conv.padding = [MPSNNDefaultPadding
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
MPSNNPaddingMethodAlignCentered |
MPSNNPaddingMethodSizeValidOnly];
}
[conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];
[cb commit];
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder synchronizeResource:tempC.texture];
[encoder endEncoding];
[cb waitUntilCompleted];
(*f_img2buf)(&tmp_out, output);
});
} // namespace contrib
} // namespace tvm
#include "../../runtime/metal/metal_common.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include "mps_utils.h"
namespace tvm {
namespace contrib {
......@@ -11,83 +7,81 @@ namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// Get Metal device API
MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
CHECK_EQ(A->ctx, B->ctx);
CHECK_EQ(A->ctx, C->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
NSUInteger M = A->shape[0 + transa?1:0];
NSUInteger N = B->shape[1 - transb?1:0];
NSUInteger K = B->shape[0 + transb?1:0];
CHECK_EQ(A->shape[1-transa?1:0], K);
// mps a
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
columns:K
rowBytes:M * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
MPSMatrix *matrixA =
[[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
// mps b
MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:K
columns:N
rowBytes:K * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
MPSMatrix *matrixB =
[[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
// mps c
MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
columns:N
rowBytes:M * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
MPSMatrix *matrixC =
[[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
// kernel
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// Get Metal device API
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
// CHECK_EQ(A->ctx, B->ctx);
// CHECK_EQ(A->ctx, C->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
NSUInteger M = A->shape[0 + (transa ? 1 : 0)];
NSUInteger N = B->shape[1 - (transb ? 1 : 0)];
NSUInteger K = B->shape[0 + (transb ? 1 : 0)];
CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K);
// mps a
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
columns:K
rowBytes:K * sizeof(MPSDataTypeFloat32)
dataType:MPSDataTypeFloat32];
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
MPSMatrix *matrixA =
[[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
// mps b
MPSMatrixDescriptor *descB =
[MPSMatrixDescriptor matrixDescriptorWithDimensions:K
columns:N
rowBytes:N * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
MPSMatrix *matrixB =
[[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
// mps c
MPSMatrixDescriptor *descC =
[MPSMatrixDescriptor matrixDescriptorWithDimensions:M
columns:N
rowBytes:N * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
MPSMatrix *matrixC =
[[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
// kernel
MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
transposeLeft:transa
transposeRight:transb
resultRows:M
resultColumns:N
interiorColumns:K
alpha:1.0f
beta:0.0f];
CHECK(sgemm != nil);
[sgemm encodeToCommandBuffer:cb
leftMatrix:matrixA
rightMatrix:matrixB
resultMatrix:matrixC];
[cb commit];
[mul_obj dealloc];
[matrixA dealloc];
[matrixB dealloc];
[matrixC dealloc];
});
MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
transposeLeft:transa
transposeRight:transb
resultRows:M
resultColumns:N
interiorColumns:K
alpha:1.0f
beta:0.0f];
CHECK(sgemm != nil);
[sgemm encodeToCommandBuffer:cb
leftMatrix:matrixA
rightMatrix:matrixB
resultMatrix:matrixC];
[cb commit];
});
} // namespace contrib
} // namespace tvm
......@@ -6,11 +6,15 @@
#ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
#define TVM_CONTRIB_MPS_MPS_UTILS_H_
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <vector>
#include "../../runtime/metal/metal_common.h"
namespace tvm {
namespace contrib {
......@@ -19,12 +23,15 @@ struct MPSType {
static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
}; // struct MPSType
struct MetalThreadEntry {
MetalThreadEntry();
~MetalThreadEntry();
runtime::MetalWorkspace *metal_api{nullptr};
static MetalThreadEntry* ThreadLocal();
MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc);
MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb,
MPSImageDescriptor *desc);
runtime::metal::MetalWorkspace *metal_api{nullptr};
static MetalThreadEntry *ThreadLocal();
std::vector<MPSImage *> img_table;
}; // MetalThreadEntry
} // namespace contrib
......
......@@ -3,10 +3,6 @@
* \file Use external mps utils function
*/
#include "mps_utils.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace contrib {
......@@ -14,31 +10,54 @@ namespace contrib {
// MPS Data Type
MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
switch (dtype.code) {
case kDLInt:
if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeInt8;
else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeInt16;
else
LOG(FATAL) << "Unsupported type";
break;
case kDLUInt:
if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeUInt8;
else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeUInt16;
else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeUInt32;
LOG(FATAL) << "Unsupported type";
break;
case kDLFloat:
if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeFloat16;
else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeFloat32;
else
LOG(FATAL) << "Unsupported type";
break;
default:
LOG(FATAL) << "Unsupported type";
}
case kDLInt:
if (dtype.bits == 8 && dtype.lanes == 1)
return MPSDataTypeInt8;
else if (dtype.bits == 16 && dtype.lanes == 1)
return MPSDataTypeInt16;
else
LOG(FATAL) << "Unsupported type";
break;
case kDLUInt:
if (dtype.bits == 8 && dtype.lanes == 1)
return MPSDataTypeUInt8;
else if (dtype.bits == 16 && dtype.lanes == 1)
return MPSDataTypeUInt16;
else if (dtype.bits == 32 && dtype.lanes == 1)
return MPSDataTypeUInt32;
LOG(FATAL) << "Unsupported type";
break;
case kDLFloat:
if (dtype.bits == 16 && dtype.lanes == 1)
return MPSDataTypeFloat16;
else if (dtype.bits == 32 && dtype.lanes == 1)
return MPSDataTypeFloat32;
else
LOG(FATAL) << "Unsupported type";
break;
default:
LOG(FATAL) << "Unsupported type";
}
return MPSDataTypeFloat32;
}
// MetalThreadEntry
MPSImage *MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev,
MPSImageDescriptor *desc) {
MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc];
img_table.push_back(mpsimg);
return mpsimg;
}
MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
MPSImageDescriptor *desc) {
MPSTemporaryImage *mpsimg =
[MPSTemporaryImage temporaryImageWithCommandBuffer:cb
imageDescriptor:desc];
return mpsimg;
}
MetalThreadEntry::MetalThreadEntry() {
auto func = runtime::Registry::Get("device_api.metal");
void *ret = (*func)();
......@@ -46,13 +65,16 @@ MetalThreadEntry::MetalThreadEntry() {
}
MetalThreadEntry::~MetalThreadEntry() {
for (int i = 0; i < img_table.size(); ++i) {
[img_table[i] dealloc];
}
}
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
MetalThreadEntry *MetalThreadEntry::ThreadLocal() {
return MetalThreadStore::Get();
}
} // namespace contrib
} // namespace tvm
} // namespace contrib
} // namespace tvm
......@@ -126,10 +126,18 @@ void* MetalWorkspace::AllocDataSpace(
TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// allocate buffer in GPU only mode.
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
/*
#if TARGET_OS_IPHONE
storage_mode = MTLResourceStorageModeShared;
#else
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [
dev newBufferWithLength:nbytes
options:MTLResourceStorageModePrivate];
options:storage_mode];
CHECK(buf != nil);
return (__bridge void*)([buf retain]);
}
......
......@@ -2,39 +2,83 @@ import tvm
import numpy as np
from tvm.contrib import mps
def test_matmul_add():
def test_matmul():
if not tvm.module.enabled("metal"):
print("skip because %s is not enabled..." % "metal")
return
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
m = 256
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C1 = mps.matmul(A, B)
C2 = mps.matmul(B, A, True, True)
D1 = tvm.compute(C1.shape, lambda i, j: C1[i,j] + bias, name="D1")
D2 = tvm.compute(C2.shape, lambda i, j: C2[i,j] + bias, name="D2")
s1 = tvm.create_schedule(D1.op)
s2 = tvm.create_schedule(D2.op)
def verify(A, B, D, s, bias, target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
C = mps.matmul(A, B)
D = tvm.compute(
C.shape,
lambda *i: C(*i) + 1.
)
s = tvm.create_schedule(D.op)
yo, xo = D.op.axis
block_y = tvm.thread_axis("blockIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis("threadIdx.y")
thread_x = tvm.thread_axis("threadIdx.x")
by, ty = s[D].split(yo, factor=16)
bx, tx = s[D].split(xo, factor=16)
s[D].bind(by, block_y)
s[D].bind(bx, block_x)
s[D].bind(ty, thread_y)
s[D].bind(tx, thread_x)
def verify(A, B, D, s, target="metal"):
if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
print("skip because extern function is not avalable")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
ctx = tvm.metal(0)
f = tvm.build(s, [A, B, D], "metal")
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
verify(A, B, D1, s1, bias)
verify(A, B, D2, s2, bias)
c.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 1, rtol=1e-5)
verify(A, B, D, s)
def test_conv2d():
if not tvm.module.enabled("metal"):
print("skip because %s is not enabled..." % "metal")
return
n = 1
h = 14
w = 14
ci = 2
co = 4
kh = 3
kw = 3
stride = 2
A = tvm.placeholder((n, h, w, ci), name="x")
B = tvm.placeholder((co, kh, kw, ci), name="w")
C = mps.conv2d(A, B, 'SAME', 2)
s1 = tvm.create_schedule(C.op)
def verify(A, B, C, target="llvm"):
if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
print("skip because extern function is not avalable")
return
ctx = tvm.metal(0)
f = tvm.build(s1, [A, B, C], "metal")
a = tvm.nd.array(np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), ctx)
f(a, b, c)
# print(c.asnumpy())
# print(c.shape)
verify(A, B, C, s1)
if __name__ == "__main__":
test_matmul_add()
#test_matmul()
test_conv2d()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment