Commit 150f7a8b by MORINAGA Committed by Tianqi Chen

[Frontend][MXNet] ones zeros ones_like zeros_like ops support (#1814)

parent 39c8bc2a
...@@ -273,6 +273,14 @@ def _lrn(inputs, attrs): ...@@ -273,6 +273,14 @@ def _lrn(inputs, attrs):
new_attrs['size'] = _required_attr(attrs, 'nsize') new_attrs['size'] = _required_attr(attrs, 'nsize')
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _ones(_, attrs):
op_name = "ones"
return _get_nnvm_op(op_name)(**attrs)
def _zeros(_, attrs):
op_name = "zeros"
return _get_nnvm_op(op_name)(**attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__', '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
...@@ -281,8 +289,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', ...@@ -281,8 +289,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add', 'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp', 'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative', 'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'slice_like', 'softmax', 'sum', 'tanh', 'ones_like', 'relu', 'sigmoid', 'slice_like', 'softmax',
'transpose'] 'sum', 'tanh', 'transpose', 'zeros_like']
_convert_map = { _convert_map = {
'_copy' : _rename('copy'), '_copy' : _rename('copy'),
...@@ -294,6 +302,8 @@ _convert_map = { ...@@ -294,6 +302,8 @@ _convert_map = {
'_rminus_scalar': _rename('__rsub_scalar__'), '_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'), '_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection, '_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'_ones' : _ones,
'_zeros' : _zeros,
'Activation' : _activations, 'Activation' : _activations,
'BatchNorm' : _batch_norm, 'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm, 'BatchNorm_v1' : _batch_norm,
...@@ -397,13 +407,14 @@ def _from_mxnet_impl(symbol, graph): ...@@ -397,13 +407,14 @@ def _from_mxnet_impl(symbol, graph):
if node: if node:
return node[output_index] return node[output_index]
attr = symbol.list_attr() attr = symbol.list_attr()
# op_name = symbol.attr('op_name') op_name = symbol.attr('op_name')
childs = symbol.get_children() childs = symbol.get_children()
if childs is not None: if childs is not None:
op_name = symbol.attr('op_name')
childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))] childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol childs = [x for y in childs for x in _as_list(y)] # expand group symbol
node = _convert_symbol(op_name, childs, attr) node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr) # no input symbol
else: else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op'] op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr) node = _sym.Variable(name=name, **attr)
......
...@@ -153,6 +153,28 @@ def test_forward_lrn(): ...@@ -153,6 +153,28 @@ def test_forward_lrn():
mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5) mx_sym = mx.sym.LRN(data, alpha=2, beta=2, knorm=1, nsize=5)
verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24)) verify_mxnet_frontend_impl(mx_sym, (1, 10, 24, 24), (1, 10, 24, 24))
def test_forward_ones():
data = mx.sym.var('data')
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, ones)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros():
data = mx.sym.var('data')
zeros = mx.sym.zeros(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.elemwise_add(data, zeros)
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_ones_like():
data = mx.sym.var('data')
mx_sym = mx.sym.ones_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros_like():
data = mx.sym.var('data')
mx_sym = mx.sym.zeros_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -168,3 +190,7 @@ if __name__ == '__main__': ...@@ -168,3 +190,7 @@ if __name__ == '__main__':
test_forward_expand_dims() test_forward_expand_dims()
test_forward_pooling() test_forward_pooling()
test_forward_lrn() test_forward_lrn()
test_forward_ones()
test_forward_zeros()
test_forward_ones_like()
test_forward_zeros_like()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment