Commit 9e4f07b4 by Peter Yeh Committed by masahi

Enable miopen transpose convolution and fp16 support (#3952)

* Enable miopen transpose convolution and fp16 support

* linter
parent 0482623e
...@@ -49,7 +49,8 @@ def conv2d_forward(x, ...@@ -49,7 +49,8 @@ def conv2d_forward(x,
pad_w=0, pad_w=0,
dilation_h=1, dilation_h=1,
dilation_w=1, dilation_w=1,
conv_mode=0): conv_mode=0,
data_type=1):
"""Create an extern op that compute 2D convolution with MIOpen """Create an extern op that compute 2D convolution with MIOpen
Parameters Parameters
...@@ -73,18 +74,22 @@ def conv2d_forward(x, ...@@ -73,18 +74,22 @@ def conv2d_forward(x,
conv_mode: int conv_mode: int
0: miopenConvolution 0: miopenConvolution
1: miopenTranspose 1: miopenTranspose
data_type: int
0: miopenHalf (fp16)
1: miopenFloat (fp32)
Returns Returns
------- -------
y: Tensor y: Tensor
The result tensor The result tensor
""" """
assert conv_mode == 0, "Transpose convolutions not supported yet." assert (conv_mode == 0 or conv_mode == 1), "0: miopenConvolution / 1: miopenTranspose"
oshape = np.zeros((len(x.shape)), dtype=np.int32) oshape = np.zeros((len(x.shape)), dtype=np.int32)
xshape = x.shape xshape = x.shape
wshape = w.shape wshape = w.shape
setup_func = _get_global_func("tvm.contrib.miopen.conv2d.setup") setup_func = _get_global_func("tvm.contrib.miopen.conv2d.setup")
algo = setup_func(conv_mode, algo = setup_func(conv_mode,
data_type,
pad_h, pad_h,
pad_w, pad_w,
stride_h, stride_h,
...@@ -106,6 +111,7 @@ def conv2d_forward(x, ...@@ -106,6 +111,7 @@ def conv2d_forward(x,
lambda ins, outs: _intrin.call_packed( lambda ins, outs: _intrin.call_packed(
"tvm.contrib.miopen.conv2d.forward", "tvm.contrib.miopen.conv2d.forward",
conv_mode, conv_mode,
data_type,
pad_h, pad_h,
pad_w, pad_w,
stride_h, stride_h,
......
...@@ -35,21 +35,22 @@ using namespace runtime; ...@@ -35,21 +35,22 @@ using namespace runtime;
TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
const int mode = args[0]; const int mode = args[0];
const int pad_h = args[1]; const int dtype = args[1];
const int pad_w = args[2]; const int pad_h = args[2];
const int stride_h = args[3]; const int pad_w = args[3];
const int stride_w = args[4]; const int stride_h = args[4];
const int dilation_h = args[5]; const int stride_w = args[5];
const int dilation_w = args[6]; const int dilation_h = args[6];
const int x_dim0 = args[7]; const int dilation_w = args[7];
const int x_dim1 = args[8]; const int x_dim0 = args[8];
const int x_dim2 = args[9]; const int x_dim1 = args[9];
const int x_dim3 = args[10]; const int x_dim2 = args[10];
const int w_dim0 = args[11]; const int x_dim3 = args[11];
const int w_dim1 = args[12]; const int w_dim0 = args[12];
const int w_dim2 = args[13]; const int w_dim1 = args[13];
const int w_dim3 = args[14]; const int w_dim2 = args[14];
void *out_shape = args[15]; const int w_dim3 = args[15];
void *out_shape = args[16];
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
// Set Mode // Set Mode
...@@ -57,7 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ...@@ -57,7 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
// Set Ctx // Set Ctx
entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0}; entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
// Set Data Type // Set Data Type
entry_ptr->conv_entry.data_type = miopenFloat; // MIOpen only suppports fp32 entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
// this moment.
// Set Desc // Set Desc
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.mode, entry_ptr->conv_entry.mode,
...@@ -170,16 +173,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ...@@ -170,16 +173,17 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
const int mode = args[0]; const int mode = args[0];
const int pad_h = args[1]; const int dtype = args[1];
const int pad_w = args[2]; const int pad_h = args[2];
const int stride_h = args[3]; const int pad_w = args[3];
const int stride_w = args[4]; const int stride_h = args[4];
const int dilation_h = args[5]; const int stride_w = args[5];
const int dilation_w = args[6]; const int dilation_h = args[6];
const int algo = args[7]; const int dilation_w = args[7];
const DLTensor *x = args[8]; const int algo = args[8];
const DLTensor *w = args[9]; const DLTensor *x = args[9];
const DLTensor *y = args[10]; const DLTensor *w = args[10];
const DLTensor *y = args[11];
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo); entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo);
...@@ -188,7 +192,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") ...@@ -188,7 +192,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
// Set Ctx // Set Ctx
entry_ptr->conv_entry.ctx = x->ctx; entry_ptr->conv_entry.ctx = x->ctx;
// Set Data Type // Set Data Type
entry_ptr->conv_entry.data_type = miopenFloat; // MIOpen only suppports fp32 entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
// this moment.
// Set Desc // Set Desc
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
entry_ptr->conv_entry.mode, entry_ptr->conv_entry.mode,
......
...@@ -50,7 +50,8 @@ def test_conv2d(): ...@@ -50,7 +50,8 @@ def test_conv2d():
pad_w, pad_w,
dilation_h, dilation_h,
dilation_w, dilation_w,
conv_mode=0) conv_mode=0,
data_type=1)
yshape = [x.value for x in Y.shape] yshape = [x.value for x in Y.shape]
import topi import topi
...@@ -65,7 +66,7 @@ def test_conv2d(): ...@@ -65,7 +66,7 @@ def test_conv2d():
y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx) y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx)
f(x, w, y) f(x, w, y)
Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w)) Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w))
with tvm.target.rocm(): with tvm.target.rocm():
s_ref = topi.generic.schedule_conv2d_nchw([Y_ref]) s_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm") f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm")
......
...@@ -78,7 +78,8 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou ...@@ -78,7 +78,8 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
pad_w, pad_w,
dilation_h, dilation_h,
dilation_w, dilation_w,
conv_mode=0) conv_mode=0,
data_type=1)
return conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype) return conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment