Commit 9bb16872 by Yong Wu Committed by Tianqi Chen

[Relay][Frontend] Add a bunch of ops in tf converter (#3270)

parent c9e96d9f
......@@ -777,12 +777,12 @@ def _sum():
ignores=['name', 'Tidx'])([inputs[0]], attr)
return _impl
def _reduce_all():
def _reduce(op):
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()
axis = tuple(axis)
return AttrCvt(
op_name='all',
op_name=op,
extras={'axis': axis},
transforms={'keep_dims':'keepdims'},
ignores=['name', 'Tidx'])([inputs[0]], attr)
......@@ -807,6 +807,14 @@ def _gather():
'Taxis', '_class'])(new_input, attr)
return _impl
def _gather_nd():
"""GatherNd"""
def _impl(inputs, attr, params):
return AttrCvt(op_name="gather_nd",
ignores=['Tindices', 'Tparams',\
'Taxis', '_class'])(inputs, attr)
return _impl
def _stridedSlice():
def _impl(inputs, attr, params):
"""Strided Slice.
......@@ -971,15 +979,18 @@ def _rank():
def _range():
def _impl(inputs, attr, params):
start = _get_num_param(params, inputs[0])
limit = _get_num_param(params, inputs[1])
delta = _get_num_param(params, inputs[2])
name = attr["_node_name"]
params[name] = tvm.nd.array([start, limit, delta])
return [_expr.var(name,
shape=params[name].shape,
dtype='int32')]
start = params.pop(inputs[0].name_hint).asnumpy()[0]
limit = params.pop(inputs[1].name_hint).asnumpy()[0] \
if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0]
delta = params.pop(inputs[2].name_hint).asnumpy()[0]
dtype = attr['dtype'].name if 'dtype' in attr else "int32"
return AttrCvt(
op_name="arange",
ignores=['Tidx'],
extras={'start': start,
"stop": limit,
'step': delta,
'dtype': dtype})([], attr)
return _impl
def _elu():
......@@ -1099,6 +1110,13 @@ def _topk():
extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr)
return _impl
def _floordiv():
def _impl(inputs, attr, params):
assert len(inputs) == 2
div = AttrCvt('divide')(inputs, attr)
return _get_relay_op('floor')(div)
return _impl
def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
......@@ -1207,8 +1225,9 @@ _identity_list = []
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
'Abs' : AttrCvt('abs'),
'Add' : _elemwise('add'),
'All' : _reduce_all(),
'All' : _reduce('all'),
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'AvgPool' : _pooling('avg_pool'),
......@@ -1232,26 +1251,33 @@ _convert_map = {
'ExpandDims' : _expand_dims(),
'Fill' : _fill(),
'Floor' : AttrCvt('floor'),
'FloorDiv' : _floordiv(),
'FusedBatchNorm' : _fused_batch_norm(),
'FusedBatchNormV2' : _fused_batch_norm(),
'Gather' : _gather(),
'GatherNd' : _gather_nd(),
'GatherV2' : _gather(),
'Greater' : _broadcast('greater'),
'GreaterEqual' : _broadcast('greater_equal'),
'Identity' : _identity(),
'LeakyRelu' : AttrCvt('leaky_relu'),
'LeftShift' : AttrCvt('left_shift'),
'Less' : _broadcast('less'),
'LessEqual' : _broadcast('less_equal'),
'Log' : AttrCvt('log'),
'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'),
'LogSoftmax' : AttrCvt('log_softmax'),
'LRN' : _lrn(),
'MatMul' : _matmul(),
'Max' : _reduce('max'),
'MaxPool' : _pooling('max_pool'),
'Maximum' : _elemwise('maximum'),
'Mean' : _mean(),
'Min' : _reduce('min'),
'Minimum' : _elemwise('minimum'),
'Mod' : _elemwise('mod'),
'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NotEqual' : _broadcast('not_equal'),
......@@ -1269,6 +1295,7 @@ _convert_map = {
'ResizeBilinear' : _resize_bilinear(),
'ResizeBicubic' : _resize_bilinear(),
'ReverseV2' : _reverse_v2(),
'RightShift' : AttrCvt('right_shift'),
'Round' : AttrCvt('round'),
'Rsqrt' : _rsqrt(),
'Select' : _where(),
......@@ -1292,7 +1319,9 @@ _convert_map = {
'Tile' : _tile(),
'TopKV2' : _topk(),
'Transpose' : _transpose(),
'TruncateMod' : _elemwise('mod'),
'Unpack' : _unpack(),
'ZerosLike' : AttrCvt('zeros_like'),
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment