Commit 9bb16872 by Yong Wu Committed by Tianqi Chen

[Relay][Frontend] Add a bunch of ops in tf converter (#3270)

parent c9e96d9f
...@@ -777,12 +777,12 @@ def _sum(): ...@@ -777,12 +777,12 @@ def _sum():
ignores=['name', 'Tidx'])([inputs[0]], attr) ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl return _impl
def _reduce_all(): def _reduce(op):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy() axis = params.pop(inputs[1].name_hint).asnumpy()
axis = tuple(axis) axis = tuple(axis)
return AttrCvt( return AttrCvt(
op_name='all', op_name=op,
extras={'axis': axis}, extras={'axis': axis},
transforms={'keep_dims':'keepdims'}, transforms={'keep_dims':'keepdims'},
ignores=['name', 'Tidx'])([inputs[0]], attr) ignores=['name', 'Tidx'])([inputs[0]], attr)
...@@ -807,6 +807,14 @@ def _gather(): ...@@ -807,6 +807,14 @@ def _gather():
'Taxis', '_class'])(new_input, attr) 'Taxis', '_class'])(new_input, attr)
return _impl return _impl
def _gather_nd():
"""GatherNd"""
def _impl(inputs, attr, params):
return AttrCvt(op_name="gather_nd",
ignores=['Tindices', 'Tparams',\
'Taxis', '_class'])(inputs, attr)
return _impl
def _stridedSlice(): def _stridedSlice():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
"""Strided Slice. """Strided Slice.
...@@ -971,15 +979,18 @@ def _rank(): ...@@ -971,15 +979,18 @@ def _rank():
def _range(): def _range():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
start = _get_num_param(params, inputs[0]) start = params.pop(inputs[0].name_hint).asnumpy()[0]
limit = _get_num_param(params, inputs[1]) limit = params.pop(inputs[1].name_hint).asnumpy()[0] \
delta = _get_num_param(params, inputs[2]) if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0]
delta = params.pop(inputs[2].name_hint).asnumpy()[0]
name = attr["_node_name"] dtype = attr['dtype'].name if 'dtype' in attr else "int32"
params[name] = tvm.nd.array([start, limit, delta]) return AttrCvt(
return [_expr.var(name, op_name="arange",
shape=params[name].shape, ignores=['Tidx'],
dtype='int32')] extras={'start': start,
"stop": limit,
'step': delta,
'dtype': dtype})([], attr)
return _impl return _impl
def _elu(): def _elu():
...@@ -1099,6 +1110,13 @@ def _topk(): ...@@ -1099,6 +1110,13 @@ def _topk():
extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr)
return _impl return _impl
def _floordiv():
def _impl(inputs, attr, params):
assert len(inputs) == 2
div = AttrCvt('divide')(inputs, attr)
return _get_relay_op('floor')(div)
return _impl
def _logical(name): def _logical(name):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr) return AttrCvt(op_name=name)(inputs, attr)
...@@ -1207,8 +1225,9 @@ _identity_list = [] ...@@ -1207,8 +1225,9 @@ _identity_list = []
# for 1 to N mapping(composed), use custom callable functions # for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?) # for N to 1 mapping, currently not supported(?)
_convert_map = { _convert_map = {
'Abs' : AttrCvt('abs'),
'Add' : _elemwise('add'), 'Add' : _elemwise('add'),
'All' : _reduce_all(), 'All' : _reduce('all'),
'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'), 'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'), 'AvgPool' : _pooling('avg_pool'),
...@@ -1232,26 +1251,33 @@ _convert_map = { ...@@ -1232,26 +1251,33 @@ _convert_map = {
'ExpandDims' : _expand_dims(), 'ExpandDims' : _expand_dims(),
'Fill' : _fill(), 'Fill' : _fill(),
'Floor' : AttrCvt('floor'), 'Floor' : AttrCvt('floor'),
'FloorDiv' : _floordiv(),
'FusedBatchNorm' : _fused_batch_norm(), 'FusedBatchNorm' : _fused_batch_norm(),
'FusedBatchNormV2' : _fused_batch_norm(), 'FusedBatchNormV2' : _fused_batch_norm(),
'Gather' : _gather(), 'Gather' : _gather(),
'GatherNd' : _gather_nd(),
'GatherV2' : _gather(), 'GatherV2' : _gather(),
'Greater' : _broadcast('greater'), 'Greater' : _broadcast('greater'),
'GreaterEqual' : _broadcast('greater_equal'), 'GreaterEqual' : _broadcast('greater_equal'),
'Identity' : _identity(), 'Identity' : _identity(),
'LeakyRelu' : AttrCvt('leaky_relu'), 'LeakyRelu' : AttrCvt('leaky_relu'),
'LeftShift' : AttrCvt('left_shift'),
'Less' : _broadcast('less'), 'Less' : _broadcast('less'),
'LessEqual' : _broadcast('less_equal'), 'LessEqual' : _broadcast('less_equal'),
'Log' : AttrCvt('log'), 'Log' : AttrCvt('log'),
'LogicalAnd' : _logical('logical_and'), 'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'), 'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'), 'LogicalNot' : _logical('logical_not'),
'LogSoftmax' : AttrCvt('log_softmax'),
'LRN' : _lrn(), 'LRN' : _lrn(),
'MatMul' : _matmul(), 'MatMul' : _matmul(),
'Max' : _reduce('max'),
'MaxPool' : _pooling('max_pool'), 'MaxPool' : _pooling('max_pool'),
'Maximum' : _elemwise('maximum'), 'Maximum' : _elemwise('maximum'),
'Mean' : _mean(), 'Mean' : _mean(),
'Min' : _reduce('min'),
'Minimum' : _elemwise('minimum'), 'Minimum' : _elemwise('minimum'),
'Mod' : _elemwise('mod'),
'Mul' : _elemwise('multiply'), 'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'), 'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'), 'NotEqual' : _broadcast('not_equal'),
...@@ -1269,6 +1295,7 @@ _convert_map = { ...@@ -1269,6 +1295,7 @@ _convert_map = {
'ResizeBilinear' : _resize_bilinear(), 'ResizeBilinear' : _resize_bilinear(),
'ResizeBicubic' : _resize_bilinear(), 'ResizeBicubic' : _resize_bilinear(),
'ReverseV2' : _reverse_v2(), 'ReverseV2' : _reverse_v2(),
'RightShift' : AttrCvt('right_shift'),
'Round' : AttrCvt('round'), 'Round' : AttrCvt('round'),
'Rsqrt' : _rsqrt(), 'Rsqrt' : _rsqrt(),
'Select' : _where(), 'Select' : _where(),
...@@ -1292,7 +1319,9 @@ _convert_map = { ...@@ -1292,7 +1319,9 @@ _convert_map = {
'Tile' : _tile(), 'Tile' : _tile(),
'TopKV2' : _topk(), 'TopKV2' : _topk(),
'Transpose' : _transpose(), 'Transpose' : _transpose(),
'TruncateMod' : _elemwise('mod'),
'Unpack' : _unpack(), 'Unpack' : _unpack(),
'ZerosLike' : AttrCvt('zeros_like'),
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment