Commit 0257a88b by Peter Yeh Committed by masahi

Enable miopen Group Convolution (#3987)

* enable group conv through miopen

* linter fix
parent beb1c252
......@@ -50,7 +50,8 @@ def conv2d_forward(x,
dilation_h=1,
dilation_w=1,
conv_mode=0,
data_type=1):
data_type=1,
group_count=1):
"""Create an extern op that compute 2D convolution with MIOpen
Parameters
......@@ -77,13 +78,16 @@ def conv2d_forward(x,
data_type: int
0: miopenHalf (fp16)
1: miopenFloat (fp32)
group_count: int
number of groups
Returns
-------
y: Tensor
The result tensor
"""
assert (conv_mode == 0 or conv_mode == 1), "0: miopenConvolution / 1: miopenTranspose"
assert (0 <= conv_mode <= 2), "0: miopenConvolution / 1: miopenTranspose / 2: miopenGroupConv"
if group_count > 1:
conv_mode = 2
oshape = np.zeros((len(x.shape)), dtype=np.int32)
xshape = x.shape
wshape = w.shape
......@@ -104,6 +108,7 @@ def conv2d_forward(x,
wshape[1].value,
wshape[2].value,
wshape[3].value,
group_count,
_get_np_int32_array_handle(oshape))
return _api.extern(
......
......@@ -50,16 +50,20 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
const int w_dim1 = args[13];
const int w_dim2 = args[14];
const int w_dim3 = args[15];
void *out_shape = args[16];
const int n_group = args[16];
void *out_shape = args[17];
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
assert(n_group > 0 && "Group Size > 0 is expected");
if (n_group > 1)
assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
// Set Mode
entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
// Set Ctx
entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
// Set Data Type
entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
// this moment.
// Set Desc
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
......@@ -70,11 +74,13 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
stride_w,
dilation_h,
dilation_w));
if (n_group > 1)
MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group));
// Set Filter
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.data_type,
w_dim0,
w_dim1,
w_dim1/n_group,
w_dim2,
w_dim3));
// Set Input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment