Commit 5a15664e by Tatsuya Nishiyama Committed by Tianqi Chen

Add PReLU support to mxnet frontend (#1249)

parent 5ba24773
...@@ -151,9 +151,10 @@ def _dropout(inputs, attrs): ...@@ -151,9 +151,10 @@ def _dropout(inputs, attrs):
def _leaky_relu(inputs, attrs): def _leaky_relu(inputs, attrs):
act_type = _required_attr(attrs, 'act_type') act_type = _required_attr(attrs, 'act_type')
if act_type in ['leaky']: if act_type in ['leaky', 'prelu']:
op_name, new_attrs = 'leaky_relu', {} op_name, new_attrs = act_type, {}
new_attrs['alpha'] = attrs.get('slope', 0.25) if act_type == 'leaky':
new_attrs['alpha'] = attrs.get('slope', 0.25)
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs) sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'elu': elif act_type == 'elu':
slope = attrs.get('slope', 0.25) slope = attrs.get('slope', 0.25)
......
...@@ -97,6 +97,13 @@ def test_forward_rrelu(): ...@@ -97,6 +97,13 @@ def test_forward_rrelu():
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7) mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_prelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
gamma = mx.sym.zeros(shape=(6,))
mx_sym = mx.sym.LeakyReLU(data, gamma=gamma, act_type='prelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_softrelu(): def test_forward_softrelu():
data = mx.sym.var('data') data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
...@@ -126,6 +133,7 @@ if __name__ == '__main__': ...@@ -126,6 +133,7 @@ if __name__ == '__main__':
test_forward_resnet() test_forward_resnet()
test_forward_elu() test_forward_elu()
test_forward_rrelu() test_forward_rrelu()
test_forward_prelu()
test_forward_softrelu() test_forward_softrelu()
test_forward_fc_flatten() test_forward_fc_flatten()
test_forward_clip() test_forward_clip()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment