Commit 9a6feca6 authored by Joshua Z. Zhang, committed by Tianqi Chen

[FRONTEND] Composed operators (#175)

* fix for composed symbol

* fix

* clean up

* fix exception type
parent 9fb13a69
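
For orientation: before this commit each converter returned an (op_name, attrs) pair and the caller instantiated the op; after it, a converter receives the already-converted input symbols and returns a finished nnvm.Symbol, which is what lets a single MXNet op expand into a composition of several nnvm ops. A minimal before/after sketch of that converter shape (illustration only, assuming nnvm is installed; not part of this diff):

import nnvm.symbol as _sym

def _get_nnvm_op(op_name):
    # Mirrors the helper added in this commit: look the operator up on nnvm.symbol.
    op = getattr(_sym, op_name, None)
    if not op:
        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
    return op

# Old style: return a name/attr pair and let the caller build the op.
def _dropout_old(attrs):
    return 'dropout', {'rate': attrs.get('p', 0.5)}

# New style: receive the input symbols and return a composed nnvm.Symbol directly.
def _dropout_new(inputs, attrs):
    return _get_nnvm_op('dropout')(*inputs, rate=attrs.get('p', 0.5))

data = _sym.Variable('data')
out = _dropout_new([data], {})
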
@@ -7,6 +7,12 @@ from .. import symbol as _sym
__all__ = ['from_mxnet']
def _get_nnvm_op(op_name):
op = getattr(_sym, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op
def _get_mxnet_version():
try:
import mxnet as mx
@@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'):
return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
def _rename(new_name):
def impl(attr):
return new_name, attr
def impl(inputs, attrs):
return _get_nnvm_op(new_name)(*inputs, **attrs)
return impl
def _variable(attrs):
return "Variable", attrs
def _pooling(attrs):
def _pooling(inputs, attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_raise_not_supported('non-2d kernel', 'pool_2d')
@@ -61,9 +64,9 @@ def _pooling(attrs):
new_attrs['strides'] = attrs.get('stride', (1, 1))
new_attrs['padding'] = attrs.get('pad', (0, 0))
new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _batch_norm(attrs):
def _batch_norm(inputs, attrs):
if _parse_bool_str(attrs, 'output_mean_var'):
_raise_not_supported('output_mean_var', 'batch_norm')
# if _parse_bool_str(attrs, 'fix_gamma'):
@@ -77,14 +80,14 @@ def _batch_norm(attrs):
new_attrs['epsilon'] = attrs.get('eps', 0.001)
new_attrs['center'] = True
new_attrs['scale'] = True
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _concat(attrs):
def _concat(inputs, attrs):
op_name = 'concatenate'
new_attrs = {'axis': attrs.get('dim', 1)}
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _conv2d(attrs):
def _conv2d(inputs, attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2:
_raise_not_supported('non 2d kernel', 'conv2d')
@@ -100,9 +103,9 @@ def _conv2d(attrs):
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _conv2d_transpose(attrs):
def _conv2d_transpose(inputs, attrs):
if 'target_shape' in attrs:
_raise_not_supported('target_shape', 'conv2d_transpose')
kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
@@ -121,51 +124,68 @@ def _conv2d_transpose(attrs):
new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _dense(attrs):
def _dense(inputs, attrs):
op_name, new_attrs = 'dense', {}
new_attrs['units'] = _required_attr(attrs, 'num_hidden')
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
major, minor, micro = _get_mxnet_version()
if major >= 0 and minor >= 11 and micro >= 1:
new_attrs['flatten'] = _parse_bool_str(attrs, 'flatten', 'True')
return op_name, new_attrs
use_flatten = _parse_bool_str(attrs, 'flatten', 'True')
if use_flatten:
inputs[0] = _sym.flatten(inputs[0])
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _dropout(attrs):
def _dropout(inputs, attrs):
op_name, new_attrs = 'dropout', {}
new_attrs['rate'] = attrs.get('p', 0.5)
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _leaky_relu(attrs):
def _leaky_relu(inputs, attrs):
act_type = _required_attr(attrs, 'act_type')
if act_type not in ['leaky']:
if act_type in ['leaky']:
op_name, new_attrs = 'leaky_relu', {}
new_attrs['alpha'] = attrs.get('slope', 0.25)
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'elu':
slope = attrs.get('slope', 0.25)
sym = -slope * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs)
elif act_type == 'rrelu':
lower_bound = float(_required_attr(attrs, 'lower_bound'))
upper_bound = float(_required_attr(attrs, 'upper_bound'))
slope = (lower_bound + upper_bound) / 2.0
op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)}
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
else:
_raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = 'leaky_relu', {}
new_attrs['alpha'] = attrs.get('slope', 0.25)
return op_name, new_attrs
return sym
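
The elu branch above builds ELU out of relu and exp, since nnvm has no dedicated elu op; a quick numpy check of the identity it relies on, -slope * relu(1 - exp(x)) + relu(x) == elu(x) (illustration only, not part of the diff):

import numpy as np

def elu_reference(x, slope):
    # Standard ELU: x for x >= 0, slope * (exp(x) - 1) for x < 0.
    return np.where(x >= 0, x, slope * (np.exp(x) - 1))

def elu_composed(x, slope):
    # Same expression the converter composes from nnvm primitives:
    # -slope * relu(1 - exp(x)) + relu(x).
    relu = lambda v: np.maximum(v, 0)
    return -slope * relu(1 - np.exp(x)) + relu(x)

x = np.linspace(-3, 3, 7)
assert np.allclose(elu_reference(x, 0.25), elu_composed(x, 0.25))
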
def _activations(attrs):
def _activations(inputs, attrs):
act_type = _required_attr(attrs, 'act_type')
if act_type not in ['relu', 'sigmoid', 'tanh']:
if act_type in ['relu', 'sigmoid', 'tanh']:
op_name, new_attrs = act_type, {}
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'softrelu':
sym = _sym.log((1 + _sym.exp(*inputs)))
else:
_raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = act_type, {}
return op_name, new_attrs
return sym
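
Similarly, the softrelu branch composes softplus, log(1 + exp(x)), from the existing log and exp ops; a one-line numpy sanity check of that identity (illustration only):

import numpy as np

x = np.linspace(-3.0, 3.0, 7)
# log(1 + exp(x)) is exactly logaddexp(0, x), i.e. the softplus / softrelu curve.
assert np.allclose(np.log(1 + np.exp(x)), np.logaddexp(0, x))
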
def _reshape(attrs):
def _reshape(inputs, attrs):
if _parse_bool_str(attrs, 'reverse'):
_raise_not_supported('reverse', 'reshape')
op_name, new_attrs = 'reshape', {}
new_attrs['shape'] = _required_attr(attrs, 'shape')
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _split(attrs):
def _split(inputs, attrs):
if _parse_bool_str(attrs, 'squeeze_axis'):
_raise_not_supported('squeeze_axis', 'split')
op_name, new_attrs = 'split', {}
new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
new_attrs['axis'] = attrs.get('axis', 1)
return op_name, new_attrs
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
@@ -178,7 +198,12 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
_convert_map = {
'null' : _variable,
'_div_scalar' : _rename('__div_scalar__'),
'_minus_scalar' : _rename('__sub_scalar__'),
'_mul_scalar' : _rename('__mul_scalar__'),
'_plus_scalar' : _rename('__add_scalar__'),
'_rdiv_scalar' : _rename('__rdiv_scalar__'),
'_rminus_scalar': _rename('__rsub_scalar__'),
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
@@ -202,7 +227,7 @@ _convert_map = {
'sum_axis' : _rename('sum'),
}
def _convert_symbol(op_name, attrs,
def _convert_symbol(op_name, inputs, attrs,
identity_list=None,
convert_map=None):
"""Convert from mxnet op to nnvm op.
@@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs,
----------
op_name : str
Operator name, such as Convolution, FullyConnected
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict
Dict of operator attributes
identity_list : list
@@ -224,21 +251,19 @@ def _convert_symbol(op_name, attrs,
Returns
-------
(op_name, attrs)
Converted (op_name, attrs) for nnvm.
sym : nnvm.Symbol
Converted nnvm Symbol
"""
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list:
pass
op = _get_nnvm_op(op_name)
sym = op(*inputs, **attrs)
elif op_name in convert_map:
op_name, attrs = convert_map[op_name](attrs)
sym = convert_map[op_name](inputs, attrs)
else:
_raise_not_supported('Operator: ' + op_name)
op = getattr(_sym, op_name, None)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op, attrs
return sym
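
After this change _convert_symbol dispatches in three ways: ops in _identity_list map one-to-one onto nnvm.sym, ops in _convert_map go through a converter that receives the input symbols, and anything else raises. An illustrative call into the internal helper (assumes the module is importable as nnvm.frontend.mxnet; these names are private, not a stable API):

import nnvm.symbol as _sym
from nnvm.frontend.mxnet import _convert_symbol

data = _sym.Variable('data')
# 'relu' is in _identity_list, so it maps straight onto nnvm.sym.relu.
out1 = _convert_symbol('relu', [data], {})
# 'Activation' is in _convert_map; its converter composes softrelu as log(1 + exp(x)).
out2 = _convert_symbol('Activation', [data], {'act_type': 'softrelu'})
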
def _is_mxnet_group_symbol(symbol):
"""Internal check for mxnet group symbol."""
@@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph):
node = graph.get(name, None)
if node:
return node
attr = symbol.list_attr()
# op_name = symbol.attr('op_name')
if symbol.get_children():
childs = symbol.get_children()
if childs:
op_name = symbol.attr('op_name')
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
attr = symbol.list_attr()
new_op, new_attr = _convert_symbol(op_name, attr)
if new_op == _sym.Variable:
node = new_op(name=name, **new_attr)
else:
childs = symbol.get_children()
childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
if new_op == _sym.dense and 'flatten' in new_attr:
if new_attr['flatten']:
childs[0] = _sym.flatten(childs[0])
new_attr.pop('flatten')
node = new_op(name=name, *childs, **new_attr)
node = _convert_symbol(op_name, childs, attr)
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr)
graph[name] = node
return node
def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format.
@@ -46,7 +46,7 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
assert "data" not in args
for target, ctx in ctx_list():
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_mlp():
mlp = model_zoo.mx_mlp
@@ -62,7 +62,40 @@ def test_forward_resnet():
mx_sym = model_zoo.mx_resnet[n]
verify_mxnet_frontend_impl(mx_sym)
def test_forward_elu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_rrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.Activation(data, act_type='softrelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_fc_flatten():
# test flatten=True option in mxnet 0.11.1
data = mx.sym.var('data')
try:
mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
except:
pass
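
A hedged sketch of one more test in the same style as those above, exercising the scalar-op mappings (_mul_scalar, _plus_scalar, _minus_scalar) added to _convert_map (hypothetical, not part of this commit; reuses verify_mxnet_frontend_impl defined earlier in this file):

def test_forward_scalar_ops():
    # Hypothetical: scalar arithmetic on an mxnet symbol lowers to the
    # _mul_scalar / _plus_scalar / _minus_scalar ops newly mapped in _convert_map.
    data = mx.sym.var('data')
    mx_sym = (data * 2.0 + 1.0) - 0.5
    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))
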
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
test_forward_resnet()
test_forward_elu()
test_forward_rrelu()
test_forward_softrelu()
test_forward_fc_flatten()