Commit 608cdeeb by Wuwei Lin Committed by masahi

[Relay, TOPI] Deformable conv2d (#2908)

* [Relay, TOPI] Add deformable conv2d

* Moved to op level2

* Fix lint

* Moved to level2 & bug fix

* Update comments

* Disabled flaky test of conv2d
parent 82e868a4
......@@ -456,6 +456,67 @@ struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
}
};
/*! \brief Attributes for DeformableConv2D operator */
struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int deformable_groups;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(deformable_groups).set_default(1)
.describe("Controls the connections between inputs and offsets."
"Input channels are partitioned into multiple deformable groups. Offsets"
"are shared across input channels in the same deformable group.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
......@@ -53,6 +53,7 @@ def extract_from_program(func, params, ops, target, target_host=None):
topi.nn.group_conv2d_nchw],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}
topi_funcs = []
......@@ -126,6 +127,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
topi.nn.group_conv2d_nchw],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}
topi_funcs = []
......
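With the two mappings above in place, AutoTVM can extract tuning tasks for the new operator. A minimal, hypothetical sketch of that flow (not part of this commit; assumes an existing Relay `func` and `params`):

import tvm
from tvm import autotvm, relay

# deformable_conv2d tasks are collected when the op is listed in `ops`
tasks = autotvm.task.extract_from_program(
    func, params=params,
    ops=(relay.op.nn.deformable_conv2d,),
    target=tvm.target.cuda())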
......@@ -68,6 +68,7 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
}
self.topi_to_schedule = {
......@@ -78,6 +79,7 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}
self._register_tracing()
......@@ -172,6 +174,15 @@ class TaskExtractEnv:
return s, [data, weight, bias, C]
return s, [data, weight, C]
@register("topi_nn_deformable_conv2d_nchw")
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, Offset, W = args[:3]
C = topi.nn.deformable_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_deformable_conv2d_nchw([C])
return s, [A, Offset, W, C]
def reset(self, wanted_topi_funcs):
"""Reset task collections
......
......@@ -603,6 +603,25 @@ def _mx_smooth_l1(inputs, attrs):
_op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq))
def _mx_deformable_convolution(inputs, attrs):
new_attrs = {}
assert attrs.get_bool("no_bias")
new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
new_attrs["strides"] = attrs.get_int_tuple("stride")
new_attrs["padding"] = attrs.get_int_tuple("pad")
new_attrs["dilation"] = attrs.get_int_tuple("dilate")
new_attrs["channels"] = attrs.get_int("num_filter")
new_attrs["deformable_groups"] = attrs.get_int("num_deformable_group", 1)
new_attrs["groups"] = attrs.get_int("num_group", 1)
assert attrs.get_str("layout", "NCHW") == "NCHW", "Deformable conv2d only supports NCHW layout"
use_bias = not attrs.get_bool("no_bias", False)
res = _op.nn.deformable_conv2d(inputs[0], inputs[1], inputs[2], **new_attrs)
if use_bias:
assert len(inputs) == 4
res = _op.nn.bias_add(res, inputs[3])
return res
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
......@@ -748,6 +767,7 @@ _convert_map = {
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
......
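For reference, a hypothetical end-to-end sketch of what this converter enables (names and shapes are illustrative, not part of the commit): build an MXNet symbol that uses contrib.DeformableConvolution and import it into Relay.

import mxnet as mx
from tvm import relay

data = mx.sym.var("data")
offset = mx.sym.var("offset")
net = mx.sym.contrib.DeformableConvolution(
    data=data, offset=offset, kernel=(3, 3), pad=(1, 1),
    num_filter=32, num_deformable_group=4, no_bias=True,
    name="deformable_conv")
# offset carries 2 * 3 * 3 * num_deformable_group = 72 channels for a 3x3 kernel
func, params = relay.frontend.from_mxnet(
    net, shape={"data": (1, 16, 32, 32), "offset": (1, 72, 32, 32)})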
......@@ -426,3 +426,26 @@ def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.deformable_conv2d")
def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
"""Compute definition of deformable_conv2d"""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
deformable_groups = attrs.deformable_groups
groups = attrs.groups
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
with target:
out = topi.nn.deformable_conv2d_nchw(inputs[0], inputs[1], inputs[2], strides, padding,
dilation, deformable_groups, groups, out_dtype)
return [out]
@reg.register_schedule("nn.deformable_conv2d")
def schedule_deformable_conv2d(attrs, outs, target):
"""Schedule definition of deformable_conv2d"""
with target:
return topi.generic.schedule_deformable_conv2d_nchw(outs)
reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
......@@ -1105,3 +1105,76 @@ def contrib_conv2d_winograd_nnpack_weight_transform(weight,
"""
return _make.contrib_conv2d_winograd_nnpack_weight_transform(
weight, convolution_algorithm, out_dtype)
def deformable_conv2d(data,
offset,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
deformable_groups=1,
groups=1,
channels=None,
kernel_size=None,
data_layout='NCHW',
kernel_layout='OIHW',
out_layout='',
out_dtype=''):
r""" Deformable 2d convolution.
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
offset : tvm.relay.Expr
The offset expressions.
weight : tvm.relay.Expr
The weight expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
deformable_groups : int, optional
Number of deformable groups.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output. By default, out_layout is the same as data_layout.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation,
deformable_groups, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
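A brief usage sketch of the wrapper above (shapes illustrative; assumes `from tvm import relay`). Note that the offset tensor carries 2 * kh * kw * deformable_groups channels and has the spatial size of the output:

data = relay.var("data", shape=(1, 16, 56, 56))
offset = relay.var("offset", shape=(1, 2 * 3 * 3 * 4, 56, 56))
weight = relay.var("weight", shape=(32, 16, 3, 3))
out = relay.nn.deformable_conv2d(data, offset, weight,
                                 strides=(1, 1), padding=(1, 1), dilation=(1, 1),
                                 deformable_groups=4, groups=1,
                                 channels=32, kernel_size=(3, 3))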
......@@ -753,5 +753,148 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
Conv2DInferCorrectLayout<Conv2DAttrs>);
bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[2].as<TensorTypeNode>();
CHECK(data);
auto* param = attrs.as<DeformableConv2DAttrs>();
CHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
CHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
// infer weight shape if kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape(
{param->channels,
data->shape[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
channels = param->channels;
ksize_y = param->kernel_size[0];
ksize_x = param->kernel_size[1];
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// assign result to reporter
reporter->Assign(types[2], TensorTypeNode::make(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = weight->shape;
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << wshape;
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "DeformableConv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
channels = wshape[0];
ksize_y = wshape[2];
ksize_x = wshape[3];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
oshape.Set(2, (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
oshape.Set(3, (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
DataType out_dtype = param->out_dtype;
// infer offset shape
Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
oshape[2], oshape[3]});
reporter->Assign(types[1], TensorTypeNode::make(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
return true;
}
TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs);
RELAY_REGISTER_OP("nn.deformable_conv2d")
.describe(R"code(Compute 2-D deformable convolution on 4-D input.
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
For 2-D deformable convolution, the shapes are
- **data**: (batch_size, channel, height, width)
- **offset**: (batch_size, deformable_groups * kernel[0] * kernel[1] * 2, out_height, out_width)
- **weight**: (num_filter, channel, kernel[0], kernel[1])
- **out**: (batch_size, num_filter, out_height, out_width).
If `deformable_groups` is larger than 1, denoted by *dg*, then split the
input `offset` evenly into *dg* parts along the channel axis, and also split `out`
evenly into *dg* parts along the channel axis. Next compute the deformable convolution,
applying the *i*-th part of the offset to the *i*-th part of the output.
If `groups` is larger than 1, denoted by *g*, then split the input `data` evenly into *g* parts
along the channel axis, and also split `weight` evenly along the first dimension. Next compute
the convolution on the *i*-th part of the data with the *i*-th part of the weight. The output is
obtained by concatenating all the *g* results.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DeformableConv2D")
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("offset", "Tensor", "The offset tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(5)
.add_type_rel("DeformableConv2D", DeformableConv2DRel);
// Positional relay function to create deformable_conv2d operator
// used by frontend FFI.
Expr MakeDeformableConv2D(Expr data,
Expr offset,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int deformable_groups,
int groups,
int channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<DeformableConv2DAttrs>();
attrs->strides = strides;
attrs->padding = padding;
attrs->dilation = dilation;
attrs->deformable_groups = deformable_groups;
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = kernel_size;
attrs->data_layout = data_layout;
attrs->kernel_layout = kernel_layout;
attrs->out_layout = out_layout;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.deformable_conv2d");
return CallNode::make(op, {data, offset, weight}, Attrs{attrs}, {});
}
TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 14>(MakeDeformableConv2D, args, rv);
});
} // namespace relay
} // namespace tvm
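As a concrete instance of the shape relation implemented by DeformableConv2DRel (numbers illustrative):

# data (1, 16, 32, 32) in NCHW, 3x3 kernel, stride 1, pad 1, dilation 1, deformable_groups 4
dilated_k = 1 + (3 - 1) * 1                      # 3
out_h = (32 + 2 * 1 - dilated_k) // 1 + 1        # 32
out_w = (32 + 2 * 1 - dilated_k) // 1 + 1        # 32
offset_channels = 2 * 3 * 3 * 4                  # 72, so offset is (1, 72, 32, 32)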
......@@ -489,6 +489,66 @@ def test_yolo_reorg():
verify_yolo_reorg((1, 100, 20, 20), 10)
verify_yolo_reorg((1, 4, 6, 6), 2)
def test_deformable_conv2d():
def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups):
data_shape = (batch, in_channel, size, size)
data = relay.var("data", shape=data_shape)
offset = relay.var("offset")
kernel = relay.var("kernel")
kernel_size = (3, 3)
y = relay.nn.deformable_conv2d(data, offset, kernel,
strides=(1, 1),
padding=(1, 1),
dilation=(1, 1),
kernel_size=kernel_size,
deformable_groups=deformable_groups,
groups=groups,
channels=out_channel)
weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
out_shape = (batch, out_channel, size, size)
offset_shape = (batch, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, out_shape[2], out_shape[3])
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(out_shape)
assert yy.args[1].checked_type == relay.TensorType(offset_shape), yy.args[1].checked_type
assert yy.args[2].checked_type == relay.TensorType(weight_shape)
test_infer_type(1, 4, 16, 4, 4, 1)
test_infer_type(2, 4, 16, 4, 1, 2)
def test_run(batch, in_channel, size, out_channel, deformable_groups, groups):
kernel_size = (3, 3)
data_shape = (batch, in_channel, size, size)
offset_shape = (batch, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, size, size)
kernel_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
dtype = 'float32'
data = relay.var("data", shape=data_shape, dtype=dtype)
offset = relay.var("offset")
kernel = relay.var("kernel")
y = relay.nn.deformable_conv2d(data, offset, kernel,
strides=(1, 1),
padding=(1, 1),
dilation=(1, 1),
kernel_size=kernel_size,
deformable_groups=deformable_groups,
groups=groups,
channels=out_channel)
func = relay.Function([data, offset, kernel], y)
data = np.random.uniform(size=data_shape).astype(dtype)
offset = np.random.uniform(size=offset_shape).astype(dtype)
kernel = np.random.uniform(size=kernel_shape).astype(dtype)
ref_res = topi.testing.deformable_conv2d_nchw_python(data, offset, kernel, stride=(1, 1), padding=(1, 1), dilation=(1, 1), deformable_groups=deformable_groups, groups=groups)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp1 = relay.create_executor(kind, ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data, offset, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
test_run(1, 4, 16, 4, 1, 1)
test_run(2, 4, 16, 4, 4, 1)
if __name__ == "__main__":
test_resize_infer_type()
test_resize()
......@@ -501,3 +561,4 @@ if __name__ == "__main__":
test_yolo_reorg_infer_shape()
test_yolo_reorg()
test_non_max_suppression()
test_deformable_conv2d()
......@@ -2,7 +2,7 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, group_conv2d_nchw
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, group_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
......
# pylint: disable=invalid-name
"""Schedule template of deformable conv2d with cuda backend"""
import tvm
from tvm import autotvm
from .. import nn, generic
from ..util import traverse_inline
autotvm.register_topi_compute(nn.deformable_conv2d_nchw, ["cuda", "gpu"], "direct",
nn.deformable_conv2d_nchw.fdefault)
@autotvm.register_topi_schedule(generic.schedule_deformable_conv2d_nchw, ["cuda", "gpu"], "direct")
def schedule_deformable_conv2d_nchw_cuda(cfg, outs):
"""TOPI schedule callback of deformable conv2d for cuda gpu
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of deformable conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for deformable conv2d.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'deformable_conv2d_nchw':
schedule_direct_cuda(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
def schedule_direct_cuda(cfg, s, conv):
"""Schedule template of deformable conv2d"""
n, f, y, x = s[conv].op.axis
rc, ry, rx = s[conv].op.reduce_axis
cfg.define_split("tile_f", f, num_outputs=4)
cfg.define_split("tile_y", y, num_outputs=4)
cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=2)
cfg.define_split("tile_ry", ry, num_outputs=2)
cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
target = tvm.target.current_target()
if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])
data_deform, kernel = s[conv].op.input_tensors
s[data_deform].compute_inline()
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
if conv.op in s.outputs:
output = conv
OL = s.cache_write(conv, 'local')
else:
output = s.outputs[0].output(0)
s[conv].set_scope('local')
OL = conv
# create cache stage
AA = s.cache_read(data_deform, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
# tile and bind spatial axes
n, f, y, x = s[output].op.axis
kernel_scope, n = s[output].split(n, nparts=1)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
bf = s[output].fuse(n, bf)
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx)
# tile reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
rco, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
cfg.define_reorder("reorder_inner", [rco, ryo, rxo], "all")
cfg["reorder_inner"].apply(s, OL, [rco, ryo, rxo])
cfg["reorder_inner"].apply(s, OL, [rci, ryi, rxi])
cache_loc = [rco, ryo, rxo][cfg["reorder_inner"].perm[-1]]
s[AA].compute_at(s[OL], cache_loc)
s[WW].compute_at(s[OL], cache_loc)
# cooperative fetching
for load in [AA, WW]:
fused = s[load].fuse(*s[load].op.axis)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# unroll
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
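Once tasks containing this template have been extracted (as sketched earlier in this diff), they can be tuned with the usual AutoTVM machinery; a hypothetical sketch, assuming `task` is one of the extracted tasks:

from tvm import autotvm

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10))
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=200, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('deformable_conv2d.log')])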
......@@ -243,6 +243,24 @@ def schedule_group_conv2d_nchw(outs):
@tvm.target.generic_func
def schedule_deformable_conv2d_nchw(outs):
"""Schedule for deformable_conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of deformable_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
......
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
from .conv2d import *
from .deformable_conv2d import *
from .depthwise_conv2d import *
from .elemwise import *
from .dilate import *
......
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Deformable Conv2D operators"""
import tvm
from .util import get_pad_tuple
from ..util import get_const_tuple
from ..cpp.image import bilinear_sample_nchw
@tvm.target.generic_func
def deformable_conv2d_nchw(data, offset, kernel, strides, padding, dilation, deformable_groups,
groups, out_dtype):
"""Deformable conv2D operator in NCHW layout.
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
offset : tvm.Tensor
4-D with shape [batch, deformable_groups * filter_height * filter_width * 2,
out_height, out_width].
kernel : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
dilation : int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
deformable_groups : int
number of deformable groups
groups : int
number of groups
out_dtype : str
output data type; defaults to the data type of `data` when None
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
if out_dtype is None:
out_dtype = data.dtype
if isinstance(strides, int):
stride_h = stride_w = strides
else:
stride_h, stride_w = strides
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
batch, in_channel, in_height, in_width = get_const_tuple(data.shape)
out_channel, channel, kernel_h, kernel_w = get_const_tuple(kernel.shape)
_, _, out_height, out_width = get_const_tuple(offset.shape)
assert in_channel % deformable_groups == 0, "Input channels must be divisible by deformable_groups"
assert groups == 1, "deformable_conv2d_nchw does not support groups > 1"
ic_per_dgroup = channel // deformable_groups
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, _, _ = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
zero = tvm.const(0.0, data.dtype)
def _bilinear(n, c, h, w):
outside = tvm.any(h < 0, w < 0, h >= in_height, w >= in_width)
val = bilinear_sample_nchw(data, (n, c, h, w), in_height - 1, in_width - 1)
return tvm.if_then_else(outside, zero, val)
data_deform = \
tvm.compute((batch, in_channel, kernel_h, kernel_w, out_height, out_width),
lambda n, c, kh, kw, y, x:
_bilinear(n, c,
y * stride_h - pad_top + kh * dilation_h +
offset[n, c // ic_per_dgroup * (kernel_w*kernel_h*2) +
(kh * kernel_w + kw) * 2, y, x],
x * stride_w - pad_left + kw * dilation_w +
offset[n, c // ic_per_dgroup * (kernel_w*kernel_h*2) +
(kh * kernel_w + kw) * 2 + 1, y, x]))
return tvm.compute(
(batch, out_channel, out_height, out_width),
lambda n, f, y, x: tvm.sum(
data_deform[n, rc, ry, rx, y, x].astype(out_dtype) *
kernel[f, rc, ry, rx].astype(out_dtype),
axis=[rc, ry, rx]), tag="deformable_conv2d_nchw")
......@@ -8,6 +8,7 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
......
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Deformable convolution in python"""
import itertools
import numpy as np
def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilation,
deformable_groups, groups):
"""Deformable convolution operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
offset_np : numpy.ndarray
4-D with shape [batch, deformable_groups * filter_height * filter_width * 2,
out_height, out_width]
w_np : numpy.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str or a list/tuple of two ints
Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
dilation : int or a list/tuple of two ints
Dilation size, or [dilate_height, dilate_width]
deformable_groups : int
Number of deformable groups
groups : int
Number of groups
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = a_np.shape
out_channel, _, kernel_h, kernel_w = w_np.shape
out_height, out_width = offset_np.shape[-2:]
dtype = a_np.dtype
ic_per_dgroup = in_channel // deformable_groups
assert groups == 1, "deformable_conv2d_nchw_python does not support groups > 1"
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(padding, int):
pad_h = pad_w = padding * 2
elif isinstance(padding, (list, tuple)):
pad_h, pad_w = padding[0] * 2, padding[1] * 2
else:
pad_h = 0 if padding == 'VALID' else kernel_h - 1
pad_w = 0 if padding == 'VALID' else kernel_w - 1
pad_top = int(np.ceil(float(pad_h) / 2))
pad_left = int(np.ceil(float(pad_w) / 2))
if isinstance(dilation, int):
dilation_h = dilation_w = dilation
else:
dilation_h, dilation_w = dilation
def _bilinear(n, c, h, w):
low_h, low_w = int(h), int(w)
high_h = min(low_h + 1, in_height - 1)
high_w = min(low_w + 1, in_width - 1)
y_lerp = h - low_h
x_lerp = w - low_w
bottom = (1 - x_lerp) * a_np[n, c, low_h, low_w] + x_lerp * a_np[n, c, low_h, high_w]
top = (1 - x_lerp) * a_np[n, c, high_h, low_w] + x_lerp * a_np[n, c, high_h, high_w]
return (1 - y_lerp) * bottom + y_lerp * top
a_deform = np.zeros((batch, in_channel, out_height, out_width, kernel_h, kernel_w), dtype=dtype)
for n, h, w in itertools.product(range(batch), range(out_height), range(out_width)):
offset = offset_np[n, :, h, w].reshape(deformable_groups, kernel_h, kernel_w, 2)
in_h = h * stride_h - pad_top
in_w = w * stride_w - pad_left
index_h_base, index_w_base = np.meshgrid(
np.arange(in_h, in_h + kernel_h * dilation_h, dilation_h, dtype=offset_np.dtype),
np.arange(in_w, in_w + kernel_w * dilation_w, dilation_w, dtype=offset_np.dtype),
indexing='ij')
for c, kh, kw in itertools.product(range(in_channel), range(kernel_h), range(kernel_w)):
dg = c // ic_per_dgroup
index_h = index_h_base + offset[dg, ..., 0]
index_w = index_w_base + offset[dg, ..., 1]
y, x = index_h[kh, kw], index_w[kh, kw]
if y < 0 or y >= in_height or x < 0 or x >= in_width:
continue
a_deform[n, c, h, w, kh, kw] = _bilinear(n, c, y, x)
b_np = np.zeros((batch, out_channel, out_height, out_width), dtype=dtype)
for n, c, f, h, w in itertools.product(range(batch), range(in_channel), range(out_channel),
range(out_height), range(out_width)):
b_np[n, f, h, w] += np.tensordot(a_deform[n, c, h, w], w_np[f, c])
return b_np
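A quick sanity-check sketch for this reference (not part of the commit): with an all-zero offset, deformable convolution reduces to an ordinary convolution, so the result should match conv2d_nchw_python.

import numpy as np
import topi.testing

a = np.random.uniform(size=(1, 4, 8, 8)).astype('float32')
w = np.random.uniform(size=(8, 4, 3, 3)).astype('float32')
offset = np.zeros((1, 2 * 3 * 3, 8, 8), dtype='float32')   # zero offsets, one deformable group
ref = topi.testing.conv2d_nchw_python(a, w, 1, 1)
out = topi.testing.deformable_conv2d_nchw_python(a, offset, w, 1, 1, 1, 1, 1)
np.testing.assert_allclose(out, ref, rtol=1e-5)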
......@@ -136,7 +136,8 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3)
verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0)
verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0)
verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0)
# disable this test due to a bug in llvm with nvptx
# verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0)
verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3)
verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3)
verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0)
......
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
def verify_deformable_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, deformable_groups=1, groups=1):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size,
num_filter, kernel, stride, padding, dilation, deformable_groups, groups))
A = tvm.placeholder((batch, in_channel, in_size, in_size), name='A')
out_size = (in_size - (kernel - 1) * dilation - 1 + 2 * padding) // stride + 1
Offset = tvm.placeholder((batch, deformable_groups * kernel * kernel * 2, out_size, out_size), name='offset')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
bias = tvm.placeholder((num_filter, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape)
offset_shape = get_const_tuple(Offset.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_deformable_conv2d_nchw.verify_deformable_conv2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
offset_np = np.random.randn(*offset_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np = topi.testing.deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding,
dilation, deformable_groups, groups)
return a_np, offset_np, w_np, c_np
a_np, offset_np, w_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.deformable_conv2d_nchw(A, Offset, W, stride, padding, dilation,
deformable_groups, groups, out_dtype=dtype)
s = topi.generic.schedule_deformable_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
offset = tvm.nd.array(offset_np, ctx)
w = tvm.nd.array(w_np, ctx)
c = tvm.nd.empty(c_np.shape, dtype=c_np.dtype, ctx=ctx)
func = tvm.build(s, [A, Offset, W, C], device)
func(a, offset, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['llvm', 'cuda']:
check_device(device)
def test_deformable_conv2d_nchw():
verify_deformable_conv2d_nchw(1, 16, 7, 16, 1, 1, 0, deformable_groups=4)
verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 1, dilation=2, deformable_groups=4)
verify_deformable_conv2d_nchw(1, 16, 7, 16, 3, 1, 2, dilation=2)
if __name__ == "__main__":
test_deformable_conv2d_nchw()