Commit ad7ffd35 by Joshua Z. Zhang Committed by Tianqi Chen

add softmaxoutput (#207)

parent 8b6f80c2
......@@ -184,6 +184,12 @@ def _split(inputs, attrs):
new_attrs['axis'] = attrs.get('axis', 1)
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
def _softmax_output(inputs, attrs):
op_name, new_attrs = 'softmax', {}
if _parse_bool_str(attrs, 'multi_output'):
new_attrs['axis'] = 1
return _get_nnvm_op(op_name)(inputs[0], **new_attrs)
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
......@@ -217,6 +223,7 @@ _convert_map = {
'Pooling_v1' : _pooling,
'Reshape' : _reshape,
'Softmax' : _rename('softmax'),
'SoftmaxOutput' : _softmax_output,
'concat' : _concat,
'max_axis' : _rename('max'),
'min_axis' : _rename('min'),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment