Commit 8f56949b by Ruizhe Zhao (Vincent), committed by Wuwei Lin

Fixed issue #3069 by checking op tag (#3070)

* Fixed issue #3069 by adding in_channels

* Registered group_conv2d_nchw as topi compute

* Improved by checking tag value

* Removed group_conv2d_nchw topi registration

* Added test for relay group_conv2d_nchw

* Added assertions to forbid small group size

* Removed hard-coded oc_block_factor

* Added explanatory comments to group_conv2d_nchw_cuda

* Updated group_conv2d_nchw_cuda schedule

Removed 'direct' CUDA tests

* Reverted an accidental change in a conv2d test

* Fixed indentation problems

* Fixed a mis-commented line

* Reverted change in group_conv2d_nchw tag

* Removed commented int8 group_conv2d test

* Fixed group size assertions in group_conv2d_nchw_cuda
parent 7e68d63f
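The core of the fix is the tag check in the schedule_conv2d hunk below: when groups != 1, the schedule walks back from the output tensor to the conv2d compute op and picks the depthwise or group schedule depending on whether 'depthwise' appears in that op's tag, instead of guessing from the conv2d attributes. A self-contained sketch of that dispatch idea (the Op and Tensor namedtuples are illustrative stand-ins, not TVM types; only the traversal and the tag check mirror the diff):

# Illustrative sketch (not part of the commit): tag-based dispatch between
# depthwise and group conv2d schedules. Op/Tensor are toy stand-ins for TVM's
# tensor ops; only the traversal logic mirrors the _find_conv2d_op hunk below.
from collections import namedtuple

Op = namedtuple("Op", ["tag", "input_tensors"])
Tensor = namedtuple("Tensor", ["op"])


def _find_conv2d_op(op):
    """Find the op with 'conv2d' in its tag by traversing its inputs."""
    if 'conv2d' in op.tag:
        return op
    for tensor in op.input_tensors:
        op_ = _find_conv2d_op(tensor.op)
        if op_ is not None:
            return op_
    return None


def pick_schedule(out_op):
    conv_op = _find_conv2d_op(out_op)
    assert conv_op is not None
    # Depthwise and group conv2d both have groups != 1; the tag written by the
    # TOPI compute is what tells them apart (the fix for issue #3069).
    return "depthwise" if 'depthwise' in conv_op.tag else "group"


# A fused elementwise op sitting on top of a group conv2d compute:
conv = Op(tag="group_conv2d_nchw", input_tensors=[])
relu = Op(tag="elemwise", input_tensors=[Tensor(op=conv)])
print(pick_schedule(relu))   # -> "group"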
@@ -169,7 +169,6 @@ class Conv(OnnxOpConverter):
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        # get number of channels
        out = AttrCvt(op_name=dimension_picker('conv'),
                      transforms={
                          'kernel_shape': 'kernel_size',
......
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
@@ -34,16 +34,19 @@ def schedule_softmax(_, outputs, target):
    with target:
        return topi.generic.schedule_softmax(outputs)

reg.register_pattern("nn.softmax", OpPattern.OPAQUE)

schedule_broadcast = schedule_injective


@reg.register_schedule("nn.log_softmax")
def schedule_log_softmax(_, outputs, target):
    """Schedule definition of log_softmax"""
    with target:
        return topi.generic.schedule_softmax(outputs)

reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)

@@ -55,12 +58,14 @@ def compute_dense(attrs, inputs, out_type, target):
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    return [topi.nn.dense(inputs[0], inputs[1], out_dtype=out_dtype)]


@reg.register_schedule("nn.dense")
def schedule_dense(attrs, outputs, target):
    """Schedule definition of dense"""
    with target:
        return topi.generic.schedule_dense(outputs)

reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -70,16 +75,29 @@ def compute_batch_matmul(attrs, inputs, out_type, target):
    """Compute definition of batch_matmul"""
    return [topi.nn.batch_matmul(inputs[0], inputs[1])]


@reg.register_schedule("nn.batch_matmul")
def schedule_batch_matmul(attrs, outputs, target):
    """Schedule definition of batch_matmul"""
    with target:
        return topi.generic.schedule_batch_matmul(outputs)

reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d
+def _find_conv2d_op(op):
+    """Find the op with conv2d in its tag by traversing."""
+    if 'conv2d' in op.tag:
+        return op
+    for tensor in op.input_tensors:
+        op_ = _find_conv2d_op(tensor.op)
+        if op_ is not None:
+            return op_
+    return None
+
+
@reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target):
    """Compute definition of conv2d"""
@@ -127,6 +145,7 @@ def schedule_conv2d(attrs, outs, target):
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout

    with target:
        if groups == 1 and layout == "NCHW":
            return topi.generic.schedule_conv2d_nchw(outs)
@@ -135,12 +154,19 @@ def schedule_conv2d(attrs, outs, target):
        if groups == 1 and layout == "NHWC":
            return topi.generic.schedule_conv2d_nhwc(outs)
        if groups != 1:
+            # collect in_channels to distinguish depthwise and group conv2d
+            op = _find_conv2d_op(outs[0].op)
+            assert op is not None
+            is_depthwise = 'depthwise' in op.tag
+
+            if is_depthwise:
                if layout == "NCHW":
                    # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
                    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
                if layout == "NHWC" and kernel_layout == "HWOI":
                    return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
-            if layout == "NCHW4c":
+            else:
+                if layout in ["NCHW", "NCHW4c"]:
                    return topi.generic.schedule_group_conv2d_nchw(outs)
    raise ValueError("No compatible schedule")
@@ -151,6 +177,7 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
    from ... import op
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -169,18 +196,21 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
    assert layout == "NCHW", "only support nchw for now"
    assert dilation == (1, 1), "not support dilate now"
    assert groups == 1, "only support groups == 1 for now"
-    out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding, out_dtype)
+    out = topi.nn.conv2d_transpose_nchw(
+        inputs[0], inputs[1], strides, padding, out_dtype)
    output_padding = get_const_tuple(attrs.output_padding)
    out = topi.nn.pad(out,
                      [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
    return [out]


@reg.register_schedule("nn.conv2d_transpose")
def schedule_conv2d_transpose(attrs, outs, target):
    """Schedule definition of conv2d_transpose"""
    with target:
        return topi.generic.schedule_conv2d_transpose_nchw(outs)

reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
# bias_add
@@ -196,6 +226,7 @@ def schedule_max_pool2d(attrs, outs, target):
    with target:
        return topi.generic.schedule_pool(outs, layout)

reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -207,6 +238,7 @@ def schedule_avg_pool2d(attrs, outs, target):
    with target:
        return topi.generic.schedule_pool(outs, layout)

reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -217,6 +249,7 @@ def schedule_global_max_pool2d(_, outs, target):
    with target:
        return topi.generic.schedule_global_pool(outs)

reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -227,6 +260,7 @@ def schedule_global_avg_pool2d(_, outs, target):
    with target:
        return topi.generic.schedule_global_pool(outs)

reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

# leaky_relu
@@ -250,12 +284,14 @@ def compute_lrn(attrs, inputs, out_dtype, target):
    return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
                        attrs.alpha, attrs.beta, attrs.bias)]


@reg.register_schedule("nn.lrn")
def schedule_lrn(attrs, outs, target):
    """Schedule definition of lrn"""
    with target:
        return topi.generic.schedule_lrn(outs)

reg.register_pattern("nn.lrn", OpPattern.OPAQUE)
@@ -265,20 +301,26 @@ def compute_l2_normalize(attrs, inputs, out_dtype, target):
    """Compute definition of l2 normalize"""
    return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]


@reg.register_schedule("nn.l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
    """Schedule definition of l2 normalize"""
    with target:
        return topi.generic.schedule_l2_normalize(outs)

reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)

# upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective)


def schedule_upsampling(_, outs, target):
    """Schedule definition of upsampling"""
    with target:
        return topi.generic.schedule_injective(outs)

# pad
reg.register_schedule("nn.pad", schedule_broadcast)

@@ -304,12 +346,14 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
    """Schedule definition of conv2d_winograd_without_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -317,15 +361,18 @@ reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
    """Compute definition of contrib_conv2d_winograd_weight_transform"""
-    out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size'))
+    out = topi.nn.conv2d_winograd_weight_transform(
+        inputs[0], attrs.get_int('tile_size'))
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
    """Schedule definition of contrib_conv2d_winograd_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -353,12 +400,14 @@ def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
    """Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
                     OpPattern.OPAQUE)

@@ -371,12 +420,14 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d
        inputs[0], convolution_algorithm, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
    """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)

reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
                     OpPattern.OPAQUE)

@@ -397,15 +448,18 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
                                 data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
    """Schedule definition of contrib_conv2d_NCHWc"""
    with target:
        return topi.generic.schedule_conv2d_NCHWc(outs)

reg.register_pattern("nn.contrib_conv2d_NCHWc",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
    """Compute definition of depthwise conv2d NCHWc"""
@@ -422,15 +476,18 @@ def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
                                           data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
    """Schedule definition of contrib_conv2d_NCHWc"""
    with target:
        return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)

reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.deformable_conv2d")
def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
    """Compute definition of deformable_conv2d"""

@@ -446,10 +503,12 @@ def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
                                  dilation, deformable_groups, groups, out_dtype)
    return [out]


@reg.register_schedule("nn.deformable_conv2d")
def schedule_deformable_conv2d(attrs, outs, target):
    """Schedule definition of deformable_conv2d"""
    with target:
        return topi.generic.schedule_deformable_conv2d_nchw(outs)

reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -296,7 +296,7 @@ def override_native_generic_func(func_name):
    def generic_func(fdefault):
        """Wrap a target generic function.

-        Generic function allows registeration of further functions
+        Generic function allows registration of further functions
        that can be dispatched on current target context.
        If no registered dispatch is matched, the fdefault will be called.
......
@@ -86,9 +86,13 @@ def test_conv2d_run():
                        fref=None,
                        groups=1,
                        dilation=(1, 1),
+                        except_targets=None,
                        **attrs):
-        x = relay.var("x", shape=dshape)
-        w = relay.var("w")
+        if except_targets is None:
+            except_targets = []
+
+        x = relay.var("x", shape=dshape, dtype=dtype)
+        w = relay.var("w", dtype=dtype)
        y = relay.nn.conv2d(x, w,
                            padding=padding,
                            dilation=dilation,
@@ -100,11 +104,15 @@ def test_conv2d_run():
        dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation)
        if fref is None:
            ref_res = topi.testing.conv2d_nchw_python(
-                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding)
+                data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding,
+                groups=groups)
        else:
            ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))

        for target, ctx in ctx_list():
+            if target in except_targets:
+                continue
            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
            op_res1 = intrp1.evaluate(func)(data, kernel)
            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
@@ -117,6 +125,21 @@ def test_conv2d_run():
                    fref=lambda x, w: topi.testing.depthwise_conv2d_python_nchw(
                        x, w, (1, 1), "SAME"))

+    # CUDA is disabled for 'direct' schedule:
+    # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553
+    # group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (32, 4, 3, 3)
+    run_test_conv2d("float32", "float32", 1, dshape, kshape,
+                    padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3),
+                    except_targets=['cuda'])
+    # also group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (64, 1, 3, 3)
+    run_test_conv2d("float32", "float32", 1, dshape, kshape,
+                    padding=(1, 1), channels=64, groups=32, kernel_size=(3, 3),
+                    except_targets=['cuda'])
+
    # normal conv2d
    dshape = (1, 3, 224, 224)
    kshape = (10, 3, 3, 3)
......
@@ -27,10 +27,13 @@ from ..util import traverse_inline, get_const_tuple, get_const_int
from .. import nn, generic


-@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['direct', 'int8'])
+autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], 'direct',
+                              nn.group_conv2d_nchw.fdefault)
+@autotvm.register_topi_compute(nn.group_conv2d_nchw, ['cuda', 'gpu'], ['int8'])
def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
                           out_dtype='float32'):
-    """Group convolution operator in NCHW layout.
+    """Group convolution operator for 'group_conv2d_NCHWc_int8'.

    Parameters
    ----------
@@ -76,7 +79,7 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
    assert out_channels % groups == 0, "output channels must divide group size"
    assert channels % ic_block_factor == 0, \
        "Number of input channels per group must divide {}".format(ic_block_factor)
-    assert out_channels % 4 == 0, \
+    assert out_channels % oc_block_factor == 0, \
        "Number of output channels per group must divide {}".format(oc_block_factor)

    packed_data = tvm.compute((batch, channels // ic_block_factor, height, width,
@@ -99,6 +102,17 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
    oc_chunk, _, kernel_h, kernel_w, oc_block, ic_block = get_const_tuple(
        packed_kernel.shape)

+    # TODO(kumasento): these assertions ensure that the number of groups
+    # should be smaller or equal to the number of blocks, so that each
+    # group will have at least one block.
+    # Shall we pad the channels to avoid raising assertions?
+    assert groups <= oc_chunk, \
+        ('Number of groups {} should be less than '
+         'output channel chunk size {}'.format(groups, oc_chunk))
+    assert groups <= ic_chunk, \
+        ('Number of groups {} should be less than '
+         'input channel chunk size {}'.format(groups, ic_chunk))
+
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
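These assertions can be sanity-checked against the shapes used in the relay test earlier in this commit. A rough sketch, assuming the int8 path packs channels with ic_block_factor = oc_block_factor = 4 (assumed values; the block factors are referenced but not defined in this hunk):

# Rough check of the new group-size assertions (illustrative only).
def groups_fit_blocks(in_channels, out_channels, groups,
                      ic_block_factor=4, oc_block_factor=4):
    ic_chunk = in_channels // ic_block_factor
    oc_chunk = out_channels // oc_block_factor
    # mirrors: groups <= oc_chunk and groups <= ic_chunk
    return groups <= oc_chunk and groups <= ic_chunk

print(groups_fit_blocks(32, 32, groups=8))    # True: 8 groups over 8 input blocks
print(groups_fit_blocks(32, 64, groups=32))   # False: 32 groups, only 8 input blocks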
@@ -109,9 +123,9 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
    else:
        dilation_h, dilation_w = dilation

+    # pad the input data
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (kernel_h, kernel_w))
-    # compute graph
    pad_before = [0, 0, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_down, pad_right, 0]
    pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")
@@ -129,6 +143,17 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
    kh = tvm.reduce_axis((0, kernel_h), name='kh')
    kw = tvm.reduce_axis((0, kernel_w), name='kw')

+    # NOTE(kumasento): explanation of this snippet -
+    # oc_chunk//groups and ic_chunk//groups give you the number of blocks,
+    # i.e., chunk, per group.
+    # occ is the ID of the output channel block, so that occ//(oc_chunk//groups)
+    # produces the ID of the group.
+    # Multiplying that result with ic_chunk//groups resulting in the ID
+    # of the beginning block of the corresponding input group.
+    # Adding the block offset (icc) will give you the exact block ID.
+    #
+    # Compared with a normal convolution, group convolution only sums
+    # input channels from the group that an output channel resides in.
    conv = tvm.compute(oshape, lambda n, occ, oh, ow, ocb:
                       tvm.sum(pad_data[n, occ//(oc_chunk//groups)*(ic_chunk//groups)+icc,
                                        oh*stride_h+kh*dilation_h, ow*stride_w+kw*dilation_w, icb]
@@ -138,8 +163,10 @@ def group_conv2d_nchw_cuda(cfg, data, kernel, stride, padding, dilation, groups,
                               .astype('int32'),
                               axis=[icc, kh, kw, icb]))

+    # Type conversion
    output = tvm.compute(oshape, lambda *index: conv(*index).astype(out_dtype),
                         tag='group_conv2d_NCHWc_int8')

    num_flop = batch * oc_chunk * oc_block * out_height * out_width * \
        ic_chunk * ic_block * kernel_h * kernel_w * 2 // groups
    cfg.add_flop(num_flop)
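The NOTE in the hunk above maps an output-channel block to the input-channel blocks of its group via occ//(oc_chunk//groups)*(ic_chunk//groups)+icc. A small numeric walk-through of that arithmetic (the concrete sizes are illustrative, not taken from the commit):

# Numeric walk-through of the block-index arithmetic in the NOTE above.
oc_chunk, ic_chunk, groups = 8, 8, 4
out_blocks_per_group = oc_chunk // groups    # 2 output-channel blocks per group
in_blocks_per_group = ic_chunk // groups     # 2 input-channel blocks per group

occ = 5                                       # some output-channel block ID
group_id = occ // out_blocks_per_group        # -> group 2
first_input_block = group_id * in_blocks_per_group   # -> input block 4
for icc in range(in_blocks_per_group):
    # same expression as the pad_data channel index in the compute body:
    # occ//(oc_chunk//groups)*(ic_chunk//groups) + icc
    print(first_input_block + icc)            # prints 4, then 5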
@@ -295,7 +322,7 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
@autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
-                                ["cuda", "gpu"], ["direct", "int8"])
+                                ["cuda", "gpu"], ["int8"])
def schedule_conv2d_nchw_cuda(cfg, outs):
    """TOPI schedule callback of group conv2d for cuda gpu
......
@@ -242,7 +242,7 @@ def schedule_depthwise_conv2d_NCHWc(outs):
@tvm.target.generic_func
def schedule_group_conv2d_nchw(outs):
-    """Schedule for conv2d_nchw
+    """Schedule for group_conv2d_nchw

    Parameters
    ----------
......
@@ -603,4 +603,4 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
                      yy * stride_h + ry * dilation_h,
                      xx * stride_w + rx * dilation_w].astype(out_dtype) *
                Filter[ff, rc, ry, rx].astype(out_dtype),
-               axis=[rc, ry, rx]), tag="conv2d_nchw")
+               axis=[rc, ry, rx]), tag='group_conv2d_nchw')