Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
078c767c
Commit
078c767c
authored
Mar 06, 2018
by
Leyuan Wang
Committed by
Tianqi Chen
Mar 06, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
MPS conv (#822)
parent
6731f660
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
418 additions
and
145 deletions
+418
-145
make/contrib/mps.mk
+8
-2
python/tvm/contrib/mps.py
+40
-4
python/tvm/contrib/rpc_proxy.py
+2
-0
src/contrib/mps/conv.mm
+154
-0
src/contrib/mps/gemm.mm
+76
-82
src/contrib/mps/mps_utils.h
+11
-4
src/contrib/mps/mps_utils.mm
+50
-28
src/runtime/metal/metal_device_api.mm
+10
-2
tests/python/contrib/test_mps.py
+67
-23
No files found.
make/contrib/mps.mk
View file @
078c767c
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm
, src/contrib/mps/*.cc
)
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm)
MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))
MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))
ifeq ($(USE_MPS), 1)
ifeq ($(USE_MPS), 1)
...
@@ -6,9 +6,15 @@ FRAMEWORKS += -framework MetalPerformanceShaders
...
@@ -6,9 +6,15 @@ FRAMEWORKS += -framework MetalPerformanceShaders
CFLAGS +=
CFLAGS +=
ADD_LDFLAGS +=
ADD_LDFLAGS +=
RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
CONTRIB_OBJ += $(MPS_CONTRIB_OBJ)
endif
endif
build/contrib/mps/%.o: src/contrib/mps/%.mm src/contrib/mps/%.cc
build/contrib/mps/%.o: src/contrib/mps/%.mm
@mkdir -p $(@D)
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/contrib/mps/%.o: src/contrib/mps/%.cc
@mkdir -p $(@D)
@mkdir -p $(@D)
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
python/tvm/contrib/mps.py
View file @
078c767c
"""External function interface to MPS libraroes."""
"""External function interface to MPS libraroes."""
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
from
..
import
api
as
_api
from
..
import
api
as
_api
from
..
import
intrin
as
_intrin
from
..
import
intrin
as
_intrin
# pylint: disable=C0103,W0612
def
matmul
(
lhs
,
rhs
,
transa
=
False
,
transb
=
False
):
def
matmul
(
lhs
,
rhs
,
transa
=
False
,
transb
=
False
):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
...
@@ -26,10 +26,46 @@ def matmul(lhs, rhs, transa=False, transb=False):
...
@@ -26,10 +26,46 @@ def matmul(lhs, rhs, transa=False, transb=False):
C : Tensor
C : Tensor
The result tensor.
The result tensor.
"""
"""
m
=
lhs
.
shape
[
0
]
m
=
lhs
.
shape
[
0
]
if
transa
is
False
else
lhs
.
shape
[
1
]
n
=
rhs
.
shape
[
1
]
n
=
rhs
.
shape
[
1
]
if
transb
is
False
else
rhs
.
shape
[
0
]
if
transa
:
m
=
b
if
transb
:
n
=
c
return
_api
.
extern
(
return
_api
.
extern
(
(
n
,
m
),
[
lhs
,
rhs
],
(
m
,
n
),
[
lhs
,
rhs
],
lambda
ins
,
outs
:
_intrin
.
call_packed
(
lambda
ins
,
outs
:
_intrin
.
call_packed
(
"tvm.contrib.mps.matmul"
,
ins
[
0
],
ins
[
1
],
outs
[
0
],
transa
,
transb
),
"tvm.contrib.mps.matmul"
,
ins
[
0
],
ins
[
1
],
outs
[
0
],
transa
,
transb
),
name
=
"C"
)
name
=
"C"
)
def conv2d(data, weight, pad='SAME', stride=1):
    """Create an extern op that convolves data with weight using MPS.

    Parameters
    ----------
    data : Tensor
        The input data, format NHWC.

    weight : Tensor
        The conv weight, format output_feature * kH * kW * input_feature.

    pad : str
        Padding method, 'SAME' or 'VALID'.

    stride : int
        Convolution stride, used for both spatial dimensions.

    Returns
    -------
    output : Tensor
        The result tensor, format NHWC.
    """
    n, hi, wi, _ = data.shape
    co, kh, kw, _ = weight.shape
    # The packed function takes an integer flag: 0 means 'SAME'; every other
    # value of `pad` is forwarded as 1 ('VALID'), as before.
    padding = 0 if pad == 'SAME' else 1
    if padding == 0:
        # 'SAME': one output per stride step, i.e. ceil(input / stride).
        ho = (hi + stride - 1) // stride
        wo = (wi + stride - 1) // stride
    else:
        # 'VALID': only positions where the whole kernel fits in the input.
        ho = (hi - kh) // stride + 1
        wo = (wi - kw) // stride + 1

    return _api.extern(
        (n, ho, wo, co), [data, weight],
        lambda ins, outs: _intrin.call_packed(
            "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride),
        name="C")
python/tvm/contrib/rpc_proxy.py
View file @
078c767c
...
@@ -70,6 +70,7 @@ class ForwardHandler(object):
...
@@ -70,6 +70,7 @@ class ForwardHandler(object):
ProxyServerHandler
.
current
.
handler_ready
(
self
)
ProxyServerHandler
.
current
.
handler_ready
(
self
)
def
on_data
(
self
,
message
):
def
on_data
(
self
,
message
):
"""on data"""
assert
isinstance
(
message
,
bytes
)
assert
isinstance
(
message
,
bytes
)
if
self
.
forward_proxy
:
if
self
.
forward_proxy
:
self
.
forward_proxy
.
send_data
(
message
)
self
.
forward_proxy
.
send_data
(
message
)
...
@@ -98,6 +99,7 @@ class ForwardHandler(object):
...
@@ -98,6 +99,7 @@ class ForwardHandler(object):
self
.
close
()
self
.
close
()
def
on_close_event
(
self
):
def
on_close_event
(
self
):
"""on close event"""
assert
not
self
.
_done
assert
not
self
.
_done
logging
.
info
(
"RPCProxy:on_close
%
s ..."
,
self
.
name
())
logging
.
info
(
"RPCProxy:on_close
%
s ..."
,
self
.
name
())
self
.
_done
=
True
self
.
_done
=
True
...
...
src/contrib/mps/conv.mm
0 → 100644
View file @
078c767c
#include "mps_utils.h"
namespace tvm {
namespace contrib {
using namespace runtime;
// Packed function "tvm.contrib.mps.buffer2img": stages the contents of a
// Metal buffer tensor into a newly allocated float32 MPSImage.
//   args[0] (buf): source DLTensor whose data field is an MTLBuffer; shape[1],
//                  shape[2] and shape[3] are used as image height, width and
//                  feature-channel count respectively.
//   args[1] (img): destination DLTensor; only its data field is written, and
//                  it is set to the MPSImage pointer.
TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *buf = args[0];
DLTensor *img = args[1];
// Copy the whole source buffer into a temp buffer whose contents pointer is
// read by writeBytes below.
// NOTE(review): presumably the temp buffer is CPU-visible while the source
// may not be - confirm against GetTempBuffer.
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr
);
// Describe a float32 image; note width comes from shape[2] and height from
// shape[1].
MPSImageDescriptor *desc = [MPSImageDescriptor
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
width:buf->shape[2]
height:buf->shape[1]
featureChannels:buf->shape[3]];
// The image is allocated through the thread-local entry rather than directly.
MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);
// Fill image 0 from the temp buffer, interpreted as height x width x channels.
[mpsimg writeBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
// Hand the image back to the caller through the tensor's data pointer.
img->data = (__bridge void *)mpsimg;
// NOTE(review): this reads the image straight back into the same temp buffer
// it was just written from; it looks redundant - confirm whether it is needed.
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
});
// Packed function "tvm.contrib.mps.img2buffer": copies the pixels of an
// MPSImage back into a Metal buffer tensor.
//   args[0] (img): DLTensor whose data field holds an MPSImage pointer.
//   args[1] (buf): destination DLTensor whose data field is an MTLBuffer; its
//                  full length is overwritten.
TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *img = args[0];
DLTensor *buf = args[1];
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
// Temp buffer sized to the destination; readBytes writes through its
// contents pointer.
// NOTE(review): assumes the image holds at least [mtlbuf length] bytes of
// height x width x channel data - confirm against callers.
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
// Read image 0 out in height x width x channels order.
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
// Copy the staged bytes from the temp buffer into the destination buffer.
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr);
});
// Packed function "tvm.contrib.mps.conv2d": runs a 2-D convolution through
// MPSCNNConvolution.
//   args[0] (data):   4-D tensor with batch size 1 (shape[0] == 1).
//   args[1] (weight): 4-D tensor whose shape is read as
//                     (out_channels, kH, kW, in_channels).
//   args[2] (output): 4-D tensor with batch size 1; receives the result.
//   args[3] (pad):    0 selects the "size same" padding policy, 1 selects
//                     "valid only"; other values keep the kernel's default.
//   args[4] (stride): stride in pixels, used for both X and Y.
// All three tensors must be compact (strides == nullptr).
TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  // MPS-NHWC
  DLTensor *data = args[0];
  DLTensor *weight = args[1];
  DLTensor *output = args[2];
  int pad = args[3];
  int stride = args[4];

  CHECK_EQ(data->ndim, 4);
  CHECK_EQ(weight->ndim, 4);
  CHECK_EQ(output->ndim, 4);
  CHECK(output->strides == nullptr);
  CHECK(weight->strides == nullptr);
  CHECK(data->strides == nullptr);

  CHECK_EQ(data->shape[0], 1);
  CHECK_EQ(output->shape[0], 1);

  int oCh = weight->shape[0];
  int kH = weight->shape[1];
  int kW = weight->shape[2];
  int iCh = weight->shape[3];

  // Registry::Get returns a pointer; fail with a message instead of
  // dereferencing null if a helper is not registered.
  auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
  auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
  CHECK(f_buf2img != nullptr)
      << "tvm.contrib.mps.buffer2img is not registered";
  CHECK(f_img2buf != nullptr)
      << "tvm.contrib.mps.img2buffer is not registered";

  // Get Metal device API
  MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
  runtime::metal::MetalThreadEntry *rt =
      runtime::metal::MetalThreadEntry::ThreadLocal();
  id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
  id<MTLCommandQueue> queue =
      entry_ptr->metal_api->GetCommandQueue(data->ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];

  // data to MPSImage (buffer2img only writes tmp_in.data)
  DLTensor tmp_in;
  (*f_buf2img)(data, &tmp_in);
  MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;

  // weight to temp memory, so the kernel can be initialised from a CPU pointer
  id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
  id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
  entry_ptr->metal_api->CopyDataFromTo(
      (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
      weight->ctx, weight->ctx, nullptr);
  float *ptr_w = (float *)[tempB contents];

  // output to MPSImage (gives the kernel a destination image of output's shape)
  DLTensor tmp_out;
  (*f_buf2img)(output, &tmp_out);
  MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;

  // conv desc
  MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
      cnnConvolutionDescriptorWithKernelWidth:kW
                                 kernelHeight:kH
                         inputFeatureChannels:iCh
                        outputFeatureChannels:oCh];
  [conv_desc setStrideInPixelsX:stride];
  [conv_desc setStrideInPixelsY:stride];

  MPSCNNConvolution *conv =
      [[MPSCNNConvolution alloc] initWithDevice:dev
                          convolutionDescriptor:conv_desc
                                  kernelWeights:ptr_w
                                      biasTerms:nil
                                          flags:MPSCNNConvolutionFlagsNone];
  if (pad == 0) {
    conv.padding = [MPSNNDefaultPadding
        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
                          MPSNNPaddingMethodAlignCentered |
                          MPSNNPaddingMethodSizeSame];
  } else if (pad == 1) {
    conv.padding = [MPSNNDefaultPadding
        paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
                          MPSNNPaddingMethodAlignCentered |
                          MPSNNPaddingMethodSizeValidOnly];
  }
  [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];

  // Encode the texture synchronisation BEFORE committing: a committed command
  // buffer no longer accepts new encoders.
  id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
  [encoder synchronizeResource:tempC.texture];
  [encoder endEncoding];
  [cb commit];
  [cb waitUntilCompleted];

  // This directory manages Objective-C objects manually (it sends dealloc
  // explicitly elsewhere), so balance the alloc/init above once the GPU work
  // has finished.
  [conv release];

  (*f_img2buf)(&tmp_out, output);
});
} // namespace contrib
} // namespace tvm
src/contrib/mps/gemm.mm
View file @
078c767c
#include "../../runtime/metal/metal_common.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include "mps_utils.h"
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
namespace tvm {
namespace tvm {
namespace contrib {
namespace contrib {
...
@@ -11,83 +7,81 @@ namespace contrib {
...
@@ -11,83 +7,81 @@ namespace contrib {
using namespace runtime;
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *A = args[0];
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *B = args[1];
DLTensor *C = args[2];
DLTensor *C = args[2];
bool transa = args[3];
bool transa = args[3];
bool transb = args[4];
bool transb = args[4];
// call gemm for simple compact code.
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// Get Metal device API
// Get Metal device API
MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
CHECK_EQ(A->ctx, B->ctx);
// CHECK_EQ(A->ctx, B->ctx);
CHECK_EQ(A->ctx, C->ctx);
// CHECK_EQ(A->ctx, C->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = [queue commandBuffer];
NSUInteger M = A->shape[0 + transa?1:0];
NSUInteger M = A->shape[0 + (transa ? 1 : 0)];
NSUInteger N = B->shape[1 - transb?1:0];
NSUInteger N = B->shape[1 - (transb ? 1 : 0)];
NSUInteger K = B->shape[0 + transb?1:0];
NSUInteger K = B->shape[0 + (transb ? 1 : 0)];
CHECK_EQ(A->shape[1-transa?1:0], K);
// mps a
CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K);
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
// mps a
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
matrixDescriptorWithDimensions:M
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
columns:K
matrixDescriptorWithDimensions:M
rowBytes:M * sizeof(dtype)
columns:K
dataType:dtype];
rowBytes:K * sizeof(MPSDataTypeFloat32)
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
dataType:MPSDataTypeFloat32];
MPSMatrix *matrixA =
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
[[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
MPSMatrix *matrixA =
// mps b
[[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
// mps b
matrixDescriptorWithDimensions:K
MPSMatrixDescriptor *descB =
columns:N
[MPSMatrixDescriptor matrixDescriptorWithDimensions:K
rowBytes:K * sizeof(dtype)
columns:N
dataType:dtype];
rowBytes:N * sizeof(dtype)
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
dataType:dtype];
MPSMatrix *matrixB =
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
[[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
MPSMatrix *matrixB =
// mps c
[[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
// mps c
matrixDescriptorWithDimensions:M
MPSMatrixDescriptor *descC =
columns:N
[MPSMatrixDescriptor matrixDescriptorWithDimensions:M
rowBytes:M * sizeof(dtype)
columns:N
dataType:dtype];
rowBytes:N * sizeof(dtype)
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
dataType:dtype];
MPSMatrix *matrixC =
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
[[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
MPSMatrix *matrixC =
// kernel
[[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
// kernel
MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
transposeLeft:transa
transposeLeft:transa
transposeRight:transb
transposeRight:transb
resultRows:M
resultRows:M
resultColumns:N
resultColumns:N
interiorColumns:K
interiorColumns:K
alpha:1.0f
alpha:1.0f
beta:0.0f];
beta:0.0f];
CHECK(sgemm != nil);
CHECK(sgemm != nil);
[sgemm encodeToCommandBuffer:cb
[sgemm encodeToCommandBuffer:cb
leftMatrix:matrixA
leftMatrix:matrixA
rightMatrix:matrixB
rightMatrix:matrixB
resultMatrix:matrixC];
resultMatrix:matrixC];
[cb commit];
[cb commit];
[mul_obj dealloc];
[matrixA dealloc];
});
[matrixB dealloc];
[matrixC dealloc];
});
} // namespace contrib
} // namespace contrib
} // namespace tvm
} // namespace tvm
src/contrib/mps/mps_utils.h
View file @
078c767c
...
@@ -6,11 +6,15 @@
...
@@ -6,11 +6,15 @@
#ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
#ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
#define TVM_CONTRIB_MPS_MPS_UTILS_H_
#define TVM_CONTRIB_MPS_MPS_UTILS_H_
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <vector>
#include "../../runtime/metal/metal_common.h"
#include "../../runtime/metal/metal_common.h"
namespace
tvm
{
namespace
tvm
{
namespace
contrib
{
namespace
contrib
{
...
@@ -19,12 +23,15 @@ struct MPSType {
...
@@ -19,12 +23,15 @@ struct MPSType {
static
MPSDataType
DLTypeToMPSType
(
const
DLDataType
&
dtype
);
static
MPSDataType
DLTypeToMPSType
(
const
DLDataType
&
dtype
);
};
// struct MPSType
};
// struct MPSType
struct
MetalThreadEntry
{
struct
MetalThreadEntry
{
MetalThreadEntry
();
MetalThreadEntry
();
~
MetalThreadEntry
();
~
MetalThreadEntry
();
runtime
::
MetalWorkspace
*
metal_api
{
nullptr
};
MPSImage
*
AllocMPSImage
(
id
<
MTLDevice
>
dev
,
MPSImageDescriptor
*
desc
);
static
MetalThreadEntry
*
ThreadLocal
();
MPSTemporaryImage
*
AllocTempImage
(
id
<
MTLCommandBuffer
>
cb
,
MPSImageDescriptor
*
desc
);
runtime
::
metal
::
MetalWorkspace
*
metal_api
{
nullptr
};
static
MetalThreadEntry
*
ThreadLocal
();
std
::
vector
<
MPSImage
*>
img_table
;
};
// MetalThreadEntry
};
// MetalThreadEntry
}
// namespace contrib
}
// namespace contrib
...
...
src/contrib/mps/mps_utils.
cc
→
src/contrib/mps/mps_utils.
mm
View file @
078c767c
...
@@ -3,10 +3,6 @@
...
@@ -3,10 +3,6 @@
* \file Use external mps utils function
* \file Use external mps utils function
*/
*/
#include "mps_utils.h"
#include "mps_utils.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace tvm {
namespace contrib {
namespace contrib {
...
@@ -14,31 +10,54 @@ namespace contrib {
...
@@ -14,31 +10,54 @@ namespace contrib {
// MPS Data Type
// MPS Data Type
MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
switch (dtype.code) {
switch (dtype.code) {
case
kDLInt
:
case kDLInt:
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeInt8
;
if (dtype.bits == 8 && dtype.lanes == 1)
else
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeInt16
;
return MPSDataTypeInt8;
else
else if (dtype.bits == 16 && dtype.lanes == 1)
LOG
(
FATAL
)
<<
"Unsupported type"
;
return MPSDataTypeInt16;
break
;
else
case
kDLUInt
:
LOG(FATAL) << "Unsupported type";
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeUInt8
;
break;
else
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeUInt16
;
case kDLUInt:
else
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeUInt32
;
if (dtype.bits == 8 && dtype.lanes == 1)
LOG
(
FATAL
)
<<
"Unsupported type"
;
return MPSDataTypeUInt8;
break
;
else if (dtype.bits == 16 && dtype.lanes == 1)
case
kDLFloat
:
return MPSDataTypeUInt16;
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeFloat16
;
else if (dtype.bits == 32 && dtype.lanes == 1)
else
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeFloat32
;
return MPSDataTypeUInt32;
else
LOG(FATAL) << "Unsupported type";
LOG
(
FATAL
)
<<
"Unsupported type"
;
break;
break
;
case kDLFloat:
default
:
if (dtype.bits == 16 && dtype.lanes == 1)
LOG
(
FATAL
)
<<
"Unsupported type"
;
return MPSDataTypeFloat16;
}
else if (dtype.bits == 32 && dtype.lanes == 1)
return MPSDataTypeFloat32;
else
LOG(FATAL) << "Unsupported type";
break;
default:
LOG(FATAL) << "Unsupported type";
}
return MPSDataTypeFloat32;
}
}
// MetalThreadEntry
// MetalThreadEntry
// Creates an MPSImage on `dev` from `desc`, appends it to img_table and
// returns it.
MPSImage *MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev,
                                          MPSImageDescriptor *desc) {
  img_table.push_back(
      [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc]);
  return img_table.back();
}
// Returns a temporary image described by `desc`, obtained from
// MPSTemporaryImage for command buffer `cb`. Nothing is recorded in img_table.
MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
                                                    MPSImageDescriptor *desc) {
  return [MPSTemporaryImage temporaryImageWithCommandBuffer:cb
                                            imageDescriptor:desc];
}
MetalThreadEntry::MetalThreadEntry() {
MetalThreadEntry::MetalThreadEntry() {
auto func = runtime::Registry::Get("device_api.metal");
auto func = runtime::Registry::Get("device_api.metal");
void *ret = (*func)();
void *ret = (*func)();
...
@@ -46,13 +65,16 @@ MetalThreadEntry::MetalThreadEntry() {
...
@@ -46,13 +65,16 @@ MetalThreadEntry::MetalThreadEntry() {
}
}
MetalThreadEntry::~MetalThreadEntry() {
MetalThreadEntry::~MetalThreadEntry() {
for (int i = 0; i < img_table.size(); ++i) {
[img_table[i] dealloc];
}
}
}
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
MetalThreadEntry
*
MetalThreadEntry
::
ThreadLocal
()
{
MetalThreadEntry
*
MetalThreadEntry::ThreadLocal() {
return MetalThreadStore::Get();
return MetalThreadStore::Get();
}
}
}
// namespace contrib
} // namespace contrib
}
// namespace tvm
} // namespace tvm
src/runtime/metal/metal_device_api.mm
View file @
078c767c
...
@@ -126,10 +126,18 @@ void* MetalWorkspace::AllocDataSpace(
...
@@ -126,10 +126,18 @@ void* MetalWorkspace::AllocDataSpace(
TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
this->Init();
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
id<MTLDevice> dev = GetDevice(ctx);
// allocate buffer in GPU only mode.
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
/*
#if TARGET_OS_IPHONE
storage_mode = MTLResourceStorageModeShared;
#else
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [
id<MTLBuffer> buf = [
dev newBufferWithLength:nbytes
dev newBufferWithLength:nbytes
options:
MTLResourceStorageModePrivat
e];
options:
storage_mod
e];
CHECK(buf != nil);
CHECK(buf != nil);
return (__bridge void*)([buf retain]);
return (__bridge void*)([buf retain]);
}
}
...
...
tests/python/contrib/test_mps.py
View file @
078c767c
...
@@ -2,39 +2,83 @@ import tvm
...
@@ -2,39 +2,83 @@ import tvm
import
numpy
as
np
import
numpy
as
np
from
tvm.contrib
import
mps
from
tvm.contrib
import
mps
def
test_matmul_add
():
def
test_matmul
():
if
not
tvm
.
module
.
enabled
(
"metal"
):
print
(
"skip because
%
s is not enabled..."
%
"metal"
)
return
n
=
1024
n
=
1024
l
=
128
l
=
128
m
=
235
m
=
256
bias
=
tvm
.
var
(
'bias'
,
dtype
=
tvm
.
float32
)
A
=
tvm
.
placeholder
((
n
,
l
),
name
=
'A'
)
A
=
tvm
.
placeholder
((
n
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
l
,
m
),
name
=
'B'
)
B
=
tvm
.
placeholder
((
l
,
m
),
name
=
'B'
)
C1
=
mps
.
matmul
(
A
,
B
)
C
=
mps
.
matmul
(
A
,
B
)
C2
=
mps
.
matmul
(
B
,
A
,
True
,
True
)
D
=
tvm
.
compute
(
D1
=
tvm
.
compute
(
C1
.
shape
,
lambda
i
,
j
:
C1
[
i
,
j
]
+
bias
,
name
=
"D1"
)
C
.
shape
,
D2
=
tvm
.
compute
(
C2
.
shape
,
lambda
i
,
j
:
C2
[
i
,
j
]
+
bias
,
name
=
"D2"
)
lambda
*
i
:
C
(
*
i
)
+
1.
s1
=
tvm
.
create_schedule
(
D1
.
op
)
)
s2
=
tvm
.
create_schedule
(
D2
.
op
)
s
=
tvm
.
create_schedule
(
D
.
op
)
yo
,
xo
=
D
.
op
.
axis
def
verify
(
A
,
B
,
D
,
s
,
bias
,
target
=
"llvm"
):
block_y
=
tvm
.
thread_axis
(
"blockIdx.y"
)
if
not
tvm
.
module
.
enabled
(
target
):
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
print
(
"skip because
%
s is not enabled..."
%
target
)
thread_y
=
tvm
.
thread_axis
(
"threadIdx.y"
)
return
thread_x
=
tvm
.
thread_axis
(
"threadIdx.x"
)
by
,
ty
=
s
[
D
]
.
split
(
yo
,
factor
=
16
)
bx
,
tx
=
s
[
D
]
.
split
(
xo
,
factor
=
16
)
s
[
D
]
.
bind
(
by
,
block_y
)
s
[
D
]
.
bind
(
bx
,
block_x
)
s
[
D
]
.
bind
(
ty
,
thread_y
)
s
[
D
]
.
bind
(
tx
,
thread_x
)
def
verify
(
A
,
B
,
D
,
s
,
target
=
"metal"
):
if
not
tvm
.
get_global_func
(
"tvm.contrib.mps.matmul"
,
True
):
if
not
tvm
.
get_global_func
(
"tvm.contrib.mps.matmul"
,
True
):
print
(
"skip because extern function is not avalable"
)
print
(
"skip because extern function is not avalable"
)
return
return
ctx
=
tvm
.
cpu
(
0
)
ctx
=
tvm
.
metal
(
0
)
f
=
tvm
.
build
(
s
,
[
A
,
B
,
D
,
bias
],
target
)
f
=
tvm
.
build
(
s
,
[
A
,
B
,
D
],
"metal"
)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
l
))
.
astype
(
A
.
dtype
),
ctx
)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
l
))
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
l
,
m
))
.
astype
(
B
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
l
,
m
))
.
astype
(
B
.
dtype
),
ctx
)
d
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
D
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
C
.
dtype
),
ctx
)
bb
=
10.0
f
(
a
,
b
,
c
)
f
(
a
,
b
,
d
,
bb
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
d
.
asnumpy
(),
np
.
dot
(
a
.
asnumpy
(),
b
.
asnumpy
())
+
bb
,
rtol
=
1e-5
)
c
.
asnumpy
(),
np
.
dot
(
a
.
asnumpy
(),
b
.
asnumpy
())
+
1
,
rtol
=
1e-5
)
verify
(
A
,
B
,
D1
,
s1
,
bias
)
verify
(
A
,
B
,
D
,
s
)
verify
(
A
,
B
,
D2
,
s2
,
bias
)
def test_conv2d():
    """Smoke test: build the MPS conv2d extern op for Metal and run it once."""
    if not tvm.module.enabled("metal"):
        print("skip because %s is not enabled..." % "metal")
        return
    n = 1
    h = 14
    w = 14
    ci = 2
    co = 4
    kh = 3
    kw = 3
    stride = 2
    A = tvm.placeholder((n, h, w, ci), name="x")
    B = tvm.placeholder((co, kh, kw, ci), name="w")
    C = mps.conv2d(A, B, 'SAME', stride)
    s1 = tvm.create_schedule(C.op)

    def verify(A, B, C, s, target="metal"):
        """Build schedule `s` for `target` and execute it on random input."""
        if not tvm.get_global_func("tvm.contrib.mps.conv2d", True):
            print("skip because extern function is not avalable")
            return
        ctx = tvm.metal(0)
        f = tvm.build(s, [A, B, C], target)
        a = tvm.nd.array(
            np.random.uniform(size=(n, h, w, ci)).astype(A.dtype), ctx)
        b = tvm.nd.array(
            np.random.uniform(size=(co, kh, kw, ci)).astype(B.dtype), ctx)
        c = tvm.nd.array(
            np.zeros((n, h // stride, w // stride, co), dtype=C.dtype), ctx)
        # Only checks that the op builds and runs; the result is not yet
        # compared against a reference convolution.
        f(a, b, c)

    verify(A, B, C, s1)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_matmul_add
()
#test_matmul()
test_conv2d
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment