Commit ce031438 by lixiaoquan Committed by Tianqi Chen

[TensorFlow] Fix limitation that depth_mult can only be 1 for DepthwiseConv2dNative (#3676)

* [TensorFlow] Fix limitation that depth_mult can only be 1 for DepthwiseConv2dNative

* Improve code readability
parent 710ac146
...@@ -251,9 +251,6 @@ def _conv(opname): ...@@ -251,9 +251,6 @@ def _conv(opname):
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if opname == 'depthwise': if opname == 'depthwise':
if depth_mult > 1:
raise tvm.error.OpNotImplemented('depth_mult > 1 of operator DepthwiseConv2dNative'
' is not supported.')
attr['groups'] = attr['channels'] attr['groups'] = attr['channels']
# Fix padding # Fix padding
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import topi import topi
from topi.util import get_const_int, get_const_tuple from topi.util import get_const_tuple
from .. import op as reg from .. import op as reg
from ..op import OpPattern, schedule_injective from ..op import OpPattern, schedule_injective
...@@ -144,19 +144,20 @@ def compute_conv2d(attrs, inputs, out_type, target): ...@@ -144,19 +144,20 @@ def compute_conv2d(attrs, inputs, out_type, target):
if dilation_h < 1 or dilation_w < 1: if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value") raise ValueError("dilation should be positive value")
def _get_out_depth():
weight_shape = get_const_tuple(inputs[1].shape)
if kernel_layout == "HWOI":
return weight_shape[2] * weight_shape[3]
return weight_shape[0] * weight_shape[1]
if groups == 1: if groups == 1:
out = topi.nn.conv2d( out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding, inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype) dilation, layout, out_dtype)
elif layout == "NCHW" and \ elif layout == "NCHW" and _get_out_depth() == groups:
get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1:
out = topi.nn.depthwise_conv2d_nchw( out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype) inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout == "NHWC" and \ elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1:
out = topi.nn.depthwise_conv2d_nhwc( out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype) inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ['NCHW', 'NCHW4c']: elif layout in ['NCHW', 'NCHW4c']:
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
* \brief Convolution operators * \brief Convolution operators
*/ */
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
#include <tvm/ir_pass.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <vector> #include <vector>
...@@ -74,11 +75,23 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -74,11 +75,23 @@ bool Conv2DRel(const Array<Type>& types,
if (param->kernel_size.defined() && param->channels.defined()) { if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2); CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape( Array<IndexExpr> wshape;
{param->channels,
dshape_nchw[1] / param->groups, if (tvm::ir::Equal(param->channels, param->groups)) {
param->kernel_size[0], // infer weight's shape for depthwise convolution
param->kernel_size[1]}); wshape = {
{dshape_nchw[1],
param->groups / dshape_nchw[1],
param->kernel_size[0],
param->kernel_size[1]}};
} else {
wshape = {
{param->channels,
dshape_nchw[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]}};
}
wshape = trans_kernel_layout.BackwardShape(wshape); wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels; channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
......
...@@ -275,6 +275,7 @@ def test_forward_convolution(): ...@@ -275,6 +275,7 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 19, 17, 17], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('depthwise', [4, 19, 17, 17], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW') _test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
...@@ -284,6 +285,7 @@ def test_forward_convolution(): ...@@ -284,6 +285,7 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
####################################################################### #######################################################################
# BiasAdd # BiasAdd
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment