Commit 46e6cae5 by Sheng Zha Committed by Tianqi Chen

[CONTRIB] MPS DNN Dense (#615)

* mps

* update
parent 72992208
...@@ -133,6 +133,7 @@ endif ...@@ -133,6 +133,7 @@ endif
include make/contrib/cblas.mk include make/contrib/cblas.mk
include make/contrib/nnpack.mk include make/contrib/nnpack.mk
include make/contrib/cudnn.mk include make/contrib/cudnn.mk
include make/contrib/mps.mk
ifdef ADD_CFLAGS ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS) CFLAGS += $(ADD_CFLAGS)
......
...@@ -68,3 +68,6 @@ USE_NNPACK = 0 ...@@ -68,3 +68,6 @@ USE_NNPACK = 0
# Whether use CuDNN # Whether use CuDNN
USE_CUDNN = 0 USE_CUDNN = 0
# Whether use MPS
USE_MPS = 0
MPS_CONTRIB_SRC = $(wildcard src/contrib/mps/*.mm, src/contrib/mps/*.cc)
MPS_CONTRIB_OBJ = $(patsubst src/%.mm, build/%.o, $(MPS_CONTRIB_SRC))
ifeq ($(USE_MPS), 1)
FRAMEWORKS += -framework MetalPerformanceShaders
CFLAGS +=
ADD_LDFLAGS +=
RUNTIME_DEP += $(MPS_CONTRIB_OBJ)
endif
build/contrib/mps/%.o: src/contrib/mps/%.mm src/contrib/mps/%.cc
@mkdir -p $(@D)
$(CXX) $(OBJCFLAGS) $(CFLAGS) -MM -MT build/contrib/mps/$*.o $< >build/contrib/mps/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
"""External function interface to MPS libraroes."""
from __future__ import absolute_import as _abs
from .. import api as _api
from .. import intrin as _intrin
def matmul(lhs, rhs, transa=False, transb=False):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to calle external libraries.
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
Returns
-------
C : Tensor
The result tensor.
"""
m = lhs.shape[0]
n = rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb),
name="C")
#include "../../runtime/metal/metal_common.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
namespace tvm {
namespace contrib {
using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);
CHECK(C->strides == nullptr);
CHECK(B->strides == nullptr);
CHECK(A->strides == nullptr);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));
// Get Metal device API
MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal();
CHECK_EQ(A->ctx, B->ctx);
CHECK_EQ(A->ctx, C->ctx);
id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(A->ctx);
id<MTLCommandQueue> queue = entry_ptr->metal_api->GetCommandQueue(A->ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
NSUInteger M = A->shape[0 + transa?1:0];
NSUInteger N = B->shape[1 - transb?1:0];
NSUInteger K = B->shape[0 + transb?1:0];
CHECK_EQ(A->shape[1-transa?1:0], K);
// mps a
MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype);
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
columns:K
rowBytes:M * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufA = (__bridge id<MTLBuffer>)(A->data);
MPSMatrix *matrixA =
[[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA];
// mps b
MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:K
columns:N
rowBytes:K * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(B->data);
MPSMatrix *matrixB =
[[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB];
// mps c
MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
matrixDescriptorWithDimensions:M
columns:N
rowBytes:M * sizeof(dtype)
dataType:dtype];
id<MTLBuffer> bufC = (__bridge id<MTLBuffer>)(C->data);
MPSMatrix *matrixC =
[[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC];
// kernel
MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init];
MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev
transposeLeft:transa
transposeRight:transb
resultRows:M
resultColumns:N
interiorColumns:K
alpha:1.0f
beta:0.0f];
CHECK(sgemm != nil);
[sgemm encodeToCommandBuffer:cb
leftMatrix:matrixA
rightMatrix:matrixB
resultMatrix:matrixC];
[cb commit];
[mul_obj dealloc];
[matrixA dealloc];
[matrixB dealloc];
[matrixC dealloc];
});
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file Use external mps utils function
*/
#include "mps_utils.h"
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace contrib {
// MPS Data Type
MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) {
switch (dtype.code) {
case kDLInt:
if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeInt8;
else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeInt16;
else
LOG(FATAL) << "Unsupported type";
break;
case kDLUInt:
if (dtype.bits == 8 && dtype.lanes == 1) return MPSDataTypeUInt8;
else if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeUInt16;
else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeUInt32;
LOG(FATAL) << "Unsupported type";
break;
case kDLFloat:
if (dtype.bits == 16 && dtype.lanes == 1) return MPSDataTypeFloat16;
else if (dtype.bits == 32 && dtype.lanes == 1) return MPSDataTypeFloat32;
else
LOG(FATAL) << "Unsupported type";
break;
default:
LOG(FATAL) << "Unsupported type";
}
}
// MetalThreadEntry
MetalThreadEntry::MetalThreadEntry() {
auto func = runtime::Registry::Get("device_api.metal");
void *ret = (*func)();
metal_api = static_cast<runtime::metal::MetalWorkspace *>(ret);
}
MetalThreadEntry::~MetalThreadEntry() {
}
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
return MetalThreadStore::Get();
}
} // namespace contrib
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file Use external mps utils function
*/
#ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
#define TVM_CONTRIB_MPS_MPS_UTILS_H_
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include "../../runtime/metal/metal_common.h"
namespace tvm {
namespace contrib {
/*! breif Convert DLTensor type to MPS type */
struct MPSType {
static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
}; // struct MPSType
struct MetalThreadEntry {
MetalThreadEntry();
~MetalThreadEntry();
runtime::MetalWorkspace *metal_api{nullptr};
static MetalThreadEntry* ThreadLocal();
}; // MetalThreadEntry
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_MPS_MPS_UTILS_H_
import tvm
import numpy as np
from tvm.contrib import mps
def test_matmul_add():
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C1 = mps.matmul(A, B)
C2 = mps.matmul(B, A, True, True)
D1 = tvm.compute(C1.shape, lambda i, j: C1[i,j] + bias, name="D1")
D2 = tvm.compute(C2.shape, lambda i, j: C2[i,j] + bias, name="D2")
s1 = tvm.create_schedule(D1.op)
s2 = tvm.create_schedule(D2.op)
def verify(A, B, D, s, bias, target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.mps.matmul", True):
print("skip because extern function is not avalable")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
np.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
verify(A, B, D1, s1, bias)
verify(A, B, D2, s2, bias)
if __name__ == "__main__":
test_matmul_add()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment