Commit bb7df695 by Siva, committed by Tianqi Chen

[NNVM][CONVOLUTION] Group convolution generalization for NHWC (#1232)

parent 5b33e7b8
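This change generalizes grouped convolution so that depthwise conv2d (groups == channels) also works with NHWC data layout and an HWOI kernel layout, both in the TensorFlow frontend (new DepthwiseConv2dNative, Relu6, Shape and FusedBatchNorm converters) and in the conv2d compute/schedule registration. A minimal sketch of the symbol-level usage this enables, mirroring the new NHWC test added below (shapes are illustrative):

# Sketch only: NHWC depthwise conv2d via the NNVM symbol API, as exercised by
# the new test_grouped_conv2d_nhwc test in this commit.
import nnvm.symbol as sym

x = sym.Variable("x")                      # NHWC input, e.g. shape (1, 18, 18, 32)
y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
               padding=(1, 1), layout="NHWC", kernel_layout="HWOI",
               name="y")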
@@ -33,6 +33,8 @@ class AttrCvt(object):
         self._ignores.append('_input_shapes')
         self._ignores.append('T')
         self._ignores.append('use_cudnn_on_gpu')
+        self._ignores.append('_node_name')
+        self._ignores.append('is_training')
         return AttrConvert(self._op_name, self._transforms, self._excludes,
                            self._disables, self._ignores, self._extras,
                            self._custom_check)(inputs, attrs, *args)
@@ -230,6 +232,85 @@ def _conv():
             custom_check=_dimension_constraint())(inputs, attr)
     return _impl

+def _depthwise_conv():
+    def _impl(inputs, attr, params):
+        attr['data_format'] = attr['data_format'].decode("utf-8")
+        input_shapes = attr['_input_shapes'][inputs[0]]
+
+        # Extract kernel shape from params
+        conv_param_weights = params[inputs[1].list_output_names()[0]]
+
+        if attr['data_format'] == 'NHWC':
+            kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
+            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
+            attr['channels'] = input_shapes[0][3] * depth_mult
+            if 'dilations' in attr:
+                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
+        elif attr['data_format'] == 'NCHW':
+            depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
+            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
+            attr['channels'] = input_shapes[0][1] * depth_mult
+            if 'dilations' in attr:
+                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
+        else:
+            raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
+
+        # Fix strides
+        attr['strides'] = (attr['strides'][1], attr['strides'][2])
+
+        # Fix groups
+        attr['groups'] = attr['channels']
+
+        # Fix padding
+        attr['padding'] = attr['padding'].decode("utf-8")
+
+        if attr['padding'] == 'VALID':
+            attr['padding'] = [0, 0]
+        elif attr['padding'] == 'SAME':
+            stride_h, stride_w = attr['strides']
+            kernel_h, kernel_w = attr['kernel_shape']
+            if attr['data_format'] == 'NHWC':
+                in_h = input_shapes[0][1]
+                in_w = input_shapes[0][2]
+            else:
+                in_h = input_shapes[0][2]
+                in_w = input_shapes[0][3]
+
+            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
+
+            if attr['data_format'] == 'NHWC':
+                inputs[0] = _sym.pad(data=inputs[0],
+                                     pad_width=((0, 0),
+                                                (pad_v[0], pad_v[1]),
+                                                (pad_h[0], pad_h[1]),
+                                                (0, 0)))
+            else:
+                inputs[0] = _sym.pad(data=inputs[0],
+                                     pad_width=((0, 0),
+                                                (0, 0),
+                                                (pad_v[0], pad_v[1]),
+                                                (pad_h[0], pad_h[1])))
+
+            attr['padding'] = [0, 0]
+        else:
+            raise TypeError("Unsupported padding type : {}".format(attr['padding']))
+
+        if 'kernel_layout' not in attr:
+            attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
+
+        return AttrCvt(
+            op_name=_dimension_picker('conv'),
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'data_format': 'layout',
+                'dilations': ('dilation', (0, 0)),
+                'group': ('groups', 1)},
+            extras={'use_bias': len(inputs) == 3},
+            custom_check=_dimension_constraint())(inputs, attr)
+    return _impl

 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: expecting user to feed decoded input to next layer; drop this layer.
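The _depthwise_conv converter above relies on the frontend's existing _get_pad_pair helper, whose body is not part of this diff. As a rough reference only, it is assumed to compute TensorFlow-style SAME padding pairs for one spatial axis, along the lines of this sketch:

# Hedged sketch of what _get_pad_pair is assumed to do (its body is not shown
# in this diff): TensorFlow-style SAME padding for a single spatial dimension.
def _get_pad_pair_sketch(input1d, kernel1d, stride1d):
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]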
@@ -358,9 +439,27 @@ def _batch_norm():
             op_name='batch_norm',
             transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
             extras={'axis': 3}, # Fix axis
+            ignores=['data_format'],
             disables=['momentum'])(new_inputs, attr)
     return _impl

+def _relu6():
+    def _impl(inputs, attr, params):
+        return _sym.clip(inputs[0], a_min=0, a_max=6)
+    return _impl
+
+def _shape():
+    def _impl(inputs, attr, params):
+        input_shapes = attr['_input_shapes'][inputs[0]]
+
+        # Fix the -1 dimensions to 1
+        input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]]
+        params[attr['_node_name']] = tvm.nd.array(input_shapes[0])
+
+        return _sym.Variable(name=attr['_node_name'],
+                             shape=params[attr['_node_name']].shape)
+    return _impl

 # compatible operators that do NOT require any conversion.
 _identity_list = []
@@ -392,6 +491,10 @@ _convert_map = {
     'Add'                       : _elemwise('add'),
     'Rsqrt'                     : _rsqrt(),
     'Squeeze'                   : _squeeze(),
+    'FusedBatchNorm'            : _batch_norm(),
+    'Relu6'                     : _relu6(),
+    'DepthwiseConv2dNative'     : _depthwise_conv(),
+    'Shape'                     : _shape(),
 }
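With these entries registered, a frozen TensorFlow graph containing DepthwiseConv2dNative, Relu6, FusedBatchNorm or Shape nodes can be imported through the frontend. A minimal sketch (the file name "frozen_model.pb" is hypothetical; the graph must be exported with add_shapes=True, which the importer below now explicitly requires):

# Sketch only: exercising the new converters by importing a frozen TF graph.
import tensorflow as tf
import nnvm

graph_def = tf.GraphDef()
with open("frozen_model.pb", "rb") as f:   # hypothetical frozen GraphDef
    graph_def.ParseFromString(f.read())

sym, params = nnvm.frontend.from_tensorflow(graph_def)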
@@ -458,9 +561,13 @@ class GraphProto(object):
                 self._num_input += 1
                 self._nodes[node.name] = _sym.Variable(name=node.name)
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList(shape) \
-                    for shape in self._parse_attr(node.attr)['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                        [tensor_util.TensorShapeProtoToList(shape) \
+                        for shape in self._parse_attr(node.attr)['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
             elif node.op == "Const":
                 # Assuming first Const node as Graph Input node
                 if self._input_node == '':
@@ -476,17 +583,29 @@ class GraphProto(object):
                         raise NotImplementedError( \
                             "Const {} couldn't be converted to Param.".format(node.name))
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList(shape) \
-                    for shape in self._parse_attr(node.attr)['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                        [tensor_util.TensorShapeProtoToList(shape) \
+                        for shape in self._parse_attr(node.attr)['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")
             else:
                 attr = self._parse_attr(node.attr)
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList(shape) for shape in attr['_output_shapes']]
+                try:
+                    self._output_shapes[node.name] = \
+                        [tensor_util.TensorShapeProtoToList(shape) \
+                        for shape in attr['_output_shapes']]
+                except KeyError:
+                    raise NotImplementedError( \
+                        "Please freeze the graph with add_shapes=True")

                 # Pass the parsed shapes instead
                 attr["_output_shapes"] = self._output_shapes[node.name]
+
+                # Pass the node name too in attr
+                attr["_node_name"] = node.name
+
                 try:
                     inputs = [self._nodes[i] for i in node.input]
                     input_shapes = {}
......
@@ -84,6 +84,7 @@ def compute_conv2d(attrs, inputs, _):
     groups = attrs.get_int("groups")
     channels = attrs.get_int("channels")
     layout = attrs["layout"]
+    kernel_layout = attrs["kernel_layout"]
     assert layout == "NCHW" or layout == "NHWC"
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
@@ -97,10 +98,18 @@ def compute_conv2d(attrs, inputs, _):
     if groups == 1:
         out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
-    elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
+    elif layout == "NCHW" and \
+         groups == get_const_int(inputs[0].shape[1]) and \
+         groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
+    elif layout == "NHWC" and \
+         kernel_layout == "HWOI" and \
+         groups == get_const_int(inputs[0].shape[3]) and \
+         groups == channels:
+        out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding)
     else:
         raise ValueError("not support arbitrary group number for now")
+
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
         expand_axis = 1 if layout == "NCHW" else 0
@@ -112,13 +121,20 @@ def compute_conv2d(attrs, inputs, _):
 def schedule_conv2d(attrs, outs, target):
     """Schedule definition of conv2d"""
     groups = attrs.get_int("groups")
+    channels = attrs.get_int("channels")
     layout = attrs["layout"]
+    kernel_layout = attrs["kernel_layout"]
     with tvm.target.create(target):
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
         elif groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
-        return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NCHW":
+            return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI":
+            return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+        else:
+            raise ValueError("No compatible schedule")

 @reg.register_alter_op_layout("conv2d")
 def alter_conv2d_layout(attrs, inputs, tinfos):
......
@@ -79,7 +79,8 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
                        param.kernel_size[1]});
     wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
-    wshape[0] *= param.groups;
+
+    wshape[kernel_layout.indexof('O')] *= param.groups;
     NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
     if (param.use_bias) {
......
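The one-line change above scales the output-channel axis of the inferred weight shape wherever the kernel layout places 'O', instead of always at index 0. A small Python illustration of the same arithmetic (hypothetical helper; the per-group OIHW weight shape (1, 1, 3, 3) and groups=32 are assumed to match the NHWC test case below, whose kshape is (3, 3, 32, 1)):

# Illustration only: multiply the 'O' axis by groups at its position in the
# kernel layout, not always at index 0.
def scale_output_axis(wshape, kernel_layout, groups):
    wshape = list(wshape)
    wshape[kernel_layout.index('O')] *= groups
    return wshape

print(scale_output_axis((1, 1, 3, 3), "OIHW", 32))  # [32, 1, 3, 3]
print(scale_output_axis((3, 3, 1, 1), "HWOI", 32))  # [3, 3, 32, 1] -> matches the NHWC test's kshape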
@@ -58,7 +58,7 @@ def test_dilated_conv2d():
     np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)

-def test_grouped_conv2d():
+def test_grouped_conv2d_nchw():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
                    name="y", padding=(1,1))
@@ -80,6 +80,28 @@ def test_grouped_conv2d():
     c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
     np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)

+def test_grouped_conv2d_nhwc():
+    x = sym.Variable("x")
+    y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
+                   name="y", padding=(1,1), layout="NHWC", kernel_layout='HWOI')
+    dtype = "float32"
+    dshape = (1, 18, 18, 32)
+    kshape = (3, 3, 32, 1)
+    oshape = (1, 18, 18, 32)
+    shape_dict = {"x": dshape}
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        bias = tvm.nd.array(np.random.uniform(size=kshape[2]).astype(dtype))
+        m.run(x=data, y_weight=kernel, y_bias=bias)
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        c_np = topi.testing.depthwise_conv2d_python_nhwc(
+            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        c_np = c_np + bias.asnumpy().reshape(1, 1, kshape[2])
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
 def test_conv2d_transpose():
     x = sym.Variable("x")
@@ -269,7 +291,8 @@ def test_resize_bilinear():
 if __name__ == "__main__":
     test_conv2d()
     test_dilated_conv2d()
-    test_grouped_conv2d()
+    test_grouped_conv2d_nchw()
+    test_grouped_conv2d_nhwc()
     test_conv2d_transpose()
     test_max_pool2d()
     test_avg_pool2d()
......