# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
# pylint: disable=import-outside-toplevel
"""TF: Tensorflow frontend."""
import warnings
from collections import defaultdict

# Numpy support
import numpy as np
import tvm

from tvm.ir import IRModule
from tvm.relay.prelude import Prelude

from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from ..expr_functor import ExprMutator
from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_channels as _infer_channels
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated

__all__ = ['from_tensorflow']


def _get_pad_pair(input1d, kernel1d, stride1d):
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)

    pad_before = pad // 2
    pad_after = pad - pad_before

    return [pad_before, pad_after]

def _math_name_picker(surfix):
    def _impl(attr):
        return 'broadcast_' + surfix
    return _impl

def _dimension_picker(prefix, surfix=''):
    def _impl(attr):
        kernel = attr['kernel_shape']
        if len(kernel) == 2:
            return prefix + '2d' + surfix
        if len(kernel) == 3:
            return prefix + '3d' + surfix
        raise tvm.error.OpAttributeInvalid(
            'Only 2D or 3D kernels are supported for operator {}'.format(prefix + '2d or 3d'))
    return _impl

def _dimension_constraint():
    def _dim_check(attrs):
        if len(attrs['kernel_shape']) in (2, 3):
            return True
        return False
    return _dim_check, "Only 2d or 3d kernel supported."

def _get_param(params, input_node):
    if isinstance(input_node, _expr.Constant):
        return np.atleast_1d(input_node.data.asnumpy())
    return params.pop(input_node.name_hint).asnumpy()

def _get_num_param(params, input_node):
    return _get_param(params, input_node).item()

def _get_list_param(params, input_node):
    return _get_param(params, input_node).tolist()

def _get_tuple_param(params, input_node):
    return tuple(_get_param(params, input_node))

def _need_module_for_shape_inference(op):
    return op in ['StridedSlice']

def _need_prelude_for_shape_inference(op):
    return "TensorArray" in op

def _rsqrt():
    def _impl(inputs, attr, params):
        inputs.append(tvm.relay.const(-0.5, attr['T'].name))
        return AttrCvt(op_name="power")(inputs, attr)
    return _impl

def _argx(func, func_name):
    """A common wrapper for argmin and argmax operations."""
    def _impl(inputs, attr, params):
        try:
            # In Tensorflow, the `axis` argument is a Tensor, not an attribute. We
            # support the case where it is provided as a scalar constant.
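            # For example, a graph built with tf.math.argmax(x, axis=1) records the
            # axis as a separate Const node; that constant surfaces either in `params`
            # or as a relay Constant and is read below via _get_num_param. A truly
            # dynamic (non-constant) axis tensor is not handled and falls through to
            # the TypeError in the except clause. (Illustrative note; tf.math.argmax
            # is just one way such graphs arise.)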
axis_input_value = [_get_num_param(params, inputs[1])] except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) return func(inputs[0], axis=axis_input_value, keepdims=False) return _impl def _elemwise(name): def _impl(inputs, attr, params): assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) return get_relay_op(name)(*inputs) return _impl def _pool3d(name): def _impl(inputs, attr, params): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False input_shape = attr['_input_shapes'][inputs[0]] if attr['data_format'] == 'NDHWC': attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3]) elif attr['data_format'] == 'NCDHW': attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3], attr['ksize'][4]) attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4]) else: msg = 'Value {} of attribute "data_format" of operator Pooling ' \ 'is not valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if attr['data_format'] == "NDHWC": input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)] inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3)) attr['data_format'] = "NCDHW" attr['_input_shapes'][inputs[0]] = input_shape flip_layout = True attr['padding'] = attr['padding'].decode("utf-8") if attr['padding'] == 'VALID': attr['padding'] = [0, 0, 0, 0, 0, 0] elif attr['padding'] == 'SAME': stride_d, stride_h, stride_w = attr['strides'] kernel_d, kernel_h, kernel_w = attr['kernel_shape'] if attr['data_format'] == 'NDHWC': in_d = input_shape[1] in_h = input_shape[2] in_w = input_shape[3] else: in_d = input_shape[2] in_h = input_shape[3] in_w = input_shape[4] pad_d = _get_pad_pair(in_d, kernel_d, stride_d) pad_v = _get_pad_pair(in_h, kernel_h, stride_h) pad_h = _get_pad_pair(in_w, kernel_w, stride_w) attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Pooling is ' \ 'not valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if name == "avg_pool": attr['count_include_pad'] = False attr['ceil_mode'] = False out = AttrCvt( op_name=name, transforms={ 'kernel_shape': 'pool_size', 'data_format': 'layout'}, ignores=['ksize'])(inputs, attr) if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 4, 1)) return out return _impl def _pooling(name): def _impl(inputs, attr, params): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False input_shape = attr['_input_shapes'][inputs[0]] if attr['data_format'] == 'NHWC': attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2]) attr['strides'] = (attr['strides'][1], attr['strides'][2]) elif attr['data_format'] == 'NCHW': attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: msg = 'Value {} of attribute "data_format" of operator Pooling ' \ 'is not valid.' 
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": tmp_shape = attr['_input_shapes'][inputs[0]] input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) attr['data_format'] = "NCHW" flip_layout = True # Fix padding attr['padding'] = attr['padding'].decode("utf-8") if attr['padding'] == 'VALID': attr['padding'] = [0, 0] elif attr['padding'] == 'SAME': stride_h, stride_w = attr['strides'] kernel_h, kernel_w = attr['kernel_shape'] if attr['data_format'] == 'NHWC': in_h = input_shape[1] in_w = input_shape[2] else: in_h = input_shape[2] in_w = input_shape[3] pad_v = _get_pad_pair(in_h, kernel_h, stride_h) pad_h = _get_pad_pair(in_w, kernel_w, stride_w) attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Pooling is ' \ 'not valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if name == "avg_pool": attr['count_include_pad'] = False out = AttrCvt( op_name=_dimension_picker(name), transforms={ 'kernel_shape':'pool_size', 'data_format':'layout'}, ignores=['ksize'], extras={'ceil_mode': False}, custom_check=_dimension_constraint())(inputs, attr) if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 1)) return out return _impl def _conv(opname): def _impl(inputs, attr, params): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False if opname == 'conv_transpose' and attr['data_format'] == 'NHWC': # transform to NCHW for TVM backend compatible and set 'flip_layout' # to have output flip back to NHWC tmp_shape = attr['_input_shapes'][inputs[2]] tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2)) attr['_input_shapes'][inputs[2]] = tmp_shape attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ attr['strides'][3], attr['strides'][1], attr['strides'][2] attr['data_format'] = 'NCHW' if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0: tmp_shape = attr['_output_shapes'][0] tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] attr['_output_shapes'][0] = tmp_shape flip_layout = True inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] # NCHW Layout require weights transpose if attr['data_format'] == 'NCHW': tmp_shape = attr['_input_shapes'][inputs[1]] if opname in ['conv', 'conv_transpose']: tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) else: tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) attr['_input_shapes'][inputs[1]] = tmp_shape input_shape = attr['_input_shapes'][inputs_data] weights_shape = attr['_input_shapes'][inputs[1]] if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) if opname in ['conv', 'conv_transpose']: weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) else: weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) attr['data_format'] = "NCHW" attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)] flip_layout = True if attr['data_format'] == 'NHWC': in_channels = input_shape[3] kernel_h, kernel_w, _, depth_mult = weights_shape attr['kernel_shape'] = (weights_shape[0], 
weights_shape[1]) if opname == 'conv': attr['channels'] = weights_shape[3] elif opname == 'conv_transpose': attr['channels'] = weights_shape[2] else: attr['channels'] = input_shape[3] * depth_mult if 'dilations' in attr: attr['dilations'] = (attr['dilations'][1], attr['dilations'][2]) attr['strides'] = (attr['strides'][1], attr['strides'][2]) elif attr['data_format'] == 'NCHW': in_channels = input_shape[1] _, depth_mult, kernel_h, kernel_w = weights_shape attr['kernel_shape'] = (weights_shape[2], weights_shape[3]) if opname == 'conv': attr['channels'] = weights_shape[0] elif opname == 'conv_transpose': attr['channels'] = weights_shape[1] else: attr['channels'] = input_shape[1] * depth_mult if attr['channels'] < 0: attr['channels'] *= -1 if 'dilations' in attr: attr['dilations'] = (attr['dilations'][2], attr['dilations'][3]) attr['strides'] = (attr['strides'][2], attr['strides'][3]) else: msg = 'Value {} in attribute "data_format" of operator Conv is ' \ 'not valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if opname == 'depthwise': attr['groups'] = in_channels # Fix padding attr['padding'] = attr['padding'].decode("utf-8") if attr['padding'] == 'VALID': attr['padding'] = [0, 0] elif attr['padding'] == 'SAME': stride_h, stride_w = attr['strides'] kernel_h, kernel_w = attr['kernel_shape'] pdata_shape = input_shape if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0: pdata_shape = attr['_output_shapes'][0] if attr['data_format'] == 'NHWC': in_h = pdata_shape[1] in_w = pdata_shape[2] else: in_h = pdata_shape[2] in_w = pdata_shape[3] dilation_h = attr['dilations'][0] dilation_w = attr['dilations'][1] dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ 'valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if 'kernel_layout' not in attr: if opname in ['conv', 'conv_transpose']: attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' else: attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW' use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4) channel_axis = 1 if attr['data_format'] == "NCHW" else 3 # Ignore the new attributes from TF2.0, for now. 
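        # Rough illustration (an assumed example, not a value computed above): a TF
        # NHWC Conv2D with strides=[1, 2, 2, 1], a 3x3 HWIO kernel, dilation 1 and
        # padding='SAME' on a 224x224 input reaches this point with
        # kernel_shape=(3, 3), strides=(2, 2) and padding=[0, 0, 1, 1] as computed by
        # _get_pad_pair; the AttrCvt call below mainly renames these onto the relay
        # conv2d attribute names (kernel_size, data_layout, dilation, groups).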
out = AttrCvt( op_name=_dimension_picker('conv', \ surfix="_transpose" if opname == 'conv_transpose' else ""), ignores=['explicit_paddings'], transforms={ 'kernel_shape': 'kernel_size', 'data_format': 'data_layout', 'dilations': ('dilation', (0, 0)), 'group': ('groups', 1)}, custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr) if use_bias: out = _op.nn.bias_add(out, inputs[2] if opname != 'conv_transpose' else inputs[3], axis=channel_axis) if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 1)) return out return _impl # Dilation2d def _dilation2d(): def _impl(inputs, attr, params): if 'data_format' not in attr: attr['data_format'] = 'NHWC' input_shape = attr['_input_shapes'][inputs[0]] weights_shape = attr['_input_shapes'][inputs[1]] if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) weights_shape = [weights_shape[ii] for ii in (2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(2, 0, 1)) attr['data_format'] = "NCHW" if attr['data_format'] in ['NHWC', 'NCHW']: if 'rates' in attr: attr['dilations'] = attr['rates'] if 'dilations' in attr: attr['dilations'] = (attr['dilations'][1], attr['dilations'][2]) attr['strides'] = (attr['strides'][1], attr['strides'][2]) else: msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' \ 'not valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) attr['padding'] = attr['padding'].decode("utf-8") if attr['padding'] == 'VALID': attr['padding'] = [0, 0] elif attr['padding'] == 'SAME': stride_h, stride_w = attr['strides'] if attr['data_format'] == 'NHWC': kernel_h, kernel_w = weights_shape[0], weights_shape[1] else: kernel_h, kernel_w = weights_shape[1], weights_shape[2] if attr['data_format'] == 'NHWC': in_h = input_shape[1] in_w = input_shape[2] else: in_h = input_shape[2] in_w = input_shape[3] dilation_h = attr['dilations'][0] dilation_w = attr['dilations'][1] dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) if attr['data_format'] == 'NHWC': inputs[0] = _op.nn.pad(data=inputs[0], pad_width=((0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1]), (0, 0))) else: inputs[0] = _op.nn.pad(data=inputs[0], pad_width=((0, 0), (0, 0), (pad_v[0], pad_v[1]), (pad_h[0], pad_h[1]))) attr['padding'] = [0, 0] else: msg = 'Value {} in attribute "padding" of operator Dilation2d is not ' \ 'valid.' 
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) attr['kernel_layout'] = 'HWI' if attr['data_format'] == 'NHWC' else 'IHW' out = AttrCvt( op_name='dilation2d', ignores=['explicit_paddings', 'rates'], transforms={ 'data_format': 'data_layout', })([inputs[0], inputs[1]], attr) if attr['_target_layout'] == "NCHW": out = _op.transpose(out, axes=(0, 2, 3, 1)) return out return _impl def _conv3d(opname): def _impl(inputs, attr, params): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] # NCDHW Layout require weights transpose if attr['data_format'] == 'NCDHW': tmp_shape = attr['_input_shapes'][inputs[1]] tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)] inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) attr['_input_shapes'][inputs[1]] = tmp_shape input_shape = attr['_input_shapes'][inputs_data] weights_shape = attr['_input_shapes'][inputs[1]] if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC": input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)] inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3)) weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)] inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) attr['data_format'] = "NCDHW" attr['strides'] = [attr['strides'][ii] for ii in (0, 4, 1, 2, 3)] flip_layout = True if attr['data_format'] == 'NDHWC': kernel_d, kernel_h, kernel_w, _, _ = weights_shape attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w) if opname == 'conv': attr['channels'] = weights_shape[4] elif opname == 'conv_transpose': attr['channels'] = weights_shape[3] if 'dilations' in attr: attr['dilations'] =\ (attr['dilations'][1], attr['dilations'][2], attr['dilations'][3]) attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3]) elif attr['data_format'] == 'NCDHW': _, _, kernel_d, kernel_h, kernel_w = weights_shape attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w) if opname == 'conv': attr['channels'] = weights_shape[0] elif opname == 'conv_transpose': attr['channels'] = weights_shape[1] if 'dilations' in attr: attr['dilations'] =\ (attr['dilations'][2], attr['dilations'][3], attr['dilations'][4]) attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4]) else: msg = 'Value {} in attribute "data_format" of operator Conv is ' \ 'not valid.' 
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) # Fix padding attr['padding'] = attr['padding'].decode("utf-8") if attr['padding'] == 'VALID': attr['padding'] = [0, 0, 0] elif attr['padding'] == 'SAME': stride_d, stride_h, stride_w = attr['strides'] kernel_d, kernel_h, kernel_w = attr['kernel_shape'] pdata_shape = input_shape if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0: pdata_shape = attr['_output_shapes'][0] if attr['data_format'] == 'NDHWC': in_d = pdata_shape[1] in_h = pdata_shape[2] in_w = pdata_shape[3] else: in_d = pdata_shape[2] in_h = pdata_shape[3] in_w = pdata_shape[4] dilation_d = attr['dilations'][0] dilation_h = attr['dilations'][1] dilation_w = attr['dilations'][2] dilated_kernel_d = (kernel_d - 1) * dilation_d + 1 dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d) pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ 'valid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if 'kernel_layout' not in attr: attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW' use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4) channel_axis = 1 if attr['data_format'] == "NCDHW" else 4 # Ignore the new attributes from TF2.0, for now. out = AttrCvt( op_name=_dimension_picker('conv', \ surfix="_transpose" if opname == 'conv_transpose' else ""), ignores=['explicit_paddings'], transforms={ 'kernel_shape': 'kernel_size', 'data_format': 'data_layout', 'dilations': ('dilation', (0, 0)), 'group': ('groups', 1)}, custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr) if use_bias: out = _op.nn.bias_add(out, inputs[2] if opname != 'conv_transpose' else inputs[3], axis=channel_axis) if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 4, 1)) return out return _impl def _decode_image(): def _impl(inputs, attr, params): # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. 
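        # Practically, a frozen graph whose input is a JPEG string tensor should be
        # fed the already-decoded image instead (e.g. a uint8/float32 HWC array
        # produced by PIL or by running tf.image.decode_jpeg outside the model);
        # the DecodeJpeg node itself is treated as an identity here.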
warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input") return inputs[0] return _impl def _crop_and_resize(): def _impl(inputs, attr, params): # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] try: crop_size = _get_list_param(params, inputs[3]) except (IndexError, KeyError): crop_size = _infer_value(inputs[3], params).asnumpy().tolist() method = attr['method'].decode() method = 'nearest_neighbor' if method == 'nearest' else method if method not in ['bilinear', 'nearest_neighbor']: raise tvm.error.OpAttributeUnImplemented( 'Method {} is not supported'.format(method)) layout = attr['layout'] if 'layout' in attr else 'NHWC' extrapolation_value = attr['extrapolation_value'] return get_relay_op("crop_and_resize")(inputs[0], inputs[1], inputs[2], crop_size, layout, method, extrapolation_value) return _impl def _cast(): def _impl(inputs, attr, params): return inputs[0].astype(attr['DstT'].name) return _impl def _expand_dims(): def _impl(inputs, attr, params): dim_input = inputs.pop(1) axis = _get_num_param(params, dim_input) return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr) return _impl def _resize(method): def _impl(inputs, attr, params): output_shape0 = attr['_output_shapes'][0] # Dynamic size models might have _output_shapes attr equal to [None] here size = output_shape0[1:3] if output_shape0 is not None else [-1, -1] # Important that the size is defined. If an axis is not, we need to infer what # the shape should be. if -1 in size: size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() attr['size'] = size inputs.pop(1) # NHWC attr['layout'] = 'NHWC' if attr.pop('align_corners') is True: attr['coordinate_transformation_mode'] = 'align_corners' else: attr['coordinate_transformation_mode'] = 'asymmetric' # Ignore the new attributes from TF2.0, for now. return AttrCvt(op_name='resize', ignores=['Tdim', 'half_pixel_centers'], extras={'method': method})(inputs, attr) return _impl def _check_numerics(): def _impl(inputs, attr, params): # Making a copy node assuming no need to verify return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr) return _impl def _assert(): # ToDo: In general people want asserts to be gone from TensorFlow graphs # when they are optimizing them, so converting it to a no-op is # reasonable. However, it would be nice to have the option to keep them # once Relay gets a Halt or Assert op. return _no_op() def _no_op(): def _impl(inputs, attr, params): # ToDo: This should really be an op that returns nothing, which could # be represented as an empty tuple. It turns out that TVM # infrastructure doesn't like running functions that return None and # also don't like running functions that return an empty tuple. So it # doesn't work, but it should be made to work and then this could be # improved. In the mean time, it is hard to imagine a case where it # matters in any real way that a no-op is converted to a constant 0. 
return tvm.relay.const(0) return _impl def _matmul(): def _impl(inputs, attr, params): channels = _infer_channels(inputs[1], not attr['transpose_b']) if attr['transpose_a']: inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) if not attr['transpose_b']: inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) return AttrCvt(op_name="dense", extras={'units': channels}, ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr) return _impl def _batch_matmul(): def _impl(inputs, attr, params): input_x = inputs[0] input_y = inputs[1] orig_shape_x = attr['_input_shapes'][input_x] orig_shape_y = attr['_input_shapes'][input_y] # reshape n-dimensional batch matmul into 3d if len(orig_shape_x) > 3: outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] num_outer_elts = np.prod(outer_dims) new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) input_x = _op.reshape(input_x, newshape=new_shape_x) input_y = _op.reshape(input_y, newshape=new_shape_y) adj_x = attr['adj_x'] adj_y = attr['adj_y'] input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y ret = get_relay_op('batch_matmul')(input_x, input_y) # reshape result back to n-dimensional if len(orig_shape_x) > 3: final_shape = list(orig_shape_x) final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2] final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1] ret = _op.reshape(ret, newshape=final_shape) return ret return _impl def _identity(): def _impl(inputs, attr, params): return inputs[0] return _impl def _concatV2(): def _impl(inputs, attr, params): pop_node = inputs.pop(len(inputs)-1) axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['T', 'N', 'Tidx'], extras={'axis': axis})([inputs], attr) return _impl def _concat(): def _impl(inputs, attr, params): pop_node = inputs.pop(0) axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['N'], extras={'axis': axis})([inputs], attr) return _impl def _pack(): def _impl(inputs, attr, params): axis = int(attr["axis"]) inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] return _op.concatenate(inputs_reshaped, axis) return _impl def _tensor_array(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get('dtype').name tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0))) return _impl def _tensor_array_scatter(): def _impl(inputs, attr, params, prelude): dtype_str = attr.get('T').name values_rank = len(inputs[2].type_annotation.shape) unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) unstack_function = prelude.get_var(unstack_name, dtype_str) values = unstack_function(inputs[2]) tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) return tensor_array_scatter_func(inputs[0], inputs[1], values) return _impl def _tensor_array_gather(): def _impl(inputs, attr, params, prelude): return prelude.tensor_array_gather(inputs[2], inputs[1]) return _impl def _tensor_array_size(): def _impl(inputs, attr, params, prelude): return prelude.length(inputs[0]) return _impl def _tensor_array_write(): def _impl(inputs, attr, params, prelude): input_rank = len(inputs[2].type_annotation.shape) dtype = attr.get('T').name tensor_name = 'tensor{}'.format(input_rank) tensor_func = 
prelude.get_var(tensor_name, dtype) v = tensor_func(inputs[2]) write_func = prelude.get_var('tensor_array_write', dtype) return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) return _impl def _tensor_array_read(): def _impl(inputs, attr, params, prelude): read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name) return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) return _impl def _tensor_array_split(): def _impl(inputs, attr, params, prelude): input_rank = len(inputs[1].type_annotation.shape) dtype_str = attr.get('T').name v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) lengths = _op.cast(inputs[2], 'int32') split_var = prelude.get_var('tensor_array_split', dtype_str) return split_var(inputs[0], v, lengths) return _impl def _tensor_array_concat(): def _impl(inputs, attr, params, prelude): concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name) return concat_func(inputs[1]) return _impl def _tile(): def _impl(inputs, attr, params): reps = _get_list_param(params, inputs.pop()) new_input = [] new_input.append(inputs.pop(0)) return AttrCvt( op_name='tile', extras={'reps': tuple(reps)}, ignores=['Tmultiples'])(new_input, attr) return _impl def _slice(): def _impl(inputs, attr, params): try: begin = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): begin = _infer_value(inputs[1], params).asnumpy().tolist()[0] try: size = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): size = _infer_value(inputs[2], params).asnumpy().tolist()[0] data_shape = attr['_input_shapes'][inputs[0]] data_dim = len(data_shape) end = size for i in range(data_dim): if size[i] == -1: end[i] = data_shape[i] else: end[i] += begin[i] return _op.strided_slice(inputs[0], begin=begin, end=end) return _impl def _reshape(): def _impl(inputs, attr, params): pop_node = inputs.pop(1) try: shape_arg = _get_tuple_param(params, pop_node) except AttributeError: # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. try: params_new = _infer_value(pop_node, params) shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) except Exception: # Deal with symbolic shape case. # Currently only shape_of can be the direct ancestor. if not isinstance(pop_node, tvm.relay.expr.Call) or \ "shape_of" not in str(pop_node.op): raise RuntimeError("If shape operator is used in reshape to " "express reshape_like, shape_of must be " "the direct ancestor of reshape when input " "shape is symbolic.") return _op.reshape_like(inputs[0], pop_node.args[0]) return AttrCvt( op_name="reshape", extras={'newshape': shape_arg}, ignores=['Tshape'])(inputs, attr) return _impl def _depth_to_space(): def _impl(inputs, attr, params): block_size = int(attr['block_size']) layout = attr['data_format'].decode("utf-8") return _op.nn.depth_to_space(inputs[0], block_size, layout) return _impl def _space_to_depth(): def _impl(inputs, attr, params): block_size = int(attr['block_size']) layout = attr['data_format'].decode("utf-8") return _op.nn.space_to_depth(inputs[0], block_size, layout) return _impl def _bias_add(): def _impl(inputs, attr, params): # Must expand for proper broadcasting in NCHW. 
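        # For example, with NCHW data of shape (1, 64, 56, 56) and a bias of shape
        # (64,), the bias is reshaped to (1, 64, 1, 1) so that the add below
        # broadcasts over H and W; in NHWC the (64,) bias already lines up with the
        # trailing channel axis and is used as-is.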
if attr['data_format'].decode("utf-8") == 'NCHW': bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) else: bias = inputs[1] return _op.add(inputs[0], bias) return _impl def _broadcast_to(): def _impl(inputs, attr, params): if isinstance(inputs[1], _expr.Var): shape = params[inputs[1].name_hint] else: shape = _infer_value(inputs[1], params) shape = list(shape.asnumpy().reshape([-1])) return _op.broadcast_to(inputs[0], shape) return _impl def _squeeze(): def _impl(inputs, attr, params): if len(attr['squeeze_dims']) == 0: attr['squeeze_dims'] = None return AttrCvt( op_name="squeeze", transforms={'squeeze_dims':'axis'}, ignores=['T'])(inputs, attr) return _impl def _fused_batch_norm(): def _impl(inputs, attr, params): # Tensorflow: (data, gamma, beta, moving_mean, moving_variance) # Relay: (data, gamma, beta, moving_mean, moving_varience) assert len(inputs) == 5 axis = 3 need_cast = False if 'data_format' in attr: attr['data_format'] = attr['data_format'].decode("utf-8") if attr['data_format'] == 'NCHW': axis = 1 if 'U' in attr: need_cast = True inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name) # Check if mean and variance are empty # If so, replace them with Mean and Variance Ops # For run-time calculation moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape] moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape] if (moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0): inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True) inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True) out = AttrCvt(op_name='batch_norm', transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'}, extras={'axis': axis}, ignores=['data_format', 'U'], disables=['momentum'])(inputs, attr) if need_cast: out = _expr.TupleGetItem(out.astuple(), 0) out = _op.cast(out, dtype=attr['T'].name) return out return _impl def _batch_norm(): def _impl(inputs, attr, params): # Rearrange inputs from # (data, moving_mean, moving_variance, beta, gamma) # to # (data, gamma, beta, moving_mean, moving_var) new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]] axis = 3 if 'data_format' in attr: attr['data_format'] = attr['data_format'].decode("utf-8") if attr['data_format'] == 'NCHW': axis = 1 return AttrCvt( op_name='batch_norm', transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'}, extras={'axis': axis}, ignores=['data_format'], disables=['momentum'])(new_inputs, attr) return _impl def _relu6(): def _impl(inputs, attr, params): return _op.clip(inputs[0], a_min=0, a_max=6) return _impl def _shape(): def _impl(inputs, attr, params): is_symbolic_shape = False for axis in attr['_input_shapes'][inputs[0]]: if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True break if is_symbolic_shape: ret = _op.shape_of(inputs[0], dtype='int32') else: ret = np.array(attr['_input_shapes'][inputs[0]], dtype='int32') return ret return _impl def _fill(): def _impl(inputs, attr, params): output_shape = attr['_output_shapes'][0] # Output shape must be defined to avoid errors. If any axis is not, we must # try to compute its shape. 
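        # For example, tf.fill([2, 3], 7.0) normally records [2, 3] in
        # _output_shapes; if the dims tensor is only known symbolically, the shape
        # is recovered below by constant-evaluating inputs[0] with _infer_value.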
if output_shape is None or -1 in output_shape: output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() fill_arg = _get_num_param(params, inputs.pop(1)) dtype = attr['T'].name return _op.full(tvm.relay.const(fill_arg, dtype), output_shape, dtype) return _impl def _lrn(): def _impl(inputs, attr, params): attr_new = {} depth_radius = attr.get('depth_radius', 5) size = (depth_radius * 2) + 1 attr_new['axis'] = 3 # Fix axis, NHWC format attr_new['size'] = size attr_new['bias'] = attr.get('bias', 1) attr_new['alpha'] = attr.get('alpha', 1) * size attr_new['beta'] = attr.get('beta', 0.5) return AttrCvt(op_name='lrn')(inputs, attr_new) return _impl def _sum(): def _impl(inputs, attr, params): axis = _get_tuple_param(params, inputs[1]) return AttrCvt( op_name='sum', extras={'axis': axis}, transforms={'keep_dims':'keepdims'}, ignores=['name', 'Tidx'])([inputs[0]], attr) return _impl def _reduce(op): def _impl(inputs, attr, params): axis = _get_list_param(params, inputs[1]) axis = tuple(axis) return AttrCvt( op_name=op, extras={'axis': axis}, transforms={'keep_dims':'keepdims'}, ignores=['name', 'Tidx'])([inputs[0]], attr) return _impl def _square(): def _impl(inputs, attr, params): return _op.multiply(inputs[0], inputs[0]) return _impl def _gather(): "GatherV2, Gather" def _impl(inputs, attr, params): if len(inputs) > 2: axis = _get_num_param(params, inputs.pop(2)) else: axis = 0 if int(attr.get('batch_dims', 0)) != 0: raise tvm.error.OpAttributeUnImplemented( 'Attribute batch_dims is not supported') new_input = inputs[0:2] return AttrCvt(op_name="take", extras={'axis': tvm.tir.const(axis, 'int32')}, ignores=['Tindices', 'Tparams', 'validate_indices', 'Taxis', '_class', 'batch_dims'])(new_input, attr) return _impl def _gather_nd(): """GatherNd""" def _impl(inputs, attr, params): return AttrCvt(op_name="gather_nd", ignores=['Tindices', 'Tparams',\ 'Taxis', '_class'])(inputs, attr) return _impl def _stridedSlice(): def _impl(inputs, attr, params, mod): """Strided Slice. Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ tensorflow/core/util/strided_slice_op.cc#L147-L368 """ begin = _get_list_param(params, inputs[1]) end = _get_list_param(params, inputs[2]) stride = _get_list_param(params, inputs[3]) begin_mask = int(attr.get('begin_mask', 0)) end_mask = int(attr.get('end_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0)) new_axis_mask = int(attr.get('new_axis_mask', 0)) shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) data_shape = attr['_input_shapes'][inputs[0]] data_dim = len(data_shape) stride_dim = len(stride) def _transform_mask(stride_dim, ellipsis_mask): """Handle mask inputs to create new begin, end, stride and output shape""" m_begin = [0] * data_dim m_end = [0] * data_dim m_stride = [0] * data_dim fshape_indices = [] #Count new axis after ellipsis_mask, consider while applying ellipsis_mask. ellipsis_seen = False new_axes_after_ellipsis = 0 for i in range(stride_dim): mask = 1 << i if ellipsis_seen and (mask & new_axis_mask) != 0: new_axes_after_ellipsis += 1 if (mask & ellipsis_mask) != 0: ellipsis_seen = True if not ellipsis_seen: #Used later for extending the stride attributes in the below loop. 
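                # For instance, a slice like x[1:3] on a rank-3 tensor reaches this
                # point with stride_dim == 1 and no ellipsis; a synthetic ellipsis is
                # appended so the loop below fills full-range begin/end/stride
                # entries for the remaining, unspecified axes.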
ellipsis_mask |= (1 << stride_dim) stride_dim += 1 final_index = 0 for index in range(stride_dim): mask = 1 << index if mask & ellipsis_mask: #Identify the end index for applying ellipsis_mask to_index = min(((data_dim - (stride_dim-index)) + 1 \ + new_axes_after_ellipsis), data_dim) for i in range(final_index, to_index): m_begin[final_index] = 0 m_end[final_index] = data_shape[final_index] m_stride[final_index] = 1 fshape_indices.append(final_index) final_index += 1 elif mask &new_axis_mask: fshape_indices.append(-1) elif not mask & new_axis_mask: if final_index == len(m_begin): break if mask & begin_mask: m_begin[final_index] = data_shape[final_index] \ if stride[index] < 0 else 0 elif begin[index]: m_begin[final_index] = begin[index] if mask & end_mask: m_end[final_index] = 0 if stride[index] < 0 \ else data_shape[final_index] elif end[index]: m_end[final_index] = end[index] m_stride[final_index] = stride[index] if mask & shrink_axis_mask: #Tensorflow make axis with shrink_axis_mask as dimension 1 m_begin[final_index] = data_shape[final_index] + begin[index] \ if begin[index] < 0 else begin[index] m_end[final_index] = begin[index] + 1 m_stride[final_index] = 1 fshape_indices.append(-2) else: fshape_indices.append(final_index) final_index += 1 return m_begin, m_end, m_stride, fshape_indices fshape_indices = None if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) out_shape = _infer_shape(out, mod=mod) if not fshape_indices: fshape_indices = range(len(out_shape)) #Create final output shape. final_output = [] for gather_index in fshape_indices: if gather_index == -1: final_output.append(1) elif gather_index == -2: pass else: final_output.append(out_shape[gather_index]) if not final_output: return out return _op.reshape(out, newshape=tuple(final_output)) return _impl def _pad(name): def _impl(inputs, attr, params): padlist = _get_param(params, inputs[1]) paddings = tuple(tuple(l) for l in padlist) attr['pad_width'] = paddings attr['pad_value'] = 0 new_inputs = [inputs[0]] if name == 'PadV2': constant_values = _get_num_param(params, inputs[2]) attr['pad_value'] = constant_values return AttrCvt( op_name='pad', ignores=['Tpaddings'],)(new_inputs, attr) return _impl def _mirror_pad(): def _impl(inputs, attr, params): padlist = _get_param(params, inputs[1]) paddings = tuple(tuple(l) for l in padlist) attr['pad_width'] = paddings mode = attr['mode'].decode('utf-8') attr['mode'] = mode new_inputs = [inputs[0]] return AttrCvt( op_name='mirror_pad', ignores=['Tpaddings'],)(new_inputs, attr) return _impl def _transpose(): def _impl(inputs, attr, params): # If perm is not specified, axes is left empty, # otherwise its value is get from params try: axes = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): axes = _infer_value_simulated(inputs[1], params).asnumpy() return _op.transpose(inputs[0], axes=axes) return _impl def _where(): def _impl(inputs, attr, params): if len(inputs) == 1: return AttrCvt(op_name="argwhere")(inputs, attr) return AttrCvt(op_name="where")(inputs, attr) return _impl def _clip_by_value(): def _impl(inputs, attr, params): a_min = _get_num_param(params, inputs[1]) a_max = _get_num_param(params, inputs[2]) return _op.clip(inputs[0], a_min=a_min, a_max=a_max) return _impl def _reverse_v2(): def _impl(inputs, attr, params): axis = _get_num_param(params, inputs[1]) return 
AttrCvt( op_name="reverse", ignores=['Tidx'], extras={'axis': int(axis)})([inputs[0]], attr) return _impl def _rank(): def _impl(inputs, attr, params): input_shape = attr['_input_shapes'][inputs[0]] name = attr["_node_name"] params[name] = tvm.nd.array([len(input_shape)]) return [_expr.var(name, shape=params[name].shape, dtype='int32')] return _impl def _range(): def _impl(inputs, attr, params): start = _get_param(params, inputs[0])[0] if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant): limit = _get_param(params, inputs[1])[0] else: if any(['Rank' in param for param in params]): limit = params.pop('Rank').asnumpy()[0] else: limit = _infer_value_simulated(inputs[1], params).asnumpy()[0] delta = _get_param(params, inputs[2])[0] dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype) return AttrCvt( op_name="arange", ignores=['Tidx'], extras={'start': _expr.const(start), "stop": _expr.const(limit), 'step': _expr.const(delta), 'dtype': dtype})([], attr) return _impl def _elu(): def _impl(inputs, attr, params): dtype = attr['T'].name alpha = tvm.relay.const(-1.0, dtype) return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]) return _impl def _selu(): def _impl(inputs, attr, params): dtype = attr['T'].name alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])) return _impl def _mean(): def _impl(inputs, attr, params): axis = _get_tuple_param(params, inputs[1]) return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], transforms={'keep_dims': 'keepdims'}, extras={'axis': axis})([inputs[0]], attr) return _impl def _broadcast(name): def _impl(inputs, attr, params): return AttrCvt( op_name=name, ignores=['name', 'incompatible_shape_error', 'Tidx'] )(inputs, attr) return _impl def _split(has_size_vector): # TF documentation https://www.tensorflow.org/api_docs/python/tf/split def _impl(inputs, attr, params): try: # order and number of inputs are different: # if has_size_vector: # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v # else: # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow, # we can only support constants if has_size_vector: input_node_index = 0 input_axis_index = 2 size_splits = _get_param(params, inputs[1]) section_beginnings = np.cumsum(size_splits)[:-1] indices_or_sections = tuple(section_beginnings) else: input_node_index = 1 input_axis_index = 0 indices_or_sections = attr['num_split'] input_node = inputs[input_node_index] axis_input_value = _get_num_param(params, inputs[input_axis_index]) except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for split: `axis` and `num_or_size_splits` " \ "should be constants") return _op.split(input_node, indices_or_sections=indices_or_sections, axis=int(axis_input_value)) return _impl def _unpack(): def _impl(inputs, attr, params): input_node = inputs[0] axis = attr['axis'] input_shape = attr['_input_shapes'][input_node] axis_length = input_shape[axis] if axis_length < 0: raise TypeError("Unstack with unknown axis length") splitted = _op.split(input_node, indices_or_sections=axis_length, axis=axis) #name=attr.get('_node_name', 'unstack')) if axis == 0: axis = None else: axis = [axis] return _expr.TupleWrapper( _expr.Tuple([_op.squeeze(split_item, 
axis=axis) \ for split_item in splitted]), len(splitted)) return _impl def _softmax(): def _impl(inputs, attr, params): return AttrCvt(op_name='softmax', transforms={'axis': ('axis', 1)})([inputs[0]], attr) return _impl def _softplus(): # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus def _impl(inputs, attr, params): exp_out = AttrCvt('exp')(inputs, attr) inputs.append(tvm.relay.const(1, attr['T'].name)) rh = tvm.relay.const(1, attr['T'].name) add_out = get_relay_op('add')(exp_out, rh) return get_relay_op('log')(add_out) return _impl def _topk(): def _impl(inputs, attr, params): k = int(_get_num_param(params, inputs.pop(1))) if k < 1: raise tvm.error.OpAttributeInvalid( 'Attribute k must be positive in operator TopKV2') if attr['sorted'] is False: raise tvm.error.OpAttributeUnImplemented( 'Attribute sorted=False is not supported in operator TopKV2') return AttrCvt(op_name='topk', ignores=['sorted'], extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) return _impl def _floordiv(): def _impl(inputs, attr, params): assert len(inputs) == 2 return AttrCvt('floor_divide')(inputs, attr) return _impl def _floormod(): def _impl(inputs, attr, params): assert len(inputs) == 2 return AttrCvt('floor_mod')(inputs, attr) return _impl def _logical(name): def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) return _impl def _space_to_batch_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] block_shape = _get_list_param(params, inputs[1]) paddings = _get_list_param(params, inputs[2]) N = len(input_shape) M = len(block_shape) batch = input_shape[0] remaining_shape_length = N - M - 1 paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d: # Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings # to produce padded of shape padded_shape. 
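        # For example, an NHWC input of shape (1, 4, 4, 1) with block_shape=[2, 2]
        # and zero paddings stays (1, 4, 4, 1) after padding, is reshaped to
        # (1, 2, 2, 2, 2, 1), transposed to (2, 2, 1, 2, 2, 1), and finally
        # reshaped to the (4, 2, 2, 1) output.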
padded = tvm.relay.nn.pad(input_node, pad_width=paddings) # Reshape padded to reshaped_padded of shape: # [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ..., # padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2] reshaped_padded = tvm.relay.reshape(padded, newshape=shape1) # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape: # block_shape + [batch] + [padded_shape[1] / block_shape[0], ..., # padded_shape[M] / block_shape[M-1]] + remaining_shape axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded) # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # producing an output tensor of shape: # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., # padded_shape[M] / block_shape[M-1]] + remaining_shape shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:] reshaped_permuted_reshaped_padded = tvm.relay.reshape(permuted_reshaped_padded, newshape=shape2) return reshaped_permuted_reshaped_padded return _impl def _batch_to_space_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] block_shape = _get_list_param(params, inputs[1]) crops = _get_list_param(params, inputs[2]) M = len(block_shape) batch = input_shape[0] # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: # Reshape input to reshaped of shape: # [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape), # input_shape[1], ..., input_shape[N-1]] shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:] reshaped = tvm.relay.reshape(input_node, newshape=shape1) # Permute dimensions of reshaped to produce permuted of shape # [batch / prod(block_shape), input_shape[1], block_shape[0], ..., # input_shape[M], block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]] axes = [M] + [axis for i in range(M) for axis in [M + i + 1, i]] + \ list(range(2 * M + 1, len(shape1))) permuted = tvm.relay.transpose(reshaped, axes=axes) # Reshape permuted to produce reshaped_permuted of shape # [batch / prod(block_shape), input_shape[1] * block_shape[0], ..., # input_shape[M] * block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]] shape2 = [0] + [-3] * M + [-2] reshaped_permuted = tvm.relay.reshape(permuted, newshape=shape2) # Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops # to produce the output of shape: # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # input_shape[M+1], ..., input_shape[N-1]] reshaped_permuted_shape = _infer_shape(reshaped_permuted) cropped = reshaped_permuted for axis in range(1, M+1): crop = crops[axis - 1] if crop != [0, 0]: indices = tvm.relay.arange( _expr.const(crop[0]), _expr.const(reshaped_permuted_shape[axis] - crop[1]), dtype='int32' ) cropped = tvm.relay.take(cropped, indices=indices, axis=axis) return cropped return _impl def _prod(): def _impl(inputs, attr, params): axis = _get_num_param(params, inputs[1]) keepdims = attr['keep_dims'] return _op.prod(inputs[0], int(axis), keepdims=keepdims) return _impl def 
_log1p(): # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p def _impl(inputs, attr, params): one = tvm.relay.const(1, attr['T'].name) add_out = get_relay_op('add')(inputs[0], one) return get_relay_op('log')(add_out) return _impl def _one_hot(): def _impl(inputs, attr, params): depth = int(_get_num_param(params, inputs[1])) dtype = attr['T'].name on_value = _get_num_param(params, inputs[2]) off_value = _get_num_param(params, inputs[3]) new_inputs = [inputs[0], \ tvm.relay.const(on_value, dtype), \ tvm.relay.const(off_value, dtype)] return AttrCvt('one_hot', ignores=['TI'], extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr) return _impl def _squared_difference(): def _impl(inputs, attr, params): difference = _op.subtract(inputs[0], inputs[1]) return _op.multiply(difference, difference) return _impl def _size(): def _impl(inputs, attr, params): new_attr = attr new_attr['out_type'] = attr['out_type'].name return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr) return _impl def _add_n(): def _impl(inputs, attr, params): if not isinstance(inputs, tuple): inputs = list(inputs) assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." _res = inputs[0] for each in inputs[1:]: _res = _op.add(_res, each) return _res return _impl # compatible operators that do NOT require any conversion. _identity_list = [] # Operators that get pruned away when the complete graph is frozen. # These operators are not needed for inference. _freezed_graph_pruned_op_list = ['ReadVariableOp', 'ResourceGather', 'Variable', 'VariableV2', 'VarHandleOp', 'Assign', 'AssignVariableOp'] # _convert_map defines maps of name to converter functor(callable) # for 1 to 1 mapping, use Renamer if nothing but name is different # use AttrCvt if attributes need to be converted # for 1 to N mapping(composed), use custom callable functions # for N to 1 mapping, currently not supported(?) 
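# For example (entries from the table below): 'Ceil' maps 1 to 1 and only needs
# AttrCvt('ceil'); 'ArgMax' needs the custom _argx functor because TensorFlow
# passes the reduction axis as an input tensor rather than a node attribute;
# 'SpaceToBatchND' is composed out of pad/reshape/transpose relay ops.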
_convert_map = { 'Abs' : AttrCvt('abs'), 'Add' : _elemwise('add'), 'AddV2' : _elemwise('add'), 'AddN' : _add_n(), 'All' : _reduce('all'), 'Any' : _reduce('any'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), 'Assert' : _assert(), 'AvgPool' : _pooling('avg_pool'), 'AvgPool3D' : _pool3d('avg_pool3d'), 'BatchMatMul' : _batch_matmul(), 'BatchMatMulV2' : _batch_matmul(), 'BatchNormWithGlobalNormalization' : _batch_norm(), 'BatchToSpaceND' : _batch_to_space_nd(), 'BiasAdd' : _bias_add(), 'BroadcastTo' : _broadcast_to(), 'Cast' : _cast(), 'Ceil' : AttrCvt('ceil'), 'CheckNumerics' : _check_numerics(), 'ClipByValue' : _clip_by_value(), 'Concat' : _concat(), 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), 'Conv3D' : _conv3d('conv'), 'Conv2DBackpropInput' : _conv('conv_transpose'), 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), 'DepthwiseConv2dNative' : _conv('depthwise'), 'DepthToSpace' : _depth_to_space(), 'Dilation2D' : _dilation2d(), 'Equal' : _broadcast('equal'), 'Elu' : _elu(), 'Erf' : AttrCvt('erf'), 'Exp' : AttrCvt('exp'), 'ExpandDims' : _expand_dims(), 'Fill' : _fill(), 'Floor' : AttrCvt('floor'), 'FloorDiv' : _floordiv(), 'FloorMod' : _floormod(), 'FusedBatchNorm' : _fused_batch_norm(), 'FusedBatchNormV2' : _fused_batch_norm(), 'FusedBatchNormV3' : _fused_batch_norm(), 'Gather' : _gather(), 'GatherNd' : _gather_nd(), 'GatherV2' : _gather(), 'Greater' : _broadcast('greater'), 'GreaterEqual' : _broadcast('greater_equal'), 'Identity' : _identity(), 'LeakyRelu' : AttrCvt('leaky_relu'), 'LeftShift' : AttrCvt('left_shift'), 'Less' : _broadcast('less'), 'LessEqual' : _broadcast('less_equal'), 'Log' : AttrCvt('log'), 'Log1p' : _log1p(), 'Tan' : AttrCvt('tan'), 'Cos' : AttrCvt('cos'), 'Sin' : AttrCvt('sin'), 'LogicalAnd' : _logical('logical_and'), 'LogicalOr' : _logical('logical_or'), 'LogicalNot' : _logical('logical_not'), 'LogSoftmax' : AttrCvt('log_softmax'), 'LRN' : _lrn(), 'MatMul' : _matmul(), 'Max' : _reduce('max'), 'MaxPool' : _pooling('max_pool'), 'MaxPool3D' : _pool3d('max_pool3d'), 'Maximum' : _elemwise('maximum'), 'Mean' : _mean(), 'Min' : _reduce('min'), 'Minimum' : _elemwise('minimum'), 'MirrorPad' : _mirror_pad(), 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), 'NoOp' : _no_op(), 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), 'Pack' : _pack(), 'TensorArrayV3' : _tensor_array(), 'TensorArrayScatterV3' : _tensor_array_scatter(), 'TensorArrayGatherV3' : _tensor_array_gather(), 'TensorArraySizeV3' : _tensor_array_size(), 'TensorArrayWriteV3' : _tensor_array_write(), 'TensorArrayReadV3' : _tensor_array_read(), 'TensorArraySplitV3' : _tensor_array_split(), 'TensorArrayConcatV3' : _tensor_array_concat(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), 'Pow' : _elemwise('power'), 'Prod' : _prod(), 'Range' : _range(), 'Rank' : _rank(), 'RealDiv' : _elemwise('divide'), 'Relu' : AttrCvt('relu'), 'Relu6' : _relu6(), 'Reshape' : _reshape(), 'ResizeBilinear' : _resize('bilinear'), 'ResizeBicubic' : _resize('bilinear'), 'ResizeNearestNeighbor' : _resize('nearest_neighbor'), 'ReverseV2' : _reverse_v2(), 'RightShift' : AttrCvt('right_shift'), 'Round' : AttrCvt('round'), 'Rsqrt' : _rsqrt(), 'Select' : _where(), 'Selu' : _selu(), 'Shape' : _shape(), 'Sigmoid' : AttrCvt('sigmoid'), 'Sign' : AttrCvt('sign'), 'Size' : _size(), 'Slice' : _slice(), 'Softmax' : _softmax(), 'Softplus' : _softplus(), 'SpaceToBatchND' : _space_to_batch_nd(), 'SpaceToDepth' : _space_to_depth(), 'Split' : _split(False), 
'SplitV' : _split(True), 'Sqrt' : AttrCvt('sqrt'), 'Square' : _square(), 'SquaredDifference' : _squared_difference(), 'Squeeze' : _squeeze(), 'StopGradient' : _identity(), 'StridedSlice' : _stridedSlice(), 'Sub' : _elemwise('subtract'), 'Sum' : _sum(), 'Tanh' : AttrCvt('tanh'), 'Tile' : _tile(), 'TopKV2' : _topk(), 'Transpose' : _transpose(), 'TruncateMod' : _elemwise('mod'), 'Unpack' : _unpack(), 'Where' : _where(), 'ZerosLike' : AttrCvt('zeros_like'), } def _LSTMBlockCell(): def _impl(inputs, in_state_c, in_state_h, attr, params): """LSTM Block cell. Calculations are described in: https://github.com/tensorflow/tensorflow/blob/ r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114 Parameters ---------- inputs : relay.Expr Input data in_state_c: list of relay.Expr Cell state input values for all the layers in_state_h: list of relay.Expr Hidden state input values for all the layers attrs : dict Dict of operator attributes params : dict List of pretrained weights and bias Returns ------- sym : relay.Expr Converted relay.Expr output: relay.Expr Output state value. """ in_data = inputs[0] in_weight = inputs[3] in_bias = inputs[7] forget_bias = attr.pop('forget_bias') input_shape = attr['_input_shapes'][inputs[0]] weight_shape = attr['_input_shapes'][inputs[3]] batch_size, input_size = input_shape[0], input_shape[1] num_hidden_layers = weight_shape[1] num_hidden = num_hidden_layers // 4 in_data = _op.reshape(in_data, newshape=(batch_size, input_size)) ixh = _op.concatenate([in_data, in_state_h], axis=1) in_weight = _op.transpose(in_weight, axes=None) gates = _op.nn.dense(ixh, in_weight, units=num_hidden_layers) gates_bias = _op.add(gates, in_bias) gate_list = _op.split(gates_bias, indices_or_sections=4, axis=1) in_gate = _op.sigmoid(gate_list[0]) in_transform = _op.tanh(gate_list[1]) forget_gate = _op.add(gate_list[2], tvm.relay.const(forget_bias, attr['T'].name)) forget_gate = _op.sigmoid(forget_gate) out_gate = _op.sigmoid(gate_list[3]) next_c = _op.add(_op.multiply(forget_gate, in_state_c), _op.multiply(in_gate, in_transform)) next_h = out_gate * _op.tanh(next_c) out_state = _op.concatenate([next_c, next_h], axis=1) out_state = _op.reshape(out_state, newshape=(2, batch_size, num_hidden)) return next_h, out_state return _impl # _convert_map_rnn defines maps of rnn operator name to # converter functor(callable) for 1 to 1 mapping. _convert_map_rnn = { 'LSTMBlockCell' : _LSTMBlockCell(), } class RecurrentNetworks(object): """Recurrent network layer handlers. Handle Layer operations. ToDo: Operators like RNN/GRU layer concepts also can be handled here Parameters ---------- nodes : list list of graph nodes used for tensorflow parsing. out_rnn : list List of RecurrentNetwork outputs. This output will be appended to the 'head' nodes of the graph. graph : tensorflow graph definition object The loaded tensorflow GraphDef convert_map : dict Dict of name : callable, where name is the op's name that require conversion to relay, callable are functions which take attrs and return (new_op_name, new_attrs) """ def __init__(self, nodes, out_rnn, graph, convert_map): self._graph = graph self._convert_map = convert_map self._nodes = nodes self._out_rnn = out_rnn self._cur_lstm_layer = 0 self._layer_name_list = [] self._recurrent_ops_layer_map = { 'LSTMBlockCell' : self._LSTMBlockCellLayer(), } def _LSTMBlockCellLayer(self): """LSTMBlockCell layer handler. Parameters ---------- op_name : str Operator name, eg:LSTMBlockCell layer_name : str list Layer name is used for creating the state input placeholder. 
class RecurrentNetworks(object):
    """Recurrent network layer handlers.

    Handle Layer operations.
    ToDo: Operators like RNN/GRU layers could also be handled here.

    Parameters
    ----------
    nodes : list
        list of graph nodes used for tensorflow parsing.

    out_rnn : list
        List of RecurrentNetwork outputs. This output will be appended to the
        'head' nodes of the graph.

    graph : tensorflow graph definition object
        The loaded tensorflow GraphDef

    convert_map : dict
        Dict of name : callable, where name is the op's name that
        requires conversion to relay, and callable is a function which
        takes attrs and returns (new_op_name, new_attrs)
    """
    def __init__(self, nodes, out_rnn, graph, convert_map):
        self._graph = graph
        self._convert_map = convert_map
        self._nodes = nodes
        self._out_rnn = out_rnn
        self._cur_lstm_layer = 0
        self._layer_name_list = []
        self._recurrent_ops_layer_map = {
            'LSTMBlockCell'             : self._LSTMBlockCellLayer(),
        }

    def _LSTMBlockCellLayer(self):
        """LSTMBlockCell layer handler.

        Parameters
        ----------
        op_name : str
            Operator name, e.g. LSTMBlockCell

        layer_name : str list
            Layer name is used for creating the state input placeholder.

        inputs : relay.Expr
            Input data

        attrs : dict
            Dict of operator attributes

        params : dict
            Dict of pretrained weights and bias

        num_layers : int
            Total number of LSTM layers present in the graph

        Returns
        -------
        sym : relay.Expr
            The returned relay Expr
        """
        def _impl(op_name, layer_name, inputs, attrs, params, num_layers):
            in_state_c_name = layer_name + '_c'
            in_state_h_name = layer_name + '_h'

            def _init_state(num_layers, batch_size, num_hidden):
                """Create the initial states for the first layer in the graph."""
                in_state_c = [_expr.var(in_state_c_name,
                                        shape=(num_layers, batch_size, num_hidden),
                                        dtype='float32')]

                in_state_h = [_expr.var(in_state_h_name,
                                        shape=(num_layers, batch_size, num_hidden),
                                        dtype='float32')]
                return in_state_c, in_state_h

            def _get_cur_input_state(in_state_c, in_state_h, num_layers,
                                     layer, batch_size, num_hidden):
                """Select the appropriate states for the current layer."""
                in_state_c_tup = _op.split(in_state_c[0],
                                           indices_or_sections=num_layers, axis=0)
                in_state_h_tup = _op.split(in_state_h[0],
                                           indices_or_sections=num_layers, axis=0)
                cur_in_state_c = _op.reshape(in_state_c_tup[layer],
                                             newshape=(batch_size, num_hidden))
                cur_in_state_h = _op.reshape(in_state_h_tup[layer],
                                             newshape=(batch_size, num_hidden))
                return cur_in_state_c, cur_in_state_h

            def _LSTMBlockCellWrapper(inputs, attr, params,
                                      num_layers, layer):
                """LSTM cell wrapper to prepare the inputs."""
                input_shape = attr['_input_shapes'][inputs[0]]
                weight_shape = attr['_input_shapes'][inputs[3]]

                batch_size = input_shape[0]
                num_hidden = weight_shape[1] // 4

                if layer == 0:
                    # Create initial state placeholders for the first layer.
                    in_state_c, in_state_h = _init_state(num_layers,
                                                         batch_size, num_hidden)
                else:
                    in_state_c = self._nodes[in_state_c_name]
                    in_state_h = self._nodes[in_state_h_name]

                cur_in_state_c, cur_in_state_h = _get_cur_input_state(
                    in_state_c, in_state_h,
                    num_layers, layer,
                    batch_size, num_hidden)
                output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
                                                               cur_in_state_h,
                                                               attr, params)
                return output, out_state, in_state_c, in_state_h

            sym, cur_out_state, in_state_c, in_state_h = \
                    _LSTMBlockCellWrapper(inputs, attrs, params,
                                          num_layers, self._cur_lstm_layer)
            self._nodes[in_state_c_name] = in_state_c
            self._nodes[in_state_h_name] = in_state_h
            cur_out_state = _op.expand_dims(cur_out_state, axis=0, num_newaxis=1)
            self._out_rnn.append(cur_out_state)
            self._cur_lstm_layer += 1
            return sym
        return _impl
    def process_op(self, op_name, inputs, attrs, params):
        """Process recurrent layer operators.

        The list '_recurrent_ops_layer_map' maps each layer-based operator to
        its layer handler. The total number of layers is calculated to form
        the input data shapes.

        Parameters
        ----------
        op_name : str
            Operator name, such as LSTMBlockCell

        inputs : relay.Expr
            Input data

        attrs : dict
            Dict of operator attributes

        params : dict
            Dict of pretrained weights and bias

        Returns
        -------
        sym : relay.Expr
            Returns relay.Expr
        """
        def _get_abs_layer_name(node):
            """Identify whether the layer name has already been handled and
            return the absolute name.
            """
            if not self._layer_name_list:
                self._layer_name_list.append(node.name)
                return node.name

            for _name in self._layer_name_list:
                if _name in node.name:
                    abs_name = _name
                else:
                    self._layer_name_list.append(node.name)
                    abs_name = node.name
            return abs_name

        # Find the number of layers of this same operator node in the graph
        # and also read the input names for the current op.
        num_layers = 0
        for _, node in enumerate(self._graph.node):
            if node.op == op_name:
                layer_name = _get_abs_layer_name(node)
                num_layers += 1

        sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
                                                     params, num_layers)
        return sym

# An internal list to contain all the control flow primitives used in Tensorflow
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

class RewriteSubgraph(ExprMutator):
    """
    A helper class to rewrite expr in while loop function to variable.

    Parameters
    ----------
    rewrite_map : Dict[expr, expr]
        A dictionary containing a set of expr-to-var mappings.
    """
    def __init__(self, rewrite_map):
        ExprMutator.__init__(self)
        self.rewrite_map = rewrite_map

    def visit(self, expr):
        if expr in self.rewrite_map:
            return self.rewrite_map[expr]
        return super().visit(expr)

def rewrite_subgraph(expr, rewrites):
    return RewriteSubgraph(rewrites).visit(expr)

def _in_while_loop(control_flow_node_map, op_name):
    """
    Check if a given control flow operator is part of a while loop execution
    frame. This is based on the fact that there is only one occurrence of
    `LoopCond` for a loop execution frame and it only appears in the loop
    construct.

    Parameters
    ----------
    control_flow_node_map : Dict[str, Set[str]]
        A dictionary mapping each unique control flow execution frame name
        to a set of primitive operators.

    op_name : str
        The name of a control flow primitive.

    Returns
    -------
    ret : bool
        Return true if the operator is in a while loop execution frame,
        otherwise, return false.
    """
    return op_name in control_flow_node_map and \
            "LoopCond" in control_flow_node_map[op_name]

class Branch:
    """A class containing the components that are used to build up a Relay if
    node.

    Parameters
    ----------
    cond : tvm.relay.Expr
        The condition of an if node.

    true_branch : tvm.relay.Expr
        The body of the true branch of an if expression.

    false_branch : tvm.relay.Expr
        The body of the false branch of an if expression.

    _if : tvm.relay.Expr
        An internal variable indicating whether an if expression has already
        been created for a matched TF condition construct.

    Examples
    --------
    The following is a cond statement written in TensorFlow:

    .. code-block:: python

        def vanilla_cond():
            i = tf.constant(1)
            j = tf.constant(4)

            def f1():
                return tf.multiply(1, 17)

            def f2():
                return tf.add(4, 23)

            r = tf.cond(tf.less(i, j), f1, f2)

    This condition statement should be converted into Relay in the following
    form:

    .. code-block:: python

        fn (%Const: Tensor[(1,), int32],
            %Const_1: Tensor[(1,), int32],
            %cond/Mul/x: Tensor[(1,), int32],
            %cond/Mul/y: Tensor[(1,), int32],
            %cond/Add/x: Tensor[(1,), int32],
            %cond/Add/y: Tensor[(1,), int32]) {
          %0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
          %1 = min(%0)
          if (%1) {
            %2 = multiply(%cond/Mul/x, %cond/Mul/y)
            %2
          } else {
            %3 = add(%cond/Add/x, %cond/Add/y)
            %3
          }
        }
    """
    def __init__(self):
        self._if = None
        self.cond = None
        self.true_branch = None
        self.false_branch = None

    def _if_node(self):
        """An internal API to create a relay if node from the matched TF
        condition construct.
        """
        # `cond` returns a tensor that contains boolean values. We add a `min`
        # operator to check if there is any false value. If so, this condition
        # does not hold.
        cond = tvm.relay.op.min(self.cond)
        return tvm.relay.If(cond, self.true_branch, self.false_branch)

    def if_node(self):
        """Create a tvm.relay.If node if it hasn't been created yet."""
        if self._if is None:
            self._if = self._if_node()
        return self._if
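# Note on the `min` reduction used by Branch._if_node above (and again by
# Loop._while_loop below): the TF condition may be a boolean tensor rather
# than a scalar, and reducing it with `min` behaves like a logical "all".
# A small NumPy illustration (not executed here):
#
#   np.min(np.array([True, True]))    # -> True,  take the true branch
#   np.min(np.array([True, False]))   # -> False, take the false branch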
class Loop:
    """
    A class containing the components that are used to build up a Relay
    recursive call.

    Parameters
    ----------
    loop_vars : List[tvm.relay.Expr]
        The loop variables that are used in a while loop.

    cond : tvm.relay.Expr
        The condition of a while loop.

    body : tvm.relay.Expr
        The body of a matched while loop.

    _loop : tvm.relay.Expr
        An internal variable indicating whether a recursive call has already
        been created for a matched TF while loop construct.

    Examples
    --------
    The following is a vanilla loop from TensorFlow:

    .. code-block:: python

        i = tf.constant(0)
        c = lambda i: tf.less(i, 10)
        b = lambda i: tf.add(i, 1)
        r = tf.while_loop(c, b, [i])

    It will be converted to the following recursive call in Relay:

    .. code-block:: python

        fn (%while/Less/y: Tensor[(1,), int32],
            %while/Add/y: Tensor[(1,), int32],
            %Const: Tensor[(1,), int32]) {
          %0 = fn(%loop_var0: Tensor[(1,), int32]) {
            %1 = less(%loop_var0, %while/Less/y)
            %2 = min(%1)
            if (%2) {
              %3 = add(%loop_var0, %while/Add/y)
              free_var %while_loop
              %4 = %while_loop(%3)
              %4
            } else {
              %5 = (%loop_var0,)
              %5
            }
          }
          let %while_loop1 = %0
          %6 = %while_loop1(%Const)
          %6
        }
    """
    def __init__(self):
        self.loop_vars = []
        self.cond = None
        self.body = []
        self._loop = None

    def _while_loop(self):
        """An internal API to create a Relay recursive call for a matched TF
        `while_loop` construct.
        """
        wl = tvm.relay.var('while_loop')

        sb = tvm.relay.scope_builder.ScopeBuilder()

        loop_vars = []
        bind_map = {}
        for i, var in enumerate(self.loop_vars):
            if not isinstance(var, _expr.Var):
                var_chk = _infer_type(var)
                var_type = var_chk.checked_type
            else:
                var_type = var.type_annotation

            v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type)
            loop_vars.append(v)
            bind_map[var] = v

        self.cond = rewrite_subgraph(self.cond, bind_map)
        self.body = [rewrite_subgraph(b, bind_map) for b in self.body]

        cond = tvm.relay.op.min(self.cond)

        with sb.if_scope(cond):
            sb.ret(wl(*self.body))
        with sb.else_scope():
            sb.ret(tvm.relay.Tuple(loop_vars))

        loop_fn = tvm.relay.Function(loop_vars, sb.get())
        sb = tvm.relay.scope_builder.ScopeBuilder()
        sb.let(wl, loop_fn)
        sb.ret(wl(*self.loop_vars))
        return sb.get()

    def while_loop(self):
        """Instantiate a while loop if it has not been created yet."""
        if self._loop is None:
            self._loop = self._while_loop()
            return self._loop
        return self._loop
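# Illustrative sketch (not executed anywhere in this module) of how the
# converter below drives a `Loop` object while it walks a TF while_loop frame;
# `cond_expr`, `var_expr` and `body_expr` stand for relay expressions produced
# in _convert_control_flow_operator:
#
#   loop = Loop()
#   loop.cond = cond_expr                 # set when the LoopCond node is seen
#   loop.loop_vars.append(var_expr)       # appended for each Switch node
#   loop.body.append(body_expr)           # appended for each NextIteration node
#   recursive_call = loop.while_loop()    # built lazily at the first Exit node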
class GraphProto(object):
    """A helper class for handling relay graph copying from Tensorflow GraphDef.
    Definition:
        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
    """
    def __init__(self):
        self._nodes = {}
        self._params = {}
        self._input_shapes = {}
        self._output_shapes = {}
        self._num_param = 0
        self._num_rnn_layer = False
        self._loops = {}
        self._branches = {}
        self._mod = IRModule({})
        self._prelude = Prelude(self._mod)

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct relay nodes from tensorflow graph definition - GraphDef.

        Follow the tensorflow graph definition to parse and convert it to Relay.
        Some of the assumptions are listed below.

            -> All Placeholders are considered as graph input.
            -> All Const nodes are params.
            -> Last node is assumed as graph output.
            -> _output_shapes : Graph should be frozen with add_shapes=True.
                                Or user can pass input shape dictionary optionally.
            -> DecodeJpeg, ResizeBilinear: These are dummy operators.
                                           Hence user should handle preprocessing outside.
            -> CheckNumerics: No implementation as of now for this.
                              Just copies input to output.

        Parameters
        ----------
        graph : tensorflow graph definition object
            The loaded tensorflow GraphDef

        layout : target layout to be used (Optional)
            NCHW only supported now to enable NHWC models on GPU.

        shape : Dictionary of input dimensions (Optional)
            Graph level input shape dictionary.

        outputs : List of output tensor names (Optional)
            if not specified then the last node is assumed as graph output.

        Returns
        -------
        mod : tvm.IRModule
            The module that optimizations will be performed on.

        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        try:
            from tensorflow.python.framework import tensor_util
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        missing_operators = self._parse_import_prerequisites(graph)

        if missing_operators:
            freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
            if freezed_ops:
                raise Exception("Graph is not frozen. Provide a frozen graph. "
                                "Found operators {}".format(freezed_ops))

            raise NotImplementedError(
                "The following operators are not implemented: {}".format(missing_operators))

        control_flow_node_map = defaultdict(set)
        for node in graph.node:
            node_name_prefix = node.name.rsplit('/', 1)[0]
            control_flow_node_map[node_name_prefix].add(node.op)
            if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault':
                # Give priority to user argument.
                if shape and node.name in shape:
                    self._input_shapes[node.name] = list(shape[node.name])
                else:
                    self._input_shapes[node.name] = \
                        tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
                    for idx, dim in enumerate(self._input_shapes[node.name]):
                        if dim < 0:
                            self._input_shapes[node.name][idx] = 1
                            warnings.warn("Use 1 instead of -1 in shape of operator %s."
                                          % node.name)

                self._output_shapes[node.name] = [self._input_shapes[node.name]]
                attr = self._parse_attr(node.attr)
                self._nodes[node.name] = [_expr.var(node.name,
                                                    shape=self._input_shapes[node.name],
                                                    dtype=attr['dtype'].name)]

            # Ignore user's input shape for non-Placeholder nodes.
            elif node.op == 'Const':
                tensor_value = node.attr['value'].tensor
                self._input_shapes[node.name] = \
                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
                if shape and node.name in shape:
                    warnings.warn("Ignore the passed shape. Shape in graphdef "
                                  "will be used for operator %s." % node.name)

        # Parse the nodes to re-create TF graph using Relay operators.
        for node in graph.node:
            # Tensorflow doesn't have a separate list for params extraction.
            # Operator name 'Const' is treated as a parameter to build the params dict.

            input_shapes = {}
            attr = self._parse_attr(node.attr)

            # A Variable converted to Const may not carry a 'value' attr,
            # so check for it explicitly.
            if 'value' in attr and node.op == 'Const':
                self._output_shapes[node.name] = [self._input_shapes[node.name]]
            elif '_output_shapes' in attr:
                self._output_shapes[node.name] = \
                    [tensor_util.TensorShapeProtoToList(tshape)
                     for tshape in attr['_output_shapes']]
            else:
                # Keep the list indexable to avoid key error.
                # Actual value will be filled after node creation.
                # Will infer shapes if the graph is not frozen with add_shapes=True
                self._output_shapes[node.name] = [None]

            if node.op == "Const":
                # All Const nodes are Param nodes; let's parse them.
                self._num_param += 1
                for key, value in node.attr.items():
                    self._parse_param(key, value, node.name, shape)
                    if node.name not in self._nodes:
                        raise NotImplementedError(
                            "Const {} couldn't be converted to Param.".format(node.name))

                attr = self._parse_attr(node.attr)

            elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault':
                # Pass the parsed shapes instead
                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

                # Pass the node name too in attr
                attr["_node_name"] = node.name

                # Pass the target layout
                attr["_target_layout"] = layout

                # Fill shapes for all inputs in a list
                inputs = []
                for i in node.input:
                    # Some TensorFlow operators internally maintain execution layers
                    # and their output name includes the layer number along with
                    # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
                    # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
                    # the number has to be ignored for single-output nodes.
                    # On the other hand, for multi-output nodes the number is the output index,
                    # and the lack of the number implies 0.
                    tensor_name = i.split(':')
                    node_name = tensor_name[0]
                    if node_name in self._nodes:
                        in_sym = self._nodes[node_name]
                        if isinstance(in_sym, _expr.TupleWrapper):
                            tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
                            in_sym = [in_sym[tensor_slot]]
                            input_shape = self._output_shapes[node_name][tensor_slot]
                        else:
                            tensor_slot = 0
                            input_shape = self._output_shapes[node_name][0]
                        inputs.append(in_sym[0])
                        input_shapes[in_sym[0]] = input_shape

                attr['_input_shapes'] = input_shapes

                if node.op in _control_flow_nodes:
                    op = self._convert_control_flow_operator(node, inputs,
                                                             attr,
                                                             control_flow_node_map)
                else:
                    op = self._convert_operator(node.op, inputs, attr, graph)

                # Check if op is converted to param
                if isinstance(op, np.ndarray):
                    self._params[node.name] = tvm.nd.array(op)
                    op = [_expr.var(node.name,
                                    shape=self._params[node.name].shape,
                                    dtype=self._params[node.name].dtype)]

                elif isinstance(op, (_expr.TupleWrapper, tuple, list)):
                    pass
                elif isinstance(op, _expr.Expr):
                    op = [op]
                else:
                    raise RuntimeError("unexpected type %s" % type(op))

                self._nodes[node.name] = op

                # Infer shapes even without specifying "add_shapes=True"
                if output_shapes == [None]:
                    out_shapes = [_infer_shape(node_item, self._mod)
                                  for node_item in self._nodes[node.name]]
                    self._output_shapes[node.name] = out_shapes

                if self._output_shapes[node.name] and shape and node.name in shape:
                    assert self._output_shapes[node.name] == list(shape[node.name])

            # Infer shapes if passed explicitly
            node_output = self._nodes[node.name]
            if shape and (not self._output_shapes[node.name][0]
                          or -1 in self._output_shapes[node.name][0]):
                out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output]
                self._output_shapes[node.name] = out_shapes

        out = []
        if outputs is None:
            if node.op == "Exit":
                out = [op[0].tuple_value]
            else:
                out = op
        else:
            for out_name in outputs:
                if ":" in out_name:
                    out_name, out_num = out_name.split(":")
                    out_num = int(out_num)
                    out.append(self._nodes[out_name][out_num])
                else:
                    out.append(self._nodes[out_name][0])

        # Add the RNN outputs to the 'head' nodes of the relay graph.
        if self._num_rnn_layer:
            if len(self._out_rnn) == 1:
                out.append(self._out_rnn[0])
            else:
                out_rnn = _op.concatenate(self._out_rnn, axis=0)
                out.append(out_rnn)

        out = out[0] if len(out) == 1 else _expr.Tuple(out)
        func = _function.Function(analysis.free_vars(out), out)
        self._mod["main"] = func
        return self._mod, self._params
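    # Illustrative note (assumed typical usage, not enforced here): the optional
    # `shape` argument maps graph input names to concrete shapes and takes
    # priority over the Placeholder shape recorded in the GraphDef, e.g.
    #
    #   shape = {"input": (1, 224, 224, 3)}   # "input" is a hypothetical name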
    def _parse_import_prerequisites(self, graph):
        """Calculate the named preconditions from TensorFlow `graph`.
        Return prerequisites for parsing:
            a. Set of operator names which don't have their mapping in TVM,
               i.e. which are not supported.
        """
        missing_operators = set()
        for node in graph.node:
            if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault':
                pass
            elif node.op == "Const":
                pass
            else:
                if any([node.op in t for t in [_identity_list, _convert_map,
                                               _convert_map_rnn,
                                               _control_flow_nodes]]):
                    pass
                else:
                    missing_operators.add(node.op)

        return missing_operators

    def _parse_param(self, key, value, name, shape):
        try:
            from tensorflow.python.framework import tensor_util
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        if key == 'value':
            np_array = tensor_util.MakeNdarray(value.tensor)

            if np_array.dtype == np.dtype(object):
                # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
                # Just leave it as a placeholder.
                if shape and name in shape:
                    var_shape = shape[name]
                else:
                    var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
                self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')]
                return

            array_ndim = len(np_array.shape)
            if array_ndim == 0:
                new_array = np.empty([1], dtype=np_array.dtype)
                new_array[0] = np_array
                self._nodes[name] = [tvm.relay.const(new_array)]
            else:
                self._params[name] = tvm.nd.array(np_array)
                self._nodes[name] = [_expr.var(name,
                                               shape=self._params[name].shape,
                                               dtype=self._params[name].dtype)]
        else:
            if key not in ('dtype', '_output_shapes', '_class'):
                raise NotImplementedError(
                    "Other attributes for a Const(param) Node {} ?".format(key))

    def _get_attr(self, buf):
        """Returns the value of the attr of this buf with the given `name`.

        Args:
          buf: attrvalue protobuf.

        Returns:
          The value of the attr, as a Python object.

        Raises:
          ValueError: If this op does not have an attr with the given `name`.
        """
        fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]

        x = buf

        ret = []

        try:
            from tensorflow.python.framework import dtypes
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        # Treat an empty oneof value as an empty list.
        if not x.WhichOneof("value"):
            return ret
        if x.HasField("list"):
            for f in fields:
                if getattr(x.list, f):
                    if f == "type":
                        ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
                    else:
                        ret += list(getattr(x.list, f))
        else:
            for f in fields:
                if x.HasField(f):
                    if f == "type":
                        ret = dtypes.as_dtype(getattr(x, f))
                    else:
                        ret = getattr(x, f)
        return ret

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys."""
        attrs = {}
        for key, value in attr_proto.items():
            attrs[key] = self._get_attr(value)

        return attrs
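    # Illustrative example (hypothetical values): for a NodeDef carrying attrs
    # {"T": DT_FLOAT, "keep_dims": false}, _parse_attr above returns a plain
    # Python dict roughly like
    #
    #   {"T": tf.float32, "keep_dims": False}
    #
    # where the dtype object comes from dtypes.as_dtype in _get_attr.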
    def _convert_rnn_operator(self, op_name, inputs,
                              attrs, params, graph, convert_map):
        """Convert RNN and its variant operators to Relay operators.
        This converter reads the input states of each layer and also
        maintains the output states of each layer in a list.

        Parameters
        ----------
        op_name : str
            Operator name, such as LSTMBlockCell

        inputs : list of relay.Expr
            List of input symbols.

        attrs : dict
            Dict of operator attributes

        params : dict
            Dict of pretrained weights and bias

        graph : Tensorflow graph object
            The graph is used to find the number of upcoming occurrences of
            the same operator, to calculate the number of layers.

        convert_map : dict
            Dict of name : callable, where name is the op's name that
            requires conversion to relay, and callable is a function which
            takes attrs and returns (new_op_name, new_attrs)

        Returns
        -------
        sym : relay.Expr
            Converted relay.Expr
        """
        if not self._num_rnn_layer:
            self._out_rnn = []
            self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
            self._num_rnn_layer = True
        sym = self.rnn.process_op(op_name, inputs, attrs, params)
        return sym

    def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
        """
        Convert a TensorFlow control flow primitive into the corresponding
        component of a Relay control flow construct, i.e. `tf.cond` and
        `tf.while_loop` are converted into a Relay `If` and a recursive call,
        respectively.

        Parameters
        ----------
        node : TensorFlow graph node object.
            A TensorFlow graph node object.

        inputs : List[tvm.relay.Expr]
            List of input symbols.

        attrs : Dict[tvm.Attrs]
            Dict of operator attributes.

        control_flow_node_map : Dict[str, Set[str]]
            A dictionary mapping each execution frame name to its set of
            primitives.

        Returns
        -------
        op : tvm.relay.Expr
            Converted relay expression.
        """
        node_name_prefix = node.name.rsplit('/', 1)[0]
        if node.op == "Merge":
            if _in_while_loop(control_flow_node_map, node_name_prefix):
                op = self._nodes[node.input[0]]
                self._loops[node_name_prefix] = Loop()
            else:
                if len(self._branches) == 0:
                    raise RuntimeError("Cannot find a created "
                                       "conditional for merge node")
                branch = self._branches[node_name_prefix]
                false_br = self._nodes[node.input[0]]
                true_br = self._nodes[node.input[1]]
                assert len(true_br) == 1
                assert len(false_br) == 1
                branch.true_branch = true_br[0]
                branch.false_branch = false_br[0]
                op = [branch.if_node()]
        elif node.op == "Exit":
            loop = self._loops[node_name_prefix]
            exit_name = node.name.split('/')[-1]
            assert str.startswith(exit_name, 'Exit')

            # TensorFlow has a different naming convention on different
            # versions.
            if '_' in exit_name:
                exit_number = int("0" + exit_name[5:])
            else:
                exit_number = int("0" + exit_name[4:])

            expr = loop.while_loop()
            op = _expr.TupleGetItem(expr, exit_number)
        elif node.op == "Enter":
            op = self._nodes[node.input[0]]
        elif node.op == "LoopCond":
            op = self._nodes[node.input[0]]
            assert len(op) == 1
            self._loops[node_name_prefix].cond = op[0]
        elif node.op == "Switch":
            op = self._nodes[node.input[0]]
            assert len(op) == 1
            if _in_while_loop(control_flow_node_map, node_name_prefix):
                self._loops[node_name_prefix].loop_vars.append(op[0])
            else:
                if node_name_prefix not in self._branches:
                    self._branches[node_name_prefix] = Branch()
                chk_op = _infer_type(op[0])
                self._branches[node_name_prefix].cond = chk_op
        elif node.op == "NextIteration":
            op = self._nodes[node.input[0]]
            assert len(op) == 1
            self._loops[node_name_prefix].body.append(op[0])
        else:
            raise Exception("Cannot identify control flow operator: "
                            "{}".format(node.op))

        return op
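    # Worked example for the Exit-number parsing above (illustrative only):
    # a node named ".../while/Exit" carries no index, so int("0" + "") == 0,
    # while ".../while/Exit_1" gives exit_name[5:] == "1" and
    # int("0" + "1") == 1, i.e. the second field of the loop result tuple.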
    def _convert_operator(self, op_name, inputs, attrs,
                          graph, identity_list=None, convert_map=None):
        """Convert from Tensorflow operator to relay operator.
        The converter must specify conversions explicitly for incompatible
        names, and apply handlers to operator attributes.

        Parameters
        ----------
        op_name : str
            Operator name, such as Conv2D, AvgPool

        inputs : list of relay.op
            List of input symbols.

        attrs : dict
            Dict of operator attributes

        identity_list : list
            List of operators that don't require conversion

        convert_map : dict
            Dict of name : callable, where name is the op's name that
            requires conversion to relay, and callable is a function which
            takes attrs and returns (new_op_name, new_attrs)

        Returns
        -------
        sym : relay.op
            Converted relay operator
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _convert_map
        convert_map_rnn = _convert_map_rnn
        if op_name in identity_list:
            sym = get_relay_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            if _need_prelude_for_shape_inference(op_name):
                sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
            elif _need_module_for_shape_inference(op_name):
                sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
            else:
                sym = convert_map[op_name](inputs, attrs, self._params)

        elif op_name in convert_map_rnn:
            sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                             self._params, graph,
                                             convert_map_rnn)
        else:
            raise NotImplementedError("Operator {} not implemented.".format(op_name))
        return sym


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
    """Load a tensorflow graph, given as a python tensorflow GraphDef object,
    into relay. The companion parameters will be handled automatically.

    Parameters
    ----------
    graph : GraphDef object
        Tensorflow GraphDef

    layout : target layout to be used (Optional)
        NCHW only supported now to enable NHWC models on GPU.

    shape : Dictionary of input dimensions (Optional)
        Graph level input shape dictionary.

    outputs : List of output tensor names (Optional)
        if not specified then the last node is assumed as graph output.

    Returns
    -------
    mod : tvm.IRModule
        The module that optimizations will be performed on.

    params : dict of str to tvm.nd.NDArray
        Dict of converted parameters stored in tvm.nd.NDArray format
    """
    g = GraphProto()
    mod, params = g.from_tensorflow(graph, layout, shape, outputs)
    return mod, params
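# Example usage (a minimal sketch; the frozen-graph path and the tensor name
# "input" are hypothetical):
#
#   import tensorflow as tf
#   from tvm import relay
#
#   with tf.io.gfile.GFile("frozen_model.pb", "rb") as f:
#       graph_def = tf.compat.v1.GraphDef()
#       graph_def.ParseFromString(f.read())
#
#   mod, params = relay.frontend.from_tensorflow(
#       graph_def, layout="NHWC", shape={"input": (1, 224, 224, 3)})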