Commit 3b9f1652 by masahi Committed by Tianqi Chen

[ROCM] MIOpen contrib for convolution kernels (#722)

* first working miopen support

* do FindFwdAlgo during build time

* fix lint

* update doc string

* import topi after checking if rocm is enabled

* add miopen namespace

* fixed descriptor overwrite bug

* add use_miopen option

* fix lint

* better miopen option handling

* fix typo

* fix options handling
parent 5d37be62
@@ -134,6 +134,7 @@ include make/contrib/cblas.mk
include make/contrib/random.mk
include make/contrib/nnpack.mk
include make/contrib/cudnn.mk
include make/contrib/miopen.mk
include make/contrib/mps.mk
ifdef ADD_CFLAGS
@@ -72,5 +72,8 @@ USE_NNPACK = 0
# Whether use CuDNN
USE_CUDNN = 0
# Whether use MIOpen
USE_MIOPEN = 0
# Whether use MPS
USE_MPS = 0
MIOPEN_CONTRIB_SRC = $(wildcard src/contrib/miopen/*.cc)
MIOPEN_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(MIOPEN_CONTRIB_SRC))
ifeq ($(USE_MIOPEN), 1)
CFLAGS += -DTVM_USE_MIOPEN=1
ADD_LDFLAGS += -lMIOpen
RUNTIME_DEP += $(MIOPEN_CONTRIB_OBJ)
endif
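Setting USE_MIOPEN = 1 compiles the contrib sources above into the runtime and links against libMIOpen. A quick way to confirm that a given TVM build actually picked up the flag is to probe for the registered setup function; this is only a sketch, and assumes the Python package from that build is on PYTHONPATH:

import tvm

# The PackedFunc below is registered only when TVM was built with USE_MIOPEN = 1.
has_miopen = tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True) is not None
print("MIOpen support compiled in:", has_miopen)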
"""External function interface to MIOpen library."""
# pylint: disable-msg=C0103
import ctypes
import numpy as np
from .. import api as _api
from .. import intrin as _intrin
from .. import get_global_func as _get_global_func
def _get_np_int32_array_handle(arr):
"""Return a void_p handle for a numpy array
Parameters
----------
arr: numpy.NDArray
source numpy array
Returns
-------
ptr: ctypes.c_void_p
pointer to the data
"""
assert arr.dtype == np.int32
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
return ctypes.cast(ptr, ctypes.c_void_p)
def conv2d_forward(x,
w,
stride_h=1,
stride_w=1,
pad_h=0,
pad_w=0,
dilation_h=1,
dilation_w=1,
conv_mode=0):
"""Create an extern op that compute 2D convolution with MIOpen
Parameters
----------
x: Tensor
input feature map
w: Tensor
convolution weight
stride_h: int
height stride
stride_w: int
width stride
pad_h: int
height pad
pad_w: int
weight pad
dilation_h: int
height dilation
dilation_w: int
width dilation
conv_mode: int
0: miopenConvolution
1: miopenTranspose
Returns
-------
y: Tensor
The result tensor
"""
assert conv_mode == 0, "Transpose convolutions not supported yet."
oshape = np.zeros((len(x.shape)), dtype=np.int32)
xshape = x.shape
wshape = w.shape
setup_func = _get_global_func("tvm.contrib.miopen.conv2d.setup")
algo = setup_func(conv_mode,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
xshape[0].value,
xshape[1].value,
xshape[2].value,
xshape[3].value,
wshape[0].value,
wshape[1].value,
wshape[2].value,
wshape[3].value,
_get_np_int32_array_handle(oshape))
return _api.extern(
list(oshape), [x, w],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.miopen.conv2d.forward",
conv_mode,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
algo,
ins[0],
ins[1],
outs[0]), name="y")
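A minimal usage sketch of the interface above, assuming a ROCm build with MIOpen enabled and a ROCm device available (the setup function queries the device while choosing an algorithm). The shapes are only illustrative; the point is that the returned tensor already carries the output shape MIOpen inferred at graph-construction time:

import tvm
from tvm.contrib import miopen

X = tvm.placeholder((1, 64, 56, 56), name="X")
W = tvm.placeholder((128, 64, 3, 3), name="W")
Y = miopen.conv2d_forward(X, W, stride_h=1, stride_w=1, pad_h=1, pad_w=1)
print([d.value for d in Y.shape])  # expected: [1, 128, 56, 56]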
@@ -88,12 +88,17 @@ class Target(object):
                 target_name,
                 options=None):
        self.target_name = target_name
        self.options = _merge_opts([], options)
        self.options = []
        self.device_name = ""
        self.libs = []
        # Parse device option
        for item in self.options:
            if item.startswith("-device="):
        for item in _merge_opts([], options):
            if item.startswith("-libs="):
                self.libs.append(item.split("=")[1])
                continue
            elif item.startswith("-device="):
                self.device_name = item.split("=")[1]
            self.options.append(item)
        # Target query searches device name first
        if self.device_name:
            self.keys = (self.device_name,)
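The Target change above is what lets a target string carry library hints such as MIOpen. A small illustration, assuming tvm.target.create parses a target string into this Target class as elsewhere in target.py:

import tvm

tgt = tvm.target.create("rocm -libs=miopen")
print(tgt.libs)         # expected: ['miopen']
print(tgt.device_name)  # expected: '' unless a -device= option was also given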
/*!
 * Copyright (c) 2017 by Contributors
 * \file External MIOpen utility functions
 */
#include "miopen_utils.h"
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <vector>
#include <string>

namespace tvm {
namespace contrib {
namespace miopen {

std::string miopenGetErrorString(int error_code) {
  const std::vector<std::string> mio_err{
      "StatusSuccess ", "StatusNotInitialized ", "StatusInvalidValue ",
      "StatusBadParm ", "StatusAllocFailed ", "StatusInternalError ",
      "StatusNotImplemented ", "StatusUnknownError "};
  return mio_err[error_code];
}

// MIOpenThreadEntry
MIOpenThreadEntry::MIOpenThreadEntry() {
  auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
  // Look up the ROCm DeviceAPI through the registry so that workspace
  // buffers can later be allocated on the device.
  auto func = runtime::Registry::Get("device_api.rocm");
  void *ret = (*func)();
  rocm_api = static_cast<runtime::DeviceAPI*>(ret);
  MIOPEN_CALL(miopenCreate(&handle));
  MIOPEN_CALL(miopenSetStream(handle, stream));
  conv_entry.rocm_api = rocm_api;
}

MIOpenThreadEntry::~MIOpenThreadEntry() {
  MIOPEN_CALL(miopenDestroy(handle));
}

typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;

MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
  return MIOpenThreadStore::Get();
}

// ConvEntry
ConvEntry::ConvEntry() {
  MIOPEN_CALL(miopenCreateConvolutionDescriptor(&conv_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&filter_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&input_desc));
  MIOPEN_CALL(miopenCreateTensorDescriptor(&output_desc));
}

ConvEntry::~ConvEntry() {
  MIOPEN_CALL(miopenDestroyConvolutionDescriptor(conv_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(filter_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(input_desc));
  MIOPEN_CALL(miopenDestroyTensorDescriptor(output_desc));
  CleanWorkspace();
}

void ConvEntry::UpdateWorkspace(const size_t wsize) {
  if (workspace_size < wsize) {
    if (workspace != nullptr) {
      CleanWorkspace();
    }
    workspace_size = wsize;
    workspace = rocm_api->AllocWorkspace(ctx, workspace_size);
  }
}

void ConvEntry::CleanWorkspace() {
  if (workspace) rocm_api->FreeWorkspace(ctx, workspace);
  workspace_size = 0;
}

}  // namespace miopen
}  // namespace contrib
}  // namespace tvm
/*!
 * Copyright (c) 2017 by Contributors
 * \file External MIOpen utility functions
 */
#ifndef TVM_CONTRIB_MIOPEN_MIOPEN_UTILS_H_
#define TVM_CONTRIB_MIOPEN_MIOPEN_UTILS_H_

#include <dmlc/logging.h>
#include <miopen/miopen.h>
#include <tvm/runtime/device_api.h>
#include <string>
#include "../../runtime/rocm/rocm_common.h"

namespace tvm {
namespace contrib {
namespace miopen {

std::string miopenGetErrorString(int error_code);

#define MIOPEN_CALL(func) \
  { \
    miopenStatus_t e = (func); \
    CHECK_EQ(e, miopenStatusSuccess) \
        << "miopen error: " << miopenGetErrorString(e); \
  }

struct ConvEntry {
  miopenConvolutionDescriptor_t conv_desc;
  miopenConvolutionMode_t mode{miopenConvolution};
  miopenTensorDescriptor_t filter_desc;
  miopenDataType_t data_type{miopenFloat};
  miopenTensorDescriptor_t input_desc;
  miopenTensorDescriptor_t output_desc;
  miopenConvFwdAlgorithm_t fwd_algo;
  TVMContext ctx;
  runtime::DeviceAPI *rocm_api;
  void *workspace{nullptr};
  size_t workspace_size{0};
  ConvEntry();
  ~ConvEntry();
  void UpdateWorkspace(const size_t wsize);
  void CleanWorkspace();
};  // ConvEntry

struct MIOpenThreadEntry {
  MIOpenThreadEntry();
  ~MIOpenThreadEntry();
  miopenHandle_t handle{nullptr};
  ConvEntry conv_entry;
  runtime::DeviceAPI *rocm_api{nullptr};
  static MIOpenThreadEntry *ThreadLocal();
};  // MIOpenThreadEntry

}  // namespace miopen
}  // namespace contrib
}  // namespace tvm

#endif  // TVM_CONTRIB_MIOPEN_MIOPEN_UTILS_H_
import tvm
from tvm.contrib import miopen
import numpy as np


def test_conv2d():
    in_channel = 64
    out_channel = 128
    filter_h = 3
    filter_w = 3
    pad_h = 1
    pad_w = 1
    stride_h = 1
    stride_w = 1
    dilation_h = 1
    dilation_w = 1

    xshape = [1, in_channel, 64, 64]
    if not tvm.module.enabled("rocm"):
        print("skip because rocm is not enabled...")
        return
    if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True):
        print("skip because miopen is not enabled...")
        return

    wshape = (out_channel, in_channel, filter_h, filter_w)
    X = tvm.placeholder(xshape, name='X')
    W = tvm.placeholder(wshape, name='W')
    Y = miopen.conv2d_forward(X,
                              W,
                              stride_h,
                              stride_w,
                              pad_h,
                              pad_w,
                              dilation_h,
                              dilation_w,
                              conv_mode=0)

    yshape = [x.value for x in Y.shape]
    s = tvm.create_schedule(Y.op)

    def verify():
        ctx = tvm.rocm(0)
        f = tvm.build(s, [X, W, Y], "rocm", target_host="llvm", name="conv2d")
        x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(np.float32), ctx)
        w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(np.float32), ctx)
        y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
        f(x, w, y)

        import topi
        Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w))
        with tvm.target.rocm():
            s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
        f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm")
        y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
        f_ref(x, w, y_ref)
        print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy())))
        np.testing.assert_allclose(y.asnumpy(), y_ref.asnumpy(), atol=1e-3)

    verify()


if __name__ == "__main__":
    test_conv2d()
#!/bin/bash
export PYTHONPATH=python
export PYTHONPATH=python:topi/python
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc