Commit ad7ffd35 by Joshua Z. Zhang Committed by Tianqi Chen

add softmaxoutput (#207)

parent 8b6f80c2
...@@ -184,6 +184,12 @@ def _split(inputs, attrs): ...@@ -184,6 +184,12 @@ def _split(inputs, attrs):
new_attrs['axis'] = attrs.get('axis', 1) new_attrs['axis'] = attrs.get('axis', 1)
return _get_nnvm_op(op_name)(*inputs, **new_attrs) return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _softmax_output(inputs, attrs):
op_name, new_attrs = 'softmax', {}
if _parse_bool_str(attrs, 'multi_output'):
new_attrs['axis'] = 1
return _get_nnvm_op(op_name)(inputs[0], **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__', '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
...@@ -217,6 +223,7 @@ _convert_map = { ...@@ -217,6 +223,7 @@ _convert_map = {
'Pooling_v1' : _pooling, 'Pooling_v1' : _pooling,
'Reshape' : _reshape, 'Reshape' : _reshape,
'Softmax' : _rename('softmax'), 'Softmax' : _rename('softmax'),
'SoftmaxOutput' : _softmax_output,
'concat' : _concat, 'concat' : _concat,
'max_axis' : _rename('max'), 'max_axis' : _rename('max'),
'min_axis' : _rename('min'), 'min_axis' : _rename('min'),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment