Unverified commit 7bc0b27e by Yao Wang, committed by GitHub

[Frontend][TensorFlow]TensorFlow Parser Control Flow Enhancement (#5020)

* Improve TF control flow major logic

* Pass mod into operator convert function

* Fix LoopBound

* Add more control flow tests

* Add two test cases for stridedslice

* Fix docstring

* Fix lint

* Fix import

* Fix test assert

* Minor fix conv3d

* Add more comments

* Fix for dilation2d

* Change newly added atan

* Change newly added unravel
parent a422589c
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=broad-except
"""Common utilities""" """Common utilities"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import logging import logging
...@@ -482,24 +483,37 @@ def infer_channels(inputs, transpose=False): ...@@ -482,24 +483,37 @@ def infer_channels(inputs, transpose=False):
return channels return channels
def infer_value(input_val, params): def infer_value(input_val, params, mod=None):
"""A hack for getting the value of an expression by evaluating a """A hack for getting the value of an expression by evaluating a
portion of the relay graph. This is often needed for functions whose portion of the relay graph. This is often needed for functions whose
output shape depends on the value of a tensor. output shape depends on the value of a tensor.
""" """
# pylint: disable=import-outside-toplevel try:
from tvm.contrib import graph_runtime # TODO(kevinthesun): Use VM for all cases.
# Check that all free variables have associated parameters. # pylint: disable=import-outside-toplevel
assert all(var.name_hint in params.keys() for var in analysis.free_vars( from tvm.contrib import graph_runtime
input_val)), "All inputs to infer must be available in params." # Check that all free variables have associated parameters.
func = _function.Function(analysis.free_vars(input_val), input_val) assert all(var.name_hint in params.keys() for var in analysis.free_vars(
with tvm.relay.build_config(opt_level=0): input_val)), "All inputs to infer must be available in params."
graph, lib, params = tvm.relay.build(func, target="llvm", params=params) func = _function.Function(analysis.free_vars(input_val), input_val)
ctx = tvm.cpu(0) with tvm.relay.build_config(opt_level=0):
m = graph_runtime.create(graph, lib, ctx) graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
m.set_input(**params) ctx = tvm.cpu(0)
m.run() m = graph_runtime.create(graph, lib, ctx)
return m.get_output(0) m.set_input(**params)
m.run()
return m.get_output(0)
except Exception:
if isinstance(mod, IRModule):
mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
inputs.append(tvm.nd.array(params[param.name_hint]))
result = exc.evaluate()(*inputs)
return result
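For illustration, a minimal usage sketch of the updated infer_value (not part of this patch; the variable names below are made up). It shows the common case where a converter needs the concrete value of a shape expression; when graph_runtime compilation fails, the new code path above falls back to the debug executor instead.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_value

# A tiny expression whose value depends on an input tensor.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
shape_expr = relay.shape_of(data, dtype="int32")

# Every free variable of the expression must be bound in params.
params = {"data": tvm.nd.array(np.zeros((1, 3, 224, 224), dtype="float32"))}

# Returns a tvm.nd.NDArray with the concrete value, here [1, 3, 224, 224].
print(infer_value(shape_expr, params).asnumpy())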
def infer_value_simulated(input_val, params): def infer_value_simulated(input_val, params):
...
...@@ -23,17 +23,17 @@ from collections import defaultdict ...@@ -23,17 +23,17 @@ from collections import defaultdict
# Numpy support # Numpy support
import numpy as np import numpy as np
import tvm import tvm
from tvm.ir import IRModule from tvm.ir import IRModule
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.analysis import structural_hash as s_hash
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function from .. import function as _function
from .. import op as _op from .. import op as _op
from ..expr_functor import ExprMutator from ..expr_functor import ExprMutator, ExprVisitor
from .common import AttrCvt, get_relay_op from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape from .common import infer_shape as _infer_shape
...@@ -92,43 +92,40 @@ def _get_list_param(params, input_node): ...@@ -92,43 +92,40 @@ def _get_list_param(params, input_node):
def _get_tuple_param(params, input_node): def _get_tuple_param(params, input_node):
return tuple(_get_param(params, input_node)) return tuple(_get_param(params, input_node))
def _need_module_for_shape_inference(op):
return op in ['StridedSlice']
def _need_prelude_for_shape_inference(op): def _need_prelude_for_shape_inference(op):
return "TensorArray" in op return "TensorArray" in op
def _rsqrt(): def _rsqrt():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
inputs.append(tvm.relay.const(-0.5, attr['T'].name)) inputs.append(tvm.relay.const(-0.5, attr['T'].name))
return AttrCvt(op_name="power")(inputs, attr) return AttrCvt(op_name="power")(inputs, attr)
return _impl return _impl
def _argx(func, func_name): def _argx(func, func_name):
""" A common wrapper for argmin and argmax operations """ """ A common wrapper for argmin and argmax operations """
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
try: try:
# In Tensorflow, `axis` argument is a Tensor, not attribute. We # In Tensorflow, `axis` argument is a Tensor, not attribute. We
# support the case where it inputs from a scalar constant. # support the case where it inputs from a scalar constant.
axis_input_value = [_get_num_param(params, inputs[1])] axis_input_value = [_get_num_param(params, inputs[1])]
except (IndexError, KeyError): except (IndexError, KeyError):
raise TypeError( \ raise TypeError(
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
return func(inputs[0], axis=axis_input_value, keepdims=False) return func(inputs[0], axis=axis_input_value, keepdims=False)
return _impl return _impl
def _elemwise(name): def _elemwise(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
return get_relay_op(name)(*inputs) return get_relay_op(name)(*inputs)
return _impl return _impl
def _pool3d(name): def _pool3d(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
attr['data_format'] = attr['data_format'].decode("utf-8") attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False flip_layout = False
input_shape = attr['_input_shapes'][inputs[0]] input_shape = _infer_shape(inputs[0], mod)
if attr['data_format'] == 'NDHWC': if attr['data_format'] == 'NDHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3]) attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3])
...@@ -141,10 +138,9 @@ def _pool3d(name): ...@@ -141,10 +138,9 @@ def _pool3d(name):
'is not valid.' 'is not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if attr['data_format'] == "NDHWC": if attr['data_format'] == "NDHWC":
input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)] input_shape = [_infer_shape(inputs[0], mod)[i] for i in (0, 4, 1, 2, 3)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3)) inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
attr['data_format'] = "NCDHW" attr['data_format'] = "NCDHW"
attr['_input_shapes'][inputs[0]] = input_shape
flip_layout = True flip_layout = True
attr['padding'] = attr['padding'].decode("utf-8") attr['padding'] = attr['padding'].decode("utf-8")
...@@ -188,12 +184,12 @@ def _pool3d(name): ...@@ -188,12 +184,12 @@ def _pool3d(name):
return _impl return _impl
def _pooling(name): def _pooling(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
attr['data_format'] = attr['data_format'].decode("utf-8") attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False flip_layout = False
input_shape = attr['_input_shapes'][inputs[0]] input_shape = _infer_shape(inputs[0], mod)
if attr['data_format'] == 'NHWC': if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2]) attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
...@@ -207,7 +203,7 @@ def _pooling(name): ...@@ -207,7 +203,7 @@ def _pooling(name):
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
tmp_shape = attr['_input_shapes'][inputs[0]] tmp_shape = _infer_shape(inputs[0], mod)
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW" attr['data_format'] = "NCHW"
...@@ -256,17 +252,16 @@ def _pooling(name): ...@@ -256,17 +252,16 @@ def _pooling(name):
return _impl return _impl
def _conv(opname): def _conv(opname):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
attr['data_format'] = attr['data_format'].decode("utf-8") attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False flip_layout = False
if opname == 'conv_transpose' and attr['data_format'] == 'NHWC': if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
# transform to NCHW for TVM backend compatibility and set 'flip_layout' # transform to NCHW for TVM backend compatibility and set 'flip_layout'
# to have output flip back to NHWC # to have output flip back to NHWC
tmp_shape = attr['_input_shapes'][inputs[2]] tmp_shape = _infer_shape(inputs[2], mod)
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2)) inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
attr['_input_shapes'][inputs[2]] = tmp_shape
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2] attr['strides'][3], attr['strides'][1], attr['strides'][2]
attr['data_format'] = 'NCHW' attr['data_format'] = 'NCHW'
...@@ -281,19 +276,19 @@ def _conv(opname): ...@@ -281,19 +276,19 @@ def _conv(opname):
inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
# NCHW Layout require weights transpose # NCHW Layout require weights transpose
weights_shape = _infer_shape(inputs[1])
if attr['data_format'] == 'NCHW': if attr['data_format'] == 'NCHW':
tmp_shape = attr['_input_shapes'][inputs[1]] tmp_shape = weights_shape
if opname in ['conv', 'conv_transpose']: if opname in ['conv', 'conv_transpose']:
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else: else:
tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)] tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
attr['_input_shapes'][inputs[1]] = tmp_shape weights_shape = tmp_shape
input_shape = attr['_input_shapes'][inputs_data]
weights_shape = attr['_input_shapes'][inputs[1]]
input_shape = _infer_shape(inputs_data)
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2))
...@@ -390,8 +385,8 @@ def _conv(opname): ...@@ -390,8 +385,8 @@ def _conv(opname):
# Ignore the new attributes from TF2.0, for now. # Ignore the new attributes from TF2.0, for now.
out = AttrCvt( out = AttrCvt(
op_name=_dimension_picker('conv', \ op_name=_dimension_picker('conv',
surfix="_transpose" if opname == 'conv_transpose' else ""), surfix="_transpose" if opname == 'conv_transpose' else ""),
ignores=['explicit_paddings'], ignores=['explicit_paddings'],
transforms={ transforms={
'kernel_shape': 'kernel_size', 'kernel_shape': 'kernel_size',
...@@ -414,12 +409,12 @@ def _conv(opname): ...@@ -414,12 +409,12 @@ def _conv(opname):
# Dilation2d # Dilation2d
def _dilation2d(): def _dilation2d():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
if 'data_format' not in attr: if 'data_format' not in attr:
attr['data_format'] = 'NHWC' attr['data_format'] = 'NHWC'
input_shape = attr['_input_shapes'][inputs[0]] input_shape = _infer_shape(inputs[0], mod)
weights_shape = attr['_input_shapes'][inputs[1]] weights_shape = _infer_shape(inputs[1], mod)
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
...@@ -497,21 +492,21 @@ def _dilation2d(): ...@@ -497,21 +492,21 @@ def _dilation2d():
def _conv3d(opname): def _conv3d(opname):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
attr['data_format'] = attr['data_format'].decode("utf-8") attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False flip_layout = False
inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
# NCDHW Layout require weights transpose # NCDHW Layout require weights transpose
weights_shape = _infer_shape(inputs[1], mod)
if attr['data_format'] == 'NCDHW': if attr['data_format'] == 'NCDHW':
tmp_shape = attr['_input_shapes'][inputs[1]] tmp_shape = weights_shape
tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)] tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)]
inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
attr['_input_shapes'][inputs[1]] = tmp_shape weights_shape = tmp_shape
input_shape = attr['_input_shapes'][inputs_data] input_shape = _infer_shape(inputs_data, mod)
weights_shape = attr['_input_shapes'][inputs[1]]
if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC": if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC":
input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)] input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
...@@ -532,7 +527,7 @@ def _conv3d(opname): ...@@ -532,7 +527,7 @@ def _conv3d(opname):
attr['channels'] = weights_shape[3] attr['channels'] = weights_shape[3]
if 'dilations' in attr: if 'dilations' in attr:
attr['dilations'] =\ attr['dilations'] = \
(attr['dilations'][1], attr['dilations'][2], attr['dilations'][3]) (attr['dilations'][1], attr['dilations'][2], attr['dilations'][3])
attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3]) attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
elif attr['data_format'] == 'NCDHW': elif attr['data_format'] == 'NCDHW':
...@@ -544,7 +539,7 @@ def _conv3d(opname): ...@@ -544,7 +539,7 @@ def _conv3d(opname):
attr['channels'] = weights_shape[1] attr['channels'] = weights_shape[1]
if 'dilations' in attr: if 'dilations' in attr:
attr['dilations'] =\ attr['dilations'] = \
(attr['dilations'][2], attr['dilations'][3], attr['dilations'][4]) (attr['dilations'][2], attr['dilations'][3], attr['dilations'][4])
attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4]) attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
else: else:
...@@ -599,8 +594,8 @@ def _conv3d(opname): ...@@ -599,8 +594,8 @@ def _conv3d(opname):
# Ignore the new attributes from TF2.0, for now. # Ignore the new attributes from TF2.0, for now.
out = AttrCvt( out = AttrCvt(
op_name=_dimension_picker('conv', \ op_name=_dimension_picker('conv',
surfix="_transpose" if opname == 'conv_transpose' else ""), surfix="_transpose" if opname == 'conv_transpose' else ""),
ignores=['explicit_paddings'], ignores=['explicit_paddings'],
transforms={ transforms={
'kernel_shape': 'kernel_size', 'kernel_shape': 'kernel_size',
...@@ -621,19 +616,19 @@ def _conv3d(opname): ...@@ -621,19 +616,19 @@ def _conv3d(opname):
return _impl return _impl
def _decode_image(): def _decode_image():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# Image decode wrapper: Expecting user to feed decoded input to next layer; drop this layer. # Image decode wrapper: Expecting user to feed decoded input to next layer; drop this layer.
warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input") warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
return inputs[0] return inputs[0]
return _impl return _impl
def _unravel_index(): def _unravel_index():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return _op.unravel_index(inputs[0], inputs[1]) return _op.unravel_index(inputs[0], inputs[1])
return _impl return _impl
def _crop_and_resize(): def _crop_and_resize():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try: try:
...@@ -654,12 +649,12 @@ def _crop_and_resize(): ...@@ -654,12 +649,12 @@ def _crop_and_resize():
return _impl return _impl
def _cast(): def _cast():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return inputs[0].astype(attr['DstT'].name) return inputs[0].astype(attr['DstT'].name)
return _impl return _impl
def _expand_dims(): def _expand_dims():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
dim_input = inputs.pop(1) dim_input = inputs.pop(1)
axis = _get_num_param(params, dim_input) axis = _get_num_param(params, dim_input)
return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
...@@ -667,14 +662,16 @@ def _expand_dims(): ...@@ -667,14 +662,16 @@ def _expand_dims():
return _impl return _impl
def _resize(method): def _resize(method):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
output_shape0 = attr['_output_shapes'][0] if attr['_output_shapes'][0] is not None:
# Dynamic size models might have _output_shapes attr equal to [None] here size = attr['_output_shapes'][0][1:3]
size = output_shape0[1:3] if output_shape0 is not None else [-1, -1] # Important that the size is defined. If an axis is not, we need to infer what
# Important that the size is defined. If an axis is not, we need to infer what # the shape should be.
# the shape should be. if -1 in size:
if -1 in size: size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
else:
size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
attr['size'] = size attr['size'] = size
inputs.pop(1) inputs.pop(1)
# NHWC # NHWC
...@@ -691,7 +688,7 @@ def _resize(method): ...@@ -691,7 +688,7 @@ def _resize(method):
return _impl return _impl
def _check_numerics(): def _check_numerics():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# Making a copy node assuming no need to verify # Making a copy node assuming no need to verify
return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr) return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
return _impl return _impl
...@@ -704,7 +701,7 @@ def _assert(): ...@@ -704,7 +701,7 @@ def _assert():
return _no_op() return _no_op()
def _no_op(): def _no_op():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# ToDo: This should really be an op that returns nothing, which could # ToDo: This should really be an op that returns nothing, which could
# be represented as an empty tuple. It turns out that TVM # be represented as an empty tuple. It turns out that TVM
# infrastructure doesn't like running functions that return None and # infrastructure doesn't like running functions that return None and
...@@ -716,7 +713,7 @@ def _no_op(): ...@@ -716,7 +713,7 @@ def _no_op():
return _impl return _impl
def _matmul(): def _matmul():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
channels = _infer_channels(inputs[1], not attr['transpose_b']) channels = _infer_channels(inputs[1], not attr['transpose_b'])
if attr['transpose_a']: if attr['transpose_a']:
inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
...@@ -729,11 +726,11 @@ def _matmul(): ...@@ -729,11 +726,11 @@ def _matmul():
return _impl return _impl
def _batch_matmul(): def _batch_matmul():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
input_x = inputs[0] input_x = inputs[0]
input_y = inputs[1] input_y = inputs[1]
orig_shape_x = attr['_input_shapes'][input_x] orig_shape_x = _infer_shape(input_x, mod)
orig_shape_y = attr['_input_shapes'][input_y] orig_shape_y = _infer_shape(input_y, mod)
# reshape n-dimensional batch matmul into 3d # reshape n-dimensional batch matmul into 3d
if len(orig_shape_x) > 3: if len(orig_shape_x) > 3:
...@@ -761,12 +758,12 @@ def _batch_matmul(): ...@@ -761,12 +758,12 @@ def _batch_matmul():
return _impl return _impl
def _identity(): def _identity():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return inputs[0] return inputs[0]
return _impl return _impl
def _concatV2(): def _concatV2():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
pop_node = inputs.pop(len(inputs)-1) pop_node = inputs.pop(len(inputs)-1)
axis = int(_get_num_param(params, pop_node)) axis = int(_get_num_param(params, pop_node))
return AttrCvt( return AttrCvt(
...@@ -775,7 +772,7 @@ def _concatV2(): ...@@ -775,7 +772,7 @@ def _concatV2():
return _impl return _impl
def _concat(): def _concat():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
pop_node = inputs.pop(0) pop_node = inputs.pop(0)
axis = int(_get_num_param(params, pop_node)) axis = int(_get_num_param(params, pop_node))
return AttrCvt( return AttrCvt(
...@@ -784,7 +781,7 @@ def _concat(): ...@@ -784,7 +781,7 @@ def _concat():
return _impl return _impl
def _pack(): def _pack():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
axis = int(attr["axis"]) axis = int(attr["axis"])
inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
return _op.concatenate(inputs_reshaped, axis) return _op.concatenate(inputs_reshaped, axis)
...@@ -854,7 +851,7 @@ def _tensor_array_concat(): ...@@ -854,7 +851,7 @@ def _tensor_array_concat():
return _impl return _impl
def _tile(): def _tile():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
reps = _get_list_param(params, inputs.pop()) reps = _get_list_param(params, inputs.pop())
new_input = [] new_input = []
new_input.append(inputs.pop(0)) new_input.append(inputs.pop(0))
...@@ -866,7 +863,7 @@ def _tile(): ...@@ -866,7 +863,7 @@ def _tile():
return _impl return _impl
def _slice(): def _slice():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
try: try:
begin = _get_list_param(params, inputs[1]) begin = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError): except (IndexError, KeyError, AttributeError):
...@@ -874,21 +871,26 @@ def _slice(): ...@@ -874,21 +871,26 @@ def _slice():
try: try:
size = _get_list_param(params, inputs[2]) size = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError): except (IndexError, KeyError, AttributeError):
size = _infer_value(inputs[2], params).asnumpy().tolist()[0] # Handle symbolic size
data_shape = attr['_input_shapes'][inputs[0]] try:
size = _infer_value(inputs[2], params).asnumpy().tolist()[0]
except Exception:
size = inputs[2]
data_shape = _infer_shape(inputs[0], mod)
data_dim = len(data_shape) data_dim = len(data_shape)
end = size end = size
for i in range(data_dim): if not isinstance(end, (_expr.Call, _expr.Var)):
if size[i] == -1: for i in range(data_dim):
end[i] = data_shape[i] if size[i] == -1:
else: end[i] = data_shape[i]
end[i] += begin[i] else:
end[i] += begin[i]
return _op.strided_slice(inputs[0], begin=begin, end=end) return _op.strided_slice(inputs[0], begin=begin, end=end)
return _impl return _impl
def _reshape(): def _reshape():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
pop_node = inputs.pop(1) pop_node = inputs.pop(1)
try: try:
...@@ -917,7 +919,7 @@ def _reshape(): ...@@ -917,7 +919,7 @@ def _reshape():
def _depth_to_space(): def _depth_to_space():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
block_size = int(attr['block_size']) block_size = int(attr['block_size'])
layout = attr['data_format'].decode("utf-8") layout = attr['data_format'].decode("utf-8")
return _op.nn.depth_to_space(inputs[0], block_size, layout) return _op.nn.depth_to_space(inputs[0], block_size, layout)
...@@ -926,7 +928,7 @@ def _depth_to_space(): ...@@ -926,7 +928,7 @@ def _depth_to_space():
def _space_to_depth(): def _space_to_depth():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
block_size = int(attr['block_size']) block_size = int(attr['block_size'])
layout = attr['data_format'].decode("utf-8") layout = attr['data_format'].decode("utf-8")
return _op.nn.space_to_depth(inputs[0], block_size, layout) return _op.nn.space_to_depth(inputs[0], block_size, layout)
...@@ -935,7 +937,7 @@ def _space_to_depth(): ...@@ -935,7 +937,7 @@ def _space_to_depth():
def _bias_add(): def _bias_add():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# Must expand for proper broadcasting in NCHW. # Must expand for proper broadcasting in NCHW.
if attr['data_format'].decode("utf-8") == 'NCHW': if attr['data_format'].decode("utf-8") == 'NCHW':
bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1))
...@@ -945,7 +947,7 @@ def _bias_add(): ...@@ -945,7 +947,7 @@ def _bias_add():
return _impl return _impl
def _broadcast_to(): def _broadcast_to():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
if isinstance(inputs[1], _expr.Var): if isinstance(inputs[1], _expr.Var):
shape = params[inputs[1].name_hint] shape = params[inputs[1].name_hint]
else: else:
...@@ -955,7 +957,7 @@ def _broadcast_to(): ...@@ -955,7 +957,7 @@ def _broadcast_to():
return _impl return _impl
def _squeeze(): def _squeeze():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
if len(attr['squeeze_dims']) == 0: if len(attr['squeeze_dims']) == 0:
attr['squeeze_dims'] = None attr['squeeze_dims'] = None
return AttrCvt( return AttrCvt(
...@@ -965,7 +967,7 @@ def _squeeze(): ...@@ -965,7 +967,7 @@ def _squeeze():
return _impl return _impl
def _fused_batch_norm(): def _fused_batch_norm():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance) # Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# Relay: (data, gamma, beta, moving_mean, moving_variance) # Relay: (data, gamma, beta, moving_mean, moving_variance)
assert len(inputs) == 5 assert len(inputs) == 5
...@@ -1001,7 +1003,7 @@ def _fused_batch_norm(): ...@@ -1001,7 +1003,7 @@ def _fused_batch_norm():
return _impl return _impl
def _batch_norm(): def _batch_norm():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# Rearrange inputs from # Rearrange inputs from
# (data, moving_mean, moving_variance, beta, gamma) # (data, moving_mean, moving_variance, beta, gamma)
# to # to
...@@ -1023,14 +1025,15 @@ def _batch_norm(): ...@@ -1023,14 +1025,15 @@ def _batch_norm():
return _impl return _impl
def _relu6(): def _relu6():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return _op.clip(inputs[0], a_min=0, a_max=6) return _op.clip(inputs[0], a_min=0, a_max=6)
return _impl return _impl
def _shape(): def _shape():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
is_symbolic_shape = False is_symbolic_shape = False
for axis in attr['_input_shapes'][inputs[0]]: input_shape = _infer_shape(inputs[0], mod)
for axis in input_shape:
if not isinstance(axis, (int, tvm.tir.IntImm)): if not isinstance(axis, (int, tvm.tir.IntImm)):
is_symbolic_shape = True is_symbolic_shape = True
break break
...@@ -1038,13 +1041,13 @@ def _shape(): ...@@ -1038,13 +1041,13 @@ def _shape():
if is_symbolic_shape: if is_symbolic_shape:
ret = _op.shape_of(inputs[0], dtype='int32') ret = _op.shape_of(inputs[0], dtype='int32')
else: else:
ret = np.array(attr['_input_shapes'][inputs[0]], dtype='int32') ret = np.array(input_shape, dtype='int32')
return ret return ret
return _impl return _impl
def _fill(): def _fill():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
output_shape = attr['_output_shapes'][0] output_shape = attr['_output_shapes'][0]
# Output shape must be defined to avoid errors. If any axis is not, we must # Output shape must be defined to avoid errors. If any axis is not, we must
# try to compute its shape. # try to compute its shape.
...@@ -1058,7 +1061,7 @@ def _fill(): ...@@ -1058,7 +1061,7 @@ def _fill():
return _impl return _impl
def _lrn(): def _lrn():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
attr_new = {} attr_new = {}
depth_radius = attr.get('depth_radius', 5) depth_radius = attr.get('depth_radius', 5)
size = (depth_radius * 2) + 1 size = (depth_radius * 2) + 1
...@@ -1071,7 +1074,7 @@ def _lrn(): ...@@ -1071,7 +1074,7 @@ def _lrn():
return _impl return _impl
def _sum(): def _sum():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
axis = _get_tuple_param(params, inputs[1]) axis = _get_tuple_param(params, inputs[1])
return AttrCvt( return AttrCvt(
op_name='sum', op_name='sum',
...@@ -1081,7 +1084,7 @@ def _sum(): ...@@ -1081,7 +1084,7 @@ def _sum():
return _impl return _impl
def _reduce(op): def _reduce(op):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
axis = _get_list_param(params, inputs[1]) axis = _get_list_param(params, inputs[1])
axis = tuple(axis) axis = tuple(axis)
return AttrCvt( return AttrCvt(
...@@ -1092,13 +1095,13 @@ def _reduce(op): ...@@ -1092,13 +1095,13 @@ def _reduce(op):
return _impl return _impl
def _square(): def _square():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return _op.multiply(inputs[0], inputs[0]) return _op.multiply(inputs[0], inputs[0])
return _impl return _impl
def _gather(): def _gather():
"GatherV2, Gather" "GatherV2, Gather"
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
if len(inputs) > 2: if len(inputs) > 2:
axis = _get_num_param(params, inputs.pop(2)) axis = _get_num_param(params, inputs.pop(2))
else: else:
...@@ -1115,7 +1118,7 @@ def _gather(): ...@@ -1115,7 +1118,7 @@ def _gather():
def _gather_nd(): def _gather_nd():
"""GatherNd""" """GatherNd"""
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return AttrCvt(op_name="gather_nd", return AttrCvt(op_name="gather_nd",
ignores=['Tindices', 'Tparams',\ ignores=['Tindices', 'Tparams',\
'Taxis', '_class'])(inputs, attr) 'Taxis', '_class'])(inputs, attr)
...@@ -1136,7 +1139,7 @@ def _stridedSlice(): ...@@ -1136,7 +1139,7 @@ def _stridedSlice():
ellipsis_mask = int(attr.get('ellipsis_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0))
new_axis_mask = int(attr.get('new_axis_mask', 0)) new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]] data_shape = _infer_shape(inputs[0], mod)
data_dim = len(data_shape) data_dim = len(data_shape)
stride_dim = len(stride) stride_dim = len(stride)
...@@ -1164,8 +1167,8 @@ def _stridedSlice(): ...@@ -1164,8 +1167,8 @@ def _stridedSlice():
mask = 1 << index mask = 1 << index
if mask & ellipsis_mask: if mask & ellipsis_mask:
#Identify the end index for applying ellipsis_mask #Identify the end index for applying ellipsis_mask
to_index = min(((data_dim - (stride_dim-index)) + 1 \ to_index = min(((data_dim - (stride_dim-index)) + 1
+ new_axes_after_ellipsis), data_dim) + new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index): for i in range(final_index, to_index):
m_begin[final_index] = 0 m_begin[final_index] = 0
m_end[final_index] = data_shape[final_index] m_end[final_index] = data_shape[final_index]
...@@ -1205,7 +1208,7 @@ def _stridedSlice(): ...@@ -1205,7 +1208,7 @@ def _stridedSlice():
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out, mod=mod) out_shape = _infer_shape(out, mod)
if not fshape_indices: if not fshape_indices:
fshape_indices = range(len(out_shape)) fshape_indices = range(len(out_shape))
...@@ -1220,12 +1223,25 @@ def _stridedSlice(): ...@@ -1220,12 +1223,25 @@ def _stridedSlice():
final_output.append(out_shape[gather_index]) final_output.append(out_shape[gather_index])
if not final_output: if not final_output:
return out if not shrink_axis_mask:
return _op.reshape(out, newshape=tuple(final_output)) ret = out
else:
final_shape = []
for dim in out_shape:
if dim != 1:
final_shape.append(dim)
if len(final_shape) == 0:
ret = _op.squeeze(out)
else:
# We need reshape to handle dynamic shape.
ret = _op.reshape(out, newshape=tuple(final_shape))
else:
ret = _op.reshape(out, newshape=tuple(final_output))
return ret
return _impl return _impl
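As a reference for the shrink_axis_mask handling above, a small NumPy sketch (illustrative only, not from the patch) of the semantics the converter has to reproduce: when a bit of shrink_axis_mask is set, the corresponding sliced axis is removed from the output instead of being kept with length 1.

import numpy as np

x = np.zeros((5, 4, 3))
# Roughly what tf.strided_slice(x, [0, 1, 0], [5, 2, 3], shrink_axis_mask=2)
# produces: bit 1 of the mask shrinks axis 1 away entirely.
y = x[:, 1, :]
assert y.shape == (5, 3)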
def _pad(name): def _pad(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
padlist = _get_param(params, inputs[1]) padlist = _get_param(params, inputs[1])
paddings = tuple(tuple(l) for l in padlist) paddings = tuple(tuple(l) for l in padlist)
attr['pad_width'] = paddings attr['pad_width'] = paddings
...@@ -1240,7 +1256,7 @@ def _pad(name): ...@@ -1240,7 +1256,7 @@ def _pad(name):
return _impl return _impl
def _mirror_pad(): def _mirror_pad():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
padlist = _get_param(params, inputs[1]) padlist = _get_param(params, inputs[1])
paddings = tuple(tuple(l) for l in padlist) paddings = tuple(tuple(l) for l in padlist)
attr['pad_width'] = paddings attr['pad_width'] = paddings
...@@ -1253,7 +1269,7 @@ def _mirror_pad(): ...@@ -1253,7 +1269,7 @@ def _mirror_pad():
return _impl return _impl
def _transpose(): def _transpose():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
# If perm is not specified, axes is left empty, # If perm is not specified, axes is left empty,
# otherwise its value is obtained from params # otherwise its value is obtained from params
try: try:
...@@ -1264,21 +1280,21 @@ def _transpose(): ...@@ -1264,21 +1280,21 @@ def _transpose():
return _impl return _impl
def _where(): def _where():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
if len(inputs) == 1: if len(inputs) == 1:
return AttrCvt(op_name="argwhere")(inputs, attr) return AttrCvt(op_name="argwhere")(inputs, attr)
return AttrCvt(op_name="where")(inputs, attr) return AttrCvt(op_name="where")(inputs, attr)
return _impl return _impl
def _clip_by_value(): def _clip_by_value():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
a_min = _get_num_param(params, inputs[1]) a_min = _get_num_param(params, inputs[1])
a_max = _get_num_param(params, inputs[2]) a_max = _get_num_param(params, inputs[2])
return _op.clip(inputs[0], a_min=a_min, a_max=a_max) return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
return _impl return _impl
def _reverse_v2(): def _reverse_v2():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
axis = _get_num_param(params, inputs[1]) axis = _get_num_param(params, inputs[1])
return AttrCvt( return AttrCvt(
op_name="reverse", op_name="reverse",
...@@ -1287,8 +1303,8 @@ def _reverse_v2(): ...@@ -1287,8 +1303,8 @@ def _reverse_v2():
return _impl return _impl
def _rank(): def _rank():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
input_shape = attr['_input_shapes'][inputs[0]] input_shape = _infer_shape(inputs[0], mod)
name = attr["_node_name"] name = attr["_node_name"]
params[name] = tvm.nd.array([len(input_shape)]) params[name] = tvm.nd.array([len(input_shape)])
...@@ -1298,31 +1314,61 @@ def _rank(): ...@@ -1298,31 +1314,61 @@ def _rank():
return _impl return _impl
def _range(): def _range():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
start = _get_param(params, inputs[0])[0] try:
start = _get_param(params, inputs[0])[0]
except (IndexError, KeyError, AttributeError):
try:
start = _infer_value(inputs[0], params, mod).asnumpy().tolist()
start = start if not isinstance(start, list) else start[0]
except Exception:
# Symbolic start
start = inputs[0]
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant): if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant):
limit = _get_param(params, inputs[1])[0] limit = _get_param(params, inputs[1])[0]
else: else:
if any(['Rank' in param for param in params]): if any(['Rank' in param for param in params]):
limit = params.pop('Rank').asnumpy()[0] limit = params.pop('Rank').asnumpy()[0]
else: else:
limit = _infer_value_simulated(inputs[1], params).asnumpy()[0] try:
delta = _get_param(params, inputs[2])[0] limit = _infer_value(inputs[1], params, mod).asnumpy().tolist()
limit = limit if not isinstance(limit, list) else limit[0]
except Exception:
# Symbolic limit
limit = inputs[1]
try:
delta = _get_param(params, inputs[2])[0]
except (IndexError, KeyError, AttributeError):
try:
delta = _infer_value(inputs[2], params, mod).asnumpy().tolist()
delta = delta if not isinstance(delta, list) else delta[0]
except Exception:
# Symbolic delta
delta = inputs[2]
dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype) dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype)
if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)):
start = _expr.const(start)
if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)):
limit = _expr.const(limit)
if isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)):
delta = _expr.const(delta)
return AttrCvt( return AttrCvt(
op_name="arange", op_name="arange",
ignores=['Tidx'], ignores=['Tidx'],
extras={'start': _expr.const(start), extras={'start': start,
"stop": _expr.const(limit), 'stop': limit,
'step': _expr.const(delta), 'step': delta,
'dtype': dtype})([], attr) 'dtype': dtype})([], attr)
return _impl return _impl
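For context, an illustrative TensorFlow snippet (made up for this note, not taken from the test suite) of the kind of graph the symbolic fall-backs above are meant to handle: the limit of tf.range is only known at run time, so it has to stay a Relay expression instead of being folded to a constant.

import tensorflow as tf

x = tf.compat.v1.placeholder(tf.float32, shape=(None, 8))
# The batch dimension is dynamic, so the limit of the range cannot be
# evaluated at conversion time.
r = tf.range(tf.shape(x)[0])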
def _elu(): def _elu():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
dtype = attr['T'].name dtype = attr['T'].name
alpha = tvm.relay.const(-1.0, dtype) alpha = tvm.relay.const(-1.0, dtype)
return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
...@@ -1330,16 +1376,16 @@ def _elu(): ...@@ -1330,16 +1376,16 @@ def _elu():
return _impl return _impl
def _selu(): def _selu():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
dtype = attr['T'].name dtype = attr['T'].name
alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype)
gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype)
return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype)
- _op.exp(inputs[0])) + _op.nn.relu(inputs[0])) - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
return _impl return _impl
def _mean(): def _mean():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
axis = _get_tuple_param(params, inputs[1]) axis = _get_tuple_param(params, inputs[1])
return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
transforms={'keep_dims': 'keepdims'}, transforms={'keep_dims': 'keepdims'},
...@@ -1347,7 +1393,7 @@ def _mean(): ...@@ -1347,7 +1393,7 @@ def _mean():
return _impl return _impl
def _broadcast(name): def _broadcast(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return AttrCvt( return AttrCvt(
op_name=name, op_name=name,
ignores=['name', 'incompatible_shape_error', 'Tidx'] ignores=['name', 'incompatible_shape_error', 'Tidx']
...@@ -1356,7 +1402,7 @@ def _broadcast(name): ...@@ -1356,7 +1402,7 @@ def _broadcast(name):
def _split(has_size_vector): def _split(has_size_vector):
# TF documentation https://www.tensorflow.org/api_docs/python/tf/split # TF documentation https://www.tensorflow.org/api_docs/python/tf/split
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
try: try:
# order and number of inputs are different: # order and number of inputs are different:
# if has_size_vector: # if has_size_vector:
...@@ -1379,8 +1425,8 @@ def _split(has_size_vector): ...@@ -1379,8 +1425,8 @@ def _split(has_size_vector):
input_node = inputs[input_node_index] input_node = inputs[input_node_index]
axis_input_value = _get_num_param(params, inputs[input_axis_index]) axis_input_value = _get_num_param(params, inputs[input_axis_index])
except (IndexError, KeyError): except (IndexError, KeyError):
raise TypeError( \ raise TypeError(
"Unsupported argument for split: `axis` and `num_or_size_splits` " \ "Unsupported argument for split: `axis` and `num_or_size_splits` "
"should be constants") "should be constants")
return _op.split(input_node, return _op.split(input_node,
indices_or_sections=indices_or_sections, indices_or_sections=indices_or_sections,
...@@ -1388,35 +1434,31 @@ def _split(has_size_vector): ...@@ -1388,35 +1434,31 @@ def _split(has_size_vector):
return _impl return _impl
def _unpack(): def _unpack():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
input_node = inputs[0] input_node = inputs[0]
axis = attr['axis'] axis = attr['axis']
input_shape = attr['_input_shapes'][input_node] input_shape = _infer_shape(input_node, mod)
axis_length = input_shape[axis] axis_length = input_shape[axis]
if axis_length < 0: if axis_length < 0:
raise TypeError("Unstack with unknown axis length") raise TypeError("Unstack with unknown axis length")
splitted = _op.split(input_node, splitted = _op.split(input_node,
indices_or_sections=axis_length, indices_or_sections=axis_length,
axis=axis) axis=axis)
#name=attr.get('_node_name', 'unstack')) axis = [axis]
if axis == 0:
axis = None
else:
axis = [axis]
return _expr.TupleWrapper( return _expr.TupleWrapper(
_expr.Tuple([_op.squeeze(split_item, axis=axis) \ _expr.Tuple([_op.squeeze(split_item, axis=axis) \
for split_item in splitted]), len(splitted)) for split_item in splitted]), len(splitted))
return _impl return _impl
def _softmax(): def _softmax():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return AttrCvt(op_name='softmax', return AttrCvt(op_name='softmax',
transforms={'axis': ('axis', 1)})([inputs[0]], attr) transforms={'axis': ('axis', 1)})([inputs[0]], attr)
return _impl return _impl
def _softplus(): def _softplus():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
exp_out = AttrCvt('exp')(inputs, attr) exp_out = AttrCvt('exp')(inputs, attr)
inputs.append(tvm.relay.const(1, attr['T'].name)) inputs.append(tvm.relay.const(1, attr['T'].name))
rh = tvm.relay.const(1, attr['T'].name) rh = tvm.relay.const(1, attr['T'].name)
...@@ -1425,8 +1467,12 @@ def _softplus(): ...@@ -1425,8 +1467,12 @@ def _softplus():
return _impl return _impl
def _topk(): def _topk():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
k = int(_get_num_param(params, inputs.pop(1))) k_input = inputs.pop(1)
try:
k = int(_get_num_param(params, k_input))
except (IndexError, KeyError, AttributeError):
k = int(_infer_value(k_input, params).asnumpy().tolist())
if k < 1: if k < 1:
raise tvm.error.OpAttributeInvalid( raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2') 'Attribute k must be positive in operator TopKV2')
...@@ -1439,28 +1485,39 @@ def _topk(): ...@@ -1439,28 +1485,39 @@ def _topk():
return _impl return _impl
def _floordiv(): def _floordiv():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
assert len(inputs) == 2 assert len(inputs) == 2
return AttrCvt('floor_divide')(inputs, attr) return AttrCvt('floor_divide')(inputs, attr)
return _impl return _impl
def _floormod(): def _floormod():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
assert len(inputs) == 2 assert len(inputs) == 2
return AttrCvt('floor_mod')(inputs, attr) return AttrCvt('floor_mod')(inputs, attr)
return _impl return _impl
def _logical(name): def _logical(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
return AttrCvt(op_name=name)(inputs, attr) return AttrCvt(op_name=name)(inputs, attr)
return _impl return _impl
def _space_to_batch_nd(): def _space_to_batch_nd():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
input_node = inputs[0] input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node] input_shape = _infer_shape(input_node, mod)
block_shape = _get_list_param(params, inputs[1]) try:
paddings = _get_list_param(params, inputs[2]) block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
block_shape = _infer_value(inputs[1], params).asnumpy().tolist()
try:
paddings = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
paddings = _infer_value(inputs[2], params).asnumpy()
paddings = np.squeeze(paddings)
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()
N = len(input_shape) N = len(input_shape)
M = len(block_shape) M = len(block_shape)
batch = input_shape[0] batch = input_shape[0]
...@@ -1495,18 +1552,29 @@ def _space_to_batch_nd(): ...@@ -1495,18 +1552,29 @@ def _space_to_batch_nd():
def _batch_to_space_nd(): def _batch_to_space_nd():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
input_node = inputs[0] input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node] input_shape = _infer_shape(input_node, mod)
block_shape = _get_list_param(params, inputs[1]) try:
crops = _get_list_param(params, inputs[2]) block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
block_shape = _infer_value(inputs[1], params).asnumpy().tolist()
try:
crops = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
crops = _infer_value(inputs[2], params).asnumpy()
crops = np.squeeze(crops)
if len(crops.shape) == 1:
crops = np.expand_dims(crops, axis=0)
crops = crops.tolist()
M = len(block_shape) M = len(block_shape)
batch = input_shape[0] batch = input_shape[0]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
# Reshape input to reshaped of shape: # Reshape input to reshaped of shape:
# [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape), # [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape),
# input_shape[1], ..., input_shape[N-1]] # input_shape[1], ..., input_shape[N-1]]
shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:] shape1 = block_shape + [batch // np.prod(block_shape)] + list(input_shape[1:])
reshaped = tvm.relay.reshape(input_node, newshape=shape1) reshaped = tvm.relay.reshape(input_node, newshape=shape1)
# Permute dimensions of reshaped to produce permuted of shape # Permute dimensions of reshaped to produce permuted of shape
# [batch / prod(block_shape), input_shape[1], block_shape[0], ..., # [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
...@@ -1541,13 +1609,13 @@ def _batch_to_space_nd(): ...@@ -1541,13 +1609,13 @@ def _batch_to_space_nd():
return _impl return _impl
def _atan2(): def _atan2():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
divide = _elemwise("divide")(inputs, attr, params) divide = _elemwise("divide")(inputs, attr, params, mod)
return get_relay_op("atan")(divide) return get_relay_op("atan")(divide)
return _impl return _impl
def _prod(): def _prod():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
axis = _get_num_param(params, inputs[1]) axis = _get_num_param(params, inputs[1])
keepdims = attr['keep_dims'] keepdims = attr['keep_dims']
return _op.prod(inputs[0], int(axis), keepdims=keepdims) return _op.prod(inputs[0], int(axis), keepdims=keepdims)
...@@ -1555,21 +1623,21 @@ def _prod(): ...@@ -1555,21 +1623,21 @@ def _prod():
def _log1p(): def _log1p():
# op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
one = tvm.relay.const(1, attr['T'].name) one = tvm.relay.const(1, attr['T'].name)
add_out = get_relay_op('add')(inputs[0], one) add_out = get_relay_op('add')(inputs[0], one)
return get_relay_op('log')(add_out) return get_relay_op('log')(add_out)
return _impl return _impl
def _one_hot(): def _one_hot():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
depth = int(_get_num_param(params, inputs[1])) depth = int(_get_num_param(params, inputs[1]))
dtype = attr['T'].name dtype = attr['T'].name
on_value = _get_num_param(params, inputs[2]) on_value = _get_num_param(params, inputs[2])
off_value = _get_num_param(params, inputs[3]) off_value = _get_num_param(params, inputs[3])
new_inputs = [inputs[0], \ new_inputs = [inputs[0],
tvm.relay.const(on_value, dtype), \ tvm.relay.const(on_value, dtype),
tvm.relay.const(off_value, dtype)] tvm.relay.const(off_value, dtype)]
return AttrCvt('one_hot', return AttrCvt('one_hot',
ignores=['TI'], ignores=['TI'],
...@@ -1577,20 +1645,20 @@ def _one_hot(): ...@@ -1577,20 +1645,20 @@ def _one_hot():
return _impl return _impl
def _squared_difference(): def _squared_difference():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
difference = _op.subtract(inputs[0], inputs[1]) difference = _op.subtract(inputs[0], inputs[1])
return _op.multiply(difference, difference) return _op.multiply(difference, difference)
return _impl return _impl
def _size(): def _size():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
new_attr = attr new_attr = attr
new_attr['out_type'] = attr['out_type'].name new_attr['out_type'] = attr['out_type'].name
return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr) return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr)
return _impl return _impl
def _add_n(): def _add_n():
def _impl(inputs, attr, params): def _impl(inputs, attr, params, mod):
if not isinstance(inputs, tuple): if not isinstance(inputs, tuple):
inputs = list(inputs) inputs = list(inputs)
assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given."
...@@ -1758,7 +1826,7 @@ _convert_map = { ...@@ -1758,7 +1826,7 @@ _convert_map = {
} }
def _LSTMBlockCell(): def _LSTMBlockCell():
def _impl(inputs, in_state_c, in_state_h, attr, params): def _impl(inputs, in_state_c, in_state_h, attr, params, mod):
"""LSTM Block cell. """LSTM Block cell.
Calculations are described in: https://github.com/tensorflow/tensorflow/blob/ Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114 r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
...@@ -1787,8 +1855,8 @@ def _LSTMBlockCell(): ...@@ -1787,8 +1855,8 @@ def _LSTMBlockCell():
in_weight = inputs[3] in_weight = inputs[3]
in_bias = inputs[7] in_bias = inputs[7]
forget_bias = attr.pop('forget_bias') forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]] input_shape = _infer_shape(inputs[0], mod)
weight_shape = attr['_input_shapes'][inputs[3]] weight_shape = _infer_shape(inputs[3], mod)
batch_size, input_size = input_shape[0], input_shape[1] batch_size, input_size = input_shape[0], input_shape[1]
num_hidden_layers = weight_shape[1] num_hidden_layers = weight_shape[1]
num_hidden = num_hidden_layers // 4 num_hidden = num_hidden_layers // 4
...@@ -1883,7 +1951,7 @@ class RecurrentNetworks(object): ...@@ -1883,7 +1951,7 @@ class RecurrentNetworks(object):
sym : relay.Expr sym : relay.Expr
The returned relay Expr The returned relay Expr
""" """
def _impl(op_name, layer_name, inputs, attrs, params, num_layers): def _impl(op_name, layer_name, inputs, attrs, params, num_layers, mod):
in_state_c_name = layer_name+'_c' in_state_c_name = layer_name+'_c'
in_state_h_name = layer_name+'_h' in_state_h_name = layer_name+'_h'
...@@ -1914,8 +1982,8 @@ class RecurrentNetworks(object): ...@@ -1914,8 +1982,8 @@ class RecurrentNetworks(object):
def _LSTMBlockCellWrapper(inputs, attr, params, def _LSTMBlockCellWrapper(inputs, attr, params,
num_layers, layer): num_layers, layer):
"""LSTM cell warapper to prepare the inputs""" """LSTM cell warapper to prepare the inputs"""
input_shape = attr['_input_shapes'][inputs[0]] input_shape = _infer_shape(inputs[0], mod)
weight_shape = attr['_input_shapes'][inputs[3]] weight_shape = _infer_shape(inputs[3], mod)
batch_size = input_shape[0] batch_size = input_shape[0]
num_hidden = weight_shape[1] // 4 num_hidden = weight_shape[1] // 4
...@@ -1928,13 +1996,13 @@ class RecurrentNetworks(object): ...@@ -1928,13 +1996,13 @@ class RecurrentNetworks(object):
in_state_c = self._nodes[in_state_c_name] in_state_c = self._nodes[in_state_c_name]
in_state_h = self._nodes[in_state_h_name] in_state_h = self._nodes[in_state_h_name]
cur_in_state_c, cur_in_state_h = _get_cur_input_state( \ cur_in_state_c, cur_in_state_h = _get_cur_input_state(
in_state_c, in_state_h, in_state_c, in_state_h,
num_layers, layer, num_layers, layer,
batch_size, num_hidden) batch_size, num_hidden)
output, out_state = self._convert_map[op_name](inputs, cur_in_state_c, output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
cur_in_state_h, cur_in_state_h,
attr, params) attr, params, mod)
return output, out_state, in_state_c, in_state_h return output, out_state, in_state_c, in_state_h
sym, cur_out_state, in_state_c, in_state_h = \ sym, cur_out_state, in_state_c, in_state_h = \
...@@ -1948,7 +2016,7 @@ class RecurrentNetworks(object): ...@@ -1948,7 +2016,7 @@ class RecurrentNetworks(object):
return sym return sym
return _impl return _impl
def process_op(self, op_name, inputs, attrs, params): def process_op(self, op_name, inputs, attrs, params, mod):
"""Process recurrent layer operators. """Process recurrent layer operators.
List '_recurrent_ops_layer_map' map each Layer based operators with its List '_recurrent_ops_layer_map' map each Layer based operators with its
...@@ -1998,7 +2066,7 @@ class RecurrentNetworks(object): ...@@ -1998,7 +2066,7 @@ class RecurrentNetworks(object):
num_layers += 1 num_layers += 1
sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs, sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
params, num_layers) params, num_layers, mod)
return sym return sym
# An internal list to contain all the control flow primitives used in Tensorflow # An internal list to contain all the control flow primitives used in Tensorflow
...@@ -2051,7 +2119,6 @@ def _in_while_loop(control_flow_node_map, op_name): ...@@ -2051,7 +2119,6 @@ def _in_while_loop(control_flow_node_map, op_name):
return op_name in control_flow_node_map and \ return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name] "LoopCond" in control_flow_node_map[op_name]
class Branch: class Branch:
"""A class contains the components that are used to build up a Relay if """A class contains the components that are used to build up a Relay if
node. node.
...@@ -2133,6 +2200,82 @@ class Branch: ...@@ -2133,6 +2200,82 @@ class Branch:
return self._if return self._if
class LoopBound(ExprVisitor):
"""
When a loop body is created, we get a Relay expression by backtracking all
the way back to the input node. This results in lots of unnecessary
expressions being placed into the loop body and computed multiple times. For example,
consider the following TensorFlow code:
.. code-block:: python
i = tf.constant(0)
data = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024))
slice = tf.strided_slice(data, 0, 512)
def c(i): return tf.less(i, 10)
def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice]
r = tf.while_loop(c, b, [i])
If we directly create a recursive function, slice will be placed into the function body.
Instead, we recognize whether slice is inside the while_loop block and pass it as an
extra loop variable to avoid duplicate computation (see the plain-Python sketch after
this class).
TODO(kevinthesun): Add a LICM pass for Relay to handle generic loop/function.
"""
def __init__(self, loop_name, hash2tfnode, while_loop_name_set):
ExprVisitor.__init__(self)
self._loop_name = loop_name
self._hash2tfnode = hash2tfnode
self._while_loop_name_set = while_loop_name_set
self.extra_loop_var_names = set()
def _find_parent_loop_name(self, node_name):
"""Find name of direct parent while loop."""
ploop_name = ""
name_prefix = node_name.rsplit('/', 1)[0]
if name_prefix.startswith("^"):
name_prefix = name_prefix[1:]
# To get the name of the direct parent while loop for a given node,
# we iterate over all the while loop names in the TensorFlow graph def.
# If we find a loop name that the current node name starts with,
# the current node is under this loop. However, because of nested
# loops, this loop may not be the direct parent while loop of the
# current node. We keep the longest matching loop name, which
# represents the innermost while loop containing the current node.
for lname in self._while_loop_name_set:
if name_prefix.startswith(lname) and len(ploop_name) < len(lname):
ploop_name = lname
if len(ploop_name) == 0:
ploop_name = name_prefix
return ploop_name
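# Illustration (not part of the patch): a standalone sketch of the
# longest-prefix rule above, using hypothetical loop and node names.
#
#     def find_parent_loop_name(node_name, while_loop_names):
#         prefix = node_name.rsplit('/', 1)[0].lstrip('^')
#         parent = ""
#         for lname in while_loop_names:
#             # The longest matching name denotes the innermost parent loop.
#             if prefix.startswith(lname) and len(parent) < len(lname):
#                 parent = lname
#         return parent or prefix
#
#     loops = {"outer/while", "outer/while/inner/while"}
#     assert find_parent_loop_name("outer/while/inner/while/Add_1", loops) == \
#         "outer/while/inner/while"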
def visit(self, expr):
"""
For each expression in the body, look up the corresponding
TensorFlow node with its structural hash. If the current loop is the
direct parent of this node, we check whether each of its input nodes belongs
to the current loop. If not, we mark that input node as an extra loop
variable of the current loop.
"""
expr_hash = s_hash(expr)
if expr_hash in self._hash2tfnode:
node = self._hash2tfnode[expr_hash]
ploop_name = self._find_parent_loop_name(node.name)
# It is possible that a node is under a loop nested inside the current loop.
# We only check the direct children of the current loop.
if ploop_name == self._loop_name:
for iname in node.input:
iploop_name = self._find_parent_loop_name(iname)
# Use startswith to deal with nested loop
if not iploop_name.startswith(self._loop_name):
if iname not in self.extra_loop_var_names:
self.extra_loop_var_names.add(iname)
super().visit(expr)
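# Illustration (not part of the patch): a plain-Python analogue of the hoisting
# that LoopBound enables, with hypothetical names. The value computed outside
# the loop is passed in as an extra loop variable instead of being recomputed
# on every iteration of the recursive loop function.
def _hoisting_sketch():
    sliced = [0.0] * 512                 # stands in for the hoisted strided_slice
    def loop(i, extra):                  # "extra" is the extra loop variable
        if i < 10:
            return loop(i + 1, extra)    # reuse the hoisted value, never recompute
        return i, extra
    return loop(0, sliced)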
class Loop: class Loop:
""" """
A class contains the components that are used to build up a Relay A class contains the components that are used to build up a Relay
...@@ -2189,11 +2332,18 @@ class Loop: ...@@ -2189,11 +2332,18 @@ class Loop:
%6 %6
} }
""" """
def __init__(self): def __init__(self, mod, loop_name, hash2tfnode,
node_map, while_loop_name_set):
self.loop_vars = [] self.loop_vars = []
self.cond = None self.cond = None
self.body = [] self.body = []
self._loop = None self._loop = None
self._mod = mod
self._loop_name = loop_name
self._hash2tfnode = hash2tfnode
self._node_map = node_map
self._while_loop_name_set = while_loop_name_set
self.aligned = False
def _while_loop(self): def _while_loop(self):
"""An internal API to create a Relay recursive call for a matched TF """An internal API to create a Relay recursive call for a matched TF
...@@ -2203,11 +2353,30 @@ class Loop: ...@@ -2203,11 +2353,30 @@ class Loop:
sb = tvm.relay.scope_builder.ScopeBuilder() sb = tvm.relay.scope_builder.ScopeBuilder()
loop_checker = LoopBound(self._loop_name,
self._hash2tfnode,
self._while_loop_name_set)
for body in self.body:
loop_checker.visit(body)
loop_vars = [] loop_vars = []
bind_map = {} bind_map = {}
loop_var_hash_set = set()
for var in self.loop_vars:
loop_var_hash_set.add(s_hash(var))
extra_nodes = []
for extra_loop_var_name in loop_checker.extra_loop_var_names:
extra_loop_var_name = extra_loop_var_name.split(':')[0].split("^")[-1]
extra_node = self._node_map[extra_loop_var_name]
extra_node = extra_node if isinstance(extra_node, _expr.Tuple) else extra_node[0]
if s_hash(extra_node) not in loop_var_hash_set:
self.loop_vars.append(extra_node)
extra_nodes.append(extra_node)
for i, var in enumerate(self.loop_vars): for i, var in enumerate(self.loop_vars):
if not isinstance(var, _expr.Var): if not isinstance(var, _expr.Var):
var_chk = _infer_type(var) var_chk = _infer_type(var, self._mod)
var_type = var_chk.checked_type var_type = var_chk.checked_type
else: else:
var_type = var.type_annotation var_type = var.type_annotation
...@@ -2216,21 +2385,37 @@ class Loop: ...@@ -2216,21 +2385,37 @@ class Loop:
loop_vars.append(v) loop_vars.append(v)
bind_map[var] = v bind_map[var] = v
self.cond = rewrite_subgraph(self.cond, bind_map) self.cond = rewrite_subgraph(self.cond, bind_map)
self.body = [rewrite_subgraph(b, bind_map) for b in self.body] self.body = [rewrite_subgraph(b, bind_map) for b in self.body]
self.body_shape = []
for body in self.body:
current_node = body
shape = _infer_shape(current_node, self._mod)
while not isinstance(shape, (tuple, list)):
current_node = current_node.args[-1]
shape = _infer_shape(current_node, self._mod)
self.body_shape.append(shape)
cond = tvm.relay.op.min(self.cond) cond = tvm.relay.op.min(self.cond)
with sb.if_scope(cond): with sb.if_scope(cond):
sb.ret(wl(*self.body)) extra_args = []
if extra_nodes:
extra_args = list(loop_vars[-len(extra_nodes):])
sb.ret(wl(*list(self.body + extra_args)))
with sb.else_scope(): with sb.else_scope():
sb.ret(tvm.relay.Tuple(loop_vars)) sb.ret(tvm.relay.Tuple(loop_vars))
loop_fn = tvm.relay.Function(loop_vars, sb.get()) loop_fn = tvm.relay.Function(loop_vars, sb.get())
sb = tvm.relay.scope_builder.ScopeBuilder() sb = tvm.relay.scope_builder.ScopeBuilder()
sb.let(wl, loop_fn) sb.let(wl, loop_fn)
sb.ret(wl(*self.loop_vars)) loop_ret = wl(*self.loop_vars)
return sb.get()
sb.ret(loop_ret)
ret = sb.get()
return ret
def while_loop(self): def while_loop(self):
"""Instantiate a while loop if it has not been created yet.""" """Instantiate a while loop if it has not been created yet."""
...@@ -2247,16 +2432,21 @@ class GraphProto(object): ...@@ -2247,16 +2432,21 @@ class GraphProto(object):
""" """
def __init__(self): def __init__(self):
self._nodes = {} self._nodes = {}
self._tf_node_map = {}
self._params = {} self._params = {}
self._input_shapes = {} self._input_shapes = {}
self._output_shapes = {} self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False self._num_rnn_layer = False
self._input_shapes = {} self._input_shapes = {}
self._loops = {} self._loops = {}
self._branches = {} self._branches = {}
self._mod = IRModule({}) self._mod = IRModule({})
self._prelude = Prelude(self._mod) self._prelude = Prelude(self._mod)
self._control_flow_node_map = defaultdict(set)
self._loop_body_order = {}
self._loop_var_order = {}
self._hash2tfnode = {}
self._while_loop_name_set = set()
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef. """Construct relay nodes from tensorflow graph definition - GraphDef.
...@@ -2296,7 +2486,6 @@ class GraphProto(object): ...@@ -2296,7 +2486,6 @@ class GraphProto(object):
params : dict params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights A dict of name: tvm.nd.array pairs, used as pretrained weights
""" """
try: try:
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
except ImportError as e: except ImportError as e:
...@@ -2304,6 +2493,10 @@ class GraphProto(object): ...@@ -2304,6 +2493,10 @@ class GraphProto(object):
"Unable to import tensorflow which is required {}".format(e)) "Unable to import tensorflow which is required {}".format(e))
missing_operators = self._parse_import_prerequisites(graph) missing_operators = self._parse_import_prerequisites(graph)
control_flow_nodes = []
self._in_shape = shape
self._layout = layout
self._graph = graph
if missing_operators: if missing_operators:
freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list] freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
...@@ -2311,13 +2504,24 @@ class GraphProto(object): ...@@ -2311,13 +2504,24 @@ class GraphProto(object):
raise Exception("Graph is not frozen. Provide a frozen graph. " raise Exception("Graph is not frozen. Provide a frozen graph. "
"Found operators {}".format(freezed_ops)) "Found operators {}".format(freezed_ops))
raise NotImplementedError( \ raise NotImplementedError(
"The following operators are not implemented: {}".format(missing_operators)) "The following operators are not implemented: {}".format(missing_operators))
control_flow_node_map = defaultdict(set)
for node in graph.node: for node in graph.node:
node_name_prefix = node.name.rsplit('/', 1)[0] node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op) self._control_flow_node_map[node_name_prefix].add(node.op)
self._tf_node_map[node.name] = node
# Parse output_shapes attribute
parsed_attr = self._parse_attr(node.attr)
if '_output_shapes' in parsed_attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in parsed_attr['_output_shapes']]
else:
self._output_shapes[node.name] = [None]
# Parse placeholder and const here since input shape info is required.
if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault': if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
# Give priority to user argument. # Give priority to user argument.
if shape and node.name in shape: if shape and node.name in shape:
...@@ -2342,120 +2546,53 @@ class GraphProto(object): ...@@ -2342,120 +2546,53 @@ class GraphProto(object):
tensor_value = node.attr['value'].tensor tensor_value = node.attr['value'].tensor
self._input_shapes[node.name] = \ self._input_shapes[node.name] = \
tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape) tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
self._output_shapes[node.name] = [self._input_shapes[node.name]]
if shape and node.name in shape: if shape and node.name in shape:
warnings.warn("Ignore the passed shape. Shape in graphdef " warnings.warn("Ignore the passed shape. Shape in graphdef "
"will be used for operator %s." % node.name) "will be used for operator %s." % node.name)
# Parse the nodes to re-create TF graph using Relay operators.
for node in graph.node:
# Tensorflow doesn't have separate list for params extraction.
# Operator name 'Const' is treated as a parameter to build params dict.
input_shapes = {}
attr = self._parse_attr(node.attr)
# Variable converted to Const will not have only value attr
if 'value' in attr and node.op == 'Const':
self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
for tshape in attr['_output_shapes']]
else:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
# Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None]
if node.op == "Const":
# All Const nodes are Param nodes, lets parse
self._num_param += 1
for key, value in node.attr.items(): for key, value in node.attr.items():
self._parse_param(key, value, node.name, shape) self._parse_param(key, value, node.name, self._in_shape)
if node.name not in self._nodes: elif node.op in _control_flow_nodes:
raise NotImplementedError( \ # We assume that the direct parent node of Exit is a while loop block
"Const {} couldn't be converted to Param.".format(node.name)) if node.op == "Exit":
self._while_loop_name_set.add(node_name_prefix)
attr = self._parse_attr(node.attr) control_flow_nodes.append(node)
elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault': # First, parse all control flow nodes.
# Pass the parsed shapes instead # Convert tf.cond to Branch and tf.while_loop to Loop.
attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] sorted_cf_nodes = []
current_node_name_prefix = None
# Pass the node name too in attr exits = []
attr["_node_name"] = node.name # Sort control flow nodes to move all Exit nodes to the end
# of corresponding while_loop block.
# Pass the target layout for i, node in enumerate(control_flow_nodes):
attr["_target_layout"] = layout node_name_prefix = node.name.rsplit('/', 1)[0]
if current_node_name_prefix is None or current_node_name_prefix != node_name_prefix:
# Fill shapes for all inputs in a list if node_name_prefix in self._while_loop_name_set:
inputs = [] sorted_cf_nodes.extend(exits)
for i in node.input: exits.clear()
# Some TensorFlow operators internally maintain execution layers current_node_name_prefix = node_name_prefix
# and their output name includes the layer number along with
# graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
# output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
# the number has to be ignored for single-output nodes.
# On the other hand, for multi-output nodes the number is the output index,
# and the lack of the number implies 0.
tensor_name = i.split(':')
node_name = tensor_name[0]
if node_name in self._nodes:
in_sym = self._nodes[node_name]
if isinstance(in_sym, _expr.TupleWrapper):
tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
in_sym = [in_sym[tensor_slot]]
input_shape = self._output_shapes[node_name][tensor_slot]
else:
tensor_slot = 0
input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym[0])
input_shapes[in_sym[0]] = input_shape
attr['_input_shapes'] = input_shapes
if node.op in _control_flow_nodes:
op = self._convert_control_flow_operator(node, inputs,
attr,
control_flow_node_map)
else:
op = self._convert_operator(node.op, inputs, attr, graph)
# Check if op is converted to param
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [_expr.var(node.name,
shape=self._params[node.name].shape,
dtype=self._params[node.name].dtype)]
elif isinstance(op, (_expr.TupleWrapper, tuple, list)):
pass
elif isinstance(op, _expr.Expr):
op = [op]
else:
raise RuntimeError("unexpected type %s" % type(op))
self._nodes[node.name] = op if node.op == "Exit":
exits.append(node)
else:
sorted_cf_nodes.append(node)
# Infer shapes even without specifying "add_shapes=True" if i == len(control_flow_nodes) - 1:
if output_shapes == [None]: sorted_cf_nodes.extend(exits)
out_shapes = [_infer_shape(node_item, self._mod)
for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape: for node in sorted_cf_nodes:
assert self._output_shapes[node.name] == list(shape[node.name]) self._backtrack_construct(node.name)
# Infer shapes if passed explicitly # Second, parse other nodes to re-create TF graph using Relay operators.
node_output = self._nodes[node.name] for node in graph.node:
if shape and (not self._output_shapes[node.name][0] self._backtrack_construct(node.name)
or -1 in self._output_shapes[node.name][0]):
out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output]
self._output_shapes[node.name] = out_shapes
out = [] out = []
if outputs is None: if outputs is None:
if node.op == "Exit": last_node = graph.node[-1]
op = self._nodes[last_node.name.split(":")[0]]
if last_node.op == "Exit":
out = [op[0].tuple_value] out = [op[0].tuple_value]
else: else:
out = op out = op
...@@ -2620,7 +2757,7 @@ class GraphProto(object): ...@@ -2620,7 +2757,7 @@ class GraphProto(object):
self._out_rnn = [] self._out_rnn = []
self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map) self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
self._num_rnn_layer = True self._num_rnn_layer = True
sym = self.rnn.process_op(op_name, inputs, attrs, params) sym = self.rnn.process_op(op_name, inputs, attrs, params, self._mod)
return sym return sym
def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map): def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
...@@ -2651,53 +2788,95 @@ class GraphProto(object): ...@@ -2651,53 +2788,95 @@ class GraphProto(object):
""" """
node_name_prefix = node.name.rsplit('/', 1)[0] node_name_prefix = node.name.rsplit('/', 1)[0]
if node.op == "Merge": if node.op == "Merge":
if _in_while_loop(control_flow_node_map, node_name_prefix): if _in_while_loop(self._control_flow_node_map, node_name_prefix):
op = self._nodes[node.input[0]] op = self._backtrack_construct(node.input[0])
self._loops[node_name_prefix] = Loop() if node_name_prefix not in self._loops:
self._loops[node_name_prefix] = Loop(self._mod,
node_name_prefix,
self._hash2tfnode,
self._nodes,
self._while_loop_name_set)
else: else:
if len(self._branches) == 0: if len(self._branches) == 0:
raise RuntimeError("Cannot find a created " raise RuntimeError("Cannot find a created "
"conditional for merge node") "conditional for merge node")
branch = self._branches[node_name_prefix] branch = self._branches[node_name_prefix]
false_br = self._nodes[node.input[0]] false_br = self._backtrack_construct(node.input[0])
true_br = self._nodes[node.input[1]] true_br = self._backtrack_construct(node.input[1])
assert len(true_br) == 1 assert len(true_br) == 1
assert len(false_br) == 1 assert len(false_br) == 1
branch.true_branch = true_br[0] branch.true_branch = true_br[0]
branch.false_branch = false_br[0] branch.false_branch = false_br[0]
op = [branch.if_node()] op = [branch.if_node()]
if node_name_prefix not in self._while_loop_name_set:
try:
cond_val = np.all(_infer_value(branch.cond, self._params,
self._mod).asnumpy())
if cond_val:
op = [branch.true_branch]
else:
op = [branch.false_branch]
except Exception:
op = [branch.if_node()]
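# Illustration (not part of the patch): the constant-condition fold above in
# plain numpy, with a hypothetical value standing in for the tensor returned
# by _infer_value for the predicate.
#
#     inferred_cond = np.array(True)     # statically known predicate
#     op = "true_branch" if np.all(inferred_cond) else "false_branch"
#
# If _infer_value raises because the predicate cannot be computed statically,
# the converter falls back to emitting the relay If node as before.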
elif node.op == "Exit": elif node.op == "Exit":
loop = self._loops[node_name_prefix] loop = self._loops[node_name_prefix]
exit_name = node.name.split('/')[-1]
assert str.startswith(exit_name, 'Exit')
# TensorFlow has different naming conventions on different # Check whether the order of loop variables aligns
# versions. # with loop body. If not, create new loop variable list
# with correct order.
if not loop.aligned:
loop_vars = []
for i in self._loop_body_order[node_name_prefix]:
for j, k in enumerate(self._loop_var_order[node_name_prefix]):
if k == i:
loop_vars.append(loop.loop_vars[j])
loop.loop_vars = loop_vars
loop.aligned = True
exit_name = node.name.split('/')[-1]
if '_' in exit_name: if '_' in exit_name:
exit_number = int("0" + exit_name[5:]) exit_number = int(exit_name[5:])
else: else:
exit_number = int("0" + exit_name[4:]) exit_number = 0
expr = loop.while_loop() expr = loop.while_loop()
op = _expr.TupleGetItem(expr, exit_number) body_pos = exit_number
for i, j in enumerate(self._loop_body_order[node_name_prefix]):
if exit_number == j:
body_pos = i
break
op = [_expr.TupleGetItem(expr, body_pos)]
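# Illustration (not part of the patch): the index remapping above with
# hypothetical orders. "Exit_2" must read the tuple field that corresponds to
# NextIteration_2, i.e. its position in the recorded loop body order.
#
#     loop_body_order = [2, 0, 1]        # NextIteration_2 was converted first
#     exit_number = 2                    # parsed from node name ".../Exit_2"
#     body_pos = loop_body_order.index(exit_number)
#     assert body_pos == 0               # Exit_2 reads tuple field 0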
elif node.op == "Enter": elif node.op == "Enter":
op = self._nodes[node.input[0]] op = self._backtrack_construct(node.input[0])
elif node.op == "LoopCond": elif node.op == "LoopCond":
op = self._nodes[node.input[0]] op = self._backtrack_construct(node.input[0])
assert len(op) == 1 assert len(op) == 1
self._loops[node_name_prefix].cond = op[0] self._loops[node_name_prefix].cond = op[0]
elif node.op == "Switch": elif node.op == "Switch":
op = self._nodes[node.input[0]] op = self._backtrack_construct(node.input[0])
cond = self._backtrack_construct(node.input[1])
assert len(op) == 1 assert len(op) == 1
if _in_while_loop(control_flow_node_map, node_name_prefix): if _in_while_loop(self._control_flow_node_map, node_name_prefix):
if node_name_prefix not in self._loop_var_order:
self._loop_var_order[node_name_prefix] = []
if node.name.endswith("Switch"):
self._loop_var_order[node_name_prefix].append(0)
else:
self._loop_var_order[node_name_prefix].\
append(int(node.name.split("Switch_")[-1]))
self._loops[node_name_prefix].loop_vars.append(op[0]) self._loops[node_name_prefix].loop_vars.append(op[0])
else: else:
if node_name_prefix not in self._branches: if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch() self._branches[node_name_prefix] = Branch()
chk_op = _infer_type(op[0]) self._branches[node_name_prefix].cond = cond[0]
self._branches[node_name_prefix].cond = chk_op
elif node.op == "NextIteration": elif node.op == "NextIteration":
op = self._nodes[node.input[0]] if node_name_prefix not in self._loop_body_order:
self._loop_body_order[node_name_prefix] = []
if node.name.endswith("NextIteration"):
self._loop_body_order[node_name_prefix].append(0)
else:
self._loop_body_order[node_name_prefix].\
append(int(node.name.split("NextIteration_")[-1]))
op = self._backtrack_construct(node.input[0])
assert len(op) == 1 assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0]) self._loops[node_name_prefix].body.append(op[0])
else: else:
...@@ -2706,7 +2885,6 @@ class GraphProto(object): ...@@ -2706,7 +2885,6 @@ class GraphProto(object):
return op return op
def _convert_operator(self, op_name, inputs, attrs, def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None): graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to relay operator. """Convert from Tensorflow operator to relay operator.
...@@ -2741,10 +2919,8 @@ class GraphProto(object): ...@@ -2741,10 +2919,8 @@ class GraphProto(object):
elif op_name in convert_map: elif op_name in convert_map:
if _need_prelude_for_shape_inference(op_name): if _need_prelude_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
elif _need_module_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
else: else:
sym = convert_map[op_name](inputs, attrs, self._params) sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
elif op_name in convert_map_rnn: elif op_name in convert_map_rnn:
sym = self._convert_rnn_operator(op_name, inputs, attrs, sym = self._convert_rnn_operator(op_name, inputs, attrs,
...@@ -2754,6 +2930,67 @@ class GraphProto(object): ...@@ -2754,6 +2930,67 @@ class GraphProto(object):
raise NotImplementedError("Operator {} not implemented.".format(op_name)) raise NotImplementedError("Operator {} not implemented.".format(op_name))
return sym return sym
def _backtrack_construct(self, node_name):
"""Convert a specific tensorflow node to relay expression.
If any of its ancestor nodes has not been converted yet, backtrack as
far as the input node and convert all nodes on the path.
This is required when parsing control flow nodes, since the parsing
order may not follow the original graph def.
Parameters
----------
node_name : str
Tensorflow node name.
Returns
-------
op : relay.Expr
Converted relay expression
"""
node_name = node_name.split(':')[0].split("^")[-1]
if node_name not in self._nodes:
node = self._tf_node_map[node_name]
attr = self._parse_attr(node.attr)
if node.op in _control_flow_nodes:
attr = self._parse_attr(node.attr)
op = self._convert_control_flow_operator(node, [],
attr,
self._control_flow_node_map)
else:
attr["_output_shapes"] = self._output_shapes[node_name]
attr["_node_name"] = node.name
attr["_target_layout"] = self._layout
inputs = []
for iname in node.input:
in_op = self._backtrack_construct(iname)
if isinstance(in_op, _expr.TupleWrapper):
tn = iname.split(':')
tensor_slot = int(tn[1]) if len(tn) > 1 else 0
in_op = in_op[tensor_slot]
else:
in_op = in_op[0]
inputs.append(in_op)
op = self._convert_operator(node.op, inputs, attr, self._graph)
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [_expr.var(node.name,
shape=self._params[node.name].shape,
dtype=self._params[node.name].dtype)]
elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)):
op = [op]
node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0])
self._hash2tfnode[node_hash] = node
self._nodes[node_name] = op
return self._nodes[node_name]
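# Illustration (not part of the patch): a minimal memoized depth-first sketch
# of the backtracking conversion above, over a hypothetical dependency graph.
# Each node is converted only after all of its inputs, and results are cached
# by node name so every node is converted exactly once.
def _backtrack_sketch():
    graph = {"c": ["a", "b"], "b": ["a"], "a": []}
    converted = {}
    def construct(name):
        if name not in converted:
            inputs = [construct(i) for i in graph[name]]
            converted[name] = ("op", name, inputs)   # stands in for relay conversion
        return converted[name]
    construct("c")
    return sorted(converted)                         # -> ['a', 'b', 'c']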
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay. """Load tensorflow graph which is a python tensorflow graph object into relay.
......
...@@ -27,14 +27,16 @@ from tvm import relay ...@@ -27,14 +27,16 @@ from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow from tvm.relay.frontend.tensorflow import from_tensorflow
def check_equal(graph, tf_out): def check_equal(graph, tf_out, input_map=None):
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
if input_map is not None:
params.update(input_map)
ex = relay.create_executor('vm', mod=mod) ex = relay.create_executor('vm', mod=mod)
relay_out = ex.evaluate()(**params) relay_out = ex.evaluate()(**params)
if isinstance(relay_out, nd.NDArray): if isinstance(relay_out, nd.NDArray):
np.testing.assert_allclose(tf_out, relay_out.asnumpy()) np.testing.assert_allclose(tf_out, relay_out.asnumpy())
else: else:
if not isinstance(tf_out, list): if not isinstance(tf_out, (list, tuple)):
tf_out = [tf_out] tf_out = [tf_out]
for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]): for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]):
np.testing.assert_allclose(x, y) np.testing.assert_allclose(x, y)
...@@ -303,9 +305,70 @@ def test_cond_in_loop(): ...@@ -303,9 +305,70 @@ def test_cond_in_loop():
check_equal(graph, tf_out) check_equal(graph, tf_out)
def test_vanilla_loop_bound():
graph = tf.Graph()
with graph.as_default():
dshape = (2, 10)
dtype = "float32"
dname = "data"
np_data = np.random.uniform(size=dshape).astype(dtype)
data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
x = tf.slice(data, [1, 4], [1, 4])
outer = x + 5.0
def body(x, y):
res = tf.cond(tf.less(y, 10), lambda: tf.add(
10.0, 20.0), lambda: tf.square(10.0))
z = tf.constant(7)
res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
return tf.multiply(res, x * outer), y + 1
y = tf.constant(0)
def condition(x, y):
return tf.less(y, 20)
r = tf.while_loop(condition, body, loop_vars=[x, y])
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})
if __name__ == "__main__": check_equal(graph, tf_out, {dname: np_data})
def test_nested_loop_bound():
graph = tf.Graph()
with graph.as_default():
dshape = (2, 10)
dtype = "float32"
dname = "data"
np_data = np.random.uniform(size=dshape).astype(dtype)
data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
x = tf.slice(data, [1, 4], [1, 4])
outer = x + 5.0
def body(x, y):
res = tf.cond(tf.less(y, 10), lambda: tf.add(
10.0, 20.0), lambda: tf.square(10.0))
def nested_body(nx, ny):
return nx + 1, res + 2.0
def nested_cond(nx, ny):
return tf.less(nx, 15)
nx = tf.constant(0)
ny = tf.constant(0.0)
nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny])
res = res + nested_res[1]
z = tf.constant(7)
res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
return tf.multiply(res, x * outer), y + 1
y = tf.constant(0)
def condition(x, y):
return tf.less(y, 20)
r = tf.while_loop(condition, body, loop_vars=[x, y])
with tf.Session() as sess:
tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})
check_equal(graph, tf_out, {dname: np_data})
if __name__ == "__main__":
# tf.while_loop # tf.while_loop
test_vanilla_loop() test_vanilla_loop()
test_loop_2_vars() test_loop_2_vars()
...@@ -325,3 +388,5 @@ if __name__ == "__main__": ...@@ -325,3 +388,5 @@ if __name__ == "__main__":
test_nested_cond() test_nested_cond()
test_loop_in_cond() test_loop_in_cond()
test_cond_in_loop() test_cond_in_loop()
test_vanilla_loop_bound()
test_nested_loop_bound()
...@@ -67,13 +67,11 @@ def test_assert_true_var_capture(): ...@@ -67,13 +67,11 @@ def test_assert_true_var_capture():
x_value = np.random.rand() x_value = np.random.rand()
assert sess.run(assert_op, feed_dict={x: x_value}) is None assert sess.run(assert_op, feed_dict={x: x_value}) is None
# ToDo: The frontend converter gets confused here as well, thinking # TODO: The frontend converter notes the output of
# that it needs to be told what x is twice. It also notes the output of
# the graph as a boolean, which is not correct - as you can see above, # the graph as a boolean, which is not correct - as you can see above,
# TF believes that the value of this graph is None. In addition, the # TF believes that the value of this graph is None.
# arity of the translated function should be 1, not 2.
np.testing.assert_allclose(True, np.testing.assert_allclose(True,
run_relay(g, None, x_value, x_value).asnumpy()) run_relay(g, None, x_value).asnumpy())
def test_assert_false(): def test_assert_false():
g = tf.Graph() g = tf.Graph()
......
...@@ -1207,6 +1207,8 @@ def test_forward_stridedslice(): ...@@ -1207,6 +1207,8 @@ def test_forward_stridedslice():
'''test StridedSlice''' '''test StridedSlice'''
_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1) _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((2, 1), [0], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((2, 3, 4), [0], [1], [1], 'float32', shrink_axis_mask=8)
_test_stridedslice((3, 4, 3), [1, -1, 0], _test_stridedslice((3, 4, 3), [1, -1, 0],
[4, -5, 3], [2, -1, 1], 'float32') [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [ _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [
......