Unverified Commit 6b1136dd by maheshambule Committed by GitHub

[Frontend] [MXNet] make_loss operator support (#4930)

* make_loss test case

* mxnet frontend make_loss support

* added comment for make_loss

* pylint fix

* Update mxnet.py
parent 8c6a7723
......@@ -644,6 +644,13 @@ def _mx_arange(inputs, attrs):
return _op.arange(**new_attrs)
# pylint: disable=unused-argument
def _mx_make_loss(inputs, attrs):
# while doing inference make_loss does not have any effect
# and it should be mapped to identity
return inputs[0]
def _mx_repeat(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
......@@ -1822,6 +1829,7 @@ _convert_map = {
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1,
"make_loss" : _mx_make_loss,
"_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
"one_hot" : _mx_one_hot,
# vision
......
......@@ -201,6 +201,12 @@ def test_forward_ones_like():
mx_sym = mx.sym.ones_like(data, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_make_loss():
data = mx.sym.var('data')
ones = mx.sym.ones(shape=(2, 3, 4), dtype='float32')
mx_sym = mx.sym.make_loss((data-ones)**2/2, dtype='float32')
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 3, 4))
def test_forward_zeros_like():
data = mx.sym.var('data')
mx_sym = mx.sym.zeros_like(data, dtype='float32')
......@@ -996,4 +1002,5 @@ if __name__ == '__main__':
test_forward_one_hot()
test_forward_convolution()
test_forward_deconvolution()
test_forward_cond()
\ No newline at end of file
test_forward_cond()
test_forward_make_loss()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment