Commit 8f5c27bd by Hao Jin Committed by Haichen Shen

support MXNet _minimum and _maximum (#2709)

parent c8259e3e
...@@ -286,6 +286,12 @@ def _lrn(inputs, attrs): ...@@ -286,6 +286,12 @@ def _lrn(inputs, attrs):
new_attrs['size'] = _required_attr(attrs, 'nsize') new_attrs['size'] = _required_attr(attrs, 'nsize')
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _minimum(inputs, attrs):
return _get_nnvm_op('broadcast_min')(*inputs, **attrs)
def _maximum(inputs, attrs):
return _get_nnvm_op('broadcast_max')(*inputs, **attrs)
def _ones(_, attrs): def _ones(_, attrs):
op_name = 'ones' op_name = 'ones'
return _get_nnvm_op(op_name)(**attrs) return _get_nnvm_op(op_name)(**attrs)
...@@ -330,6 +336,8 @@ _convert_map = { ...@@ -330,6 +336,8 @@ _convert_map = {
'_rminus_scalar': _rename('__rsub_scalar__'), '_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'), '_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection, '_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'_minimum' : _minimum,
'_maximum' : _maximum,
'_ones' : _ones, '_ones' : _ones,
'_zeros' : _zeros, '_zeros' : _zeros,
'argmax' : _argmax, 'argmax' : _argmax,
......
...@@ -227,6 +227,68 @@ def test_forward_slice(): ...@@ -227,6 +227,68 @@ def test_forward_slice():
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2)) mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2)) verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
def test_forward_maximum():
a = mx.sym.var('a')
b = mx.sym.var('b')
dshape = (10, 20)
dtype = 'float32'
mx_sym = mx.sym._internal._maximum(a, b)
np_a = np.random.uniform(size=dshape).astype(dtype)
np_b = np.random.uniform(size=dshape).astype(dtype)
mx_a = mx.nd.array(np_a)
mx_b = mx.nd.array(np_b)
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['a', 'b'])
mod.bind(data_shapes=[('a', dshape), ('b', dshape)], for_training=False)
mod.init_params()
args, auxs = mod.get_params()
mx_out = mx.nd._internal._maximum(mx_a, mx_b).asnumpy()
out_shape = dshape
new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
shape_dict = {'a': dshape, 'b': dshape}
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("a", tvm.nd.array(np_a))
m.set_input("b", tvm.nd.array(np_b))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_minimum():
a = mx.sym.var('a')
b = mx.sym.var('b')
dshape = (10, 20)
dtype = 'float32'
mx_sym = mx.sym._internal._minimum(a, b)
np_a = np.random.uniform(size=dshape).astype(dtype)
np_b = np.random.uniform(size=dshape).astype(dtype)
mx_a = mx.nd.array(np_a)
mx_b = mx.nd.array(np_b)
mod = mx.mod.Module(mx_sym, label_names=None, data_names=['a', 'b'])
mod.bind(data_shapes=[('a', dshape), ('b', dshape)], for_training=False)
mod.init_params()
args, auxs = mod.get_params()
mx_out = mx.nd._internal._minimum(mx_a, mx_b).asnumpy()
out_shape = dshape
new_sym, params = frontend.from_mxnet(mx_sym, args, auxs)
shape_dict = {'a': dshape, 'b': dshape}
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input("a", tvm.nd.array(np_a))
m.set_input("b", tvm.nd.array(np_b))
m.set_input(**params)
m.run()
# get outputs
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
...@@ -251,4 +313,6 @@ if __name__ == '__main__': ...@@ -251,4 +313,6 @@ if __name__ == '__main__':
test_forward_argmin() test_forward_argmin()
test_forward_where() test_forward_where()
test_forward_slice() test_forward_slice()
test_forward_maximum()
test_forward_minimum()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment