Commit 510bd8f6 by Neo Chien Committed by Zhi

[Relay][Frontend][ONNX] operator support: DepthToSpace, SpaceToDepth (#4271)

parent 135587aa
......@@ -472,6 +472,76 @@ class Reshape(OnnxOpConverter):
static_shape.asnumpy().astype('int32')))
return out
class DepthToSpace(OnnxOpConverter):
""" Operator converter for DepthToSpace.
"""
@classmethod
def _impl_v11(cls, inputs, attr, params):
block_size = int(attr['blocksize'])
mode = attr.get("mode", "DCR")
# handle NCHW layout
indata = infer_value_simulated(inputs[0], params)
in_n, in_c, in_h, in_w = indata.shape
# reshape to proper output
new_c = int(in_c / (block_size * block_size))
new_h = in_h * block_size
new_w = in_w * block_size
newshape = (in_n, new_c, new_h, new_w)
if mode == "DCR":
# expand input to larger dimension.
expanded = _op.reshape(inputs[0],
newshape=(in_n, block_size, block_size, new_c, in_h, in_w))
# reorder to expand spatial blocks.
transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2))
else: # CRD mode
# expand input to larger dimension.
expanded = _op.reshape(inputs[0],
newshape=(in_n, new_c, block_size, block_size, in_h, in_w))
# reorder to expand spatial blocks.
transposed = _op.transpose(expanded, axes=(0, 1, 4, 2, 5, 3))
return AttrCvt(op_name="reshape",
extras={'newshape': newshape},
ignores=['mode', 'blocksize'])([transposed], attr)
class SpaceToDepth(OnnxOpConverter):
""" Operator converter for SpaceToDepth.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
block_size = int(attr['blocksize'])
# handle NCHW layout
indata = infer_value_simulated(inputs[0], params)
in_n, in_c, in_h, in_w = indata.shape
# reshape to proper output
new_c = in_c * (block_size * block_size)
new_h = int(in_h / block_size)
new_w = int(in_w / block_size)
newshape = (in_n, new_c, new_h, new_w)
# expand input to larger dimension.
expanded = _op.reshape(inputs[0],
newshape=(in_n, in_c, new_h, block_size, new_w, block_size))
# reorder to expand spatial blocks.
transposed = _op.transpose(expanded, axes=(0, 3, 5, 1, 2, 4))
return AttrCvt(op_name="reshape",
extras={'newshape': newshape},
ignores=['blocksize'])([transposed], attr)
class Concat(OnnxOpConverter):
""" Operator converter for Concat.
"""
......@@ -1121,6 +1191,8 @@ def _get_convert_map(opset):
'Split': Split.get_converter(opset),
'Slice': Slice.get_converter(opset),
'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
'DepthToSpace': DepthToSpace.get_converter(opset),
'SpaceToDepth': SpaceToDepth.get_converter(opset),
'Gather': Gather.get_converter(opset),
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
'Unsqueeze': Unsqueeze.get_converter(opset),
......
......@@ -77,19 +77,19 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
return tvm_output.asnumpy()
def get_caffe2_output(model, x, dtype='float32'):
import caffe2.python.onnx.backend
prepared_backend = caffe2.python.onnx.backend.prepare(model)
W = {model.graph.input[0].name: x.astype(dtype)}
c2_out = prepared_backend.run(W)[0]
return c2_out
def get_onnxruntime_output(model, x, dtype='float32'):
import onnxruntime.backend
rep = onnxruntime.backend.prepare(model, 'CPU')
x = x.astype(dtype)
ort_out = rep.run(x)[0]
return ort_out
def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
dtype = 'float32'
x = np.random.uniform(size=data_shape)
model = onnx.load_model(graph_file)
c2_out = get_caffe2_output(model, x, dtype)
c2_out = get_onnxruntime_output(model, x, dtype)
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
......@@ -142,6 +142,57 @@ def test_reshape():
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
def verify_depth_to_space(inshape, outshape, mode, blockSize):
node = onnx.helper.make_node('DepthToSpace',
inputs=['x'],
outputs=['y'],
blocksize=blockSize)
graph = helper.make_graph([node],
"depth_to_space_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])
model = helper.make_model(graph, producer_name='depth_to_space_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=inshape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32')
onnx_out = get_onnxruntime_output(model, x, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out)
def test_depth_to_space():
# current onnx.checker use OpSet-1 version of DepthToSpace, which doesn't have a mode argument.
# TO-DO, we can add mode arguement to test CRD mode and DCR mode
# in the future when we update to a newer onnx version.
verify_depth_to_space((1, 8, 2, 3), (1, 2, 4, 6), mode="CRD", blockSize=2)
def verify_space_to_depth(inshape, outshape, blockSize):
node = onnx.helper.make_node('SpaceToDepth',
inputs=['x'],
outputs=['y'],
blocksize=blockSize)
graph = helper.make_graph([node],
"space_to_depth_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])
model = helper.make_model(graph, producer_name='space_to_depth_test')
for target, ctx in ctx_list():
x = np.random.uniform(size=inshape).astype('float32')
tvm_out = get_tvm_output(model, x, target, ctx, outshape, 'float32')
onnx_out = get_onnxruntime_output(model, x, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out)
def test_space_to_depth():
verify_space_to_depth((1, 1, 4, 6), (1, 4, 2, 3), 2)
def test_shape():
in_shape = (4, 3, 3, 4)
ref_shape = (6, 2, 4, 3)
......@@ -1372,7 +1423,7 @@ def check_torch_conversion(model, input_size):
onnx_model = onnx.load(file_name)
for target, ctx in ctx_list():
input_data = np.random.uniform(size=input_size).astype('int32')
c2_out = get_caffe2_output(onnx_model, input_data)
c2_out = get_onnxruntime_output(onnx_model, input_data)
tvm_out = get_tvm_output(onnx_model, input_data, target, ctx)
tvm.testing.assert_allclose(c2_out, tvm_out)
......@@ -1574,6 +1625,7 @@ def test_erf():
z = scipy.special.erf(x)
verify_erf(x, z)
def verify_where(condition, x, y, dtype, outdata):
node = helper.make_node('Where', inputs=['condition', 'x', 'y'], outputs=['out'])
graph = helper.make_graph([node],
......@@ -1588,6 +1640,7 @@ def verify_where(condition, x, y, dtype, outdata):
tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)
def test_where():
condition = np.array([[1, 0], [1, 1]], dtype=np.bool)
x = np.array([[1, 2], [3, 4]], dtype=np.int64)
......@@ -1704,3 +1757,5 @@ if __name__ == '__main__':
test_erf()
test_where()
test_or()
test_depth_to_space()
test_space_to_depth()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment