Commit 150f7a8b by MORINAGA Committed by Tianqi Chen

[Frontend][MXNet] ones zeros ones_like zeros_like ops support (#1814)

parent 39c8bc2a
......@@ -273,6 +273,14 @@ def _lrn(inputs, attrs):
new_attrs['size'] = _required_attr(attrs, 'nsize')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _ones(_, attrs):
op_name = "ones"
return _get_nnvm_op(op_name)(**attrs)
def _zeros(_, attrs):
op_name = "zeros"
return _get_nnvm_op(op_name)(**attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
......@@ -281,8 +289,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'slice_like', 'softmax', 'sum', 'tanh',
'transpose']
'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
'sum', 'tanh', 'transpose', 'zeros_like']
_convert_map = {
'_copy' : _rename('copy'),
......@@ -294,6 +302,8 @@ _convert_map = {
'_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'_ones' : _ones,
'_zeros' : _zeros,
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
......@@ -397,13 +407,14 @@ def _from_mxnet_impl(symbol, graph):
if node:
return node[output_index]
attr = symbol.list_attr()
# op_name = symbol.attr('op_name')
op_name = symbol.attr('op_name')
childs = symbol.get_children()
if childs is not None:
op_name = symbol.attr('op_name')
childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr) # no input symbol
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr)
......
......@@ -153,6 +153,28 @@ def test_forward_lrn():
mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
def test_forward_ones():
data = mx.sym.var('data')
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, ones)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros():
data = mx.sym.var('data')
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, zeros)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_ones_like():
data = mx.sym.var('data')
mx_sym = mx.sym.ones_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros_like():
data = mx.sym.var('data')
mx_sym = mx.sym.zeros_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -168,3 +190,7 @@ if __name__ == '__main__':
test_forward_expand_dims()
test_forward_pooling()
test_forward_lrn()
test_forward_ones()
test_forward_zeros()
test_forward_ones_like()
test_forward_zeros_like()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment