Commit c5fdb000 by Haichen Shen Committed by Yizhi Liu

[Relay][Frontend] Add Crop op converter (#3241)

* Add Crop op converter

* lint

* x
parent be0340eb
......@@ -269,7 +269,7 @@ def _crop_like(inputs, attrs):
raise tvm.error.OpAttributeUnimplemented(
'Center crop is not supported in operator crop_like.')
if len(inputs) < 2:
raise RuntimeError("Only support crop_like pattern.")
raise tvm.error.OpAttributeUnimplemented("Only support crop_like pattern.")
new_attrs["axis"] = [2, 3]
return get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs)
......
......@@ -149,7 +149,7 @@ def _mx_conv2d_transpose(inputs, attrs):
new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout
new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", False)
use_bias = not attrs.get_bool("no_bias", True)
res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs)
if use_bias:
......@@ -277,6 +277,28 @@ def _mx_slice_axis(inputs, attrs):
return _op.strided_slice(inputs[0], begin, end)
def _mx_crop_like(inputs, attrs):
if len(inputs) < 2:
raise tvm.error.OpAttributeUnimplemented(
"Only support crop_like pattern for operator Crop.")
if attrs.get_bool("center_crop", False):
raise tvm.error.OpAttributeUnimplemented(
"Center crop is not supported in operator Crop.")
if attrs.get_int_tuple("h_w", (0, 0)) != (0, 0):
raise tvm.error.OpAttributeUnimplemented(
"Doesn't support h_w in operator Crop.")
offset = attrs.get_int_tuple("offset", (0, 0))
new_attrs = {}
if offset == (0, 0):
new_attrs["axes"] = (2, 3)
return _op.slice_like(*inputs, **new_attrs)
like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]]
return _op.strided_slice(inputs[0], **new_attrs)
def _mx_split(inputs, attrs):
axis = attrs.get_int("axis", 1)
new_attrs = {}
......@@ -300,6 +322,10 @@ def _mx_softmax_output(inputs, attrs):
return _op.nn.softmax(inputs[0])
def _mx_linear_regression_output(inputs, _):
return inputs[0]
def _mx_concat(inputs, attrs):
axis = attrs.get_int("dim", 1)
return _op.concatenate(tuple(inputs), axis=axis)
......@@ -890,6 +916,7 @@ _convert_map = {
"argsort" : _mx_argsort,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
......@@ -905,11 +932,12 @@ _convert_map = {
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
# Depricated:
"Crop" : _mx_crop_like,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
# "Crop" : _crop_like,
}
# set identity list
......
......@@ -583,6 +583,31 @@ def test_forward_rnn_layer():
verify(mode, 64, 10, 64, 2)
verify(mode, 64, 10, 32, 2)
def test_forward_Crop():
def verify(xshape, yshape, offset=None):
x_data = np.random.uniform(size=xshape).astype("float32")
y_data = np.random.uniform(size=yshape).astype("float32")
if offset is None:
mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"))
ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data))
else:
mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"), offset=offset)
ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data), offset=offset)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
if offset is None or offset == (0, 0):
op_res = intrp.evaluate(new_sym)(x_data, y_data)
else:
op_res = intrp.evaluate(new_sym)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((1, 3, 40, 40), (1, 3, 20, 20))
verify((1, 3, 40, 40), (1, 3, 20, 20), (0, 0))
verify((1, 3, 40, 40), (1, 3, 20, 20), (10, 10))
verify((5, 32, 40, 40), (5, 32, 25, 25))
verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
if __name__ == '__main__':
test_forward_mlp()
......@@ -624,3 +649,4 @@ if __name__ == '__main__':
test_forward_gather_nd()
test_forward_bilinear_resize()
test_forward_rnn_layer()
test_forward_Crop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment