Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
078c767c
Commit
078c767c
authored
Mar 06, 2018
by
Leyuan Wang
Committed by
Tianqi Chen
Mar 06, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
MPS conv (#822)
parent
6731f660
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
346 additions
and
73 deletions
+346
-73
make/contrib/mps.mk
+8
-2
python/tvm/contrib/mps.py
+40
-4
python/tvm/contrib/rpc_proxy.py
+2
-0
src/contrib/mps/conv.mm
+154
-0
src/contrib/mps/gemm.mm
+20
-26
src/contrib/mps/mps_utils.h
+11
-4
src/contrib/mps/mps_utils.mm
+34
-12
src/runtime/metal/metal_device_api.mm
+10
-2
tests/python/contrib/test_mps.py
+67
-23
No files found.
make/contrib/mps.mk
View file @
078c767c
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm
, src/contrib/mps/*.cc
)
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm)
MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))
ifeq ($(USE_MPS), 1)
...
...
@@ -6,9 +6,15 @@ FRAMEWORKS += -framework MetalPerformanceShaders
CFLAGS +=
ADD_LDFLAGS +=
RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
CONTRIB_OBJ += $(MPS_CONTRIB_OBJ)
endif
build/contrib/mps/%.o: src/contrib/mps/%.mm src/contrib/mps/%.cc
build/contrib/mps/%.o: src/contrib/mps/%.mm
@mkdir -p $(@D)
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
build/contrib/mps/%.o: src/contrib/mps/%.cc
@mkdir -p $(@D)
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
python/tvm/contrib/mps.py
View file @
078c767c
"""External function interface to MPS libraroes."""
from
__future__
import
absolute_import
as
_abs
from
..
import
api
as
_api
from
..
import
intrin
as
_intrin
# pylint: disable=C0103,W0612
def
matmul
(
lhs
,
rhs
,
transa
=
False
,
transb
=
False
):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
...
...
@@ -26,10 +26,46 @@ def matmul(lhs, rhs, transa=False, transb=False):
C : Tensor
The result tensor.
"""
m
=
lhs
.
shape
[
0
]
n
=
rhs
.
shape
[
1
]
m
=
lhs
.
shape
[
0
]
if
transa
is
False
else
lhs
.
shape
[
1
]
n
=
rhs
.
shape
[
1
]
if
transb
is
False
else
rhs
.
shape
[
0
]
if
transa
:
m
=
b
if
transb
:
n
=
c
return
_api
.
extern
(
(
n
,
m
),
[
lhs
,
rhs
],
(
m
,
n
),
[
lhs
,
rhs
],
lambda
ins
,
outs
:
_intrin
.
call_packed
(
"tvm.contrib.mps.matmul"
,
ins
[
0
],
ins
[
1
],
outs
[
0
],
transa
,
transb
),
name
=
"C"
)
def
conv2d
(
data
,
weight
,
pad
=
'SAME'
,
stride
=
1
):
"""
Create an extern op that compute data * weight and return result in output
Parameters:
----------
data: Tensor
The input data, format NHWC
weight: Tensor
The conv weight, format output_feature * kH * kW * input_feature
pad: str
Padding method, 'SAME' or 'VALID'
stride: int
convolution stride
Returns
-------
output: Tensor
The result tensor
"""
n
,
hi
,
wi
,
ci
=
data
.
shape
co
,
kh
,
kw
,
ciw
=
weight
.
shape
padding
=
0
if
pad
==
'SAME'
else
1
ho
=
hi
//
stride
wo
=
wi
//
stride
return
_api
.
extern
(
(
n
,
ho
,
wo
,
co
),
[
data
,
weight
],
lambda
ins
,
outs
:
_intrin
.
call_packed
(
"tvm.contrib.mps.conv2d"
,
ins
[
0
],
ins
[
1
],
outs
[
0
],
padding
,
stride
),
name
=
"C"
)
python/tvm/contrib/rpc_proxy.py
View file @
078c767c
...
...
@@ -70,6 +70,7 @@ class ForwardHandler(object):
ProxyServerHandler
.
current
.
handler_ready
(
self
)
def
on_data
(
self
,
message
):
"""on data"""
assert
isinstance
(
message
,
bytes
)
if
self
.
forward_proxy
:
self
.
forward_proxy
.
send_data
(
message
)
...
...
@@ -98,6 +99,7 @@ class ForwardHandler(object):
self
.
close
()
def
on_close_event
(
self
):
"""on close event"""
assert
not
self
.
_done
logging
.
info
(
"RPCProxy:on_close
%
s ..."
,
self
.
name
())
self
.
_done
=
True
...
...
src/contrib/mps/conv.mm
0 → 100644
View file @
078c767c
#include "mps_utils.h"
namespace tvm {
namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *buf = args[0];
DLTensor *img = args[1];
// copy to temp
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr
);
MPSImageDescriptor *desc = [MPSImageDescriptor
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
width:buf->shape[2]
height:buf->shape[1]
featureChannels:buf->shape[3]];
MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc);
[mpsimg writeBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
img->data = (__bridge void *)mpsimg;
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
});
TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *img = args[0];
DLTensor *buf = args[1];
id<MTLBuffer> mtlbuf = (__bridge id<MTLBuffer>)(buf->data);
MPSImage *mpsimg = (__bridge MPSImage *)(img->data);
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
[mpsimg readBytes:[temp contents]
dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels
imageIndex:0];
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length],
buf->ctx, buf->ctx, nullptr);
});
TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d")
.set_body([](TVMArgs args, TVMRetValue *ret) {
// MPS-NHWC
DLTensor *data = args[0];
DLTensor *weight = args[1];
DLTensor *output = args[2];
int pad = args[3];
int stride = args[4];
CHECK_EQ(data->ndim, 4);
CHECK_EQ(weight->ndim, 4);
CHECK_EQ(output->ndim, 4);
CHECK(output->strides == nullptr);
CHECK(weight->strides == nullptr);
CHECK(data->strides == nullptr);
CHECK_EQ(data->shape[0], 1);
CHECK_EQ(output->shape[0], 1);
int oCh = weight->shape[0];
int kH = weight->shape[1];
int kW = weight->shape[2];
int iCh = weight->shape[3];
auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img");
auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer");
// Get Metal device API
MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal();
runtime::metal::MetalThreadEntry *rt =
runtime::metal::MetalThreadEntry::ThreadLocal();
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(data->ctx);
id<MTLCommandQueue> queue =
entry_ptr->metal_api->GetCommandQueue(data->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
// data to MPSImage
DLTensor tmp_in;
(*f_buf2img)(data, &tmp_in);
MPSImage *tempA = (__bridge MPSImage *)tmp_in.data;
// weight to temp memory
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
entry_ptr->metal_api->CopyDataFromTo(
(__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length],
weight->ctx, weight->ctx, nullptr);
float *ptr_w = (float *)[tempB contents];
// output to MPSImage
DLTensor tmp_out;
(*f_buf2img)(output, &tmp_out);
MPSImage *tempC = (__bridge MPSImage *)tmp_out.data;
// conv desc
MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor
cnnConvolutionDescriptorWithKernelWidth:kW
kernelHeight:kH
inputFeatureChannels:iCh
outputFeatureChannels:oCh];
[conv_desc setStrideInPixelsX:stride];
[conv_desc setStrideInPixelsY:stride];
MPSCNNConvolution *conv =
[[MPSCNNConvolution alloc] initWithDevice:dev
convolutionDescriptor:conv_desc
kernelWeights:ptr_w
biasTerms:nil
flags:MPSCNNConvolutionFlagsNone];
if (pad == 0) {
conv.padding = [MPSNNDefaultPadding
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
MPSNNPaddingMethodAlignCentered |
MPSNNPaddingMethodSizeSame];
} else if (pad == 1) {
conv.padding = [MPSNNDefaultPadding
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft |
MPSNNPaddingMethodAlignCentered |
MPSNNPaddingMethodSizeValidOnly];
}
[conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC];
[cb commit];
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder synchronizeResource:tempC.texture];
[encoder endEncoding];
[cb waitUntilCompleted];
(*f_img2buf)(&tmp_out, output);
});
} // namespace contrib
} // namespace tvm
src/contrib/mps/gemm.mm
View file @
078c767c
#include "../../runtime/metal/metal_common.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include "mps_utils.h"
namespace tvm {
namespace contrib {
...
...
@@ -11,7 +7,7 @@ namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
...
...
@@ -28,40 +24,41 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// Get Metal device API
MetalThreadEntry*
entry_ptr = MetalThreadEntry::ThreadLocal();
CHECK_EQ(A->ctx, B->ctx);
CHECK_EQ(A->ctx, C->ctx);
MetalThreadEntry *
entry_ptr = MetalThreadEntry::ThreadLocal();
//
CHECK_EQ(A->ctx, B->ctx);
//
CHECK_EQ(A->ctx, C->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
NSUInteger M = A->shape[0 + transa?1:0];
NSUInteger N = B->shape[1 - transb?1:0];
NSUInteger K = B->shape[0 + transb?1:0];
CHECK_EQ(A->shape[1-transa?1:0], K);
NSUInteger M = A->shape[0 + (transa ? 1 : 0)];
NSUInteger N = B->shape[1 - (transb ? 1 : 0)];
NSUInteger K = B->shape[0 + (transb ? 1 : 0)];
CHECK_EQ(A->shape[1 - (transa ? 1 : 0)], K);
// mps a
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
columns:K
rowBytes:M * sizeof(dtype
)
dataType:dtype
];
rowBytes:K * sizeof(MPSDataTypeFloat32
)
dataType:MPSDataTypeFloat32
];
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
MPSMatrix *matrixA =
[[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
// mps b
MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:K
MPSMatrixDescriptor *descB =
[MPSMatrixDescriptor
matrixDescriptorWithDimensions:K
columns:N
rowBytes:K
* sizeof(dtype)
rowBytes:N
* sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
MPSMatrix *matrixB =
[[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
// mps c
MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
MPSMatrixDescriptor *descC =
[MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
columns:N
rowBytes:M
* sizeof(dtype)
rowBytes:N
* sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
MPSMatrix *matrixC =
...
...
@@ -83,10 +80,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
rightMatrix:matrixB
resultMatrix:matrixC];
[cb commit];
[mul_obj dealloc];
[matrixA dealloc];
[matrixB dealloc];
[matrixC dealloc];
});
} // namespace contrib
...
...
src/contrib/mps/mps_utils.h
View file @
078c767c
...
...
@@ -6,11 +6,15 @@
#ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
#define TVM_CONTRIB_MPS_MPS_UTILS_H_
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <vector>
#include "../../runtime/metal/metal_common.h"
namespace
tvm
{
namespace
contrib
{
...
...
@@ -19,12 +23,15 @@ struct MPSType {
static
MPSDataType
DLTypeToMPSType
(
const
DLDataType
&
dtype
);
};
// struct MPSType
struct
MetalThreadEntry
{
MetalThreadEntry
();
~
MetalThreadEntry
();
runtime
::
MetalWorkspace
*
metal_api
{
nullptr
};
static
MetalThreadEntry
*
ThreadLocal
();
MPSImage
*
AllocMPSImage
(
id
<
MTLDevice
>
dev
,
MPSImageDescriptor
*
desc
);
MPSTemporaryImage
*
AllocTempImage
(
id
<
MTLCommandBuffer
>
cb
,
MPSImageDescriptor
*
desc
);
runtime
::
metal
::
MetalWorkspace
*
metal_api
{
nullptr
};
static
MetalThreadEntry
*
ThreadLocal
();
std
::
vector
<
MPSImage
*>
img_table
;
};
// MetalThreadEntry
}
// namespace contrib
...
...
src/contrib/mps/mps_utils.
cc
→
src/contrib/mps/mps_utils.
mm
View file @
078c767c
...
...
@@ -3,10 +3,6 @@
* \file Use external mps utils function
*/
#include "mps_utils.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace contrib {
...
...
@@ -15,30 +11,53 @@ namespace contrib {
MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
switch (dtype.code) {
case kDLInt:
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeInt8
;
else
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeInt16
;
if (dtype.bits == 8 && dtype.lanes == 1)
return MPSDataTypeInt8;
else if (dtype.bits == 16 && dtype.lanes == 1)
return MPSDataTypeInt16;
else
LOG(FATAL) << "Unsupported type";
break;
case kDLUInt:
if
(
dtype
.
bits
==
8
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeUInt8
;
else
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeUInt16
;
else
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeUInt32
;
if (dtype.bits == 8 && dtype.lanes == 1)
return MPSDataTypeUInt8;
else if (dtype.bits == 16 && dtype.lanes == 1)
return MPSDataTypeUInt16;
else if (dtype.bits == 32 && dtype.lanes == 1)
return MPSDataTypeUInt32;
LOG(FATAL) << "Unsupported type";
break;
case kDLFloat:
if
(
dtype
.
bits
==
16
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeFloat16
;
else
if
(
dtype
.
bits
==
32
&&
dtype
.
lanes
==
1
)
return
MPSDataTypeFloat32
;
if (dtype.bits == 16 && dtype.lanes == 1)
return MPSDataTypeFloat16;
else if (dtype.bits == 32 && dtype.lanes == 1)
return MPSDataTypeFloat32;
else
LOG(FATAL) << "Unsupported type";
break;
default:
LOG(FATAL) << "Unsupported type";
}
return MPSDataTypeFloat32;
}
// MetalThreadEntry
MPSImage *MetalThreadEntry::AllocMPSImage(id<MTLDevice> dev,
MPSImageDescriptor *desc) {
MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc];
img_table.push_back(mpsimg);
return mpsimg;
}
MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id<MTLCommandBuffer> cb,
MPSImageDescriptor *desc) {
MPSTemporaryImage *mpsimg =
[MPSTemporaryImage temporaryImageWithCommandBuffer:cb
imageDescriptor:desc];
return mpsimg;
}
MetalThreadEntry::MetalThreadEntry() {
auto func = runtime::Registry::Get("device_api.metal");
void *ret = (*func)();
...
...
@@ -46,11 +65,14 @@ MetalThreadEntry::MetalThreadEntry() {
}
MetalThreadEntry::~MetalThreadEntry() {
for (int i = 0; i < img_table.size(); ++i) {
[img_table[i] dealloc];
}
}
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
MetalThreadEntry
*
MetalThreadEntry
::
ThreadLocal
()
{
MetalThreadEntry
*
MetalThreadEntry::ThreadLocal() {
return MetalThreadStore::Get();
}
...
...
src/runtime/metal/metal_device_api.mm
View file @
078c767c
...
...
@@ -126,10 +126,18 @@ void* MetalWorkspace::AllocDataSpace(
TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) {
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// allocate buffer in GPU only mode.
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
/*
#if TARGET_OS_IPHONE
storage_mode = MTLResourceStorageModeShared;
#else
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [
dev newBufferWithLength:nbytes
options:
MTLResourceStorageModePrivat
e];
options:
storage_mod
e];
CHECK(buf != nil);
return (__bridge void*)([buf retain]);
}
...
...
tests/python/contrib/test_mps.py
View file @
078c767c
...
...
@@ -2,39 +2,83 @@ import tvm
import
numpy
as
np
from
tvm.contrib
import
mps
def
test_matmul_add
():
def
test_matmul
():
if
not
tvm
.
module
.
enabled
(
"metal"
):
print
(
"skip because
%
s is not enabled..."
%
"metal"
)
return
n
=
1024
l
=
128
m
=
235
bias
=
tvm
.
var
(
'bias'
,
dtype
=
tvm
.
float32
)
m
=
256
A
=
tvm
.
placeholder
((
n
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
l
,
m
),
name
=
'B'
)
C1
=
mps
.
matmul
(
A
,
B
)
C2
=
mps
.
matmul
(
B
,
A
,
True
,
True
)
D1
=
tvm
.
compute
(
C1
.
shape
,
lambda
i
,
j
:
C1
[
i
,
j
]
+
bias
,
name
=
"D1"
)
D2
=
tvm
.
compute
(
C2
.
shape
,
lambda
i
,
j
:
C2
[
i
,
j
]
+
bias
,
name
=
"D2"
)
s1
=
tvm
.
create_schedule
(
D1
.
op
)
s2
=
tvm
.
create_schedule
(
D2
.
op
)
def
verify
(
A
,
B
,
D
,
s
,
bias
,
target
=
"llvm"
):
if
not
tvm
.
module
.
enabled
(
target
):
print
(
"skip because
%
s is not enabled..."
%
target
)
return
C
=
mps
.
matmul
(
A
,
B
)
D
=
tvm
.
compute
(
C
.
shape
,
lambda
*
i
:
C
(
*
i
)
+
1.
)
s
=
tvm
.
create_schedule
(
D
.
op
)
yo
,
xo
=
D
.
op
.
axis
block_y
=
tvm
.
thread_axis
(
"blockIdx.y"
)
block_x
=
tvm
.
thread_axis
(
"blockIdx.x"
)
thread_y
=
tvm
.
thread_axis
(
"threadIdx.y"
)
thread_x
=
tvm
.
thread_axis
(
"threadIdx.x"
)
by
,
ty
=
s
[
D
]
.
split
(
yo
,
factor
=
16
)
bx
,
tx
=
s
[
D
]
.
split
(
xo
,
factor
=
16
)
s
[
D
]
.
bind
(
by
,
block_y
)
s
[
D
]
.
bind
(
bx
,
block_x
)
s
[
D
]
.
bind
(
ty
,
thread_y
)
s
[
D
]
.
bind
(
tx
,
thread_x
)
def
verify
(
A
,
B
,
D
,
s
,
target
=
"metal"
):
if
not
tvm
.
get_global_func
(
"tvm.contrib.mps.matmul"
,
True
):
print
(
"skip because extern function is not avalable"
)
return
ctx
=
tvm
.
cpu
(
0
)
f
=
tvm
.
build
(
s
,
[
A
,
B
,
D
,
bias
],
target
)
ctx
=
tvm
.
metal
(
0
)
f
=
tvm
.
build
(
s
,
[
A
,
B
,
D
],
"metal"
)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
l
))
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
l
,
m
))
.
astype
(
B
.
dtype
),
ctx
)
d
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
D
.
dtype
),
ctx
)
bb
=
10.0
f
(
a
,
b
,
d
,
bb
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
m
),
dtype
=
C
.
dtype
),
ctx
)
f
(
a
,
b
,
c
)
np
.
testing
.
assert_allclose
(
d
.
asnumpy
(),
np
.
dot
(
a
.
asnumpy
(),
b
.
asnumpy
())
+
bb
,
rtol
=
1e-5
)
verify
(
A
,
B
,
D1
,
s1
,
bias
)
verify
(
A
,
B
,
D2
,
s2
,
bias
)
c
.
asnumpy
(),
np
.
dot
(
a
.
asnumpy
(),
b
.
asnumpy
())
+
1
,
rtol
=
1e-5
)
verify
(
A
,
B
,
D
,
s
)
def
test_conv2d
():
if
not
tvm
.
module
.
enabled
(
"metal"
):
print
(
"skip because
%
s is not enabled..."
%
"metal"
)
return
n
=
1
h
=
14
w
=
14
ci
=
2
co
=
4
kh
=
3
kw
=
3
stride
=
2
A
=
tvm
.
placeholder
((
n
,
h
,
w
,
ci
),
name
=
"x"
)
B
=
tvm
.
placeholder
((
co
,
kh
,
kw
,
ci
),
name
=
"w"
)
C
=
mps
.
conv2d
(
A
,
B
,
'SAME'
,
2
)
s1
=
tvm
.
create_schedule
(
C
.
op
)
def
verify
(
A
,
B
,
C
,
target
=
"llvm"
):
if
not
tvm
.
get_global_func
(
"tvm.contrib.mps.conv2d"
,
True
):
print
(
"skip because extern function is not avalable"
)
return
ctx
=
tvm
.
metal
(
0
)
f
=
tvm
.
build
(
s1
,
[
A
,
B
,
C
],
"metal"
)
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
n
,
h
,
w
,
ci
))
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
(
co
,
kh
,
kw
,
ci
))
.
astype
(
B
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
((
n
,
h
//
stride
,
w
//
stride
,
co
),
dtype
=
C
.
dtype
),
ctx
)
f
(
a
,
b
,
c
)
# print(c.asnumpy())
# print(c.shape)
verify
(
A
,
B
,
C
,
s1
)
if
__name__
==
"__main__"
:
test_matmul_add
()
#test_matmul()
test_conv2d
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment