Commit 0257a88b by Peter Yeh Committed by masahi

Enable MIOpen Group Convolution (#3987)

* enable group conv through miopen

* linter fix
parent beb1c252
...@@ -50,7 +50,8 @@ def conv2d_forward(x, ...@@ -50,7 +50,8 @@ def conv2d_forward(x,
dilation_h=1, dilation_h=1,
dilation_w=1, dilation_w=1,
conv_mode=0, conv_mode=0,
data_type=1): data_type=1,
group_count=1):
"""Create an extern op that compute 2D convolution with MIOpen """Create an extern op that compute 2D convolution with MIOpen
Parameters Parameters
...@@ -77,13 +78,16 @@ def conv2d_forward(x, ...@@ -77,13 +78,16 @@ def conv2d_forward(x,
data_type: int data_type: int
0: miopenHalf (fp16) 0: miopenHalf (fp16)
1: miopenFloat (fp32) 1: miopenFloat (fp32)
group_count: int
number of groups
Returns Returns
------- -------
y: Tensor y: Tensor
The result tensor The result tensor
""" """
assert (conv_mode == 0 or conv_mode == 1), "0: miopenConvolution / 1: miopenTranspose" assert (0 <= conv_mode <= 2), "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
if group_count > 1:
conv_mode = 2
oshape = np.zeros((len(x.shape)), dtype=np.int32) oshape = np.zeros((len(x.shape)), dtype=np.int32)
xshape = x.shape xshape = x.shape
wshape = w.shape wshape = w.shape
...@@ -104,6 +108,7 @@ def conv2d_forward(x, ...@@ -104,6 +108,7 @@ def conv2d_forward(x,
wshape[1].value, wshape[1].value,
wshape[2].value, wshape[2].value,
wshape[3].value, wshape[3].value,
group_count,
_get_np_int32_array_handle(oshape)) _get_np_int32_array_handle(oshape))
return _api.extern( return _api.extern(
......
...@@ -50,16 +50,20 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ...@@ -50,16 +50,20 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
const int w_dim1 = args[13]; const int w_dim1 = args[13];
const int w_dim2 = args[14]; const int w_dim2 = args[14];
const int w_dim3 = args[15]; const int w_dim3 = args[15];
void *out_shape = args[16]; const int n_group = args[16];
void *out_shape = args[17];
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
assert(n_group > 0 && "Group Size > 0 is expected");
if (n_group > 1)
assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
// Set Mode // Set Mode
entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode); entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
// Set Ctx // Set Ctx
entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0}; entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
// Set Data Type // Set Data Type
entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>( entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
// this moment. // this moment.
// Set Desc // Set Desc
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
...@@ -70,11 +74,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ...@@ -70,11 +74,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
stride_w, stride_w,
dilation_h, dilation_h,
dilation_w)); dilation_w));
if (n_group > 1)
MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group));
// Set Filter // Set Filter
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.data_type, entry_ptr->conv_entry.data_type,
w_dim0, w_dim0,
w_dim1, w_dim1/n_group,
w_dim2, w_dim2,
w_dim3)); w_dim3));
// Set Input // Set Input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment