Commit cbec5b94 by Yong Wu Committed by Siva

[Relay] Add ResizeNearestNeighbor and CropAndResize in tf converter (#3393)

parent fb95a985
...@@ -484,6 +484,54 @@ def _decode_image(): ...@@ -484,6 +484,54 @@ def _decode_image():
return inputs[0] return inputs[0]
return _impl return _impl
def _crop_and_resize():
def _impl(inputs, attr, params):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try:
boxes = params.pop(inputs[1].name_hint).asnumpy().tolist()
box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist()
crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist()
except (IndexError, KeyError):
boxes = _infer_value(inputs[1], params).asnumpy().tolist()
box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
crop_size = _infer_value(inputs[3], params).asnumpy().tolist()
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape)
method = attr['method'].decode()
attrs = {}
attrs['size'] = crop_size
attrs['layout'] = 'NHWC'
if method.lower() == 'nearest':
raise tvm.error.OpAttributeUnimplemented(
'Attribute method=nearest is not supported')
else:
attrs['align_corners'] = True
attrs['method'] = 'BILINEAR'
out = None
begin = [0] * data_dim
size = data_shape[:]
for idx in box_ind:
# 1) Crop
# y is mapped to the image coordinate at y * (image_height - 1)
# x is mapped to the image coordinate at x * (image_width - 1)
begin[0] = idx
begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1)))
begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1)))
size[0] = idx + 1
size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1
size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1
res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
# 2) Resize
res_resize = _get_relay_op('resize')(res_crop, **attrs)
out = _op.concatenate([out, res_resize], axis=0) if out else res_resize
return out
return _impl
def _cast(): def _cast():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
return inputs[0].astype(attr['DstT'].name) return inputs[0].astype(attr['DstT'].name)
...@@ -514,6 +562,21 @@ def _resize_bilinear(): ...@@ -514,6 +562,21 @@ def _resize_bilinear():
extras={'method': "BILINEAR"})(inputs, attr) extras={'method': "BILINEAR"})(inputs, attr)
return _impl return _impl
def _resize_nearest_neighbor():
def _impl(inputs, attr, params):
size = attr['_output_shapes'][0][1:3]
if -1 in size:
size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
attr['size'] = size
inputs.pop(1)
# NHWC
attr['layout'] = 'NHWC'
return AttrCvt(op_name="resize",
ignores=['Tdim'],
extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr)
return _impl
def _check_numerics(): def _check_numerics():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
# Making a copy node assuming no need to verify # Making a copy node assuming no need to verify
...@@ -593,7 +656,7 @@ def _slice(): ...@@ -593,7 +656,7 @@ def _slice():
end[i] = data_shape[i] - begin[i] end[i] = data_shape[i] - begin[i]
else: else:
end[i] += begin[i] end[i] += begin[i]
return _op.strided_slice(inputs[0], begin=begin, end=size) return _op.strided_slice(inputs[0], begin=begin, end=end)
return _impl return _impl
...@@ -1243,6 +1306,7 @@ _convert_map = { ...@@ -1243,6 +1306,7 @@ _convert_map = {
'Concat' : _concat(), 'Concat' : _concat(),
'ConcatV2' : _concatV2(), 'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'), 'Conv2D' : _conv('conv'),
'CropAndResize' : _crop_and_resize(),
'DecodeJpeg' : _decode_image(), 'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'), 'DepthwiseConv2dNative' : _conv('depthwise'),
'DepthToSpace' : _depth_to_space(), 'DepthToSpace' : _depth_to_space(),
...@@ -1295,6 +1359,7 @@ _convert_map = { ...@@ -1295,6 +1359,7 @@ _convert_map = {
'Reshape' : _reshape(), 'Reshape' : _reshape(),
'ResizeBilinear' : _resize_bilinear(), 'ResizeBilinear' : _resize_bilinear(),
'ResizeBicubic' : _resize_bilinear(), 'ResizeBicubic' : _resize_bilinear(),
'ResizeNearestNeighbor' : _resize_nearest_neighbor(),
'ReverseV2' : _reverse_v2(), 'ReverseV2' : _reverse_v2(),
'RightShift' : AttrCvt('right_shift'), 'RightShift' : AttrCvt('right_shift'),
'Round' : AttrCvt('round'), 'Round' : AttrCvt('round'),
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from __future__ import print_function from __future__ import print_function
import os import os
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util from tvm.contrib import util
...@@ -35,12 +34,12 @@ class TFParser(object): ...@@ -35,12 +34,12 @@ class TFParser(object):
-------- --------
.. code-block:: python .. code-block:: python
parser = TfParser(model_dir) parser = TFParser(model_dir)
graph = parser.parse() graphdef = parser.parse()
# graph is related graphdef of the model
""" """
def __init__(self, model_dir): def __init__(self, model_dir):
from tensorflow.core.framework import graph_pb2
self._tmp_dir = util.tempdir() self._tmp_dir = util.tempdir()
self._model_dir = model_dir self._model_dir = model_dir
self._graph = graph_pb2.GraphDef() self._graph = graph_pb2.GraphDef()
...@@ -96,6 +95,7 @@ class TFParser(object): ...@@ -96,6 +95,7 @@ class TFParser(object):
from tensorflow.python.tools import freeze_graph from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
from tensorflow.core.framework import graph_pb2
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"InputConfiguration: Unable to import tensorflow which is " "InputConfiguration: Unable to import tensorflow which is "
......
...@@ -949,8 +949,8 @@ def test_forward_multi_output(): ...@@ -949,8 +949,8 @@ def test_forward_multi_output():
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
####################################################################### #######################################################################
# Resize Bilinear # Resize Bilinear, Nearest_Neighbor
# --------------- # ---------------------------------
def _test_resize_bilinear(in_shape, to_shape, align_corners): def _test_resize_bilinear(in_shape, to_shape, align_corners):
""" One iteration of resize bilinear """ """ One iteration of resize bilinear """
...@@ -980,13 +980,31 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners): ...@@ -980,13 +980,31 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners):
compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0') compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
def test_forward_resize_bilinear():
""" Resize Bilinear """ def _test_resize_nearest_neighbor(in_shape, to_shape):
""" One iteration of resize nearest neighbor """
data = np.random.uniform(size=in_shape).astype('float32')
shape_data = np.array(to_shape).astype('int32')
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant(
shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.image.resize_nearest_neighbor(in_data, shape_data, name='resize_nearest_neighbor')
compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0')
def test_forward_resize():
""" Resize Bilinear, Nearest_Neighbor """
_test_resize_bilinear((4, 16, 32, 32), [50, 50], False) _test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True) _test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
_test_resize_bilinear_from_tensor((4, 16, 32, 32), False) _test_resize_bilinear_from_tensor((4, 16, 32, 32), False)
_test_resize_bilinear_from_tensor((6, 32, 50, 50), True) _test_resize_bilinear_from_tensor((6, 32, 50, 50), True)
_test_resize_nearest_neighbor((6, 32, 64, 64), [20, 20])
####################################################################### #######################################################################
# BroadcastTo # BroadcastTo
...@@ -1081,6 +1099,39 @@ def test_forward_crop(): ...@@ -1081,6 +1099,39 @@ def test_forward_crop():
####################################################################### #######################################################################
# CropAndResize
# -------------
def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='bilinear', dtype="float32"):
image = np.random.uniform(0, 10, size=img_shape).astype(dtype)
tf.reset_default_graph()
in_data = tf.placeholder(dtype, image.shape, name="in_data")
tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx, crop_size=crop_size,
method=method, name="crop_and_resize")
compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
def test_forward_crop_and_resize():
""" CropAndResize """
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5])
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5])
_test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5])
_test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4])
_test_forward_crop_and_resize([1, 106, 106, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
_test_forward_crop_and_resize([10, 11, 11, 3],
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
[0, 1],
[5, 5])
_test_forward_crop_and_resize([3, 11, 11, 3],
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8],[0, 0, 1, 1]],
[0, 1, 2],
[3, 3])
_test_forward_crop_and_resize([3, 11, 11, 3],
[[0, 0, 1, 0.8], [0, 0, 0.9, 0.9], [0, 0, 1, 0.8]],
[2, 1, 0],
[3, 3])
#######################################################################
# LSTM # LSTM
# ---- # ----
...@@ -1979,10 +2030,11 @@ if __name__ == '__main__': ...@@ -1979,10 +2030,11 @@ if __name__ == '__main__':
test_forward_depthtospace() test_forward_depthtospace()
test_forward_squeeze() test_forward_squeeze()
test_forward_pack() test_forward_pack()
test_forward_resize_bilinear()
test_forward_broadcast_to() test_forward_broadcast_to()
test_forward_fill() test_forward_fill()
test_forward_crop() test_forward_crop()
test_forward_resize()
test_forward_crop_and_resize()
test_forward_pad() test_forward_pad()
test_forward_unpack() test_forward_unpack()
test_forward_gather() test_forward_gather()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment