Commit 5a15664e authored by Tatsuya Nishiyama, committed by Tianqi Chen

Add PReLU support to mxnet frontend (#1249)

parent 5ba24773
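
For context: PReLU (parametric ReLU, He et al. 2015) computes f(x) = x for x > 0 and f(x) = gamma * x otherwise, where gamma is a learned slope, typically one per channel; MXNet exposes it via mx.sym.LeakyReLU(..., act_type='prelu'). Below is a minimal NumPy reference of the per-channel computation, for illustration only -- the function name and the NCHW broadcasting layout are assumptions, not part of this commit:

    import numpy as np

    def prelu_reference(x, gamma):
        # Per-channel PReLU on an NCHW tensor: x where positive, gamma * x otherwise.
        gamma = gamma.reshape(1, -1, 1, 1)   # broadcast one slope per channel over N, H, W
        return np.where(x > 0, x, gamma * x)

    x = np.random.randn(1, 6, 100, 100).astype('float32')
    gamma = np.zeros(6, dtype='float32')     # the test below also uses zero slopes
    print(prelu_reference(x, gamma).min())   # >= 0.0: a zero slope clips all negatives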
@@ -151,9 +151,10 @@ def _dropout(inputs, attrs):
 
 def _leaky_relu(inputs, attrs):
     act_type = _required_attr(attrs, 'act_type')
-    if act_type in ['leaky']:
-        op_name, new_attrs = 'leaky_relu', {}
-        new_attrs['alpha'] = attrs.get('slope', 0.25)
+    if act_type in ['leaky', 'prelu']:
+        op_name, new_attrs = act_type, {}
+        if act_type == 'leaky':
+            new_attrs['alpha'] = attrs.get('slope', 0.25)
         sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
     elif act_type == 'elu':
         slope = attrs.get('slope', 0.25)
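
The converted op name is taken from act_type itself, so 'prelu' resolves directly to NNVM's prelu operator; only the 'leaky' branch needs the extra alpha attribute, with MXNet's default slope of 0.25. A standalone sketch of this dispatch follows, with the frontend helpers stubbed out: it returns the op name and attribute dict instead of calling _get_nnvm_op, and the explicit 'leaky' -> 'leaky_relu' mapping is an assumption about how NNVM registers the op:

    def convert_leaky_relu(attrs):
        # Mirror of the branch above: pick an NNVM op name and its attributes
        # from an MXNet LeakyReLU attribute dict.
        act_type = attrs['act_type']
        if act_type in ('leaky', 'prelu'):
            new_attrs = {}
            op_name = 'leaky_relu' if act_type == 'leaky' else 'prelu'
            if act_type == 'leaky':
                new_attrs['alpha'] = attrs.get('slope', 0.25)  # MXNet default slope
            return op_name, new_attrs
        raise NotImplementedError('act_type %s not handled in this sketch' % act_type)

    print(convert_leaky_relu({'act_type': 'prelu'}))                 # ('prelu', {})
    print(convert_leaky_relu({'act_type': 'leaky', 'slope': 0.1}))   # ('leaky_relu', {'alpha': 0.1})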
@@ -97,6 +97,13 @@ def test_forward_rrelu():
     mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+def test_forward_prelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    gamma = mx.sym.zeros(shape=(6,))
+    mx_sym = mx.sym.LeakyReLU(data, gamma=gamma, act_type='prelu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
 def test_forward_softrelu():
     data = mx.sym.var('data')
     data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
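
Note that gamma is fixed to zeros here, so the negative half produced by concat(data, -data, dim=1) maps to zero and the expected output coincides with plain ReLU; the test still exercises the new prelu conversion path, and the (1, 6, 100, 100) output shape reflects the channel doubling from the concat.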
@@ -126,6 +133,7 @@ if __name__ == '__main__':
     test_forward_resnet()
     test_forward_elu()
     test_forward_rrelu()
+    test_forward_prelu()
     test_forward_softrelu()
     test_forward_fc_flatten()
     test_forward_clip()