Commit 9a6feca6 by Joshua Z. Zhang Committed by Tianqi Chen

[FRONTEND] Composed operators (#175)

* fix for composed symbol

* fix

* clean up

* fix exception type
parent 9fb13a69
...@@ -7,6 +7,12 @@ from .. import symbol as _sym ...@@ -7,6 +7,12 @@ from .. import symbol as _sym
__all__ = ['from_mxnet'] __all__ = ['from_mxnet']
def _get_nnvm_op(op_name):
op = getattr(_sym, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op
def _get_mxnet_version(): def _get_mxnet_version():
try: try:
import mxnet as mx import mxnet as mx
...@@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'): ...@@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'):
return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes'] return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
def _rename(new_name): def _rename(new_name):
def impl(attr): def impl(inputs, attrs):
return new_name, attr return _get_nnvm_op(new_name)(*inputs, **attrs)
return impl return impl
def _variable(attrs): def _pooling(inputs, attrs):
return "Variable", attrs
def _pooling(attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel')) kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2: if len(kernel) != 2:
_raise_not_supported('non-2d kernel', 'pool_2d') _raise_not_supported('non-2d kernel', 'pool_2d')
...@@ -61,9 +64,9 @@ def _pooling(attrs): ...@@ -61,9 +64,9 @@ def _pooling(attrs):
new_attrs['strides'] = attrs.get('stride', (1, 1)) new_attrs['strides'] = attrs.get('stride', (1, 1))
new_attrs['padding'] = attrs.get('pad', (0, 0)) new_attrs['padding'] = attrs.get('pad', (0, 0))
new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full') new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _batch_norm(attrs): def _batch_norm(inputs, attrs):
if _parse_bool_str(attrs, 'output_mean_var'): if _parse_bool_str(attrs, 'output_mean_var'):
_raise_not_supported('output_mean_var', 'batch_norm') _raise_not_supported('output_mean_var', 'batch_norm')
# if _parse_bool_str(attrs, 'fix_gamma'): # if _parse_bool_str(attrs, 'fix_gamma'):
...@@ -77,14 +80,14 @@ def _batch_norm(attrs): ...@@ -77,14 +80,14 @@ def _batch_norm(attrs):
new_attrs['epsilon'] = attrs.get('eps', 0.001) new_attrs['epsilon'] = attrs.get('eps', 0.001)
new_attrs['center'] = True new_attrs['center'] = True
new_attrs['scale'] = True new_attrs['scale'] = True
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _concat(attrs): def _concat(inputs, attrs):
op_name = 'concatenate' op_name = 'concatenate'
new_attrs = {'axis': attrs.get('dim', 1)} new_attrs = {'axis': attrs.get('dim', 1)}
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _conv2d(attrs): def _conv2d(inputs, attrs):
kernel = _parse_tshape(_required_attr(attrs, 'kernel')) kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
if len(kernel) != 2: if len(kernel) != 2:
_raise_not_supported('non 2d kernel', 'conv2d') _raise_not_supported('non 2d kernel', 'conv2d')
...@@ -100,9 +103,9 @@ def _conv2d(attrs): ...@@ -100,9 +103,9 @@ def _conv2d(attrs):
new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout new_attrs['layout'] = layout
new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False' new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _conv2d_transpose(attrs): def _conv2d_transpose(inputs, attrs):
if 'target_shape' in attrs: if 'target_shape' in attrs:
_raise_not_supported('target_shape', 'conv2d_transpose') _raise_not_supported('target_shape', 'conv2d_transpose')
kernel = _parse_tshape(_required_attr(attrs, 'kernel')) kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
...@@ -121,51 +124,68 @@ def _conv2d_transpose(attrs): ...@@ -121,51 +124,68 @@ def _conv2d_transpose(attrs):
new_attrs['groups'] = attrs.get('num_group', 1) new_attrs['groups'] = attrs.get('num_group', 1)
new_attrs['layout'] = layout new_attrs['layout'] = layout
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _dense(attrs): def _dense(inputs, attrs):
op_name, new_attrs = 'dense', {} op_name, new_attrs = 'dense', {}
new_attrs['units'] = _required_attr(attrs, 'num_hidden') new_attrs['units'] = _required_attr(attrs, 'num_hidden')
new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias') new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
major, minor, micro = _get_mxnet_version() major, minor, micro = _get_mxnet_version()
if major >= 0 and minor >= 11 and micro >= 1: if major >= 0 and minor >= 11 and micro >= 1:
new_attrs['flatten'] = _parse_bool_str(attrs, 'flatten', 'True') use_flatten = _parse_bool_str(attrs, 'flatten', 'True')
return op_name, new_attrs if use_flatten:
inputs[0] = _sym.flatten(inputs[0])
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _dropout(attrs): def _dropout(inputs, attrs):
op_name, new_attrs = 'dropout', {} op_name, new_attrs = 'dropout', {}
new_attrs['rate'] = attrs.get('p', 0.5) new_attrs['rate'] = attrs.get('p', 0.5)
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _leaky_relu(attrs): def _leaky_relu(inputs, attrs):
act_type = _required_attr(attrs, 'act_type') act_type = _required_attr(attrs, 'act_type')
if act_type not in ['leaky']: if act_type in ['leaky']:
op_name, new_attrs = 'leaky_relu', {}
new_attrs['alpha'] = attrs.get('slope', 0.25)
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'elu':
slope = attrs.get('slope', 0.25)
sym = -slope * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs)
elif act_type == 'rrelu':
lower_bound = float(_required_attr(attrs, 'lower_bound'))
upper_bound = float(_required_attr(attrs, 'upper_bound'))
slope = (lower_bound + upper_bound) / 2.0
op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)}
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
else:
_raise_not_supported('act_type: ' + act_type) _raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = 'leaky_relu', {} return sym
new_attrs['alpha'] = attrs.get('slope', 0.25)
return op_name, new_attrs
def _activations(attrs): def _activations(inputs, attrs):
act_type = _required_attr(attrs, 'act_type') act_type = _required_attr(attrs, 'act_type')
if act_type not in ['relu', 'sigmoid', 'tanh']: if act_type in ['relu', 'sigmoid', 'tanh']:
op_name, new_attrs = act_type, {}
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'softrelu':
sym = _sym.log((1 + _sym.exp(*inputs)))
else:
_raise_not_supported('act_type: ' + act_type) _raise_not_supported('act_type: ' + act_type)
op_name, new_attrs = act_type, {} return sym
return op_name, new_attrs
def _reshape(attrs): def _reshape(inputs, attrs):
if _parse_bool_str(attrs, 'reverse'): if _parse_bool_str(attrs, 'reverse'):
_raise_not_supported('reverse', 'reshape') _raise_not_supported('reverse', 'reshape')
op_name, new_attrs = 'reshape', {} op_name, new_attrs = 'reshape', {}
new_attrs['shape'] = _required_attr(attrs, 'shape') new_attrs['shape'] = _required_attr(attrs, 'shape')
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _split(attrs): def _split(inputs, attrs):
if _parse_bool_str(attrs, 'squeeze_axis'): if _parse_bool_str(attrs, 'squeeze_axis'):
_raise_not_supported('squeeze_axis', 'split') _raise_not_supported('squeeze_axis', 'split')
op_name, new_attrs = 'split', {} op_name, new_attrs = 'split', {}
new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs') new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
new_attrs['axis'] = attrs.get('axis', 1) new_attrs['axis'] = attrs.get('axis', 1)
return op_name, new_attrs return _get_nnvm_op(op_name)(*inputs, **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
...@@ -178,7 +198,12 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', ...@@ -178,7 +198,12 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose'] 'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
_convert_map = { _convert_map = {
'null' : _variable, '_div_scalar' : _rename('__div_scalar__'),
'_minus_scalar' : _rename('__sub_scalar__'),
'_mul_scalar' : _rename('__mul_scalar__'),
'_plus_scalar' : _rename('__add_scalar__'),
'_rdiv_scalar' : _rename('__rdiv_scalar__'),
'_rminus_scalar': _rename('__rsub_scalar__'),
'Activation' : _activations, 'Activation' : _activations,
'BatchNorm' : _batch_norm, 'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm, 'BatchNorm_v1' : _batch_norm,
...@@ -202,7 +227,7 @@ _convert_map = { ...@@ -202,7 +227,7 @@ _convert_map = {
'sum_axis' : _rename('sum'), 'sum_axis' : _rename('sum'),
} }
def _convert_symbol(op_name, attrs, def _convert_symbol(op_name, inputs, attrs,
identity_list=None, identity_list=None,
convert_map=None): convert_map=None):
"""Convert from mxnet op to nnvm op. """Convert from mxnet op to nnvm op.
...@@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs, ...@@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs,
---------- ----------
op_name : str op_name : str
Operator name, such as Convolution, FullyConnected Operator name, such as Convolution, FullyConnected
inputs : list of nnvm.Symbol
List of input symbols.
attrs : dict attrs : dict
Dict of operator attributes Dict of operator attributes
identity_list : list identity_list : list
...@@ -224,21 +251,19 @@ def _convert_symbol(op_name, attrs, ...@@ -224,21 +251,19 @@ def _convert_symbol(op_name, attrs,
Returns Returns
------- -------
(op_name, attrs) sym : nnvm.Symbol
Converted (op_name, attrs) for nnvm. Converted nnvm Symbol
""" """
identity_list = identity_list if identity_list else _identity_list identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list: if op_name in identity_list:
pass op = _get_nnvm_op(op_name)
sym = op(*inputs, **attrs)
elif op_name in convert_map: elif op_name in convert_map:
op_name, attrs = convert_map[op_name](attrs) sym = convert_map[op_name](inputs, attrs)
else: else:
_raise_not_supported('Operator: ' + op_name) _raise_not_supported('Operator: ' + op_name)
op = getattr(_sym, op_name, None) return sym
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op, attrs
def _is_mxnet_group_symbol(symbol): def _is_mxnet_group_symbol(symbol):
"""Internal check for mxnet group symbol.""" """Internal check for mxnet group symbol."""
...@@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph): ...@@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph):
node = graph.get(name, None) node = graph.get(name, None)
if node: if node:
return node return node
attr = symbol.list_attr()
# op_name = symbol.attr('op_name') # op_name = symbol.attr('op_name')
if symbol.get_children(): childs = symbol.get_children()
if childs:
op_name = symbol.attr('op_name') op_name = symbol.attr('op_name')
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
attr = symbol.list_attr()
new_op, new_attr = _convert_symbol(op_name, attr)
if new_op == _sym.Variable:
node = new_op(name=name, **new_attr)
else:
childs = symbol.get_children()
childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)] childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol childs = [x for y in childs for x in _as_list(y)] # expand group symbol
if new_op == _sym.dense and 'flatten' in new_attr: node = _convert_symbol(op_name, childs, attr)
if new_attr['flatten']: else:
childs[0] = _sym.flatten(childs[0]) op_name = json.loads(symbol.tojson())['nodes'][0]['op']
new_attr.pop('flatten') node = _sym.Variable(name=name, **attr)
node = new_op(name=name, *childs, **new_attr)
graph[name] = node graph[name] = node
return node return node
def from_mxnet(symbol, arg_params=None, aux_params=None): def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format. """Convert from MXNet's model into compatible NNVM format.
......
...@@ -46,7 +46,7 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape ...@@ -46,7 +46,7 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
assert "data" not in args assert "data" not in args
for target, ctx in ctx_list(): for target, ctx in ctx_list():
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype) tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5) np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_forward_mlp(): def test_forward_mlp():
mlp = model_zoo.mx_mlp mlp = model_zoo.mx_mlp
...@@ -62,7 +62,40 @@ def test_forward_resnet(): ...@@ -62,7 +62,40 @@ def test_forward_resnet():
mx_sym = model_zoo.mx_resnet[n] mx_sym = model_zoo.mx_resnet[n]
verify_mxnet_frontend_impl(mx_sym) verify_mxnet_frontend_impl(mx_sym)
def test_forward_elu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_rrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
mx_sym = mx.sym.Activation(data, act_type='softrelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_fc_flatten():
# test flatten=True option in mxnet 0.11.1
data = mx.sym.var('data')
try:
mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
except:
pass
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
test_forward_vgg() test_forward_vgg()
test_forward_resnet() test_forward_resnet()
test_forward_elu()
test_forward_rrelu()
test_forward_softrelu()
test_forward_fc_flatten()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment