Commit a0537ecb by Hao Jin Committed by Yizhi Liu

add support for mxnet smooth_l1 (#2905)

parent 53511bf1
...@@ -594,6 +594,15 @@ def _mx_embedding(inputs, _): ...@@ -594,6 +594,15 @@ def _mx_embedding(inputs, _):
return _op.take(weight, indices.astype('int32'), axis=0) return _op.take(weight, indices.astype('int32'), axis=0)
def _mx_smooth_l1(inputs, attrs):
scalar = attrs.get_float("scalar", 1.0)
scalar_sq = scalar * scalar
mask = _op.less(inputs[0], _expr.const(1.0 / scalar_sq, dtype='float32'))
return _op.where(mask,
_expr.const(scalar_sq / 2.0, dtype='float32') * inputs[0] * inputs[0],
_op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq))
# Note: due to attribute conversion constraint # Note: due to attribute conversion constraint
# ops in the identity set must be attribute free # ops in the identity set must be attribute free
_identity_list = [ _identity_list = [
...@@ -729,6 +738,7 @@ _convert_map = { ...@@ -729,6 +738,7 @@ _convert_map = {
"Embedding" : _mx_embedding, "Embedding" : _mx_embedding,
"SoftmaxOutput" : _mx_softmax_output, "SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation, "SoftmaxActivation" : _mx_softmax_activation,
"smooth_l1" : _mx_smooth_l1,
# vision # vision
"_contrib_BilinearResize2D" : _mx_upsampling, "_contrib_BilinearResize2D" : _mx_upsampling,
"_contrib_MultiBoxPrior" : _mx_multibox_prior, "_contrib_MultiBoxPrior" : _mx_multibox_prior,
......
...@@ -464,6 +464,14 @@ def test_forward_embedding(): ...@@ -464,6 +464,14 @@ def test_forward_embedding():
verify((2, 2), (4, 5)) verify((2, 2), (4, 5))
verify((2, 3, 4), (4, 5)) verify((2, 3, 4), (4, 5))
def test_forward_smooth_l1():
data = mx.sym.var('data')
mx_sym = mx.sym.smooth_l1(data)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
mx_sym = mx.sym.smooth_l1(data, scalar=1.0)
verify_mxnet_frontend_impl(mx_sym, (3, 4), (3, 4))
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
...@@ -498,3 +506,4 @@ if __name__ == '__main__': ...@@ -498,3 +506,4 @@ if __name__ == '__main__':
test_forward_broadcast_axis() test_forward_broadcast_axis()
test_forward_full() test_forward_full()
test_forward_embedding() test_forward_embedding()
test_forward_smooth_l1()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment